├── data └── .gitkeep ├── shallowspeed ├── __init__.py ├── optimizer.py ├── utils.py ├── functional.py ├── dataset.py ├── layers.py └── pipe.py ├── .github └── assets │ ├── title_picture.jpg │ └── PP_pebble_graph.gif ├── setup.py ├── environment.yml ├── .pre-commit-config.yaml ├── tests ├── test_dataset.py ├── test_layers.py ├── test_schedules.py └── test_functional.py ├── download_dataset.py ├── README.md ├── .gitignore ├── train.py └── scripts └── DDP_PyTorch_MNIST.py /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /shallowspeed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/assets/title_picture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siboehm/ShallowSpeed/HEAD/.github/assets/title_picture.jpg -------------------------------------------------------------------------------- /.github/assets/PP_pebble_graph.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siboehm/ShallowSpeed/HEAD/.github/assets/PP_pebble_graph.gif -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="shallowspeed", 5 | version="0.0.1", 6 | author="siboehm", 7 | author_email="", 8 | packages=setuptools.find_packages(), 9 | ) 10 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: shallowspeed 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python>=3.10 6 | - mpi4py 7 | - pandas 8 | - scikit-learn 9 | - matplotlib 10 | - numpy 11 | - jupyter 12 | - pyarrow 13 | - pytest 14 | - pytest-mpi 15 | - pre-commit 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.10.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | - repo: https://github.com/PyCQA/isort 8 | rev: 5.12.0 9 | hooks: 10 | - id: isort 11 | args: ["--profile", "black", "--filter-files"] 12 | verbose: true 13 | -------------------------------------------------------------------------------- /shallowspeed/optimizer.py: -------------------------------------------------------------------------------- 1 | from shallowspeed.layers import Parameter 2 | 3 | 4 | class SGD: 5 | def __init__(self, parameters: list[Parameter], lr: float): 6 | # Boring stateless optimizer is boring 7 | self._params = parameters 8 | self._lr = lr 9 | 10 | def step(self): 11 | for param in self._params: 12 | if param.requires_grad: 13 | param.data -= self._lr * param.grad 14 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from shallowspeed.dataset import Dataset 7 | 8 | 9 | def test_dataset(): 10 | save_path = 
Path("data/mnist_784") 11 | dataset = Dataset(save_path, 128, 8) 12 | input_X = pd.read_parquet(save_path / f"x_train.parquet").to_numpy() 13 | 14 | num_sample = 59500 15 | num_sample_no_tile_quantization = num_sample - (num_sample % 128) 16 | dataset.load(DP_rank=1, DP_size=4) 17 | assert len(dataset) == num_sample_no_tile_quantization // 4 18 | assert dataset.load_micro_batch_input(0, 0).dtype == np.float32 19 | -------------------------------------------------------------------------------- /download_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.datasets import fetch_openml 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | def download_MNIST(save_dir): 10 | x, y = fetch_openml("mnist_784", version=1, data_home="data_cache", return_X_y=True) 11 | 12 | x /= 255.0 13 | x -= x.mean() 14 | y = pd.get_dummies(y) 15 | 16 | x_train, x_val, y_train, y_val = train_test_split( 17 | x, y, test_size=0.15, random_state=42 18 | ) 19 | save_dir.mkdir() 20 | x_train.to_parquet(save_dir / "x_train.parquet") 21 | x_val.to_parquet(save_dir / "x_val.parquet") 22 | np.save(save_dir / "y_train.npy", y_train) 23 | np.save(save_dir / "y_val.npy", y_val) 24 | 25 | 26 | if __name__ == "__main__": 27 | save_dir = Path("../data/mnist_784/") 28 | print(f"Downloading MNIST dataset at {save_dir.resolve()}") 29 | download_MNIST(save_dir) 30 | -------------------------------------------------------------------------------- /shallowspeed/utils.py: -------------------------------------------------------------------------------- 1 | from hashlib import sha1 2 | 3 | from mpi4py import MPI 4 | 5 | from shallowspeed.layers import Parameter 6 | 7 | 8 | def rprint(*args, **kwargs): 9 | if MPI.COMM_WORLD.Get_rank() == 0: 10 | print(*args, **kwargs) 11 | 12 | 13 | def get_model_hash(model): 14 | # this is probably not the most efficient way to do this, but it's 15 | # not straightforward to get a deterministic, content-based hash of a model's parameters 16 | hash_str = "" 17 | for param in model.parameters(): 18 | if isinstance(param, Parameter): 19 | param = param.data 20 | 21 | # concat the strings to form a single hash later 22 | hash_str += sha1(param).hexdigest() 23 | # hash to concatenated strings 24 | return sha1(hash_str.encode("utf-8")).hexdigest() 25 | 26 | 27 | def assert_sync(comm, model_hash): 28 | # check that all processes have the same model hash 29 | model_hash_all = comm.gather(model_hash, root=0) 30 | if comm.rank == 0 and len(set(model_hash_all)) > 1: 31 | raise ValueError("Model hash mismatch") 32 | -------------------------------------------------------------------------------- /shallowspeed/functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def relu(input): 5 | return input.clip(min=0.0) 6 | 7 | 8 | def relu_grad(grad_output, bitmask): 9 | assert bitmask.dtype == bool 10 | return grad_output * bitmask 11 | 12 | 13 | def linear(input, weight, bias): 14 | """ 15 | y = x@A^T + b 16 | """ 17 | return input @ weight.T + bias 18 | 19 | 20 | def linear_grad(grad_output, input, weight): 21 | return grad_output @ weight, grad_output.T @ input, grad_output.sum(axis=0) 22 | 23 | 24 | def softmax(input): 25 | # logsumexp trick 26 | input_exp = np.exp(input - np.max(input)) 27 | return input_exp / (input_exp.sum(axis=1, keepdims=True) + 1e-7) 28 | 29 | 30 | def 
softmax_grad(grad_output, input): 31 | # ideally we would cache the output instead of the input during FW, 32 | # avoiding the recomputation 33 | output = softmax(input) 34 | new_grad = output * grad_output 35 | return new_grad - output * new_grad.sum(axis=-1, keepdims=True) 36 | 37 | 38 | def mse_loss(input, target, batch_size: int): 39 | assert input.shape == target.shape 40 | return ((target - input) ** 2).sum() / batch_size 41 | 42 | 43 | def mse_loss_grad(input, target, batch_size: int): 44 | return -2 * (target - input) / batch_size 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shallowspeed 2 | ![stability-wip](https://img.shields.io/badge/stability-work_in_progress-lightgrey.svg) 3 | 4 | A tiny POC implementation of distributed training for sequential deep learning models. 5 | Implemented using plain Numpy & mpi4py. 6 | 7 | ![](.github/assets/title_picture.jpg) 8 | 9 | 10 | Currently implements: 11 | - Sequential models / deep MLPs, training using SGD. 12 | - Data parallel training with interleaved communication & computation, similar to PyTorch's [DistributedDataParallel](https://arxiv.org/abs/2006.15704). 13 | - Pipeline parallel training: 14 | - Naive schedule without interleaved stages. 15 | - [Gpipe](https://arxiv.org/abs/1811.06965) schedule with interleaved FWD & interleaved BWD. 16 | - (soon) [PipeDream Flush](https://arxiv.org/abs/2006.09503) schedule with additional inter-FWD & BWD interleaving. 17 | - Any combination of DP & PP algorithms. 18 | 19 | ## Setup 20 | ```bash 21 | conda env create 22 | pip install -e . 23 | # M1 Macs: conda install "libblas=*=*accelerate" 24 | python download_dataset.py 25 | pytest 26 | ``` 27 | 28 | ## Usage 29 | ```bash 30 | # Sequential training 31 | python train.py 32 | # Data parallel distributed training 33 | mpirun -n 4 python train.py --dp 4 34 | # Pipeline parallel distributed training 35 | mpirun -n 4 python train.py --pp 4 --schedule naive 36 | # Data & pipeline parallel distributed training 37 | mpirun -n 8 python train.py --dp 2 --pp 4 --schedule gpipe 38 | ``` 39 | 40 | ## Internals 41 | ![](.github/assets/PP_pebble_graph.gif) 42 | -------------------------------------------------------------------------------- /tests/test_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from shallowspeed.functional import mse_loss_grad 4 | from shallowspeed.layers import MLP, Linear, ReLU, Sequential, Softmax 5 | 6 | 7 | def test_MLP_basic(): 8 | layer_sizes = [132, 40, 11, 9] 9 | layers = [ 10 | Linear( 11 | layer_sizes[i], 12 | layer_sizes[i + 1], 13 | activation="relu" if i < len(layer_sizes) - 2 else None, 14 | ) 15 | for i in range(len(layer_sizes) - 1) 16 | ] 17 | layers.append(Softmax()) 18 | dnn = Sequential(layers) 19 | assert len(dnn.parameters()) == 6 20 | x = np.ones((13, 132), dtype=np.float32) 21 | 22 | dnn.eval() 23 | output = dnn(x) 24 | assert output.shape == (13, 9) 25 | assert output.dtype == np.float32 26 | assert np.allclose(output.sum(), 13.0) 27 | 28 | dnn.train() 29 | output = dnn(x) 30 | target = np.diag(np.ones(9, dtype=np.float32)) 31 | target = np.concatenate((target, target[:4])) 32 | assert target.shape == (13, 9) 33 | 34 | dout = dnn.backward(mse_loss_grad(output, target, 13)) 35 | # TODO: Make sure the last layer doesn't return the gradients wrt the input 36 | assert dout.shape == (13, 132) 37 | assert 
dout.dtype == np.float32 38 | # check if parameters were updated 39 | for param in dnn.parameters(): 40 | assert param.requires_grad 41 | assert np.abs(param.grad).sum() > 0 42 | assert param.grad.shape == param.data.shape 43 | 44 | dnn.zero_grad() 45 | for param in dnn.parameters(): 46 | assert np.abs(param.grad).sum() == 0 47 | assert param.grad.shape == param.data.shape 48 | 49 | assert len(dnn.parameters()) == 6 50 | 51 | 52 | def test_distributed_MLP_init(): 53 | layer_sizes = [1, 22, 98, 14, 132, 40, 11, 9, 33] 54 | n_stages = 3 55 | batch_size = 13 56 | 57 | # first in pipeline 58 | dnn = MLP(layer_sizes, 0, n_stages, batch_size) 59 | assert len(dnn.parameters()) == 2 * 3 60 | assert len(dnn.layers) == 3 61 | assert all(isinstance(l.activation, ReLU) for l in dnn.layers) 62 | assert dnn.in_dim == 1 and dnn.out_dim == 14 63 | 64 | # last in pipeline 65 | dnn = MLP(layer_sizes, 2, n_stages, batch_size) 66 | assert len(dnn.parameters()) == 2 * 2 67 | assert len(dnn.layers) == 4 68 | assert isinstance(dnn.layers[0].activation, ReLU) 69 | assert dnn.layers[1].activation is None 70 | assert dnn.in_dim == 11 and dnn.out_dim == 33 71 | -------------------------------------------------------------------------------- /shallowspeed/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | class Dataset: 6 | """ 7 | TODO this is a pretty awkward implementation. 8 | How to improve: 9 | - Turn microbatch loading into an iterator 10 | - We only ever need to access μBatches sequentially, even during PipeDream 11 | - Figure out a clean interface for not having every process load the dataset 12 | - Maybe by injecting a load_dataset() function into __init__, which returns the numpy slices? 13 | - Write some tests to ensure equal μBatches are loaded during Sequential & distributed training 14 | """ 15 | 16 | input_X = None 17 | target_y = None 18 | 19 | def __init__( 20 | self, 21 | save_dir, 22 | global_batch_size, 23 | mubatch_size, 24 | validation=False, 25 | ): 26 | assert save_dir.is_dir(), "Download the dataset first!" 27 | self.save_dir = save_dir 28 | self.global_batch_size = global_batch_size 29 | self.local_batch_size = None 30 | self.mubatch_size = mubatch_size 31 | self._val = validation 32 | 33 | def load(self, DP_rank, DP_size): 34 | assert DP_rank < DP_size 35 | assert self.global_batch_size % DP_size == 0 36 | assert ( 37 | self.global_batch_size // DP_size 38 | ) % self.mubatch_size == 0, "μBatchsize must divide batchsize!" 
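        # Note: each DP replica trains on an equal shard of every global batch; since the loss is
        # rescaled by the global batch size and the replicas' gradients are summed via AllReduce,
        # the resulting update matches sequential training on the full batch.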
39 | self.local_batch_size = self.global_batch_size // DP_size 40 | 41 | # each process loads the whole dataset 42 | # this is inefficient for large datasets, but fine for tiny MNIST 43 | suffix = "val" if self._val else "train" 44 | input_X = pd.read_parquet(self.save_dir / f"x_{suffix}.parquet").to_numpy( 45 | dtype=np.float32 46 | ) 47 | target_y = np.load(self.save_dir / f"y_{suffix}.npy").astype(np.float32) 48 | assert len(input_X) == len(target_y) 49 | 50 | # drop last few samples such that each batch is exactly `global_batch_size` long 51 | # this is important to ensure equivalence when changing the number of μBatches 52 | full_tiles_length = len(input_X) - (len(input_X) % self.global_batch_size) 53 | 54 | # each DP process selects its subset of the datasets by a `rank`-offset and `size`-strides 55 | # the copy() is super important, else the array is not continuous in memory 56 | # which results in horrible matmul performance 57 | self.input_X = input_X[DP_rank:full_tiles_length:DP_size].copy() 58 | self.target_y = target_y[DP_rank:full_tiles_length:DP_size].copy() 59 | 60 | assert len(self.input_X) % self.mubatch_size == 0 61 | assert len(self.input_X) % self.local_batch_size == 0 62 | 63 | def __len__(self): 64 | return len(self.input_X) 65 | 66 | def load_micro_batch_input(self, batch_id, mubatch_id): 67 | assert batch_id < self.get_num_batches() 68 | assert mubatch_id < self.get_num_mubatches() 69 | start_idx = batch_id * self.local_batch_size + mubatch_id * self.mubatch_size 70 | end_idx = start_idx + self.mubatch_size 71 | assert end_idx <= len(self.input_X) 72 | return self.input_X[start_idx:end_idx] 73 | 74 | def load_micro_batch_target(self, batch_id, mubatch_id): 75 | assert batch_id < self.get_num_batches() 76 | assert mubatch_id < self.get_num_mubatches() 77 | start_idx = batch_id * self.local_batch_size + mubatch_id * self.mubatch_size 78 | end_idx = start_idx + self.mubatch_size 79 | assert end_idx <= len(self.input_X) 80 | return self.target_y[start_idx:end_idx] 81 | 82 | def get_num_batches(self): 83 | return len(self) // self.local_batch_size 84 | 85 | def get_num_mubatches(self): 86 | return self.local_batch_size // self.mubatch_size 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | !data/.gitkeep 3 | 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | .idea/ 165 | -------------------------------------------------------------------------------- /tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import shallowspeed.pipe as pipe 2 | from shallowspeed.pipe import GPipeSchedule, NaiveParallelSchedule, PipeInstr 3 | 4 | """ 5 | TODO these tests are fairly useless. 6 | How to improve: 7 | - define a "happens before" predicate and eg test: 8 | - happens_before(BWD_MuBatch1, FWD_MuBatch2) in Naive schedule 9 | - happens_before(FWD_MuBatch, BWD_MuBatch1) in GPipe 10 | """ 11 | 12 | 13 | def flatten(sched): 14 | if not isinstance(sched[0], list): 15 | return sched 16 | 17 | flat_sched = [] 18 | for cmds in sched: 19 | for cmd in cmds: 20 | flat_sched.append(cmd) 21 | return flat_sched 22 | 23 | 24 | def cmd_is_in(cmd_t, sched): 25 | flat_sched = flatten(sched) 26 | return any(isinstance(x, cmd_t) for x in flat_sched) 27 | 28 | 29 | def test_naive_pp_schedule_dp_only(): 30 | # naive scheduler without any pipeline parallelism (num_stages=1) 31 | sched = NaiveParallelSchedule(num_micro_batches=5, num_stages=1, stage_id=0) 32 | cmds = list(sched.steps()) 33 | assert cmd_is_in(pipe.ZeroGrad, cmds[0]) 34 | assert not cmd_is_in(pipe.ZeroGrad, cmds[1:]) 35 | 36 | # for the final microbatch we AllReduce and step 37 | assert cmd_is_in(pipe.BackwardGradAllReduce, cmds[-2]) 38 | assert cmd_is_in(pipe.OptimizerStep, cmds[-1]) 39 | assert not cmd_is_in(pipe.BackwardGradAllReduce, cmds[:-2]) 40 | assert not cmd_is_in(pipe.OptimizerStep, cmds[:-1]) 41 | 42 | 43 | def test_naive_pp_schedule_pp_only(): 44 | # naive scheduler without any μBatches parallelism 45 | first_sched = NaiveParallelSchedule(num_micro_batches=1, num_stages=2, stage_id=0) 46 | first_cmds = list(first_sched.steps()) 47 | # init 48 | assert cmd_is_in(pipe.ZeroGrad, first_cmds[0]) 49 | assert cmd_is_in(pipe.LoadMuBatchInput, first_cmds[1]) 50 | assert not cmd_is_in(pipe.LoadMuBatchTarget, first_cmds) 51 | # finish 52 | assert cmd_is_in(pipe.BackwardGradAllReduce, first_cmds[-2]) 53 | assert not cmd_is_in(pipe.BackwardGradAllReduce, first_cmds[:-2]) 54 | assert cmd_is_in(pipe.OptimizerStep, first_cmds[-1]) 55 | assert not cmd_is_in(pipe.OptimizerStep, first_cmds[:-1]) 56 | 57 | second_sched = NaiveParallelSchedule(num_micro_batches=1, num_stages=2, stage_id=1) 58 | second_cmds = list(second_sched.steps()) 59 | # init 60 | assert cmd_is_in(pipe.ZeroGrad, second_cmds[0]) 61 | assert cmd_is_in(pipe.RecvActivations, second_cmds[1]) 62 | assert not cmd_is_in(pipe.LoadMuBatchInput, second_cmds) 63 | # processing 64 | assert cmd_is_in(pipe.LoadMuBatchTarget, second_cmds[1]) 65 | # finish 66 | assert cmd_is_in(pipe.BackwardGradAllReduce, second_cmds[-2]) 67 | assert not cmd_is_in(pipe.BackwardGradAllReduce, second_cmds[:-2]) 68 | assert cmd_is_in(pipe.OptimizerStep, first_cmds[-1]) 69 | assert not cmd_is_in(pipe.OptimizerStep, second_cmds[:-1]) 70 | 71 | 72 | def test_gpipe_schedule(): 73 | first_sched = GPipeSchedule(num_micro_batches=2, num_stages=3, stage_id=0) 74 | first_cmds = list(first_sched.steps()) 75 | 76 | # init 77 | assert cmd_is_in(pipe.ZeroGrad, first_cmds[0]) 78 | assert cmd_is_in(pipe.LoadMuBatchInput, first_cmds[1]) 79 | assert not cmd_is_in(pipe.LoadMuBatchTarget, first_cmds) 80 | # finish 81 | assert cmd_is_in(pipe.BackwardGradAllReduce, first_cmds[-2]) 82 | assert not cmd_is_in(pipe.BackwardGradAllReduce, first_cmds[:-2]) 83 | assert cmd_is_in(pipe.OptimizerStep, first_cmds[-1]) 84 | assert not 
cmd_is_in(pipe.OptimizerStep, first_cmds[:-1]) 85 | 86 | second_sched = GPipeSchedule(num_micro_batches=2, num_stages=3, stage_id=1) 87 | second_cmds = list(second_sched.steps()) 88 | # init 89 | assert cmd_is_in(pipe.ZeroGrad, second_cmds[0]) 90 | assert cmd_is_in(pipe.RecvActivations, second_cmds[1]) 91 | assert not cmd_is_in(pipe.LoadMuBatchInput, second_cmds) 92 | assert not cmd_is_in(pipe.LoadMuBatchTarget, second_cmds) 93 | # processing 94 | assert cmd_is_in(pipe.SendActivations, second_cmds[1:]) 95 | assert cmd_is_in(pipe.RecvActivations, second_cmds[1:]) 96 | assert cmd_is_in(pipe.SendInputGrad, second_cmds[1:]) 97 | assert cmd_is_in(pipe.RecvOutputGrad, second_cmds[1:]) 98 | # finish 99 | assert cmd_is_in(pipe.BackwardGradAllReduce, second_cmds[-2]) 100 | assert not cmd_is_in(pipe.BackwardGradAllReduce, second_cmds[:-2]) 101 | assert cmd_is_in(pipe.OptimizerStep, first_cmds[-1]) 102 | assert not cmd_is_in(pipe.OptimizerStep, second_cmds[:-1]) 103 | -------------------------------------------------------------------------------- /tests/test_functional.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from shallowspeed.functional import ( 4 | linear, 5 | linear_grad, 6 | mse_loss, 7 | relu, 8 | relu_grad, 9 | softmax, 10 | softmax_grad, 11 | ) 12 | 13 | EPS = 10e-6 14 | 15 | 16 | def test_shapes(): 17 | # relu 18 | x = np.empty((2, 3)) 19 | y = relu(x) 20 | dinput = relu_grad(y, x > 0) 21 | assert x.shape == dinput.shape 22 | 23 | # linear 24 | weight = np.empty((3, 10)) 25 | bias = np.empty((3,)) 26 | x = np.empty((20, 10)) 27 | y = linear(x, weight, bias) 28 | assert y.shape == (20, 3) 29 | dinput, dweight, dbias = linear_grad(np.empty((20, 3)), x, weight) 30 | assert dinput.shape == x.shape 31 | assert dweight.shape == weight.shape 32 | assert dbias.shape == bias.shape 33 | 34 | # softmax 35 | x = np.empty((20, 10)) 36 | y = softmax(x) 37 | assert y.shape == (20, 10) 38 | dinput = softmax_grad(y, x) 39 | assert dinput.shape == x.shape 40 | 41 | 42 | def test_relu(): 43 | x = np.array([[-1, 2, -3], [4, -5, 6]]) 44 | y = relu(x) 45 | assert np.allclose(y, np.array([[0, 2, 0], [4, 0, 6]])) 46 | 47 | 48 | def test_relu_grad(): 49 | x = np.array([[-1, -2, -3], [0.1, 5, 6]]) 50 | finite_diff = (relu(x + EPS / 2) - relu(x - EPS / 2)) / EPS 51 | assert np.allclose(relu_grad(np.ones_like(x), x > 0), finite_diff) 52 | 53 | 54 | def I_ij(i, j, n, m): 55 | # 1 at position i,j zero otherwise 56 | result = np.zeros((n, m)) 57 | result[i, j] += 1 58 | return result 59 | 60 | 61 | def test_linear_grad(): 62 | x = np.array([[-1, -2, -3]], dtype=float) 63 | W = np.array([[2, 3, -1], [1, 0, 4], [9, -9, 1], [1, -3, 5]], dtype=float) 64 | b = np.array([[1, -1, 1, 3]], dtype=float) 65 | grad_out = np.arange(4, dtype=float).reshape((1, 4)) 66 | 67 | # TODO get rid of some duplication by introducing a function 68 | # TODO calculate the Jacobian using vectorized operations 69 | # calculating Jacobian for input using finite differences method 70 | jacobian_i_fd = np.zeros((3, 4), dtype=float) 71 | for i in range(3): 72 | for o in range(4): 73 | jacobian_i_fd[i, o] += ( 74 | ( 75 | linear(x + EPS / 2 * I_ij(0, i, 1, 3), W, b) 76 | - linear(x - EPS / 2 * I_ij(0, i, 1, 3), W, b) 77 | ) 78 | / EPS 79 | )[0][o] 80 | jvp = jacobian_i_fd @ grad_out[0] 81 | real = linear_grad(grad_out, x, W)[0] 82 | assert np.allclose(jvp, real) 83 | 84 | # calculating Jacobian for weights using finite differences method 85 | jacobian_W_fd = np.zeros((4, 3, 4), 
dtype=float) 86 | for r in range(4): 87 | for c in range(3): 88 | for o in range(4): 89 | jacobian_W_fd[r, c, o] += ( 90 | ( 91 | linear(x, W + EPS / 2 * I_ij(r, c, 4, 3), b) 92 | - linear(x, W - EPS / 2 * I_ij(r, c, 4, 3), b) 93 | ) 94 | / EPS 95 | )[0][o] 96 | jvp = jacobian_W_fd @ grad_out[0] 97 | real = linear_grad(grad_out, x, W)[1] 98 | assert np.allclose(jvp, real) 99 | 100 | # calculating Jacobian for bias using finite differences method 101 | jacobian_b_fd = np.zeros((4, 4), dtype=float) 102 | for b in range(4): 103 | for o in range(4): 104 | jacobian_b_fd[b, o] += ( 105 | ( 106 | linear(x, W, b + EPS / 2 * I_ij(0, b, 1, 4)) 107 | - linear(x, W, b - EPS / 2 * I_ij(0, b, 1, 4)) 108 | ) 109 | / EPS 110 | )[0][o] 111 | jvp = jacobian_b_fd @ grad_out[0] 112 | real = linear_grad(grad_out, x, W)[2] 113 | assert np.allclose(jvp, real) 114 | 115 | 116 | def test_softmax(): 117 | x = np.array([[-1, 2, -3], [4, 5, 6]]) 118 | y = softmax(x) 119 | assert np.allclose(y.sum(axis=1), np.ones(y.shape[0])) 120 | assert (y > 0).all() 121 | # softmax is invariant to shifts 122 | assert np.allclose(softmax(x), softmax(x - 6)) 123 | 124 | 125 | def test_softmax_grad(): 126 | # TODO test against batch size > 1 127 | x = np.array([[-1, -2, -3]], dtype=float) 128 | grad_out = np.array([[1, 9, 11]], dtype=float) 129 | 130 | # calculating Jacobian for input using finite differences method 131 | jacobian_i_fd = np.zeros((3, 3), dtype=float) 132 | for i_i in range(3): 133 | for o_i in range(3): 134 | jacobian_i_fd[i_i, o_i] += ( 135 | ( 136 | softmax(x + EPS / 2 * I_ij(0, i_i, 1, 3)) 137 | - softmax(x - EPS / 2 * I_ij(0, i_i, 1, 3)) 138 | ) 139 | / EPS 140 | )[0][o_i] 141 | 142 | jvp = jacobian_i_fd @ grad_out[0] 143 | real = softmax_grad(grad_out, x) 144 | assert np.allclose(jvp, real) 145 | 146 | 147 | def test_mse(): 148 | # TODO write a grad test for MSE 149 | input = np.array([[1, 0, 0], [0, 1, 0]]) 150 | target = np.array([[1, 0, 0], [0, 1, 0]]) 151 | mse = mse_loss(input, target, input.shape[0]) 152 | assert np.allclose(mse, 0) 153 | 154 | input = np.array([[0.25, 0.5, 0.25], [0.5, 0.5, 0.0]]) 155 | mse = mse_loss(input, target, input.shape[0]) 156 | assert np.allclose(mse, (0.625 + 0.75) / 2) 157 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from mpi4py import MPI 7 | 8 | from shallowspeed.dataset import Dataset 9 | from shallowspeed.layers import MLP 10 | from shallowspeed.optimizer import SGD 11 | from shallowspeed.pipe import ( 12 | GPipeSchedule, 13 | InferenceSchedule, 14 | NaiveParallelSchedule, 15 | PipeDreamSchedule, 16 | Worker, 17 | ) 18 | from shallowspeed.utils import assert_sync, get_model_hash 19 | 20 | 21 | def compute_accuracy(model, worker, dataset): 22 | """ 23 | This function does a forward pass of x, then checks if the indices 24 | of the maximum value in the output equals the indices in the label 25 | y. Then it sums over each prediction and calculates the accuracy. 
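    Only the last pipeline stage holds the final model outputs, so only that
    stage returns an accuracy; every other stage returns None.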
26 | """ 27 | model.eval() 28 | 29 | correct = 0 30 | total = 0 31 | for batch_id in range(dataset.get_num_batches()): 32 | schedule = InferenceSchedule( 33 | num_micro_batches=1, 34 | num_stages=worker.pipeline_depth, 35 | stage_id=worker.stage_id, 36 | ) 37 | worker.execute(schedule, batch_id) 38 | 39 | if worker.stage_id == worker.pipeline_depth - 1: 40 | pred = np.argmax(worker.output_buffers[0], axis=-1) 41 | target = np.argmax(dataset.load_micro_batch_target(batch_id, 0), axis=-1) 42 | correct += np.sum(pred == target) 43 | total += pred.shape[0] 44 | 45 | model.train() 46 | if worker.stage_id == worker.pipeline_depth - 1: 47 | return correct / total 48 | 49 | 50 | SCHEDULE_NAME_TO_CLS = { 51 | "naive": NaiveParallelSchedule, 52 | "gpipe": GPipeSchedule, 53 | "pipedream": PipeDreamSchedule, 54 | } 55 | 56 | EPOCHS = 20 57 | # We use a big batch size, to make training more amenable to parallelization 58 | GLOBAL_BATCH_SIZE = 128 59 | N_MUBATCHES = 4 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument( 65 | "--dp", 66 | type=int, 67 | default=1, 68 | help="Degree of data parallelism (=number of full model replicas)", 69 | ) 70 | parser.add_argument("--pp", type=int, default=1, help="Number of pipeline stages") 71 | parser.add_argument( 72 | "--schedule", type=str, choices=["pipedream", "gpipe", "naive"], default="naive" 73 | ) 74 | args = parser.parse_args() 75 | DP_tile_factor = args.dp 76 | PP_tile_factor = args.pp 77 | 78 | assert DP_tile_factor >= 1 and PP_tile_factor >= 1 79 | assert DP_tile_factor * PP_tile_factor == MPI.COMM_WORLD.size, ( 80 | f"Number of started workers is {MPI.COMM_WORLD.size}, " 81 | f"but should be {DP_tile_factor * PP_tile_factor} (DP * PP)" 82 | ) 83 | assert ( 84 | GLOBAL_BATCH_SIZE % DP_tile_factor == 0 85 | ), "Batch size must be properly divisible by DP" 86 | 87 | # create MPI communicators for data parallel AllReduce & pipeline parallel send & recv 88 | # if the `color=` parameter is the same, then those two workers end up in the same communicator 89 | dp_comm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.Get_rank() % PP_tile_factor) 90 | # to run it truly distributed (like on a RaspberryPi cluster) you'd use comm.Split_type 91 | # instead of this color splitting, eg TYPE_SOCKET for PP 92 | pp_comm = MPI.COMM_WORLD.Split(color=MPI.COMM_WORLD.Get_rank() // PP_tile_factor) 93 | # sanity check 94 | assert dp_comm.Get_size() == DP_tile_factor and pp_comm.Get_size() == PP_tile_factor 95 | 96 | # Set up the local model. 
97 | # Layer_sizes is the total model size, which we split into PP-many stages 98 | layer_sizes = [784, 128, 127, 126, 125, 124, 123, 10] 99 | model = MLP( 100 | layer_sizes, 101 | stage_idx=pp_comm.rank, 102 | n_stages=PP_tile_factor, 103 | batch_size=GLOBAL_BATCH_SIZE, 104 | ) 105 | model.train() 106 | 107 | optimizer = SGD(model.parameters(), lr=0.006) 108 | 109 | # Each DP-worker gets a slice of the global batch-size 110 | # TODO not every worker needs the dataset 111 | save_dir = Path("data/mnist_784/") 112 | local_batch_size = GLOBAL_BATCH_SIZE // DP_tile_factor 113 | dataset = Dataset( 114 | save_dir, 115 | global_batch_size=GLOBAL_BATCH_SIZE, 116 | mubatch_size=local_batch_size // N_MUBATCHES, 117 | validation=False, 118 | ) 119 | dataset.load(dp_comm.Get_rank(), dp_comm.Get_size()) 120 | worker = Worker(dp_comm, pp_comm, model, dataset, optimizer) 121 | 122 | val_dataset = Dataset( 123 | save_dir, 124 | global_batch_size=GLOBAL_BATCH_SIZE, 125 | mubatch_size=GLOBAL_BATCH_SIZE, 126 | validation=True, 127 | ) 128 | val_dataset.load(DP_rank=0, DP_size=1) 129 | val_worker = Worker(None, pp_comm, model, val_dataset, None) 130 | 131 | start_time = time.time() 132 | for iteration in range(EPOCHS): 133 | accuracy = compute_accuracy(model, val_worker, val_dataset) 134 | if accuracy: 135 | print( 136 | f"Epoch: {iteration}, Time Spent: {time.time() - start_time:.2f}s, Accuracy: {accuracy * 100:.2f}%", 137 | ) 138 | 139 | for batch_id in range(0, dataset.get_num_batches()): 140 | schedule = SCHEDULE_NAME_TO_CLS[args.schedule]( 141 | num_micro_batches=N_MUBATCHES, 142 | num_stages=PP_tile_factor, 143 | stage_id=pp_comm.rank, 144 | ) 145 | # do the actual work 146 | worker.execute(schedule, batch_id) 147 | 148 | accuracy = compute_accuracy(model, val_worker, val_dataset) 149 | if accuracy is not None: 150 | print( 151 | f"Epoch: {EPOCHS}, Time Spent: {time.time() - start_time:.2f}s, Accuracy: {accuracy * 100:.2f}%", 152 | ) 153 | 154 | # Sanity check: Make sure data parallel replicas have the same model weights 155 | assert_sync(dp_comm, get_model_hash(model)) 156 | -------------------------------------------------------------------------------- /scripts/DDP_PyTorch_MNIST.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a self-contained example of non-interleaved data parallel training 3 | using PyTorch and MPI. 
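Illustrative invocation (assumes torchvision, an MPI installation, and an existing
data/models/ directory; the final divergence check also expects a prior serial run
to have produced data/models/model_p1.pkl):

    mpirun -n 1 python scripts/DDP_PyTorch_MNIST.py   # serial baseline
    mpirun -n 4 python scripts/DDP_PyTorch_MNIST.py   # 4-way data parallel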
4 | """ 5 | 6 | import time 7 | from hashlib import sha1 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from mpi4py import MPI 13 | from torch.utils.data import DataLoader, Subset 14 | from torchvision import datasets, transforms 15 | 16 | # Make sure that all kernel's used by PyTorch run in a single thread 17 | # Eg the matrix multiplication kernel by default will use multiple threads (this is called intra-op parallelism) 18 | torch.set_num_threads(1) 19 | torch.manual_seed(0) 20 | 21 | 22 | # Define an MLP classifier 23 | class MLP(nn.Module): 24 | def __init__(self, input_size: int, hidden_size: int, output_size: int): 25 | super(MLP, self).__init__() 26 | self.fc1 = nn.Linear(input_size, hidden_size) 27 | self.fc2 = nn.Linear(hidden_size, hidden_size) 28 | self.fc3 = nn.Linear(hidden_size, output_size) 29 | 30 | def forward(self, x): 31 | x = F.relu(self.fc1(x)) 32 | x = F.relu(self.fc2(x)) 33 | x = self.fc3(x) 34 | return x 35 | 36 | 37 | def get_model_hash(model): 38 | # this is probably not the most efficient way to do this, but it's 39 | # not straightforward to get a deterministic, content-based hash of a model's parameters 40 | hash_str = "" 41 | for param in model.parameters(): 42 | numpy_param = param.data.cpu().numpy() 43 | # concat the strings to form a single hash later 44 | hash_str += sha1(numpy_param).hexdigest() 45 | # hash to concatenated strings 46 | return sha1(hash_str.encode("utf-8")).hexdigest() 47 | 48 | 49 | def rprint(*args, **kwargs): 50 | if MPI.COMM_WORLD.Get_rank() == 0: 51 | print(*args, **kwargs) 52 | 53 | 54 | def assert_sync(model_hash): 55 | # check that all processes have the same model hash 56 | model_hash_all = comm.gather(model_hash, root=0) 57 | if MPI.COMM_WORLD.rank == 0 and len(set(model_hash_all)) > 1: 58 | raise ValueError("Model hash mismatch") 59 | 60 | 61 | NUM_EPOCHS = 5 62 | BATCH_SIZE = 128 63 | 64 | if __name__ == "__main__": 65 | # init MPI 66 | comm = MPI.COMM_WORLD 67 | rank = comm.Get_rank() 68 | size = comm.Get_size() 69 | 70 | # download the MNIST dataset 71 | transform = transforms.Compose( 72 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 73 | ) 74 | dataset = datasets.MNIST("data/", train=True, download=True, transform=transform) 75 | # get a distinct subset of the dataset for each process by using strides 76 | dataset = Subset( 77 | dataset, torch.arange(start=rank, end=len(dataset), step=size, dtype=torch.long) 78 | ) 79 | # Note: to make distributed training as similar as possible to serial training, 80 | # we need to turn of shuffling and make sure that `size` evenly divides the batch size 81 | assert BATCH_SIZE % size == 0 82 | train_loader = DataLoader( 83 | dataset=dataset, 84 | batch_size=BATCH_SIZE // size, 85 | shuffle=False, 86 | ) 87 | test_loader = DataLoader( 88 | datasets.MNIST("data/", train=False, transform=transform), 89 | batch_size=64, 90 | shuffle=True, 91 | ) 92 | 93 | # define the model 94 | model = MLP(input_size=28 * 28, hidden_size=64, output_size=10) 95 | # make sure the initialization is the same on all processes 96 | assert_sync(get_model_hash(model)) 97 | 98 | # define the loss function 99 | criterion = nn.CrossEntropyLoss() 100 | # define the optimizer 101 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 102 | 103 | # train the model 104 | start_time = time.time() 105 | for epoch in range(NUM_EPOCHS): 106 | epoch_start_time = time.time() 107 | for i, (images, labels) in enumerate(train_loader): 108 | images = 
images.view(-1, 28 * 28) 109 | outputs = model(images) 110 | loss = criterion(outputs, labels) 111 | # rescale the loss to be a mean over the global batch size instead of 112 | # just the local batch size (= global batch size / comm.size, see above) 113 | loss /= comm.size 114 | optimizer.zero_grad() 115 | # compute the gradients locally 116 | loss.backward() 117 | 118 | if size > 1: 119 | # todo: gather the activations instead of gradients 120 | # todo: do this in a single communication step by allocating a large tensor 121 | for param in model.parameters(): 122 | comm.Allreduce(MPI.IN_PLACE, param.grad, op=MPI.SUM) 123 | 124 | optimizer.step() 125 | if (i + 1) % 100 == 0: 126 | rprint( 127 | "Epoch [{:2}/{}], Step [{}/{}], Loss: {:.4f}".format( 128 | epoch + 1, 129 | NUM_EPOCHS, 130 | i + 1, 131 | len(train_loader), 132 | loss.item() * 4, 133 | ) 134 | ) 135 | rprint("Time(epoch) {:4.1f}s".format(time.time() - epoch_start_time)) 136 | rprint("Total training time: {:4.1f}s".format(time.time() - start_time)) 137 | 138 | # test the model 139 | with torch.no_grad(): 140 | correct = 0 141 | total = 0 142 | for images, labels in test_loader: 143 | images = images.view(-1, 28 * 28) 144 | outputs = model(images) 145 | _, predicted = torch.max(outputs.data, 1) 146 | total += labels.size(0) 147 | correct += (predicted == labels).sum().item() 148 | rprint( 149 | "Accuracy of the model on the {} test images: {} %".format( 150 | total, 100 * correct / total 151 | ), 152 | ) 153 | 154 | # make sure the final model is the same on all processes 155 | assert_sync(get_model_hash(model)) 156 | 157 | torch.save(model.state_dict(), f"data/models/model_p{size}.pkl") 158 | 159 | # compare the absolute divergence between the two models 160 | sequential_model = MLP(input_size=28 * 28, hidden_size=64, output_size=10) 161 | sequential_model.load_state_dict(torch.load(f"data/models/model_p1.pkl")) 162 | divergence = 0 163 | for param1, param2 in zip(model.parameters(), sequential_model.parameters()): 164 | divergence += torch.abs(param1 - param2).sum().item() 165 | rprint( 166 | "Total absolute divergence cmp'd to serial weights: {:.8f}".format(divergence), 167 | ) 168 | -------------------------------------------------------------------------------- /shallowspeed/layers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | from numpy.random import MT19937, RandomState, SeedSequence 5 | 6 | from shallowspeed.functional import ( 7 | linear, 8 | linear_grad, 9 | mse_loss_grad, 10 | relu, 11 | relu_grad, 12 | softmax, 13 | softmax_grad, 14 | ) 15 | 16 | 17 | class Parameter(ABC): 18 | """ 19 | Encapsulates a numpy array and keeps track of its gradient. 20 | """ 21 | 22 | def __init__(self, data: np.array, requires_grad: bool = True): 23 | self.data = data 24 | self.grad = np.zeros_like(data, dtype=np.float32) 25 | self.requires_grad = requires_grad 26 | 27 | def __repr__(self): 28 | return f"Parameter(shape={self.data.shape}, requires_grad={self.requires_grad})" 29 | 30 | 31 | class Module(ABC): 32 | """ 33 | A module is a stateful object, encapsulating a function and keeping track 34 | of the trainable parameters. It also keeps track of cached activations. 
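    Activations are cached per μBatch (keyed by mubatch_id), so several μBatches can
    be in flight at once during pipeline-parallel training; each backward() call frees
    the cache entry of the μBatch it processes.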
35 | """ 36 | 37 | def __init__(self): 38 | self._params = {} 39 | self._cache = {} 40 | self._training = True 41 | 42 | def __call__(self, inputs, mubatch_id=0): 43 | return self.forward(inputs, mubatch_id=mubatch_id) 44 | 45 | @abstractmethod 46 | def forward(self, inputs: np.array, mubatch_id=0): 47 | raise NotImplementedError 48 | 49 | @abstractmethod 50 | def backward(self, dout: np.array, mubatch_id=0): 51 | raise NotImplementedError 52 | 53 | def train(self): 54 | self._training = True 55 | 56 | def eval(self): 57 | self._training = False 58 | 59 | def zero_grad(self): 60 | for param in self.parameters(): 61 | param.grad.fill(0.0) 62 | 63 | def parameters(self): 64 | return list(self._params.values()) 65 | 66 | 67 | class ReLU(Module): 68 | def forward(self, inputs, mubatch_id=0): 69 | if self._training: 70 | self._cache[f"bitmask_{mubatch_id}"] = inputs > 0 71 | return relu(inputs) 72 | 73 | def backward(self, dout, mubatch_id=0): 74 | assert self._training 75 | dout = relu_grad(dout, self._cache[f"bitmask_{mubatch_id}"]) 76 | del self._cache[f"bitmask_{mubatch_id}"] 77 | return dout 78 | 79 | def __repr__(self): 80 | return "ReLU()" 81 | 82 | 83 | class Softmax(Module): 84 | def forward(self, inputs, mubatch_id=0): 85 | if self._training: 86 | self._cache[f"input_{mubatch_id}"] = inputs 87 | return softmax(inputs) 88 | 89 | def backward(self, dout, mubatch_id=0): 90 | assert self._training 91 | dout = softmax_grad(dout, self._cache[f"input_{mubatch_id}"]) 92 | del self._cache[f"input_{mubatch_id}"] 93 | return dout 94 | 95 | def __repr__(self): 96 | return "Softmax()" 97 | 98 | 99 | class Linear(Module): 100 | def __init__(self, in_dims, out_dims, activation="relu"): 101 | super().__init__() 102 | assert activation is None or activation == "relu" 103 | 104 | # we want to get the same initial weights, no matter 105 | # if the model is distributed across workers or not 106 | rs = RandomState(MT19937(SeedSequence(in_dims + out_dims * 1337))) 107 | 108 | self.activation = ReLU() if activation == "relu" else None 109 | self._params["W"] = Parameter( 110 | rs.normal(0.0, 1.0, (out_dims, in_dims)).astype(np.float32) 111 | / np.sqrt(in_dims) 112 | ) 113 | self._params["b"] = Parameter(np.zeros((1, out_dims), dtype=np.float32)) 114 | 115 | def forward(self, inputs, mubatch_id=0): 116 | if self._training: 117 | self._cache[f"input_{mubatch_id}"] = inputs 118 | result = linear(inputs, self._params["W"].data, self._params["b"].data) 119 | 120 | if self.activation: 121 | return self.activation(result, mubatch_id) 122 | return result 123 | 124 | def backward(self, dout, mubatch_id=0): 125 | assert self._training 126 | 127 | if self.activation: 128 | dout = self.activation.backward(dout, mubatch_id) 129 | 130 | dout, dW, db = linear_grad( 131 | dout, self._cache[f"input_{mubatch_id}"], self._params["W"].data 132 | ) 133 | 134 | # accumulate gradients 135 | self._params["W"].grad += dW 136 | self._params["b"].grad += db 137 | 138 | del self._cache[f"input_{mubatch_id}"] 139 | return dout 140 | 141 | def __repr__(self): 142 | return f"Linear({self._params['W'].data.shape[1]}->{self._params['W'].data.shape[0]}, act: {self.activation})" 143 | 144 | 145 | class MSELoss(Module): 146 | def __init__(self, batch_size: int): 147 | super().__init__() 148 | self.batch_size = batch_size 149 | 150 | # You don't need to calculate the loss to compute the gradient 151 | # so we just don't do it 152 | def forward(self, input: np.array, mubatch_id=0): 153 | if self._training: 154 | 
self._cache[f"input_{mubatch_id}"] = input 155 | return input 156 | 157 | def backward(self, target, mubatch_id=0): 158 | assert self._training 159 | dout = mse_loss_grad( 160 | self._cache[f"input_{mubatch_id}"], target, self.batch_size 161 | ) 162 | del self._cache[f"input_{mubatch_id}"] 163 | return dout 164 | 165 | def __repr__(self): 166 | return f"MSELoss()" 167 | 168 | 169 | class Sequential(Module): 170 | def __init__(self, layers: list[Module]): 171 | super().__init__() 172 | self.layers = layers 173 | self._grad_hooks = [] 174 | self._post_grad_hooks = [] 175 | 176 | def forward(self, inputs, mubatch_id=0): 177 | result = inputs 178 | for layer in self.layers: 179 | result = layer(result, mubatch_id) 180 | return result 181 | 182 | def register_grad_hook(self, hook): 183 | """ 184 | Register a hook to be run when the gradient for a parameter has been calculated 185 | """ 186 | assert id not in self._grad_hooks 187 | self._grad_hooks.append(hook) 188 | 189 | def reset_grad_hooks(self): 190 | self._grad_hooks = [] 191 | 192 | def register_post_grad_hook(self, hook): 193 | """ 194 | Register a hook to be run before returning from the backwards()-function 195 | """ 196 | self._post_grad_hooks.append(hook) 197 | 198 | def reset_post_grad_hooks(self): 199 | self._post_grad_hooks = [] 200 | 201 | def backward(self, dout, mubatch_id=0): 202 | result = dout 203 | for layer in reversed(self.layers): 204 | result = layer.backward(result, mubatch_id) 205 | 206 | for hook in self._grad_hooks: 207 | for param in layer.parameters(): 208 | hook(param) 209 | 210 | for hook in self._post_grad_hooks: 211 | hook(self.parameters()) 212 | 213 | return result 214 | 215 | def train(self): 216 | self._training = True 217 | for l in self.layers: 218 | l.train() 219 | 220 | def eval(self): 221 | self._training = False 222 | for l in self.layers: 223 | l.eval() 224 | 225 | def zero_grad(self): 226 | for l in self.layers: 227 | l.zero_grad() 228 | 229 | def parameters(self): 230 | result = [] 231 | for l in self.layers: 232 | result += l.parameters() 233 | return result 234 | 235 | 236 | class MLP(Sequential): 237 | def __init__(self, sizes: list[int], stage_idx, n_stages, batch_size): 238 | """ 239 | :param batch_size: The total batch size. 
This is necessary for rescaling the 240 | loss while operating on DP-slices & μBatches 241 | """ 242 | assert len(sizes) % n_stages == 0 243 | stage_size = len(sizes) // n_stages 244 | 245 | # construct & init local layers 246 | is_last_stage = stage_idx == n_stages - 1 247 | local_sizes = sizes[ 248 | stage_idx 249 | * stage_size : min(len(sizes), stage_size * stage_idx + stage_size + 1) 250 | ] 251 | layers = [ 252 | Linear( 253 | local_sizes[i], 254 | local_sizes[i + 1], 255 | activation=None 256 | if i == len(local_sizes) - 2 and is_last_stage 257 | else "relu", 258 | ) 259 | for i in range(len(local_sizes) - 1) 260 | ] 261 | if is_last_stage: 262 | layers.append(Softmax()) 263 | layers.append(MSELoss(batch_size=batch_size)) 264 | super().__init__(layers) 265 | 266 | print(layers) 267 | 268 | self.in_dim = local_sizes[0] 269 | # softmax & losses don't change output dimensions 270 | self.out_dim = local_sizes[-1] 271 | -------------------------------------------------------------------------------- /shallowspeed/pipe.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | 5 | import numpy as np 6 | from mpi4py import MPI 7 | 8 | from shallowspeed.dataset import Dataset 9 | from shallowspeed.layers import MLP 10 | 11 | 12 | class PipeInstr: 13 | pass 14 | 15 | 16 | @dataclass 17 | class ZeroGrad(PipeInstr): 18 | """ 19 | Set param.grad to zero for all trainable parameters in model 20 | This starts a new phase of gradient accumulation 21 | """ 22 | 23 | pass 24 | 25 | 26 | @dataclass 27 | class OptimizerStep(PipeInstr): 28 | """ 29 | Update the trainable parameters of the model using parameter.grad 30 | """ 31 | 32 | pass 33 | 34 | 35 | @dataclass 36 | class BufferPipeInstr(PipeInstr): 37 | buffer_id: int 38 | 39 | 40 | @dataclass 41 | class RecvActivations(BufferPipeInstr): 42 | """ 43 | Recv activations from the previous pipeline stage and store them into 44 | a buffer. Currently, this is a blocking Op. 45 | """ 46 | 47 | pass 48 | 49 | 50 | @dataclass 51 | class SendActivations(BufferPipeInstr): 52 | """ 53 | Send the results of the local FWD-pass to the next pipeline stage 54 | Currently, this is a blocking Op. 55 | """ 56 | 57 | pass 58 | 59 | 60 | @dataclass 61 | class RecvOutputGrad(BufferPipeInstr): 62 | """ 63 | Recv gradients wrt the outputs of the local FWD-pass from the next pipeline stage 64 | and store them into a buffer. Currently, this is a blocking Op. 65 | """ 66 | 67 | pass 68 | 69 | 70 | @dataclass 71 | class SendInputGrad(BufferPipeInstr): 72 | """ 73 | Send gradients wrt the inputs of the local FWD-pass to the previous pipeline stage 74 | and store them into a buffer. Currently, this is a blocking Op. 75 | """ 76 | 77 | pass 78 | 79 | 80 | @dataclass 81 | class MuBatchPipeInstr(PipeInstr): 82 | buffer_id: int 83 | mubatch_id: int 84 | 85 | 86 | @dataclass 87 | class Forward(MuBatchPipeInstr): 88 | """ 89 | Perform a local FWD-pass on the given `mubatch_id`-μBatch and store the 90 | result in the given buffer. 91 | """ 92 | 93 | pass 94 | 95 | 96 | @dataclass 97 | class BackwardGradAcc(MuBatchPipeInstr): 98 | """ 99 | Perform a local BWD-pass on the given `mubatch_id`-μBatch and store the 100 | result in the given buffer. Accumulate the gradients for each parameter 101 | in param.grad (ie param.grad += ). 
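    Because the loss is already rescaled by the full batch size, summing the
    per-μBatch gradients reproduces the gradient of a single pass over the whole batch.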
102 | """ 103 | 104 | pass 105 | 106 | 107 | @dataclass 108 | class BackwardGradAllReduce(MuBatchPipeInstr): 109 | """ 110 | Like `BackwardGradAcc`, but start non-blocking AllReduce for each param.grad 111 | once it has been computed. This interleaves communication & computation while 112 | performing the local BWDs pass. 113 | """ 114 | 115 | pass 116 | 117 | 118 | @dataclass 119 | class LoadInstruction(MuBatchPipeInstr): 120 | pass 121 | 122 | 123 | @dataclass 124 | class LoadMuBatchInput(LoadInstruction): 125 | """ 126 | Load the inputs X of a new μBatch into the given buffer. 127 | """ 128 | 129 | pass 130 | 131 | 132 | @dataclass 133 | class LoadMuBatchTarget(LoadInstruction): 134 | """ 135 | Load the targets y of a new μBatch into the given buffer. 136 | """ 137 | 138 | pass 139 | 140 | 141 | class Schedule(ABC): 142 | def __init__(self, num_micro_batches, num_stages, stage_id): 143 | assert stage_id < num_stages 144 | self.num_stages = num_stages 145 | self.stage_id = stage_id 146 | self.num_micro_batches = num_micro_batches 147 | 148 | @abstractmethod 149 | def steps(self): 150 | """ 151 | This returns a generator, which contains all the operations to 152 | process a single batch 153 | """ 154 | pass 155 | 156 | @property 157 | @abstractmethod 158 | def num_buffers(self): 159 | """ 160 | The number of buffers necessary for sending & receiving data. 161 | This should always be a multiple of 2, since we have input buffers and 162 | corresponding output buffers (at least during training) 163 | """ 164 | pass 165 | 166 | @property 167 | def is_first_stage(self): 168 | return self.stage_id == 0 169 | 170 | @property 171 | def is_last_stage(self): 172 | return self.stage_id == self.num_stages - 1 173 | 174 | def is_first_mubatch(self, mubatch_id): 175 | return mubatch_id == 0 176 | 177 | def is_last_mubatch(self, mubatch_id): 178 | return mubatch_id == self.num_micro_batches - 1 179 | 180 | def is_valid_stage_id(self, stage_id): 181 | return 0 <= stage_id < self.num_stages 182 | 183 | 184 | class NaiveParallelSchedule(Schedule): 185 | """ 186 | A pipeline schedule without any interleaving of μBatches. 
187 | Only one pipeline stage is activate at any given time 188 | """ 189 | 190 | def steps(self): 191 | yield [ZeroGrad()] 192 | for mubatch_id in range(self.num_micro_batches): 193 | yield self.steps_mubatch(mubatch_id) 194 | # updating the weights is the last step of processing a batch 195 | yield [OptimizerStep()] 196 | 197 | def steps_mubatch(self, mubatch_id): 198 | cmds = [] 199 | if self.is_first_stage: 200 | cmds.append(LoadMuBatchInput(mubatch_id=mubatch_id, buffer_id=0)) 201 | else: 202 | cmds.append((RecvActivations(buffer_id=0))) 203 | cmds.append(Forward(buffer_id=0, mubatch_id=mubatch_id)) 204 | if self.is_last_stage: 205 | cmds.append(LoadMuBatchTarget(mubatch_id=mubatch_id, buffer_id=0)) 206 | else: 207 | cmds.append(SendActivations(buffer_id=0)) 208 | cmds.append(RecvOutputGrad(buffer_id=0)) 209 | if self.is_last_mubatch(mubatch_id): 210 | cmds.append(BackwardGradAllReduce(buffer_id=0, mubatch_id=mubatch_id)) 211 | else: 212 | cmds.append(BackwardGradAcc(buffer_id=0, mubatch_id=mubatch_id)) 213 | if not self.is_first_stage: 214 | cmds.append(SendInputGrad(buffer_id=0)) 215 | return cmds 216 | 217 | @property 218 | def num_buffers(self): 219 | # need 1 Buffer for receiving input and 1 buffer for sending output 220 | # since this is naive PP, there's only ever one μB in flight at the 221 | # same time 222 | return 2 223 | 224 | 225 | class GPipeSchedule(Schedule): 226 | def steps(self): 227 | yield [ZeroGrad()] 228 | 229 | # STAGE 1: FWD all μBatches 230 | for mubatch_id in range(self.num_micro_batches): 231 | yield self.steps_FWD_mubatch(mubatch_id) 232 | 233 | # STAGE 2: BWD all μBatches 234 | for mubatch_id in reversed(range(self.num_micro_batches)): 235 | yield from self.steps_BWD_mubatch(mubatch_id) 236 | 237 | # updating the weights is the last step of processing any batch 238 | yield [OptimizerStep()] 239 | 240 | def steps_BWD_mubatch(self, mubatch_id): 241 | cmds = [] 242 | if self.is_last_stage: 243 | cmds.append(LoadMuBatchTarget(mubatch_id=mubatch_id, buffer_id=0)) 244 | else: 245 | cmds.append(RecvOutputGrad(buffer_id=0)) 246 | if self.is_first_mubatch(mubatch_id): 247 | # interleaved backprop & AllReduce during last μBatch of BWD 248 | cmds.append(BackwardGradAllReduce(buffer_id=0, mubatch_id=mubatch_id)) 249 | else: 250 | cmds.append(BackwardGradAcc(buffer_id=0, mubatch_id=mubatch_id)) 251 | if not self.is_first_stage: 252 | cmds.append(SendInputGrad(buffer_id=0)) 253 | yield cmds 254 | 255 | def steps_FWD_mubatch(self, mubatch_id): 256 | cmds = [] 257 | if self.is_first_stage: 258 | cmds.append(LoadMuBatchInput(buffer_id=0, mubatch_id=mubatch_id)) 259 | else: 260 | cmds.append(RecvActivations(buffer_id=0)) 261 | cmds.append(Forward(buffer_id=0, mubatch_id=mubatch_id)) 262 | # the last stage just discards the output of its `forward()` pass since 263 | # it's not necessary for running BWD. The last stage just needs the target values 264 | # (loaded from disk) and the activations (cached inside the `Module`s) for BWD. 
265 | if not self.is_last_stage: 266 | cmds.append(SendActivations(buffer_id=0)) 267 | return cmds 268 | 269 | @property 270 | def num_buffers(self): 271 | # TODO should keep more buffers around and make the sending & receiving async 272 | return 2 273 | 274 | 275 | class InferenceSchedule(Schedule): 276 | def steps(self): 277 | for mubatch_id in range(self.num_micro_batches): 278 | cmds = [] 279 | 280 | if self.is_first_stage: 281 | cmds.append(LoadMuBatchInput(mubatch_id=mubatch_id, buffer_id=0)) 282 | else: 283 | cmds.append(RecvActivations(buffer_id=0)) 284 | 285 | cmds.append(Forward(buffer_id=0, mubatch_id=mubatch_id)) 286 | 287 | if not self.is_last_stage: 288 | cmds.append(SendActivations(buffer_id=0)) 289 | yield cmds 290 | 291 | @property 292 | def num_buffers(self): 293 | # Could be done with 1 buffer (by doing the FWD inplace) 294 | return 2 295 | 296 | 297 | class PipeDreamSchedule: 298 | def __init__(self): 299 | raise NotImplementedError() 300 | 301 | 302 | def backprop_allreduce_gradient(comm, param): 303 | """ 304 | start a non-blocking AllReduce for the parameters for which we just 305 | calculated the final gradient. 306 | This interleaves communication of this layer's gradients with 307 | computation of the next layers gradients 308 | 309 | Starting a new communication for each parameter is quite wasteful, particularly if 310 | the parameters are small. PyTorch's DDP implementation uses bucketing to get around this. 311 | """ 312 | if param.requires_grad: 313 | # we won't be touching param.grad until the Op is done, so we do it inplace 314 | param._request = comm.Iallreduce( 315 | sendbuf=MPI.IN_PLACE, recvbuf=param.grad, op=MPI.SUM 316 | ) 317 | 318 | 319 | def backprop_block_for_comms(params): 320 | # after the full backwards pass we wait for all communication to finish 321 | # only then can we be certain that the gradients are the same on all processes 322 | requests = [ 323 | param._request 324 | for param in params 325 | if param.requires_grad and param._request is not None 326 | ] 327 | MPI.Request.Waitall(requests) 328 | 329 | 330 | class Worker: 331 | """ 332 | Executes all stages in a schedule, during each batch 333 | The buffers don't keep any state between batches. 
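    As a rough sketch (not normative output): for the first of two stages with a
    single μBatch, a NaiveParallelSchedule expands to
        ZeroGrad
        LoadMuBatchInput -> Forward -> SendActivations -> RecvOutputGrad -> BackwardGradAllReduce
        OptimizerStep
    and execute() simply dispatches each instruction to the matching Worker method
    via _INSTRUCTION_MAP.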
334 | """ 335 | 336 | input_buffers = None 337 | output_buffers = None 338 | 339 | def __init__( 340 | self, 341 | dp_comm: MPI.Comm, 342 | pp_comm: MPI.Comm, 343 | model: MLP, 344 | dataset: Dataset, 345 | optimizer, 346 | ): 347 | self.stage_id = pp_comm.Get_rank() 348 | self.pipeline_depth = pp_comm.Get_size() 349 | self.dp_comm = dp_comm 350 | self.pp_comm = pp_comm 351 | self.model = model 352 | self.dataset = dataset 353 | self.optimizer = optimizer 354 | 355 | def load_micro_batch_input(self, batch_id, mubatch_id, buffer_id): 356 | data = self.dataset.load_micro_batch_input(batch_id, mubatch_id) 357 | assert ( 358 | data.shape == self.input_buffers[buffer_id].shape 359 | ), f"shape is {data.shape} but should be {self.input_buffers[buffer_id].shape}" 360 | self.input_buffers[buffer_id] = data 361 | 362 | def load_micro_batch_target(self, batch_id, mubatch_id, buffer_id): 363 | data = self.dataset.load_micro_batch_target(batch_id, mubatch_id) 364 | assert self.output_buffers[buffer_id].shape == data.shape 365 | self.output_buffers[buffer_id] = data 366 | 367 | def send_activations(self, buffer_id): 368 | # send forwards 369 | self.pp_comm.Send(self.output_buffers[buffer_id], self.get_successor()) 370 | 371 | def recv_activations(self, buffer_id): 372 | # receive from previous 373 | self.pp_comm.Recv(self.input_buffers[buffer_id], self.get_predecessor()) 374 | 375 | def send_grad(self, buffer_id): 376 | # send backwards 377 | self.pp_comm.Send(self.input_buffers[buffer_id], self.get_predecessor()) 378 | 379 | def recv_grad(self, buffer_id): 380 | # receive from next 381 | self.pp_comm.Recv(self.output_buffers[buffer_id], self.get_successor()) 382 | 383 | def forward(self, buffer_id, mubatch_id): 384 | # FWD pass transforms input buffer into output buffer 385 | self.output_buffers[buffer_id] = self.model.forward( 386 | inputs=self.input_buffers[buffer_id], mubatch_id=mubatch_id 387 | ) 388 | 389 | def backward_and_reduce(self, buffer_id, mubatch_id): 390 | # hooks for AllReducing-ing the gradient across all dp_workers 391 | self.model.register_grad_hook( 392 | lambda param: backprop_allreduce_gradient(self.dp_comm, param) 393 | ) 394 | self.model.register_post_grad_hook(backprop_block_for_comms) 395 | 396 | # regular backwards pass will trigger the AR hooks 397 | self.backward_accumulate(buffer_id, mubatch_id=mubatch_id) 398 | 399 | self.model.reset_grad_hooks() 400 | self.model.reset_post_grad_hooks() 401 | 402 | def backward_accumulate(self, buffer_id, mubatch_id): 403 | # BWD pass transforms output buffer into input buffer 404 | self.input_buffers[buffer_id] = self.model.backward( 405 | dout=self.output_buffers[buffer_id], mubatch_id=mubatch_id 406 | ) 407 | 408 | def optimizer_step(self): 409 | self.optimizer.step() 410 | 411 | def zero_grad(self): 412 | self.model.zero_grad() 413 | 414 | def get_predecessor(self): 415 | return self.stage_id - 1 416 | 417 | def get_successor(self): 418 | return self.stage_id + 1 419 | 420 | _INSTRUCTION_MAP = { 421 | LoadMuBatchInput: load_micro_batch_input, 422 | LoadMuBatchTarget: load_micro_batch_target, 423 | Forward: forward, 424 | BackwardGradAllReduce: backward_and_reduce, 425 | BackwardGradAcc: backward_accumulate, 426 | OptimizerStep: optimizer_step, 427 | ZeroGrad: zero_grad, 428 | RecvActivations: recv_activations, 429 | SendActivations: send_activations, 430 | RecvOutputGrad: recv_grad, 431 | SendInputGrad: send_grad, 432 | } 433 | 434 | def execute(self, sched, batch_id): 435 | """ 436 | Setup buffers and use the configured schedule to 
execute a full batch 437 | 438 | Basically it'll just call the right function, given whatever the scheduler 439 | tells it to do 440 | """ 441 | 442 | # The buffers hold activations during FWD passes and gradients during BWD 443 | # activation.shape == grad.shape, hence we can reuse the same buffers for FWD & BWD 444 | # TODO make buffers persistent for the whole training run by setting μBatch-size 445 | # and schedule during __init__. Implement a worker.teardown() for free'ing the buffers. 446 | assert sched.num_buffers % 2 == 0 447 | self.input_buffers = [ 448 | np.empty((self.dataset.mubatch_size, self.model.in_dim), dtype=np.float32) 449 | for _ in range(sched.num_buffers // 2) 450 | ] 451 | self.output_buffers = [ 452 | np.empty((self.dataset.mubatch_size, self.model.out_dim), dtype=np.float32) 453 | for _ in range(sched.num_buffers // 2) 454 | ] 455 | 456 | for commands in sched.steps(): 457 | for command in commands: 458 | if isinstance(command, LoadInstruction): 459 | # data loaders need the current batch_id, every other instruction doesn't. 460 | self._INSTRUCTION_MAP[type(command)]( 461 | self, batch_id, **dataclasses.asdict(command) 462 | ) 463 | else: 464 | self._INSTRUCTION_MAP[type(command)]( 465 | self, **dataclasses.asdict(command) 466 | ) 467 | --------------------------------------------------------------------------------