├── .gitignore
├── README.md
├── benchmark.py
├── contiguous_params
│   ├── __init__.py
│   └── params.py
├── setup.py
├── test.py
└── visualizations
    ├── adam_gradnorm_trace_comparison.png
    └── sgd_trace_comparison.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Editor
.vscode

# Profiling timelines.
*timeline.json

# Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pytest_cache

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Contiguous Parameters for Pytorch

Accelerate training by storing parameters in one contiguous chunk of memory.

## Speed up your optimizer with 3 lines of code!
This graphic shows a GPU step trace comparison with and without contiguous params for a ResNet50 on CIFAR10, using *Adam and gradient clipping*.
The upper trace is with the default optimizer, the lower trace is with the parameter wrapper.
![Gradient norm + Adam](visualizations/adam_gradnorm_trace_comparison.png)

Step trace comparison for a ResNet50 on CIFAR10, using *SGD*.
![SGD](visualizations/sgd_trace_comparison.png)


## What's the difference from Apex?
Apex implements the full optimizer update in C++ and is limited to the supported
optimizers. This wrapper allows you to use any optimizer, as long as it updates the
parameters in place.


## How does it work?
Launching CUDA kernels comes with a small overhead, resulting in low GPU utilization
when launching numerous fast-returning kernels. A typical example of this is the
optimizer step.
This package accelerates training by copying all parameters into one contiguous
buffer, resetting the parameters to be views into the buffer, and applying
optimizer updates on the contiguous representation. Depending on the model, the
optimizer, the type of GPU used, etc., this can drastically reduce the time required
for the optimizer's step function, resulting in speedups anywhere from 7x to 100x.


For this to work, two requirements need to be fulfilled:
1. The computation graph may only alter the parameters and gradients in place
   and should not replace the parameter/gradient tensors with new ones.
   Make sure to call `parameters.assert_buffer_is_valid()` to detect any buffer
   invalidation.
2. All operations executed on `parameters.contiguous()` must not rely on shape
   information or statistics of the parameters, as these would be computed on the
   full buffer instead of each of the original parameters.
   For such operations, keep using `parameters.original()`.

## Disclaimer
This is still a rather new project and considered experimental. If you encounter
a bug, please file an issue if there is no matching existing issue! Also, if you
find this project helpful, consider leaving a star to keep me motivated, or spread
the word and help people train their models faster :)

## Install
To get the most recent version, it's easiest to install the package directly from
GitHub:
```
pip install git+https://github.com/philjd/contiguous_pytorch_params.git
```
Alternatively, [frgfm](https://github.com/frgfm) has kindly created a pip package (`pip install contiguous-params`)
and a conda package (`conda install -c frgfm contiguous_params`).


## Example Usage
```python
import torch
from torch import nn
from contiguous_params import ContiguousParams

data = torch.randn(5, 1, 8)
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# Create the contiguous parameters.
parameters = ContiguousParams(model.parameters())  # <--- (1) Wrap parameters.

# Use parameters.contiguous() instead of model.parameters() to initialize
# the optimizer. Note that the optimizer must update the parameters in place.
optimizer = torch.optim.Adam(parameters.contiguous())  # <--- (2) Optimize view.

# Run the training loop as usual.
for x in data:
    loss = model(x).sum()
    loss.backward()
    # Gradient clipping also profits from contiguous memory.
    nn.utils.clip_grad_norm_(parameters.contiguous(), 0.1)
    optimizer.step()
    optimizer.zero_grad()
    # !!!!!!!
    # Always make sure to call assert_buffer_is_valid() at least once, to detect
    # whether operations invalidated the buffer by overwriting/copying parameters.
    # (Except when running in DDP mode; there the buffer check doesn't work.)
    # !!!!!!!
    parameters.assert_buffer_is_valid()  # <--- (3) Check that the optimizer only applies valid ops.
```

## Debugging
Common problems that might occur:
- The loss is not going down. One reason for this could be that the gradients are
  disconnected and don't use the contiguous grad buffer. This can happen
  when the optimizer with the contiguous params is created before moving the
  model to its device. A good check is to verify that the gradient buffer
  tensor is non-zero (see the sketch after the Benchmarking section below).
- A function updates a parameter with an operation that is not in-place (in-place
  ops have an underscore suffix). This can be caught with the
  `ContiguousParams.assert_buffer_is_valid()` function, so make sure to use it
  at least once per forward pass.
- Operations try to change the parameter views in place. This happens for
  example when `nn.Module.zero_grad()` is used instead of
  `optimizer.zero_grad()`. Either override your module's zero_grad function
  to call the optimizer's zero_grad, or manually `zero_` the contiguous grad
  buffer.


## Testing
```
pytest test.py
```

## Benchmarking
Run `python benchmark.py`. This applies several updates with the original method
as well as with contiguous parameters. You should see a speedup of ~100x.
To take a look at the timeline, open Chromium, navigate to `chrome://tracing/`,
click load, and select the `*timeline.json` file.
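
The following is a rough, self-contained sketch of the checks recommended in the Debugging section above. It is not part of the package; the toy model, tensor shapes, and learning rate are made up for illustration. It accesses the shared gradient buffer via `parameters.contiguous()[0].grad` (the `.grad` of the flat parameter tensor the wrapper returns), verifies that gradients actually land in that buffer after `backward`, zeroes gradients through the optimizer instead of `nn.Module.zero_grad()`, and runs the buffer-invalidation check:

```python
import torch
from torch import nn
from contiguous_params import ContiguousParams

# Toy model, used only for illustration. Move the model to its target
# device *before* wrapping the parameters and building the optimizer.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
parameters = ContiguousParams(model.parameters())
optimizer = torch.optim.SGD(parameters.contiguous(), lr=1e-3)

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# Check 1: after backward(), the contiguous gradient buffer should be
# non-zero. An all-zero buffer hints at disconnected gradients.
grad_buffer = parameters.contiguous()[0].grad
assert grad_buffer.abs().sum() > 0, "Gradients don't reach the contiguous buffer."

optimizer.step()

# Check 2: zero gradients through the optimizer (or zero the buffer
# directly), not through model.zero_grad().
optimizer.zero_grad()  # alternatively: grad_buffer.zero_()

# Check 3: detect parameter/gradient tensors that were replaced instead
# of being updated in place.
parameters.assert_buffer_is_valid()
```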

## Distributed Data Parallel Training
Training with DDP is also easy; we just need to make sure that the parameters of each replica are contiguous.
To understand where to insert the ContiguousParams into our `nn.Module`, let's first recap how DDP
works:
1. Create the reference model.
2. Replicate the model onto the respective devices.
3. Wrap it as a DDP module. This creates hooks between the gradients, ensuring that they
   get synced across devices during `backward`. Note: DDP does not allow the parameters to change
   after this step.
4. Initialize an optimizer for each device with the device's parameters. Each
   device calls `optimizer.step` for its own parameters, but with the same
   gradients due to the syncing. This means we perform the same update on each
   device and end up with the same set of parameters, saving the round of
   parameter syncing before the forward pass that would be necessary if
   we used only one device to compute `step`.

This means the contiguous parameters need to be created after step 2 but
before step 3. The easiest way to do this is to create your optimizer after
moving the model to the desired device; otherwise you need to wrap the `Module.cuda`
and `Module.cpu` functions and recreate the contiguous parameters there.
Note: the buffer invalidation check currently doesn't work with DDP.

Contiguous params work with pytorch_lightning's DDP implementation for versions > 0.9
or on master after [this commit](https://github.com/PyTorchLightning/pytorch-lightning/commit/e3528afae3f178cf9d5d8ea6bc3f8a876646054a).

--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
"""Code to compare optimizer step times and generate tracing timelines."""
import torch
from torch import nn
from torch.autograd import profiler
import time

from copy import deepcopy

from contiguous_params import ContiguousParams


def benchmark_model(model, optimizer, parameters, name):
    # Note: `device` and `x` are globals defined in the __main__ block below.
    step_times = []
    # The autograd profiler adds some overhead, so we time the steps both with
    # and without it enabled.
    for profile_autograd in [False, True]:
        with profiler.profile(enabled=profile_autograd, use_cuda=(device == "cuda")) as prof:
            for i in range(15):
                # Warm up for five steps, reset step_times after this.
                if i == 5:
                    step_times = []
                with profiler.record_function("forward"):
                    loss = model(x).sum()
                with profiler.record_function("backward"):
                    loss.backward()
                torch.cuda.synchronize()
                start = time.time()
                with profiler.record_function("gradient_norm"):
                    torch.nn.utils.clip_grad_norm_(parameters, 0.1)
                with profiler.record_function("step"):
                    optimizer.step()
                with profiler.record_function("zero_grad"):
                    optimizer.zero_grad()
                torch.cuda.synchronize()
                step_times.append(time.time() - start)
        print(f"Mean step time: {sum(step_times) / 10} seconds. "
              f"(Autograd profiler enabled: {profile_autograd})")
        prof.export_chrome_trace(f"{name}_timeline.json")


if __name__ == "__main__":
    device = "cuda"
    model = nn.Sequential(*[nn.Linear(128, 128) for i in range(100)]).to(device)
    print("Number of parameters: ", sum(p.numel() for p in model.parameters()))
    x = torch.randn(1, 128).to(device)

    model_copies = [deepcopy(model) for _ in range(2)]

    # Benchmark the original parameters.
    parameters = list(model_copies[0].parameters())
    optimizer = torch.optim.Adam(parameters)
    benchmark_model(model_copies[0], optimizer, parameters, "original_params")

    # Benchmark the contiguous parameters.
    parameters = ContiguousParams(model_copies[1].parameters())
    optimizer = torch.optim.Adam(parameters.contiguous())
    benchmark_model(model_copies[1], optimizer, parameters.contiguous(),
                    "contiguous_params")
    # Ensure the parameter buffers are still valid.
    parameters.assert_buffer_is_valid()

--------------------------------------------------------------------------------
/contiguous_params/__init__.py:
--------------------------------------------------------------------------------
from .params import ContiguousParams

--------------------------------------------------------------------------------
/contiguous_params/params.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class ContiguousParams:

    def __init__(self, parameters):
        # Create a list of the parameters to prevent emptying an iterator.
        self._parameters = list(parameters)
        self._param_buffer = None
        self._grad_buffer = None
        self._init_buffers()
        # Store the data pointers for each parameter into the buffer. These
        # can be used to check if an operation overwrites the gradient/data
        # tensor (invalidating the assumption of a contiguous buffer).
        self.data_pointers = []
        self.grad_pointers = []
        self.make_params_contiguous()

    def _init_buffers(self):
        dtype = self._parameters[0].dtype
        device = self._parameters[0].device
        if not all(p.dtype == dtype for p in self._parameters):
            raise ValueError("All parameters must be of the same dtype.")
        if not all(p.device == device for p in self._parameters):
            raise ValueError("All parameters must be on the same device.")
        size = sum(p.numel() for p in self._parameters)
        self._param_buffer = torch.zeros(size, dtype=dtype, device=device)
        self._grad_buffer = torch.zeros(size, dtype=dtype, device=device)

    def make_params_contiguous(self):
        """Copy all parameters into the buffer and turn them into views of the buffer."""
        index = 0
        for p in self._parameters:
            size = p.numel()
            self._param_buffer[index:index + size] = p.data.view(-1)
            p.data = self._param_buffer[index:index + size].view(p.data.shape)
            p.grad = self._grad_buffer[index:index + size].view(p.data.shape)
            self.data_pointers.append(p.data.data_ptr())
            self.grad_pointers.append(p.grad.data.data_ptr())
            index += size
        # Point the param_buffer's grad at the grad_buffer, so an optimizer
        # working on the contiguous buffer sees the gradients accumulated
        # through the parameter views.
        self._param_buffer.grad = self._grad_buffer

    def contiguous(self):
        """Return all parameters as one contiguous buffer."""
        return [self._param_buffer]

    def original(self):
        """Return the non-flattened parameters."""
        return self._parameters

    def buffer_is_valid(self):
        """Verify that all parameters and gradients still use the buffer."""
        params_and_pointers = zip(self._parameters,
                                  self.data_pointers,
                                  self.grad_pointers)
        return all((p.data.data_ptr() == data_ptr) and
                   (p.grad.data.data_ptr() == grad_ptr)
                   for p, data_ptr, grad_ptr in params_and_pointers)

    def assert_buffer_is_valid(self):
        if not self.buffer_is_valid():
            raise ValueError(
                "The data or gradient buffer has been invalidated. Please make "
                "sure to use inplace operations only when updating parameters "
                "or gradients.")

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

with open("README.md", "r") as f:
    long_description = f.read()

setup(
    name="contiguous_params",
    version="1.0",
    description="Make pytorch parameters contiguous to speed up training by 100x.",
    license="Apache 2.0",
    long_description=long_description,
    author="Philipp Jund",
    author_email="ijund.phil@gmail.com",
    url="http://www.github.com/philjd/contiguous_pytorch_params",
    packages=["contiguous_params"],
    keywords="pytorch contiguous parameters speed up accelerate",
)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
"""Test contiguous parameter functions."""
from copy import deepcopy
import os

import pytest
import torch
import numpy as np
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning

from contiguous_params import ContiguousParams


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_equal_optimizer_update(device):
    """Verify that the parameters are the same after a few updates."""
    if device == "cuda" and not torch.cuda.is_available():
        print("No GPU available, skipping GPU test.")
        return
    x = torch.randn(1, 8).to(device)

    model_ref = nn.Sequential(*[nn.Linear(8, 8) for i in range(10)])
    model_ref = model_ref.to(device)
    optimizer = torch.optim.SGD(model_ref.parameters(), lr=1e-3)

    model_c = deepcopy(model_ref)
    parameters_c = ContiguousParams(model_c.parameters())
    optimizer_c = torch.optim.SGD(parameters_c.contiguous(), lr=1e-3)

    for model, optimizer in zip([model_ref, model_c], [optimizer, optimizer_c]):
        for step in range(5):
            loss = model(x).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # Verify that the model/optimizer did not modify the data or grad handle.
            parameters_c.assert_buffer_is_valid()

    # Verify that both models applied the same parameter updates.
    for p1, p2 in zip(model_ref.parameters(), model_c.parameters()):
        assert torch.allclose(p1.data, p2.data, atol=1e-06)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_buffer_invalidation_detection(device):
    """Verify that we recognize an invalidated buffer."""
    if device == "cuda" and not torch.cuda.is_available():
        print("No GPU available, skipping GPU test.")
        return
    model = nn.Linear(8, 8)
    parameters = ContiguousParams(model.parameters())
    assert parameters.buffer_is_valid()
    # Invalidate the buffer by replacing the weight tensor (not an in-place op).
    model.weight.data = model.weight + 4
    assert not parameters.buffer_is_valid()
    with pytest.raises(ValueError):
        parameters.assert_buffer_is_valid()


def test_distributed_data_parallel():
    """Verify that this works in the distributed data parallel setting."""
    np.random.seed(0)
    # Create 20 samples with 10 features, label one out of 5 classes.
    data_X = torch.as_tensor(np.random.randn(20, 10), dtype=torch.float32)
    data_y = torch.as_tensor(np.random.choice(5, (20)), dtype=torch.int64)

    class Model(pytorch_lightning.LightningModule):

        def __init__(self, use_contiguous):
            super().__init__()
            self.model = nn.Sequential(nn.Linear(10, 10),
                                       nn.ReLU(),
                                       nn.Linear(10, 5))
            self.use_contiguous = use_contiguous
            self.loss_fn = torch.nn.CrossEntropyLoss()
            self.dataset = TensorDataset(data_X, data_y)
            self.contiguous_params = None
            self.optimizer = None

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x, target = batch
            prediction = self(x)
            loss_value = self.loss_fn(prediction, target)
            return {'loss': loss_value}

        def train_dataloader(self):
            return torch.utils.data.DataLoader(self.dataset,
                                               batch_size=2,
                                               shuffle=False)

        def configure_optimizers(self):
            if self.use_contiguous:
                self.contiguous_params = ContiguousParams(self.parameters())
                params = self.contiguous_params.contiguous()
            else:
                params = self.model.parameters()
            self.optimizer = torch.optim.SGD(params, lr=1e-3)
            return self.optimizer

    model_ref = Model(use_contiguous=False)
    initial_configuration = deepcopy(model_ref.state_dict())

    model_c = Model(use_contiguous=True)
    model_c.load_state_dict(initial_configuration)

    port = 1234
    for i, model in enumerate([model_ref, model_c]):
        # Choose different ports to prevent
        # RuntimeError("Address already in use.").
        os.environ['MASTER_PORT'] = str(port + i)
        trainer = pytorch_lightning.Trainer(distributed_backend="ddp", max_epochs=1, gpus=[0])
        trainer.fit(model)
        # Make sure the optimizer did update the weights.
        for p1, p2 in zip(model.parameters(), initial_configuration.values()):
            assert not torch.allclose(p1.data, p2.data, atol=1e-06)

    for p1, p2 in zip(model_ref.parameters(), model_c.parameters()):
        assert torch.allclose(p1.data, p2.data, atol=1e-06)

--------------------------------------------------------------------------------
/visualizations/adam_gradnorm_trace_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PhilJd/contiguous_pytorch_params/e683709d516159084957f4aaf678b4e283cc58b7/visualizations/adam_gradnorm_trace_comparison.png

--------------------------------------------------------------------------------
/visualizations/sgd_trace_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PhilJd/contiguous_pytorch_params/e683709d516159084957f4aaf678b4e283cc58b7/visualizations/sgd_trace_comparison.png

--------------------------------------------------------------------------------