├── config.py ├── models └── unet │ ├── __init__.py │ ├── unet_model.py │ └── unet_parts.py ├── .gitignore ├── .gitmodules ├── figs └── monet_concept_fig.jpeg ├── monet ├── lm_ops │ ├── __init__.py │ ├── lravg.cpp │ ├── lrrelu.cu │ ├── lrhardtanh.cpp │ ├── lrrelu.cpp │ ├── softmax.py │ ├── cat.py │ ├── base.py │ ├── softmax.cpp │ ├── compress.py │ ├── compress.cu │ ├── lrconv.cpp │ ├── lrln.cu │ ├── linear.py │ ├── conv_fwd.py │ ├── pack.cu │ ├── compress.cpp │ ├── pack.py │ ├── hardtanh.py │ ├── ln.py │ ├── elementary.py │ ├── pool.py │ ├── lrln.cpp │ ├── pack.cpp │ ├── bn.py │ ├── lrfuncs.cpp │ ├── relu.py │ ├── defaultconv.py │ ├── maxpool.cu │ ├── lrbn.cpp │ ├── funcs.py │ └── greedyconv.py ├── monet_wrapper.py └── graph.py ├── checkmate ├── readme.md └── utils │ ├── timer.py │ └── solver_common.py ├── setup.py ├── examples ├── training.py ├── check_runtime_original.py └── dist_training.py ├── install.sh ├── gist ├── README.md ├── gist_graph.py ├── gist_solver_info.py └── gist_schedule.py ├── LICENSE.txt └── README.md /config.py: -------------------------------------------------------------------------------- 1 | budget = 0 -------------------------------------------------------------------------------- /models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet 2 | 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | monet_memory_optimized_training.egg-info 3 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data"] 2 | path = data 3 | url = https://github.com/utsaslab/monet-schedules 4 | -------------------------------------------------------------------------------- /figs/monet_concept_fig.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utsaslab/MONeT/HEAD/figs/monet_concept_fig.jpeg -------------------------------------------------------------------------------- /monet/lm_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import list_ops, InputStorage, OutputStorage, IntermediateStorage 2 | from . 
import conv, relu, elementary, pool, linear, bn, cat, hardtanh, funcs 3 | from .elementary import NativeOP 4 | -------------------------------------------------------------------------------- /monet/lm_ops/lravg.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 4 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 5 | 6 | 7 | torch::Tensor lr_adaptive_avg_pool_backward(const torch::Tensor& grad, const torch::Tensor& self) { 8 | return at::native::adaptive_avg_pool2d_backward_cuda(grad, self); 9 | } 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("lr_adaptive_avg_pool_backward", &lr_adaptive_avg_pool_backward, "Adaptive Avg Pool Backward"); 13 | } 14 | -------------------------------------------------------------------------------- /checkmate/readme.md: -------------------------------------------------------------------------------- 1 | This folder contains the code for our implementation of [Checkmate](https://github.com/parasj/checkmate) in PyTorch, used for comparison with MONeT. 2 | 3 | `checkmate_solver.py` has been copied almost as-is and is under the Apache-2.0 license. 4 | 5 | `checkmate_schedule.py` runs a Checkmate schedule. 6 | 7 | They can be run using: 8 | 9 | ``` 10 | python checkmate_solver.py [MODEL] [BATCH_SIZE] [BUDGET] 11 | python checkmate_schedule.py [MODEL] [BATCH_SIZE] [BUDGET] normal GUROBI --solution_file [PATH_TO_SOLUTION_FILE] --check_runtime 12 | ``` 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="monet_memory_optimized_training", 5 | version="0.0.1", 6 | description="Memory Optimized Network Training Framework", 7 | url="https://github.com/philkr/lowrank_conv", 8 | packages=setuptools.find_packages(include = ['monet', 'monet.*', 'models', 'checkmate', 'gist']), 9 | classifiers=[ 10 | "Programming Language :: Python :: 3", 11 | "License :: OSI Approved :: MIT License", 12 | "Operating System :: OS Independent", 13 | ], 14 | python_requires='>=3.6', 15 | ) 16 | -------------------------------------------------------------------------------- /monet/lm_ops/lrrelu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | __global__ void threshold_kernel(float* grad, const float* self, float threshold, float value, int N){ 5 | const int j = blockIdx.x * blockDim.x + threadIdx.x; 6 | if (j < N) { 7 | grad[j] = self[j] > threshold ?
grad[j] : value; 8 | } 9 | } 10 | 11 | void threshold_bwd(float* grad, const float* self, float threshold, float value, int N) { 12 | const int threads = 1024; 13 | const int blocks = (N+threads-1)/threads; 14 | threshold_kernel<<<blocks, threads>>>(grad, self, threshold, value, N); 15 | } 16 | -------------------------------------------------------------------------------- /examples/training.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision 2 | from monet.cvxpy_solver import Solution 3 | from monet.monet_wrapper import MONeTWrapper 4 | import time 5 | 6 | input = torch.randn(184,3,224,224).cuda() 7 | model = torchvision.models.resnet50() 8 | input_shape = (3,224,224) 9 | 10 | # Can change to use absolute path instead of relative 11 | sol_file = "../data/monet_r50_184_24hr/solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl" 12 | train_model = MONeTWrapper(model, sol_file, (3,224,224)).cuda() 13 | output = train_model(input) 14 | output.sum().backward() 15 | print("Memory used: %6.2f MB" % (torch.cuda.max_memory_allocated()/1024/1024)) 16 | 17 | -------------------------------------------------------------------------------- /monet/lm_ops/lrhardtanh.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 4 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 5 | 6 | #include 7 | #include 8 | 9 | torch::Tensor hardtanh_backward(torch::Tensor& grad, const torch::Tensor& self, float min, float max) { 10 | if (grad.type().is_cuda()) { 11 | return at::hardtanh_backward(grad, self, min, max); 12 | } else { 13 | exit(-1); 14 | // Not implementing CPU definition 15 | } 16 | } 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("hardtanh_backward", &hardtanh_backward, "HardTanh Backward"); 20 | } -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Install anaconda 2 | wget https://repo.anaconda.com/archive/Anaconda3-2020.02-Linux-x86_64.sh 3 | bash ~/Anaconda3-2020.02-Linux-x86_64.sh -b -p 4 | source ~/anaconda3/bin/activate 5 | 6 | # Create python 3.7 env 7 | conda create -n monetenv -q python=3.7 -y 8 | conda activate monetenv 9 | 10 | # Install pytorch and utils 11 | conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.1 -c pytorch -y 12 | 13 | # Install required packages 14 | sudo apt install ninja-build coinor-cbc coinor-libcbc-dev -y 15 | conda config --add channels http://conda.anaconda.org/gurobi 16 | conda install -c conda-forge cvxpy gurobi -y 17 | conda install -c anaconda pandas -y 18 | yes | pip install cylp 19 | -------------------------------------------------------------------------------- /gist/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the code for our implementation of [Gist](https://www.microsoft.com/en-us/research/uploads/prod/2018/04/fiddle-gist-isca18.pdf) in PyTorch, used for comparison with MONeT. 2 | 3 | `gist_graph.py` creates the modified forward pass graph for Gist by adding the intermediate encodings for ReLU->MaxPool layers and annotating ReLU->Conv layers for the SSDC technique in Gist. 4 | 5 | `gist_solver_info.py` adds the backward pass information to the graph. 6 | 7 | `gist_schedule.py` runs a Gist schedule.
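The ReLU->MaxPool encoding mentioned above boils down to keeping a 1-bit mask of the ReLU output instead of the full activation, since ReLU's backward pass only needs to know which elements were positive. A minimal sketch of that idea using the bit-packing helpers in `monet/lm_ops/pack.py` (an illustration only, not the actual Gist implementation; the tensor shape is an arbitrary example):

```python
import torch
from monet.lm_ops.pack import pack, unpack_multiply

# Forward: store a packed 1-bit mask instead of the full ReLU output.
x = torch.randn(32, 64, 56, 56, device="cuda")
relu_out = torch.relu(x)
mask = pack((relu_out > 0).flatten())   # uint8 tensor, ~32x smaller than relu_out

# Backward: the mask alone is enough to route gradients through the ReLU.
grad_output = torch.ones_like(relu_out)
grad_input = unpack_multiply(mask, grad_output.flatten()).view_as(relu_out)
```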
8 | 9 | 10 | It can be run using: 11 | 12 | ``` 13 | python gist_schedule.py [MODEL] [BATCH_SIZE] [BUDGET in GB] --check_runtime 14 | ``` 15 | 16 | Example, 17 | ``` 18 | python gist_schedule.py "torchvision.models.googlenet()" 320 14 --check_runtime 19 | ``` 20 | -------------------------------------------------------------------------------- /monet/lm_ops/lrrelu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 4 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 5 | 6 | #include 7 | 8 | 9 | void threshold_bwd(float* , const float* , float , float , int); 10 | 11 | torch::Tensor relu_backward(torch::Tensor& grad, const torch::Tensor& self, float threshold, int N) { 12 | float value = 0; 13 | if (grad.type().is_cuda()) { 14 | threshold_bwd(grad.data_ptr(), self.data_ptr(), threshold, value, N); 15 | return grad; 16 | } else { 17 | // Not implementing CPU definition 18 | return torch::zeros(N, torch::dtype(torch::kFloat32).device(grad.device())); 19 | } 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("relu_backward", &relu_backward, "ReLU Backward"); 24 | } -------------------------------------------------------------------------------- /monet/lm_ops/softmax.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | 5 | softmax_cpp = load(name="softmax_cpp", sources=[this_dir/"softmax.cpp", this_dir/"softmax.cu"], extra_cflags=['-std=c++17']) 6 | 7 | @implements(['aten::softmax'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal']) 8 | class Softmax(OP): 9 | params = None 10 | backward_storage = OutputStorage() 11 | 12 | def forward(self, input_, dim, dtype): 13 | with torch.no_grad(): 14 | assert dtype is None 15 | self.params = input_.type(), dim 16 | return softmax_cpp.softmax(input_, dim) 17 | 18 | def backward(self, grad_output, stores, nodel=False): 19 | with torch.no_grad(): 20 | output = stores[0] 21 | if not nodel: 22 | del stores[0] 23 | input_type, dim = self.params 24 | return softmax_cpp.softmax_backward(grad_output, output, input_type, dim) -------------------------------------------------------------------------------- /monet/lm_ops/cat.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | import torch 3 | 4 | 5 | @implements(['aten::cat'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 6 | class Cat(OP): 7 | params = None 8 | 9 | def forward(self, *input_): 10 | tensors, dim = input_[:-1], input_[-1] 11 | self.params = dim, [x.shape[dim] for x in tensors] 12 | 13 | with torch.no_grad(): 14 | if len(input_) == 2: 15 | return input_[0] 16 | return torch.cat(tensors, dim=dim) 17 | 18 | def backward(self, *grad_output, nodel=False): 19 | dim, chunk_sizes = self.params 20 | with torch.no_grad(): 21 | if len(chunk_sizes) == 1: 22 | return grad_output[0] 23 | outs = list(torch.split( 24 | grad_output[0], chunk_sizes, dim=dim)) 25 | grad_inputs = [] 26 | for out in outs: 27 | grad_inputs.append(out.contiguous()) 28 | del outs, grad_output 29 | return grad_inputs 30 | 31 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 
(c) 2020 Aashaka Shah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ Full assembly of the parts to form the complete network """ 2 | 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | 8 | class UNet(nn.Module): 9 | def __init__(self, n_channels, n_classes, height, width): 10 | super(UNet, self).__init__() 11 | self.n_channels = n_channels 12 | self.n_classes = n_classes 13 | 14 | self.inc = DoubleConv(n_channels, 64) 15 | self.down1 = Down(64, 128) 16 | self.down2 = Down(128, 256) 17 | self.down3 = Down(256, 512) 18 | factor = 2 19 | self.down4 = Down(512, 1024 // factor) 20 | self.up1 = Up(1024, 512 // factor, height // 8, width // 8) 21 | self.up2 = Up(512, 256 // factor, height // 4, width // 4) 22 | self.up3 = Up(256, 128 // factor, height // 2, width // 2) 23 | self.up4 = Up(128, 64, height, width) 24 | self.outc = OutConv(64, n_classes) 25 | 26 | def forward(self, x): 27 | x1 = self.inc(x) 28 | x2 = self.down1(x1) 29 | x3 = self.down2(x2) 30 | x4 = self.down3(x3) 31 | x5 = self.down4(x4) 32 | x = self.up1(x5, x4) 33 | x = self.up2(x, x3) 34 | x = self.up3(x, x2) 35 | x = self.up4(x, x1) 36 | logits = self.outc(x) 37 | return logits 38 | -------------------------------------------------------------------------------- /monet/lm_ops/base.py: -------------------------------------------------------------------------------- 1 | 2 | class StorageInfo: 3 | pass 4 | 5 | 6 | class InputStorage(StorageInfo): 7 | def __init__(self, *ids): 8 | self.ids = ids 9 | 10 | 11 | class OutputStorage(StorageInfo): 12 | pass 13 | 14 | 15 | class IntermediateStorage(StorageInfo): 16 | def __init__(self, size_fn): 17 | self.size = size_fn 18 | 19 | 20 | class OP: 21 | names = [] 22 | backward_storage: StorageInfo = None 23 | 24 | def intermediate(self): 25 | return None 26 | 27 | def forward(self, *args, **kwargs): 28 | raise NotImplementedError 29 | 30 | def backward(self, *args, **kwargs): 31 | raise NotImplementedError 32 | 33 | 34 | registry = {} 35 | all_modes = ["none", "normal", "multiway", "newnode", "multiway_newnode", "conv_multiway_newnode", "conv_normal", "gist"] 36 | 37 | def implements(names, modes): 38 | def _wrapper(C): 39 | for m in modes: 40 | assert m in all_modes 41 | if m not in registry: 42 | registry[m] = {} 43 | for n in 
names: 44 | if n not in registry[m]: 45 | registry[m][n] = [] 46 | registry[m][n].append(C) 47 | C.names = names 48 | return C 49 | return _wrapper 50 | 51 | 52 | def list_ops(mode, name): 53 | return registry[mode][name] 54 | -------------------------------------------------------------------------------- /monet/lm_ops/softmax.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 4 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 5 | 6 | using namespace at; 7 | using namespace at::native; 8 | 9 | torch::Tensor do_softmax_backward1(torch::Tensor &tmp, torch::Tensor &output, int64_t dim, bool half_to_float); 10 | 11 | torch::Tensor softmax_forward(const torch::Tensor& self, int64_t dim) { 12 | torch::Tensor out = at::native::softmax_cuda(self, dim, false); 13 | namedinference::propagate_names(out, self); 14 | return out; 15 | } 16 | 17 | torch::Tensor softmax_backward1(const torch::Tensor& grad, torch::Tensor output, string input_type, int64_t dim) { 18 | ScalarType iptype; 19 | if (input_type == "torch.cuda.FloatTensor") 20 | iptype = ScalarType::Float; 21 | else if (input_type == "torch.cuda.HalfTensor") 22 | iptype = ScalarType::Half; 23 | else 24 | exit(-1); 25 | 26 | bool half_to_float = grad.scalar_type() != iptype; 27 | if (half_to_float) { 28 | TORCH_CHECK((grad.scalar_type() == ScalarType::Float && iptype == ScalarType::Half), 29 | "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); 30 | } 31 | torch::Tensor tmp = grad * output; 32 | return do_softmax_backward1(tmp, output, dim, half_to_float); 33 | } 34 | 35 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 36 | m.def("softmax", &softmax_forward, "Softmax forward"); 37 | m.def("softmax_backward", &softmax_backward1, "Softmax backward"); 38 | } -------------------------------------------------------------------------------- /monet/lm_ops/compress.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | from pathlib import Path 4 | 5 | this_dir = Path(__file__).parent 6 | compress_cpp = load(name="compress_cpp", sources=[this_dir / "compress.cpp", this_dir / "compress.cu"], extra_cflags=['-std=c++17', '-lcusparse'], extra_cuda_cflags=['-lcusparse'],extra_ldflags=['-lcusparse']) 7 | 8 | def compress_csr_256(ip: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): 9 | """ 10 | Return a list of dense vectors and their indices 11 | 12 | :param ip: 1D float32 tensor - sparse form activation 13 | :return0: 1D float32 tensor - dense form CSR activation 14 | :return1: 1D int32 tensor - CSR indices 15 | 16 | """ 17 | d1,d2,d3,d4=ip.shape 18 | ip_dim = d1*d2*d3*d4 19 | pad_val = (256 - ip_dim%256)%256 20 | if pad_val != 0: 21 | ip_new = torch.nn.functional.pad(ip.view(ip_dim), (0,pad_val)) 22 | else: 23 | ip_new = ip 24 | del ip 25 | ip_new = ip_new.view(-1,256) 26 | nzrow = (ip_new!=0).sum(dim=1) 27 | nz = nzrow.sum().item() 28 | nzrow = nzrow.to(torch.int32) 29 | # print(ip_new) 30 | # print(ip_new.shape, nzrow, nz) 31 | return compress_cpp.compress_csr_256(ip_new, nzrow, nz) 32 | 33 | def uncompress_csr_256(compip: torch.Tensor, indx: torch.Tensor, row: torch.Tensor, N: int) -> torch.Tensor: 34 | """ 35 | Returns an uncompressed dense vector 36 | 37 | :param compip: 1D float32 tensor - CSR compressed input 38 | :param indx: 1D int32 tensor - CSR indices 39 | :return: 1D 
float32 tensor - uncompressed activation 40 | """ 41 | return compress_cpp.uncompress_csr_256(compip, indx, row, N) 42 | -------------------------------------------------------------------------------- /monet/monet_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from monet.cvxpy_solver import Solution 4 | from monet.schedule import Schedule, create_schedule 5 | 6 | 7 | class MONeTWrapperF(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, input_, schedule, state_dict): 10 | input_detached = input_.detach() 11 | ctx.schedule = schedule 12 | ctx.save_for_backward(torch.ones(1)) 13 | output = schedule.forward(input_detached, *state_dict) 14 | output.requires_grad_(True) 15 | return output 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | schedule = ctx.schedule 20 | schedule.backward(grad_output) 21 | return None, None, None 22 | 23 | 24 | class MONeTWrapper(torch.nn.Module): 25 | def __init__(self, model, sol_file, input_shape): 26 | super(MONeTWrapper, self).__init__() 27 | if model._get_name() == 'MobileNetV2': 28 | model = torch.nn.Sequential( 29 | model.features, 30 | torch.nn.AdaptiveAvgPool2d((1, 1)), torch.nn.Flatten(start_dim=1), 31 | model.classifier[0], model.classifier[1]) 32 | self.model = model 33 | self.sol_file = sol_file 34 | self.input_shape = input_shape 35 | self.schedule = [create_schedule(model, sol_file, input_shape)] 36 | self.recreate_state_dict = False 37 | 38 | def train(self, mode=True): 39 | self.model.train(mode) 40 | 41 | def forward(self, x): 42 | x.requires_grad_(True) 43 | state_dict = self.state_dict(keep_vars=True) 44 | return MONeTWrapperF.apply(x, self.schedule[0], list(state_dict.values())) 45 | -------------------------------------------------------------------------------- /monet/lm_ops/compress.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define ERR_NE(X,Y) do { if ((X) != (Y)) { \ 7 | fprintf(stderr,"Error in %s at %s:%d\n",__func__,__FILE__,__LINE__); \ 8 | exit(-1);}} while(0) 9 | #define CUDA_CALL(X) ERR_NE((X),cudaSuccess) 10 | #define CUSPARSE_CALL(X) ERR_NE((X),CUSPARSE_STATUS_SUCCESS) 11 | 12 | // void get_nz_cuda(const float* ip, int *nnzPerRow, int nz_ptr, int lda) { 13 | // cusparseHandle_t handle; 14 | // CUSPARSE_CALL( cusparseCreate(&handle) ); 15 | // cusparseMatDescr_t descrX; 16 | // CUSPARSE_CALL(cusparseCreateMatDescr(&descrX)); 17 | // CUSPARSE_CALL( cusparseSnnz(handle, CUSPARSE_DIRECTION_ROW, lda, 256, descrX, ip, 18 | // lda, nnzPerRow, &nz_ptr)); 19 | // } 20 | 21 | void compress_csr_256_gpu(const float* ip, float* cip, int* idx, int * rowidx, const int * nnzPerRow, size_t N) { 22 | cusparseHandle_t handle; 23 | CUSPARSE_CALL( cusparseCreate(&handle) ); 24 | cusparseMatDescr_t descrX; 25 | CUSPARSE_CALL(cusparseCreateMatDescr(&descrX)); 26 | CUSPARSE_CALL( cusparseSdense2csr( handle, (N+255)/256, 256, descrX, ip, 27 | (N+255)/256, nnzPerRow, cip, 28 | rowidx, idx )) ; 29 | } 30 | 31 | void uncompress_csr_256_gpu(const float* compIP, const int * csrIdx, const int* csrRowIdx, float* op, size_t N){ 32 | cusparseHandle_t handle; 33 | CUSPARSE_CALL( cusparseCreate(&handle) ); 34 | cusparseMatDescr_t descrX; 35 | CUSPARSE_CALL(cusparseCreateMatDescr(&descrX)); 36 | CUSPARSE_CALL( cusparseScsr2dense( handle, (N+255)/256, 256, descrX, compIP, 37 | csrRowIdx, csrIdx, 38 | op,(N+255)/256 )) ; 39 | } 40 | 
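The CUDA kernels above are exposed through `compress_csr_256` and `uncompress_csr_256` in `monet/lm_ops/compress.py`. A minimal round-trip sketch (assuming a CUDA device and that the cuSPARSE extension builds; the tensor size is an arbitrary example):

```python
import torch
from monet.lm_ops.compress import compress_csr_256, uncompress_csr_256

x = torch.relu(torch.randn(2, 3, 8, 8, device="cuda"))   # sparse activation, roughly half zeros
n = x.numel()
values, col_idx, row_ptr = compress_csr_256(x)            # CSR pieces of the zero-padded (n/256 x 256) view
y = uncompress_csr_256(values, col_idx, row_ptr, n)       # dense ((n+255)//256, 256) tensor
assert torch.equal(y.reshape(-1)[:n], x.reshape(-1))
```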
-------------------------------------------------------------------------------- /examples/check_runtime_original.py: -------------------------------------------------------------------------------- 1 | # Get runtime of original PyTorch model 2 | import argparse 3 | import torch, torchvision 4 | from models.unet import UNet 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('model') 8 | parser.add_argument('bs') 9 | args = parser.parse_args() 10 | 11 | bs = int(args.bs) 12 | model_name = args.model.split(".")[-1][:-2] 13 | print("Batch size ", bs) 14 | print("Model", model_name) 15 | 16 | if args.model == 'unet': 17 | height, width = 416, 608 18 | model = UNet(n_channels=3, n_classes=1, height=height, width=width) 19 | else: 20 | height, width = 224, 224 21 | model = eval(args.model, {'torch': torch, 'torchvision': torchvision}) 22 | 23 | if 'mobilenet_v2' in args.model: 24 | model = torch.nn.Sequential( 25 | model.features, 26 | torch.nn.AdaptiveAvgPool2d((1, 1)), torch.nn.Flatten(start_dim=1), 27 | model.classifier[0], model.classifier[1]) 28 | model.cuda() 29 | input_ = torch.randn((bs, 3, height, width)).cuda() 30 | torch.cuda.reset_max_memory_allocated() 31 | torch.cuda.synchronize() 32 | 33 | torch.backends.cudnn.benchmark = True 34 | 35 | start_event = torch.cuda.Event(enable_timing=True) 36 | end_event = torch.cuda.Event(enable_timing=True) 37 | if 'googlenet' in args.model: 38 | for i in range(120): 39 | if i==100: 40 | start_event.record() 41 | x0 = model(input_) 42 | (x0[0] + x0[1] + x0[2]).sum().backward() 43 | del x0 44 | end_event.record() 45 | torch.cuda.synchronize() 46 | else: 47 | for i in range(120): 48 | if i==100: 49 | start_event.record() 50 | x0 = model(input_) 51 | x0.sum().backward() 52 | del x0 53 | end_event.record() 54 | torch.cuda.synchronize() 55 | orig_maxmem = torch.cuda.max_memory_allocated() / 2**20 56 | print("original: %fms avg, %8.2f MB" % (start_event.elapsed_time(end_event)/20, orig_maxmem)) 57 | del model -------------------------------------------------------------------------------- /monet/lm_ops/lrconv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 4 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 5 | 6 | 7 | torch::Tensor forward(const torch::Tensor& input, const torch::Tensor& weight, 8 | torch::IntArrayRef stride, torch::IntArrayRef padding, torch::IntArrayRef dilation, int64_t groups) { 9 | // torch::NoGradGuard no_grad_guard; 10 | static torch::Tensor undefined; 11 | // return torch::conv2d(input, weight, undefined, stride, padding, dilation, groups); 12 | return torch::cudnn_convolution(input, weight, undefined, padding, stride, dilation, groups, false, true); //benchmark, deterministic 13 | } 14 | torch::Tensor backward_input(torch::IntArrayRef input_sizes, const torch::Tensor& grad_output_t, const torch::Tensor& weight, 15 | torch::IntArrayRef stride, torch::IntArrayRef padding, torch::IntArrayRef dilation, int64_t groups) { 16 | // torch::NoGradGuard no_grad_guard; 17 | // torch::Tensor grad_output = grad_output_t.contiguous(weight.suggest_memory_format()); 18 | return torch::cudnn_convolution_backward_input(input_sizes, grad_output_t, weight, padding, stride, dilation, groups, false, true); 19 | } 20 | 21 | torch::Tensor backward_weight(torch::IntArrayRef weight_sizes,const torch::Tensor& grad_output_t, const torch::Tensor& input, 22 | torch::IntArrayRef stride, torch::IntArrayRef 
padding, torch::IntArrayRef dilation, int64_t groups) { 23 | // torch::NoGradGuard no_grad_guard; 24 | // torch::Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format()); 25 | return torch::cudnn_convolution_backward_weight(weight_sizes, grad_output_t, input, padding, stride, dilation, groups, false, true); 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &forward, "Conv forward"); 30 | m.def("backward_input", &backward_input, "Conv backward input"); 31 | m.def("backward_weight", &backward_weight, "Conv backward weight"); 32 | } 33 | -------------------------------------------------------------------------------- /models/unet/unet_parts.py: -------------------------------------------------------------------------------- 1 | """ Parts of the U-Net model """ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(mid_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | nn.BatchNorm2d(out_channels), 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, height, width): 46 | super().__init__() 47 | 48 | self.up = torch.nn.AdaptiveAvgPool2d((height, width)) 49 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 50 | 51 | def forward(self, x1, x2): 52 | x1 = self.up(x1) 53 | x = torch.cat([x2, x1], dim=1) 54 | return self.conv(x) 55 | 56 | 57 | class OutConv(nn.Module): 58 | def __init__(self, in_channels, out_channels): 59 | super(OutConv, self).__init__() 60 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 61 | 62 | def forward(self, x): 63 | return self.conv(x) 64 | -------------------------------------------------------------------------------- /examples/dist_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch.multiprocessing as mp 5 | import torchvision 6 | import torch 7 | import torch.distributed as dist 8 | from monet.cvxpy_solver import Solution 9 | from monet.monet_wrapper import MONeTWrapper 10 | 11 | def train(gpu, args): 12 | rank = args.nr * args.gpus + gpu 13 | dist.init_process_group( 14 | backend='nccl', 15 | init_method='env://', 16 | world_size=args.world_size, 17 | rank=rank 18 | ) 19 | 20 | torch.manual_seed(0) 21 | sol_file = "../data/monet_r50_184_24hr/solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl" 22 | model = MONeTWrapper(torchvision.models.resnet50(), sol_file, (3,224,224)) 23 | torch.cuda.set_device(gpu) 24 | model.cuda(gpu) 25 | batch_size = 184 26 | input_ = 
torch.randn((batch_size,3,224,224)).cuda(gpu) 27 | 28 | # Wrap the model 29 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu]) 30 | for i in range(100): 31 | if i == 80: 32 | torch.cuda.synchronize() 33 | t0 = time.time() 34 | output = model(input_) 35 | output.sum().backward() 36 | torch.cuda.synchronize() 37 | print(gpu, "time:", time.time()-t0, "memory:", torch.cuda.max_memory_allocated(gpu)/1024/1024) 38 | 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('-n', '--nodes', default=1, 43 | type=int, metavar='N') 44 | parser.add_argument('-g', '--gpus', default=1, type=int, 45 | help='number of gpus per node') 46 | parser.add_argument('-nr', '--nr', default=0, type=int, 47 | help='ranking within the nodes') 48 | parser.add_argument('--epochs', default=2, type=int, 49 | metavar='N', 50 | help='number of total epochs to run') 51 | args = parser.parse_args() 52 | 53 | args.world_size = args.gpus * args.nodes 54 | os.environ['MASTER_ADDR'] = 'localhost' 55 | os.environ['MASTER_PORT'] = '12345' 56 | mp.spawn(train, nprocs=args.gpus, args=(args,)) 57 | 58 | if __name__ == '__main__': 59 | main() -------------------------------------------------------------------------------- /monet/lm_ops/lrln.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | template 12 | __global__ void LayerNormForwardCUDAKernel1( 13 | int64_t N, 14 | const T* X, 15 | const T* mean, 16 | const T* rstd, 17 | const T* gamma, 18 | const T* beta, 19 | T* Y) { 20 | using T_ACC = at::acc_type; 21 | const int64_t i = blockIdx.x; 22 | for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { 23 | const int64_t index = i * N + j; 24 | const T_ACC gamma_v = 25 | gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); 26 | const T_ACC beta_v = 27 | beta == nullptr ? 
T_ACC(0) : static_cast(beta[j]); 28 | Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * 29 | static_cast(rstd[i]) * gamma_v + 30 | beta_v; 31 | } 32 | } 33 | 34 | template 35 | void LayerNormForwardCUDA1( 36 | long M, 37 | long N, 38 | const T* X, 39 | const T* mean, 40 | const T* rstd, 41 | const T* gamma, 42 | const T* beta, 43 | T* Y) { 44 | cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); 45 | LayerNormForwardCUDAKernel1<<>>( 46 | N, X, mean, rstd, gamma, beta, Y); 47 | } 48 | 49 | template void LayerNormForwardCUDA1( 50 | long M, 51 | long N, 52 | const double* X, 53 | const double* mean, 54 | const double* rstd, 55 | const double* gamma, 56 | const double* beta, 57 | double* Y); 58 | 59 | template void LayerNormForwardCUDA1( 60 | long M, 61 | long N, 62 | const float* X, 63 | const float* mean, 64 | const float* rstd, 65 | const float* gamma, 66 | const float* beta, 67 | float* Y); 68 | 69 | // void LayerNormForwardCUDA1( 70 | // int64_t M, 71 | // int64_t N, 72 | // const c10::Half* X, 73 | // const c10::Half* mean, 74 | // const c10::Half* rstd, 75 | // const c10::Half* gamma, 76 | // const c10::Half* beta, 77 | // c10::Half* Y); -------------------------------------------------------------------------------- /checkmate/utils/timer.py: -------------------------------------------------------------------------------- 1 | from decimal import Decimal 2 | from math import ceil, log10 3 | from timeit import default_timer 4 | 5 | 6 | class Timer: 7 | def __init__(self, name, extra_data=None, print_results=False, niters=None): 8 | self.niters = niters 9 | self._elapsed = Decimal() 10 | self._name = name 11 | if extra_data: 12 | self._name += "; " + str(extra_data) 13 | self._print_results = print_results 14 | self._start_time = None 15 | self._children = {} 16 | self._count = 0 17 | 18 | @property 19 | def elapsed(self): 20 | return float(self._elapsed) 21 | 22 | def __enter__(self): 23 | self.start() 24 | return self 25 | 26 | def __exit__(self, *_): 27 | self.stop() 28 | if self._print_results: 29 | self.print_results() 30 | 31 | def child(self, name): 32 | try: 33 | return self._children[name] 34 | except KeyError: 35 | result = Timer(name, print_results=False) 36 | self._children[name] = result 37 | return result 38 | 39 | def start(self): 40 | self._count += 1 41 | self._start_time = self._get_time() 42 | 43 | def stop(self): 44 | self._elapsed += self._get_time() - self._start_time 45 | 46 | def print_results(self): 47 | print(self._format_results()) 48 | 49 | def _format_results(self, indent=" "): 50 | children = self._children.values() 51 | elapsed = self._elapsed or sum(c._elapsed for c in children) 52 | result = "%s: %.3fs" % (self._name, elapsed) 53 | max_count = max(c._count for c in children) if children else 0 54 | count_digits = 0 if max_count <= 1 else int(ceil(log10(max_count + 1))) 55 | for child in sorted(children, key=lambda c: c._elapsed, reverse=True): 56 | lines = child._format_results(indent).split("\n") 57 | child_percent = child._elapsed / elapsed * 100 58 | lines[0] += " (%d%%)" % child_percent 59 | if count_digits: 60 | # `+2` for the 'x' and the space ' ' after it: 61 | lines[0] = ("%dx " % child._count).rjust(count_digits + 2) + lines[0] 62 | for line in lines: 63 | result += "\n" + indent + line 64 | return result 65 | 66 | @staticmethod 67 | def _get_time(): 68 | return Decimal(default_timer()) -------------------------------------------------------------------------------- /monet/lm_ops/linear.py: 
-------------------------------------------------------------------------------- 1 | from .base import * 2 | import torch 3 | 4 | 5 | @implements(['aten::addmm'], ['none', 'gist']) 6 | class AddMM(OP): 7 | backward_storage = InputStorage(1, 2) 8 | params = None 9 | 10 | def forward(self, bias, input_, weight, beta, alpha): 11 | # Just don't want to deal with this right now 12 | assert alpha == 1 13 | assert beta == 1 14 | with torch.no_grad(): 15 | return torch.addmm(bias, input_, weight.t(), beta=beta, alpha=alpha) 16 | 17 | def backward(self, output_grad, stores): 18 | with torch.no_grad(): 19 | input_ = stores[0] 20 | weight = stores[1] 21 | del stores[1] 22 | del stores[0] 23 | return output_grad.sum(0), output_grad.mm(weight), output_grad.t().mm(input_) 24 | 25 | @implements(['aten::addmm'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal']) 26 | class LowMemAddMM(OP): 27 | backward_storage = InputStorage(1, 2) 28 | params = None 29 | algorithm = -1 30 | 31 | def forward(self, bias, input_, weight, beta, alpha): 32 | # NOTE: bias will never be None because otherwise the trace would have generated matmul instead of addmm 33 | self.params = [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 34 | # Just don't want to deal with this right now 35 | assert alpha == 1 36 | assert beta == 1 37 | with torch.no_grad(): 38 | return torch.addmm(bias, input_, weight.t(), beta=beta, alpha=alpha) 39 | 40 | def backward(self, output_grad, stores, nodel=False): 41 | with torch.no_grad(): 42 | addmmtype = self.algorithm // 10 43 | input_ = stores[0] 44 | weight = stores[1] 45 | if not nodel: 46 | del stores[1] 47 | del stores[0] 48 | di = dw = db = None 49 | if addmmtype == 0 or addmmtype == 2: 50 | if input_ is not None: 51 | if self.params[1]: 52 | dw = output_grad.t().mm(input_) 53 | if self.params[2]: 54 | db = output_grad.sum(0) 55 | del input_ 56 | 57 | if addmmtype == 1 or addmmtype == 2: 58 | if weight is not None: 59 | if self.params[0]: 60 | di = output_grad.mm(weight) 61 | 62 | return db, di, dw -------------------------------------------------------------------------------- /monet/lm_ops/conv_fwd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | from pathlib import Path 4 | from time import time 5 | 6 | this_dir = Path(__file__).parent 7 | conv_fwd_cpp = load(name="conv_fwd_cpp", sources=[this_dir/"conv.cpp"], extra_cflags=['-std=c++17'], extra_include_paths=[str(this_dir)], with_cuda=True) 8 | cudnn_convolution = conv_fwd_cpp.cudnn_convolution 9 | cudnn_convolution_backward_input = conv_fwd_cpp.cudnn_convolution_backward_input 10 | cudnn_convolution_backward_weight = conv_fwd_cpp.cudnn_convolution_backward_weight 11 | 12 | 13 | if __name__ == '__main__': 14 | i = torch.zeros((256, 64, 56, 56), device='cuda') 15 | w = torch.zeros((64, 64, 3, 3), device='cuda') 16 | print('fwd') 17 | for t in range(conv_fwd_cpp.n_fwd_algos()): 18 | torch.cuda.reset_peak_memory_stats() 19 | t0 = time() 20 | try: 21 | for it in range(10): 22 | o = cudnn_convolution(i, w, (1, 1), (1, 1), (1, 1), 1, t) 23 | except Exception as e: 24 | print('%02d failed' % t, e) 25 | else: 26 | torch.cuda.synchronize() 27 | M = torch.cuda.memory_stats() 28 | print('%02d %6.3f s %0.3f GB' % (t, time() - t0, M["allocated_bytes.all.peak"] / 1024. / 1024. 
/ 1024.)) 29 | 30 | print('bwd') 31 | for t in range(conv_fwd_cpp.n_bwd_ip_algos()): 32 | torch.cuda.reset_peak_memory_stats() 33 | t0 = time() 34 | try: 35 | for it in range(10): 36 | cudnn_convolution_backward_input(i.shape, o, w, (1, 1), (1, 1), (1, 1), 1, t) 37 | except Exception as e: 38 | print('%02d failed' % t, e) 39 | else: 40 | torch.cuda.synchronize() 41 | M = torch.cuda.memory_stats() 42 | print('%02d %6.3f s %0.3f GB' % (t, time() - t0, M["allocated_bytes.all.peak"] / 1024. / 1024. / 1024.)) 43 | 44 | print('bwd weight') 45 | for t in range(conv_fwd_cpp.n_bwd_wt_algos()): 46 | torch.cuda.reset_peak_memory_stats() 47 | t0 = time() 48 | try: 49 | for it in range(10): 50 | cudnn_convolution_backward_weight(w.shape, o, i, (1, 1), (1, 1), (1, 1), 1, t) 51 | except Exception as e: 52 | print('%02d failed' % t, e) 53 | else: 54 | torch.cuda.synchronize() 55 | M = torch.cuda.memory_stats() 56 | print('%02d %6.3f s %0.3f GB' % (t, time() - t0, M["allocated_bytes.all.peak"] / 1024. / 1024. / 1024.)) 57 | -------------------------------------------------------------------------------- /monet/lm_ops/pack.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | __global__ void pack_two_kernel(const uint8_t * input, size_t N, uint8_t * r) { 4 | const int j = blockIdx.x * blockDim.x + threadIdx.x; 5 | int i = 2*j; 6 | uint8_t v = 0; 7 | if (i < N) { 8 | v |= (input[i] << 4); 9 | if (i + 1 < N) { v |= (input[i+1]);} 10 | r[j] = v; 11 | } 12 | } 13 | __global__ void unpack_two_kernel(const uint8_t * input, size_t N, uint8_t * r) { 14 | const int j = blockIdx.x * blockDim.x + threadIdx.x; 15 | int i = 2*j; 16 | if (i < N) { 17 | const uint8_t v = input[j]; 18 | if (i+1 < N) { r[i+1] = v & 15; } 19 | r[i] = (v >> 4) & 15; 20 | } 21 | } 22 | __global__ void pack_kernel(const bool * input, size_t N, uint8_t * r) { 23 | const size_t j = (size_t)blockIdx.x * blockDim.x + threadIdx.x; 24 | size_t i = 8*j; 25 | if (i < N) { 26 | uint8_t v = 0; 27 | for(int k=7; k >= 0 && i < N; i++, k--) 28 | v |= (input[i] << k); 29 | r[j] = v; 30 | } 31 | } 32 | __global__ void unpack_kernel(const uint8_t * input, size_t N, bool * r) { 33 | const size_t j = (size_t)blockIdx.x * blockDim.x + threadIdx.x; 34 | size_t i = 8*j; 35 | if (i < N) { 36 | const uint8_t v = input[j]; 37 | for(int k=7; k >= 0 && i < N; i++, k--) 38 | r[i] = (v >> k) & 1; 39 | } 40 | } 41 | template <typename T> 42 | __global__ void unpack_multiply_kernel(const uint8_t * input, const T * v, size_t N, T * r) { 43 | const size_t j = (size_t)blockIdx.x * blockDim.x + threadIdx.x; 44 | size_t i = 8*j; 45 | if (i < N) { 46 | const uint8_t t = input[j]; 47 | for(int k=7; k >= 0 && i < N; i++, k--) 48 | r[i] = ((t >> k) & 1) * v[i]; 49 | } 50 | } 51 | 52 | void pack_two_gpu(const uint8_t * input, size_t N, uint8_t * r) { 53 | const int threads = 1024; 54 | const int blocks = (N + threads - 1) / threads; 55 | pack_two_kernel<<<blocks, threads>>>(input, N, r); 56 | } 57 | void unpack_two_gpu(const uint8_t * input, size_t N, uint8_t * r) { 58 | const int threads = 1024; 59 | const int blocks = (N + threads - 1) / threads; 60 | unpack_two_kernel<<<blocks, threads>>>(input, N, r); 61 | } 62 | void pack_gpu(const bool * input, size_t N, uint8_t * r) { 63 | const int threads = 1024; 64 | const int blocks = (N + threads - 1) / threads; 65 | pack_kernel<<<blocks, threads>>>(input, N, r); 66 | } 67 | void unpack_gpu(const uint8_t * input, size_t N, bool * r) { 68 | const int threads = 1024; 69 | const int blocks = (N + threads - 1) / threads; 70 | unpack_kernel<<<blocks, threads>>>(input, N, r); 71 | } 72 | template <typename T> 73 | void
unpack_multiply_gpu(const uint8_t * input, const T * v, size_t N, T * r) { 74 | const int threads = 1024; 75 | const int blocks = (N + threads - 1) / threads; 76 | unpack_multiply_kernel<<>>(input, v, N, r); 77 | } 78 | template void unpack_multiply_gpu(const uint8_t *, const float *, size_t, float *); 79 | template void unpack_multiply_gpu(const uint8_t *, const double *, size_t, double *); 80 | template void unpack_multiply_gpu(const uint8_t *, const int *, size_t, int *); 81 | template void unpack_multiply_gpu(const uint8_t *, const uint8_t *, size_t, uint8_t *); 82 | -------------------------------------------------------------------------------- /monet/lm_ops/compress.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | 13 | void compress_csr_256_gpu(const float* ip, float* cip, int* idx, int * rowidx, const int * nnzPerRow, size_t N); 14 | void uncompress_csr_256_gpu(const float* compIP, const int * csrIdx, const int* csrRowIdx, float* op, size_t N); 15 | // void get_nz_cuda(const float* ip, int *nnzPerRow, int nz_ptr, int lda); 16 | 17 | std::tuple compress_csr_256(const torch::Tensor& ip, const torch::Tensor& nnzPerRow, size_t nz) { 18 | TORCH_CHECK(ip.dim() == 2 && ip.size(1) == 256, "Only N/256 x 256 tensors supported."); 19 | CHECK_CONTIGUOUS(ip); 20 | // size_t N = ip.size(0) * ip.size(1); 21 | // int lda = (N+255)/256; 22 | // int nz1; 23 | // int *nnzPerRow1; 24 | // get_nz_cuda(ip.data_ptr(), nnzPerRow1, nz1, lda); 25 | // std::cout << "nz: "<< nz << std::endl; 26 | TORCH_CHECK(nnzPerRow.scalar_type() == torch::kInt32, "nnzPerRow should be int32."); 27 | size_t N = ip.size(0) * ip.size(1); 28 | torch::Tensor cip = torch::zeros(nz, torch::dtype(torch::kFloat32).device(ip.device())); 29 | torch::Tensor idx = torch::zeros(nz, torch::dtype(torch::kInt32).device(ip.device())); 30 | torch::Tensor rowidx = torch::zeros(ip.size(0)+1, torch::dtype(torch::kInt32).device(ip.device())); 31 | if (ip.type().is_cuda()) { 32 | // compress_csr_256_gpu(ip.data_ptr(), cip.data_ptr(), idx.data_ptr(), rowidx.data_ptr(), nnzPerRow.data_ptr(), N, nz); 33 | compress_csr_256_gpu(ip.data_ptr(), cip.data_ptr(), idx.data_ptr(), rowidx.data_ptr(), nnzPerRow.data_ptr(), N); 34 | TORCH_CHECK(cudaGetLastError() == cudaSuccess, 35 | "compress_csr_256 failed with error code ", 36 | cudaGetLastError()); 37 | } else { 38 | std::cout<<"CPU type\n"; 39 | } 40 | return {cip, idx, rowidx}; 41 | } 42 | 43 | torch::Tensor uncompress_csr_256(const torch::Tensor& compip, const torch::Tensor& indx, const torch::Tensor& rowidx, size_t N) { 44 | TORCH_CHECK(compip.dim() == 1 , "Only 1D tensors supported."); 45 | CHECK_CONTIGUOUS(compip); 46 | size_t nz = compip.size(0); 47 | torch::Tensor op = torch::zeros({(N+255)/256,256}, torch::dtype(torch::kFloat32).device(compip.device())); 48 | torch::Tensor nnzPerRow = torch::full({1}, (int)nz, torch::dtype(torch::kInt32).device(compip.device())); 49 | 50 | if (compip.type().is_cuda()) { 51 | uncompress_csr_256_gpu(compip.data_ptr(), indx.data_ptr(), rowidx.data_ptr(), op.data_ptr(), N); 52 | TORCH_CHECK(cudaGetLastError() == cudaSuccess, 53 | "uncompress_csr_256 failed with error code ", 54 | cudaGetLastError()); 55 | } else { 56 | std::cout<<"CPU type\n"; 57 | } 58 | return op; 59 | } 60 | 61 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, 
m) { 62 | m.def("compress_csr_256", &compress_csr_256, "Convert values to a csr form"); 63 | m.def("uncompress_csr_256", &uncompress_csr_256, "reconstruct vector from nonzero values and nonzero indices"); 64 | } 65 | 66 | -------------------------------------------------------------------------------- /monet/lm_ops/pack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | from pathlib import Path 4 | 5 | this_dir = Path(__file__).parent 6 | pack_cpp = load(name="pack_cpp", sources=[this_dir / "pack.cpp", this_dir / "pack.cu"], extra_cflags=['-std=c++17']) 7 | 8 | 9 | def pack(v: torch.Tensor) -> torch.Tensor: 10 | """ 11 | Pack 8 boolean values into a uint8. 12 | 13 | :param v: A 1D bool tensor 14 | :return: A 1D uint8 tensor of 8x smaller size 15 | """ 16 | return pack_cpp.pack(v) 17 | 18 | def unpack(p: torch.Tensor, n: int = 0) -> torch.Tensor: 19 | """ 20 | Unpack a uint8 into 8 boolean values. 21 | 22 | :param p: A 1D uint8 tensor 23 | :param n: The output size (default v.size(0)*8) 24 | :return: A 1D bool tensor of size N 25 | """ 26 | return pack_cpp.unpack(p, n) 27 | 28 | 29 | def unpack_multiply(p: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 30 | """ 31 | Unpack a uint8 into 8 boolean values and multiply them with another tensor. 32 | 33 | :param p: A 1D uint8 tensor 34 | :param v: A 1D tensor 35 | :return: A 1D tensor of the same type and shape as v 36 | """ 37 | return pack_cpp.unpack_multiply(p, v) 38 | 39 | def pack_two(v: torch.Tensor) -> torch.Tensor: 40 | """ 41 | Pack 2 uint8 values (<16) into a single uint8. 42 | :param v: A 1D uint8 tensor 43 | :return: A 1D uint8 tensor of 2x smaller size 44 | """ 45 | return pack_cpp.pack_two(v) 46 | 47 | def unpack_two(v: torch.Tensor, n: int = 0) -> torch.Tensor: 48 | """ 49 | Unpack a uint8 into 2 uint8 values <16. 
50 | :param v: A 1D uint8 tensor 51 | :param n: The output size (default v.size(0)*2) 52 | :return: A 1D uint8 tensor of size N*2 53 | """ 54 | return pack_cpp.unpack_two(v, n) 55 | 56 | 57 | if __name__ == "__main__": 58 | rnd = torch.rand(10) 59 | t = rnd > 0.5 60 | p = pack(t) 61 | print(t) 62 | print(p) 63 | print(unpack(p, t.size(0))) 64 | print(unpack_multiply(p, rnd)) 65 | 66 | print('*'*10, 'benchmarking', '*'*10) 67 | rnd = torch.rand(128*64*64*64) 68 | v = rnd > 0.5 69 | from time import time 70 | for d in [torch.device('cuda')]: # torch.device('cpu'), 71 | rnd = rnd.to(d) 72 | v = v.to(d) 73 | print(d) 74 | t0 = time() 75 | for i in range(1000): 76 | p = pack(v) 77 | print('pack', time()-t0) 78 | t0 = time() 79 | for i in range(1000): 80 | vv = unpack(p, v.size(0)) 81 | print('unpack', time()-t0) 82 | print('correct', bool((v == vv).all())) 83 | 84 | # Just in case pytorch or cuda optimize a bit too aggressively 85 | t0 = time() 86 | for i in range(1000): 87 | p = pack(v) 88 | v = unpack(p, v.size(0)) 89 | print('pack and unpack', time()-t0) 90 | 91 | t0 = time() 92 | r2 = rnd 93 | for i in range(1000): 94 | r2 = unpack_multiply(p, r2) 95 | print('unpack multiply', time()-t0) 96 | 97 | # Compare to thresholding and relu 98 | t0 = time() 99 | for i in range(1000): 100 | rnd = torch.relu(rnd) 101 | print('relu', time()-t0) 102 | 103 | # Compare to thresholding and relu 104 | t0 = time() 105 | for i in range(1000): 106 | v = rnd > 0.5 107 | print('ge', time()-t0) -------------------------------------------------------------------------------- /monet/lm_ops/hardtanh.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | import numpy as np 5 | 6 | lrhardtanh_cpp = load(name="lrhardtanh_cpp", sources=[this_dir/"lrhardtanh.cpp"], extra_cflags=['-std=c++17']) 7 | 8 | 9 | @implements(['aten::hardtanh'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 10 | class InputActHardTanh(OP): 11 | backward_storage = InputStorage(0) 12 | params = None 13 | inplace = False 14 | 15 | def forward(self, x, lower, upper): 16 | with torch.no_grad(): 17 | self.params = lower, upper 18 | if self.inplace: 19 | return torch.nn.functional.hardtanh_(x, lower, upper) 20 | return torch.nn.functional.hardtanh(x, lower, upper) 21 | 22 | def backward(self, x, stored, nodel=False): 23 | with torch.no_grad(): 24 | lower, upper = self.params 25 | ip = stored[0] 26 | if not nodel: 27 | del stored[0] 28 | return lrhardtanh_cpp.hardtanh_backward(x, ip, lower, upper) 29 | 30 | 31 | @implements(['aten::hardtanh'], ['multiway', 'multiway_newnode', 'conv_multiway_newnode']) 32 | class OutputActHardTanh(OP): 33 | backward_storage = OutputStorage() 34 | params = None 35 | inplace = False 36 | 37 | def forward(self, x, lower, upper): 38 | with torch.no_grad(): 39 | self.params = lower, upper 40 | if self.inplace: 41 | return torch.nn.functional.hardtanh_(x, lower, upper) 42 | return torch.nn.functional.hardtanh(x, lower, upper) 43 | 44 | def backward(self, x, stored, nodel=False): 45 | with torch.no_grad(): 46 | op = stored[0] 47 | lower, upper = self.params 48 | if not nodel: 49 | del stored[0] 50 | return lrhardtanh_cpp.hardtanh_backward(x, op, lower, upper) 51 | 52 | @implements(['aten::hardtanh_'], ['multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode']) 53 | class InputActHardTanh(OP): 54 | backward_storage = InputStorage(0) 55 | params = None 56 | inplace = 
False 57 | 58 | def forward(self, x, lower, upper): 59 | with torch.no_grad(): 60 | self.params = lower, upper 61 | if self.inplace: 62 | return torch.nn.functional.hardtanh_(x, lower, upper) 63 | return torch.nn.functional.hardtanh(x, lower, upper) 64 | 65 | def backward(self, x, stored, nodel=False): 66 | with torch.no_grad(): 67 | lower, upper = self.params 68 | ip = stored[0] 69 | if not nodel: 70 | del stored[0] 71 | return lrhardtanh_cpp.hardtanh_backward(x, ip, lower, upper) 72 | 73 | 74 | @implements(['aten::hardtanh_'], ['normal','multiway', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 75 | class OutputActHardTanh(OP): 76 | backward_storage = OutputStorage() 77 | params = None 78 | inplace = False 79 | 80 | def forward(self, x, lower, upper): 81 | with torch.no_grad(): 82 | self.params = lower, upper 83 | if self.inplace: 84 | return torch.nn.functional.hardtanh_(x, lower, upper) 85 | return torch.nn.functional.hardtanh(x, lower, upper) 86 | 87 | def backward(self, x, stored, nodel=False): 88 | with torch.no_grad(): 89 | op = stored[0] 90 | lower, upper = self.params 91 | if not nodel: 92 | del stored[0] 93 | return lrhardtanh_cpp.hardtanh_backward(x, op, lower, upper) -------------------------------------------------------------------------------- /monet/lm_ops/ln.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import * 3 | from .pack import * 4 | import numpy as np 5 | from torch.utils.cpp_extension import load 6 | from pathlib import Path 7 | 8 | this_dir = Path(__file__).parent 9 | myln_cpp = load(name="lrln_cpp", sources=[this_dir / "lrln.cpp", this_dir / "lrln.cu"], extra_cflags=['-std=c++17']) 10 | 11 | # TODO - 1) recompute using precomputed statistics - done, 2) calculate output-activated backward 12 | 13 | @implements(['aten::layer_norm'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 14 | class InputActLayerNorm(OP): 15 | backward_storage = InputStorage(0, 2, 3) 16 | params = None # params memory is unaccounted for => save_mean and save_var will occupy memory over the expected memory 17 | 18 | def forward(self, input_, shape, weight, bias, 19 | eps, use_cudnn=True, *args, **kwargs): 20 | with torch.no_grad(): 21 | assert use_cudnn 22 | if self.params == None: # if this is the first forward pass, save the values 23 | (out, mean, rstd), M, N = myln_cpp.forward(input_, shape, weight, bias, eps, use_cudnn) 24 | self.params = eps, mean, rstd, M, N, (input_.requires_grad, weight.requires_grad, bias.requires_grad) 25 | else: 26 | eps, mean, rstd, M, N, do_grad = self.params 27 | out = myln_cpp.forward_recompute(input_, shape, weight, bias, mean, rstd, M, N, eps) 28 | return out 29 | 30 | def backward(self, grad_output, stored, nodel=False): 31 | with torch.no_grad(): 32 | input_, weight, bias = stored 33 | eps, mean, rstd, M, N, do_grad = self.params 34 | if not nodel: 35 | del stored[2], stored[1], stored[0] 36 | if grad_output.is_contiguous(): 37 | di_app, dw_app, db_app = myln_cpp.cudnn_backward(grad_output, input_, mean, rstd, weight, M, N, do_grad) 38 | else: 39 | di_app, dw_app, db_app = myln_cpp.cudnn_backward(grad_output.contiguous(), input_, mean, rstd, weight, M, N, do_grad) 40 | return di_app, dw_app, db_app 41 | 42 | 43 | # @implements(['aten::layer_norm'], ['multiway', 'multiway_newnode', 'conv_multiway_newnode']) 44 | # class OutputActLayerNorm(OP): 45 | # backward_storage = [ 46 | # InputStorage(2, 3), 47 | # 
OutputStorage(), 48 | # ] 49 | # params = None 50 | 51 | # def forward(self, input_, shape, weight, bias, 52 | # eps, use_cudnn, *args, **kwargs): 53 | # with torch.no_grad(): 54 | # assert use_cudnn 55 | # if self.params == None: # if this is the first forward pass, save the values 56 | # (out, mean, rstd), M, N = myln_cpp.forward(input_, shape, weight, bias, eps, use_cudnn) 57 | # self.params = eps, mean, rstd, M, N, (input_.requires_grad, weight.requires_grad, bias.requires_grad) 58 | # else: 59 | # # TODO - reuse stats 60 | # eps, mean, rstd, M, N, do_grad = self.params 61 | # (out, mean, rstd), M, N = myln_cpp.forward(input_, shape, weight, bias, eps, use_cudnn) 62 | # return out 63 | 64 | # def backward(self, grad_output, stored, nodel=False): 65 | # with torch.no_grad(): 66 | # weight, bias, output = stored 67 | # eps, mean, rstd, M, N, do_grad = self.params 68 | # if not nodel: 69 | # del stored[2], stored[1], stored[0] 70 | # # TODO 71 | # di_app, dw_app, db_app = myln_cpp.output_activate_lnorm_backward(grad_output, output, mean, rstd, weight, M, N, do_grad) 72 | # return di_app, dw_app, db_app -------------------------------------------------------------------------------- /monet/lm_ops/elementary.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | 5 | class NativeOP(OP): 6 | op = None 7 | grad_fn = None 8 | backward_storage = None 9 | 10 | def forward(self, input_, *args, **kwargs): 11 | with torch.enable_grad(): 12 | r = self.op(input_.requires_grad_(True), *args, **kwargs) 13 | assert r.grad_fn is not None 14 | self.grad_fn = r.grad_fn 15 | rcopy = r.detach() 16 | del r, input_ 17 | for arg in args: 18 | del arg 19 | return rcopy 20 | 21 | def backward(self, grad_output, stored, nodel=False): 22 | assert len(stored) == 0 23 | return self.grad_fn(grad_output) 24 | 25 | 26 | @implements(['aten::flatten'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 27 | class Flatten(NativeOP): 28 | @staticmethod 29 | def op(x, start_dim, end_dim): 30 | return x.flatten(start_dim, end_dim) 31 | 32 | @implements(['aten::permute'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 33 | class Permute(NativeOP): 34 | @staticmethod 35 | def op(x, order): 36 | return x.permute(order) 37 | 38 | @implements(['aten::transpose'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 39 | class Transpose(NativeOP): 40 | @staticmethod 41 | def op(x, dim0, dim1): 42 | return torch.transpose(x,dim0,dim1) 43 | 44 | @implements(['aten::view'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 45 | class View(NativeOP): 46 | @staticmethod 47 | def op(x, dim_list): 48 | return x.view(dim_list) 49 | 50 | @implements(['aten::div'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 51 | class Div(NativeOP): 52 | @staticmethod 53 | def op(x, divisor): 54 | return torch.div(x, divisor) 55 | 56 | @implements(['aten::contiguous'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 57 | class Contiguous(NativeOP): 58 | @staticmethod 59 | def op(x, memory_format): 60 | assert memory_format == 0 61 | return x.contiguous() 62 | 63 | @implements(['aten::unsqueeze'], ['normal', 'multiway', 'newnode', 
'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 64 | class Unsqueeze(OP): 65 | def forward(self, x, dim): 66 | assert not x.requires_grad 67 | return torch.unsqueeze(x, dim) 68 | 69 | @implements(['aten::mul'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 70 | class Mul(NativeOP): 71 | @staticmethod 72 | def op(x, alpha): 73 | return x * alpha 74 | 75 | @implements(['aten::t'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 76 | class T(NativeOP): 77 | @staticmethod 78 | def op(x): 79 | return x.t() 80 | 81 | @implements(['aten::avg_pool2d'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 82 | class AvgPool2D(NativeOP): 83 | @staticmethod 84 | def op(*a): 85 | assert a[4] in [0, 1] and a[5] in [0, 1] 86 | a = list(a) 87 | a[4] = (a[4] == 1) 88 | a[5] = (a[5] == 1) 89 | return torch._C._nn.avg_pool2d(*a) 90 | 91 | @implements(['aten::dropout'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 92 | class Dropout(NativeOP): 93 | @staticmethod 94 | def op(input_, p, training=False): 95 | return torch.nn.functional.dropout(input_, p, bool(training)) -------------------------------------------------------------------------------- /checkmate/utils/solver_common.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | import numpy as np 4 | 5 | from checkmate.core.dfgraph import DFGraph 6 | from checkmate.core.utils.definitions import Vertex 7 | 8 | SOLVER_DTYPE = np.int 9 | 10 | 11 | def setup_implied_s_backwards(g: DFGraph, s: np.ndarray = None): 12 | """ 13 | Given a backward graph, this function will set the appropriate items in S to 1 in order 14 | to satisfy no-recompute rules during backwards optimization. 15 | """ 16 | s = s if s is not None else np.zeros((g.size, g.size), dtype=SOLVER_DTYPE) 17 | for (start, end) in g.induce_subgraph(g.vbwd): 18 | for t in range(start + 1, end + 1): 19 | s[t, start] = 1 20 | return s 21 | 22 | 23 | def gen_s_matrix_fixed_checkpoints(g: DFGraph, segment_set: Set[Vertex]): 24 | """ 25 | Given a list of checkpoint locations, this function will generate 26 | as output S matrices denoting checkpoint schedule, given a set of 27 | fixed segments (only recompute once). 28 | """ 29 | T = len(g.vfwd) 30 | Ttotal = g.size 31 | segment_set = list(sorted(segment_set)) 32 | S = np.zeros((g.size, g.size), dtype=SOLVER_DTYPE) 33 | # set minimum input requirements 34 | for v in g.v: 35 | for u in g.predecessors(v): 36 | for t in range(u + 1, v): 37 | S[t, u] = 1 38 | 39 | # stripe every k nodes 40 | for t in range(1, Ttotal): 41 | for i in segment_set: 42 | if i < t: 43 | S[t, i] = 1 44 | 45 | # checkpoint ladders 46 | starts = [0] + list(map(lambda x: x, segment_set)) 47 | ends = segment_set + [T + 1] 48 | for start, end in zip(starts, ends): 49 | for t in filter(lambda t: t < Ttotal, map(lambda x: Ttotal - x - 1, range(start, end))): 50 | for i in range(start, min(t, end)): 51 | S[t, i] = 1 52 | 53 | # forward checkpoint block 54 | for start, end in zip(starts, ends): 55 | for t in filter(lambda t: t < Ttotal, range(start, end + 1)): 56 | for i in range(start, min(t, end)): 57 | S[t, i] = 1 58 | 59 | # backward checkpoint block 60 | # originally used as baselines will checkpoint whole blocks (e.g. 
Chen 2016 checkpoints entire backwards blocks), 61 | # but removed in public release as schedules are faster without this. 62 | # for start, end in zip(starts, ends): 63 | # for t in filter(lambda _t: _t < Ttotal, range(start, end + 1)): 64 | # back_t = Ttotal - 1 - t 65 | # for i in range(start, end): 66 | # back_i = g.forward_to_backward(i) 67 | # if back_i is not None and back_i < back_t: 68 | # S[back_t, back_i] = 1 69 | 70 | S = setup_implied_s_backwards(g, S) 71 | return S 72 | 73 | 74 | def solve_r_opt(g: DFGraph, s: np.ndarray): 75 | """Find the optimal recomputation pattern given caching decisions. 76 | Given S, E = [(i, j)] where node j depends on the result of node i, 77 | find R that minimizes cost, satisfies constraints. Assumes recomputation 78 | costs are nonnegative. 79 | NOTE: Does NOT check if memory limits are exceeded. 80 | Enforcing R[t,i] != S[t,i] does not seem to be necessary. 81 | """ 82 | T = s.shape[0] 83 | assert s.shape[1] == T 84 | 85 | R = np.eye(T, dtype=s.dtype) # Enforce R_t,t = 1 86 | # Enforce S_{t+1,v} <= S_{t,v} + R_{t,v}, 87 | # i.e. R_{t,v} >= S_{t+1,v} - S_{t,v} 88 | sdiff = s[1:] - s[:-1] 89 | R[:-1] = R[:-1] | (R[:-1] < sdiff) 90 | # Create reverse adjacency list (child -> parents, i.e. node -> dependencies) 91 | adj = [[] for _ in range(T)] 92 | for (u, v) in g.edge_list: 93 | adj[v].append(u) 94 | # Enforce R_{t,v} <= R_{t,u} + S_{t,u} for all (u, v) \in E 95 | for t in range(T): 96 | for v in range(t, -1, -1): 97 | for u in adj[v]: 98 | if R[t, v] > R[t, u] + s[t, u]: 99 | R[t, u] = 1 100 | return R -------------------------------------------------------------------------------- /monet/lm_ops/pool.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | import numpy as np 5 | 6 | maxpool_cpp = load(name="maxpool_cpp", sources=[this_dir / "maxpool.cpp", this_dir / "maxpool.cu"], extra_cflags=['-std=c++17']) 7 | 8 | @implements(['aten::max_pool2d'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 9 | class SimpleMaxPool2D(OP): 10 | backward_storage = InputStorage(0) 11 | params = None 12 | 13 | # NOTE Because fwd calculated indices, the workspace memory will include the 14 | # memory of indice. Which means we are being slightly conservative in memory savings 15 | # for IndicesActMaxPool2D 16 | def forward(self, input_, kernel_size, stride, padding, dilation, ceil_mode): 17 | ceil_mode = bool(ceil_mode) 18 | self.params = kernel_size, stride, padding, dilation, ceil_mode 19 | with torch.no_grad(): 20 | return torch.max_pool2d(input_, kernel_size, stride, padding, dilation, ceil_mode) 21 | 22 | def backward(self, grad_output, stored, nodel=False): 23 | ip = stored[0] 24 | if not nodel: 25 | del stored[0] 26 | kernel_size, stride, padding, dilation, ceil_mode = self.params 27 | with torch.enable_grad(): 28 | tmp = torch.max_pool2d(ip.requires_grad_(True), kernel_size=kernel_size, stride=stride, 29 | padding=padding, dilation=dilation, ceil_mode=ceil_mode) 30 | gradfn = tmp.grad_fn 31 | assert gradfn 32 | del tmp, ip 33 | with torch.no_grad(): 34 | out = gradfn(grad_output) 35 | del gradfn 36 | return out 37 | 38 | @implements(['aten::max_pool2d'], ['newnode', 'multiway_newnode', 'conv_multiway_newnode']) 39 | class IndicesActMaxPool2D(OP): 40 | backward_storage = IntermediateStorage(lambda shape: (np.prod(shape)+1)//2) # Shape is output shape. 
Two indices are merged in 1 byte 41 | params = None 42 | 43 | def forward(self, input_, kernel_size, stride, padding, dilation, ceil_mode): 44 | with torch.no_grad(): 45 | y, indices, input_size, input_stride, input_ndim, input_numel = maxpool_cpp.max_pool2d_with_indices_cuda(input_, 46 | kernel_size, stride, padding, dilation, ceil_mode) 47 | index_shape = indices.shape 48 | index = indices.view(-1) 49 | del input_ 50 | packed_indices = pack_two(index) 51 | del indices, index 52 | self.params = kernel_size, stride, padding, dilation, bool(ceil_mode), input_size, input_stride, input_ndim, input_numel, index_shape 53 | return y, packed_indices 54 | 55 | def backward(self, grad_output, stored, nodel=False): 56 | with torch.no_grad(): 57 | packed_index = stored[0] 58 | if not nodel: 59 | del stored[0] 60 | kernel_size, stride, padding, dilation, ceil_mode, input_size, input_stride, input_ndim, input_numel, indices_shape = self.params 61 | index = unpack_two(packed_index) 62 | index.resize_(indices_shape) 63 | del packed_index 64 | gradInput = torch.zeros(input_size, device=grad_output.device) 65 | gradInput = maxpool_cpp.max_pool2d_with_indices_backward_out_cuda(gradInput, grad_output, 66 | input_size, input_stride, index, kernel_size, stride, padding, 67 | dilation, input_numel, input_ndim, ceil_mode) 68 | del index 69 | return gradInput 70 | 71 | @implements(['aten::gist_max_pool2d'], ['gist']) 72 | class IndicesActMaxPool2D(OP): 73 | backward_storage = IntermediateStorage(lambda shape: (np.prod(shape)+1)//2) # Shape is output shape. Two indices are merged in 1 byte 74 | params = None 75 | 76 | def forward(self, input_, kernel_size, stride, padding, dilation, ceil_mode): 77 | with torch.no_grad(): 78 | y, indices, input_size, input_stride, input_ndim, input_numel = maxpool_cpp.max_pool2d_with_indices_cuda(input_, 79 | kernel_size, stride, padding, dilation, ceil_mode) 80 | index_shape = indices.shape 81 | index = indices.view(-1) 82 | del input_ 83 | packed_indices = pack_two(index) 84 | del indices, index 85 | self.params = kernel_size, stride, padding, dilation, bool(ceil_mode), input_size, input_stride, input_ndim, input_numel, index_shape 86 | return y, packed_indices 87 | 88 | def backward(self, grad_output, stored, nodel=False): 89 | with torch.no_grad(): 90 | packed_index = stored[0] 91 | if not nodel: 92 | del stored[0] 93 | kernel_size, stride, padding, dilation, ceil_mode, input_size, input_stride, input_ndim, input_numel, indices_shape = self.params 94 | index = unpack_two(packed_index) 95 | index.resize_(indices_shape) 96 | del packed_index 97 | gradInput = torch.zeros(input_size, device=grad_output.device) 98 | gradInput = maxpool_cpp.max_pool2d_with_indices_backward_out_cuda(gradInput, grad_output, 99 | input_size, input_stride, index, kernel_size, stride, padding, 100 | dilation, input_numel, input_ndim, ceil_mode) 101 | del index 102 | return gradInput -------------------------------------------------------------------------------- /monet/lm_ops/lrln.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // #include 4 | 5 | // #include 6 | // #include 7 | // #include 8 | // #include 9 | // #include 10 | // #include 11 | // #include 12 | // #include 13 | // #include 14 | // #include "lrbn.cuh" 15 | 16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 18 | 19 | using namespace at; 20 | using namespace at::native; 21 | 22 | template 
23 | void LayerNormForwardCUDA1( 24 | int64_t M, 25 | int64_t N, 26 | const T* X, 27 | const T* mean, 28 | const T* rstd, 29 | const T* gamma, 30 | const T* beta, 31 | T* Y); 32 | 33 | std::tuple, int64_t, int64_t> do_forward( 34 | const torch::Tensor& input, 35 | IntArrayRef normalized_shape, 36 | const torch::Tensor& weight /* optional */, 37 | const torch::Tensor& bias /* optional */, 38 | double eps, 39 | bool cudnn_enable) { 40 | torch::NoGradGuard no_grad_guard; 41 | 42 | const int normalized_ndim = normalized_shape.size(); 43 | TORCH_CHECK( 44 | normalized_ndim >= 1, 45 | "Expected normalized_shape to be at least 1-dimensional, i.e., ", 46 | "containing at least one element, but got normalized_shape = ", 47 | normalized_shape); 48 | TORCH_CHECK( 49 | !weight.defined() || weight.sizes().equals(normalized_shape), 50 | "Expected weight to be of same shape as normalized_shape, but got ", 51 | "weight of shape ", 52 | weight.sizes(), 53 | " and normalized_shape = ", 54 | normalized_shape); 55 | TORCH_CHECK( 56 | !bias.defined() || bias.sizes().equals(normalized_shape), 57 | "Expected bias to be of same shape as normalized_shape, but got ", 58 | "bias of shape ", 59 | bias.sizes(), 60 | " and normalized_shape = ", 61 | normalized_shape); 62 | 63 | const auto input_shape = input.sizes(); 64 | const auto input_ndim = input.dim(); 65 | 66 | if (input_ndim < normalized_ndim || 67 | !input_shape.slice(input_ndim - normalized_ndim) 68 | .equals(normalized_shape)) { 69 | std::stringstream ss; 70 | ss << "Given normalized_shape=" << normalized_shape 71 | << ", expected input with shape [*"; 72 | for (auto size : normalized_shape) { 73 | ss << ", " << size; 74 | } 75 | ss << "], but got input of size" << input_shape; 76 | AT_ERROR(ss.str()); 77 | } 78 | 79 | const int axis = input_ndim - normalized_ndim; 80 | const int64_t M = std::accumulate( 81 | input_shape.cbegin(), 82 | input_shape.cbegin() + axis, 83 | 1LL, 84 | std::multiplies()); 85 | const int64_t N = std::accumulate( 86 | input_shape.cbegin() + axis, 87 | input_shape.cend(), 88 | 1LL, 89 | std::multiplies()); 90 | 91 | const auto& X = input.is_contiguous() ? input : input.contiguous(); 92 | const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); 93 | const auto& beta = bias.is_contiguous() ? bias : bias.contiguous(); 94 | 95 | return {at::native_layer_norm(X, gamma, beta, M, N, eps), M, N}; 96 | } 97 | 98 | std::tuple do_cudnn_backward( 99 | const Tensor& dY, 100 | const Tensor& X, 101 | const Tensor& mean, 102 | const Tensor& rstd, 103 | const Tensor& gamma, 104 | int64_t M, 105 | int64_t N, 106 | std::array grad_input_mask) { 107 | 108 | return at::native::layer_norm_backward_cuda(dY, X, mean, rstd, gamma, M, N, grad_input_mask); 109 | } 110 | 111 | Tensor layer_norm_recompute_cuda( 112 | const Tensor& input, 113 | IntArrayRef normalized_shape, 114 | const Tensor& weight /* optional */, 115 | const Tensor& bias /* optional */, 116 | const Tensor& mean1, 117 | const Tensor& rstd1, 118 | int64_t M, int64_t N, 119 | double eps) { 120 | 121 | const auto& X = input.is_contiguous() ? input : input.contiguous(); 122 | const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous(); 123 | const auto& beta = bias.is_contiguous() ? bias : bias.contiguous(); 124 | const auto& mean = mean1.is_contiguous() ? mean1 : mean1.contiguous(); 125 | const auto& rstd = rstd1.is_contiguous() ? 
rstd1 : rstd1.contiguous(); 126 | 127 | Tensor Y = at::native::empty_like(X, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 128 | if (M > 0) { 129 | AT_DISPATCH_FLOATING_TYPES( 130 | X.scalar_type(), "LayerNormKernelImpl1", [&]() { 131 | 132 | DCHECK_EQ(X.numel(), M * N); 133 | DCHECK(!gamma.defined() || gamma.numel() == N); 134 | DCHECK(!beta.defined() || beta.numel() == N); 135 | const scalar_t* X_data = X.data_ptr(); 136 | const scalar_t* gamma_data = gamma.defined() ? gamma.data_ptr() : nullptr; 137 | const scalar_t* beta_data = beta.defined() ? beta.data_ptr() : nullptr; 138 | scalar_t* Y_data = Y.data_ptr(); 139 | scalar_t* mean_data = mean.data_ptr(); 140 | scalar_t* rstd_data = rstd.data_ptr(); 141 | LayerNormForwardCUDA1(M, N, X_data, mean_data, 142 | rstd_data, gamma_data, beta_data, Y_data); 143 | AT_CUDA_CHECK(cudaGetLastError()); 144 | }); 145 | } 146 | return std::move(Y); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("forward", &do_forward, "LN forward"); 152 | m.def("cudnn_backward", &do_cudnn_backward, "LN backward"); 153 | m.def("forward_recompute", &layer_norm_recompute_cuda, "LN forward recompute"); 154 | } -------------------------------------------------------------------------------- /monet/lm_ops/pack.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 5 | 6 | void pack_cpu(const bool * u, size_t N, uint8_t * p) { 7 | for(size_t i=0, j=0; i= 0 && i= 0 && i> k) & 1; 20 | } 21 | } 22 | 23 | template 24 | void unpack_multiply_cpu(const uint8_t * p, const T * v, size_t N, T * r) { 25 | for(size_t i=0, j=0; i= 0 && i> k) & 1) * v[i]; 29 | } 30 | } 31 | void pack_gpu(const bool *, size_t, uint8_t *); 32 | void pack_two_gpu(const uint8_t *, size_t, uint8_t *); 33 | void unpack_gpu(const uint8_t *, size_t, bool *); 34 | void unpack_two_gpu(const uint8_t *, size_t, uint8_t *); 35 | template 36 | void unpack_multiply_gpu(const uint8_t *, const T *, size_t N, T *); 37 | 38 | torch::Tensor pack_two(const torch::Tensor& u) { 39 | TORCH_CHECK(u.dim() == 1, "Only 1D tensors supported."); 40 | TORCH_CHECK(u.scalar_type() == torch::kUInt8, "Only UInt8 tensors supported."); 41 | CHECK_CONTIGUOUS(u); 42 | size_t N = u.size(0); 43 | torch::Tensor p = torch::zeros((N+1) / 2, torch::dtype(torch::kUInt8).device(u.device())); 44 | if (u.type().is_cuda()) { 45 | pack_two_gpu(u.data_ptr(), N, p.data_ptr()); 46 | TORCH_CHECK(cudaGetLastError() == cudaSuccess, 47 | "pack_two failed with error code ", 48 | cudaGetLastError()); 49 | } else { 50 | std::cout<<"Unimplemented\n"; 51 | } 52 | return p; 53 | } 54 | torch::Tensor unpack_two(const torch::Tensor& p, size_t N=0) { 55 | TORCH_CHECK(p.dim() == 1, "Only 1D tensors supported."); 56 | TORCH_CHECK(p.scalar_type() == torch::kUInt8, "Only uint8 tensors supported."); 57 | CHECK_CONTIGUOUS(p); 58 | if (!N) 59 | N = p.size(0) * 2; 60 | torch::Tensor u = torch::zeros(N, torch::dtype(torch::kUInt8).device(p.device())); 61 | if (p.type().is_cuda()) { 62 | unpack_two_gpu(p.data_ptr(), N, u.data_ptr()); 63 | TORCH_CHECK(cudaGetLastError() == cudaSuccess, 64 | "unpack_two failed with error code ", 65 | cudaGetLastError()); 66 | } else { 67 | std::cout<<"Unimplemented\n"; 68 | } 69 | return u; 70 | } 71 | 72 | torch::Tensor pack(const torch::Tensor& u) { 73 | TORCH_CHECK(u.dim() == 1, "Only 1D tensors supported."); 74 | TORCH_CHECK(u.scalar_type() == torch::kBool, "Only bool tensors 
supported."); 75 | CHECK_CONTIGUOUS(u); 76 | size_t N = u.size(0); 77 | torch::Tensor p = torch::zeros((N-1) / 8 + 1, torch::dtype(torch::kUInt8).device(u.device())); 78 | if (u.type().is_cuda()) 79 | pack_gpu(u.data_ptr(), N, p.data_ptr()); 80 | else 81 | pack_cpu(u.data_ptr(), N, p.data_ptr()); 82 | return p; 83 | } 84 | void pack_(const torch::Tensor& u, torch::Tensor& p) { 85 | TORCH_CHECK(u.dim() == 1, "Only 1D tensors supported."); 86 | TORCH_CHECK(u.scalar_type() == torch::kBool, "Only bool tensors supported."); 87 | CHECK_CONTIGUOUS(u); 88 | size_t N = u.size(0); 89 | TORCH_CHECK(p.dim() == 1, "Only 1D tensors supported."); 90 | TORCH_CHECK(p.scalar_type() == torch::kByte, "Only byte tensors supported."); 91 | TORCH_CHECK(p.dim() == 1, "Only 1D tensors supported."); 92 | TORCH_CHECK(p.size(0) == (N-1) / 8 + 1, "Only tensors size mismatch."); 93 | CHECK_CONTIGUOUS(p); 94 | TORCH_CHECK(u.type().is_cuda() == p.type().is_cuda(), "Device mismatch."); 95 | if (u.type().is_cuda()) 96 | pack_gpu(u.data_ptr(), N, p.data_ptr()); 97 | else 98 | pack_cpu(u.data_ptr(), N, p.data_ptr()); 99 | } 100 | torch::Tensor unpack(const torch::Tensor& p, size_t N=0) { 101 | TORCH_CHECK(p.dim() == 1, "Only 1D tensors supported."); 102 | TORCH_CHECK(p.scalar_type() == torch::kUInt8, "Only uint8 tensors supported."); 103 | CHECK_CONTIGUOUS(p); 104 | if (!N) 105 | N = p.size(0) * 8; 106 | torch::Tensor u = torch::zeros(N, torch::dtype(torch::kBool).device(p.device())); 107 | if (p.type().is_cuda()) 108 | unpack_gpu(p.data_ptr(), N, u.data_ptr()); 109 | else 110 | unpack_cpu(p.data_ptr(), N, u.data_ptr()); 111 | return u; 112 | } 113 | torch::Tensor unpack_multiply(const torch::Tensor& p, const torch::Tensor& v) { 114 | TORCH_CHECK(p.dim() == 1, "Only 1D tensors supported."); 115 | TORCH_CHECK(v.dim() == 1, "Only 1D tensors supported."); 116 | TORCH_CHECK(v.device() == p.device(), "All input need to be located on the same device."); 117 | TORCH_CHECK(p.scalar_type() == torch::kUInt8, "Only uint8 tensors supported."); 118 | CHECK_CONTIGUOUS(p); 119 | CHECK_CONTIGUOUS(v); 120 | torch::Tensor r = torch::empty_like(v); 121 | size_t N = v.size(0); 122 | if (p.type().is_cuda()) 123 | AT_DISPATCH_FLOATING_TYPES(v.type(), "unpack_multiply_gpu", ([&]{ 124 | unpack_multiply_gpu(p.data_ptr(), v.data_ptr(), N, r.data_ptr()); 125 | })); 126 | else 127 | AT_DISPATCH_FLOATING_TYPES(v.type(), "unpack_multiply_cpu", ([&]{ 128 | unpack_multiply_cpu(p.data_ptr(), v.data_ptr(), N, r.data_ptr()); 129 | })); 130 | return r; 131 | } 132 | 133 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 134 | m.def("pack_two", &pack_two, "Convert a uint8 tensor to a uint8 packed to half size"); 135 | m.def("unpack_two", &unpack_two, "Convert a uint8 tensor to a double sized uint8 unpacked", py::arg("p"), py::arg("N") = 0); 136 | m.def("pack", &pack, "Convert a boolean tensor to a uint8 packed"); 137 | m.def("pack_", &pack_, "Convert a boolean tensor to a uint8 packed"); 138 | m.def("unpack", &unpack, "Convert a uint8 tensor to a boolean unpacked", py::arg("p"), py::arg("N") = 0); 139 | m.def("unpack_multiply", &unpack_multiply, "Convert a uint8 tensor to a boolean unpacked and multiply it with a second tensor", py::arg("p"), py::arg("v") = 0); 140 | } 141 | -------------------------------------------------------------------------------- /monet/lm_ops/bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base import * 3 | from .pack import * 4 | import numpy as np 5 | from 
torch.utils.cpp_extension import load 6 | from pathlib import Path 7 | 8 | this_dir = Path(__file__).parent 9 | mybn_cpp = load(name="lrbn_cpp", sources=[this_dir / "lrbn.cpp", this_dir / "lrbn.cu"], extra_cflags=['-std=c++17']) 10 | 11 | 12 | @implements(['aten::batch_norm'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal']) 13 | class InputActBatchNorm(OP): 14 | backward_storage = InputStorage(0, 1, 2, 3, 4) 15 | params = None # params memory is unaccounted for => save_mean and save_var will occupy memory over the expected memory 16 | running_mean = None 17 | running_var = None 18 | 19 | def forward(self, input_, weight, bias, 20 | running_mean, running_var, 21 | training=0, momentum=0.1, eps=1e-5, *args, **kwargs): 22 | with torch.no_grad(): 23 | if self.running_mean is None: 24 | self.running_mean = running_mean.requires_grad_(False) 25 | self.running_var = running_var.requires_grad_(False) 26 | 27 | if self.params == None: # if this is the first forward pass, save the values 28 | out, save_mean, save_var, res_space = mybn_cpp.forward( 29 | input_, self.running_mean, self.running_var, weight, bias, 30 | training, momentum, eps) 31 | self.params = training, eps, save_mean, save_var, res_space 32 | else: 33 | training, eps, batch_mean, batch_var, res_space = self.params 34 | out = torch.nn.functional.batch_norm(input_, batch_mean, 1/torch.square(batch_var), weight, bias, False, momentum, eps) 35 | return out 36 | 37 | def backward(self, grad_output, stored, nodel=False): 38 | with torch.no_grad(): 39 | input_, weight, bias, running_mean, running_var = stored 40 | training, eps, save_mean, save_var, res_space = self.params 41 | 42 | if training: 43 | di_app, dw_app, db_app = mybn_cpp.cudnn_backward( 44 | input_, grad_output, weight, 45 | running_mean, running_var, 46 | save_mean, save_var, 47 | eps, res_space) 48 | else: 49 | raise NotImplementedError 50 | 51 | if not nodel: 52 | del stored[4], stored[3], stored[2], stored[1], stored[0] 53 | 54 | return di_app, dw_app, db_app, None, None 55 | 56 | @implements(['aten::batch_norm'], ['gist']) 57 | class InputActBatchNorm(OP): 58 | backward_storage = InputStorage(0, 1, 2, 3, 4) 59 | params = None # params memory is unaccounted for => save_mean and save_var will occupy memory over the expected memory 60 | running_mean = None 61 | running_var = None 62 | 63 | def forward(self, input_, weight, bias, 64 | running_mean, running_var, 65 | training=0, momentum=0.1, eps=1e-5, *args, **kwargs): 66 | with torch.no_grad(): 67 | if self.running_mean is None: 68 | self.running_mean = running_mean.requires_grad_(False) 69 | self.running_var = running_var.requires_grad_(False) 70 | 71 | out, save_mean, save_var, res_space = mybn_cpp.forward( 72 | input_, self.running_mean, self.running_var, weight, bias, 73 | training, momentum, eps) 74 | self.params = training, eps, save_mean, save_var, res_space 75 | return out 76 | 77 | def backward(self, grad_output, stored, nodel=False): 78 | with torch.no_grad(): 79 | input_, weight, bias, running_mean, running_var = stored 80 | training, eps, save_mean, save_var, res_space = self.params 81 | 82 | if training: 83 | di_app, dw_app, db_app = mybn_cpp.cudnn_backward( 84 | input_, grad_output, weight, 85 | running_mean, running_var, 86 | save_mean, save_var, 87 | eps, res_space) 88 | else: 89 | raise NotImplementedError 90 | 91 | if not nodel: 92 | del stored[4], stored[3], stored[2], stored[1], stored[0] 93 | 94 | return di_app, dw_app, db_app, None, None 95 | 96 | 
@implements(['aten::batch_norm'], ['multiway', 'multiway_newnode', 'conv_multiway_newnode']) 97 | class OutputActBatchNorm(OP): 98 | backward_storage = [ 99 | InputStorage(1, 2, 3, 4), 100 | OutputStorage(), 101 | ] 102 | params = None 103 | running_mean = None 104 | running_var = None 105 | 106 | def forward(self, input_, weight, bias, 107 | running_mean, running_var, 108 | training=False, momentum=0.1, eps=1e-5, *args, **kwargs): 109 | with torch.no_grad(): 110 | if self.running_mean is None: 111 | self.running_mean = running_mean.requires_grad_(False) 112 | self.running_var = running_var.requires_grad_(False) 113 | 114 | if self.params == None: # if this is the first forward pass, save the values 115 | out, save_mean, save_var, res_space = mybn_cpp.forward( 116 | input_, self.running_mean, self.running_var, weight, bias, 117 | training, momentum, eps) 118 | self.params = training, eps, save_mean, save_var, res_space 119 | else: 120 | training, eps, batch_mean, batch_var, res_space = self.params 121 | out = torch.nn.functional.batch_norm(input_, batch_mean, 1/torch.square(batch_var), weight, bias, False, momentum, eps) 122 | return out 123 | 124 | def backward(self, grad_output, stored, nodel=False): 125 | with torch.no_grad(): 126 | weight, bias, running_mean, running_var, output = stored 127 | training, eps, save_mean, save_var, res_space = self.params 128 | if training: 129 | di_app, dw_app, db_app = mybn_cpp.output_activated_bn_backward( 130 | grad_output, output, weight, bias, 131 | running_mean, running_var, 132 | save_mean, save_var, training, 133 | eps, [True, True,True]) 134 | else: 135 | raise NotImplementedError 136 | 137 | if not nodel: 138 | del stored[4], stored[3], stored[2], stored[1], stored[0] 139 | 140 | return di_app, dw_app, db_app, None, None 141 | -------------------------------------------------------------------------------- /monet/lm_ops/lrfuncs.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 5 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 6 | 7 | using namespace at; 8 | using namespace at::native; 9 | 10 | torch::Tensor maybe_multiply1(const torch::Tensor & t, const Scalar & s) { 11 | bool is_one = false; 12 | if (s.isFloatingPoint()) { 13 | is_one = s.toDouble() == 1; 14 | } else if(s.isIntegral(true)) { 15 | is_one = s.toLong() == 1; 16 | } 17 | 18 | if (is_one) { 19 | return t; 20 | } else { 21 | return t * s; 22 | } 23 | } 24 | 25 | torch::Tensor adaptive_avg_pool_backward(const torch::Tensor& grad, const torch::Tensor& self) { 26 | return at::native::adaptive_avg_pool2d_backward_cuda(grad, self); 27 | } 28 | 29 | torch::Tensor embedding_forward(const torch::Tensor & weight, const torch::Tensor & indices, 30 | int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { 31 | return at::native::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); 32 | } 33 | 34 | torch::Tensor embedding_backward1(const torch::Tensor & grad_, const torch::Tensor & indices, 35 | int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { 36 | return at::native::embedding_dense_backward_cuda(grad_, indices, num_weights, padding_idx, scale_grad_by_freq); 37 | } 38 | 39 | torch::Tensor mm_mat1_backward_t(const torch::Tensor & grad, const torch::Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides) { 40 | // Not using at call because it requires mat1 also (in pytorch v1.5.1) 41 | // 
return at::native::mm_mat1_backward(grad, mat2, mat1_sizes, mat1_strides, 1); 42 | // Assuming mat1 won't be sparse 43 | if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) { 44 | return maybe_multiply1((mat2.t()).mm(grad.t()).t(), 1); 45 | } else { 46 | return maybe_multiply1(grad.mm(mat2), 1); 47 | } 48 | } 49 | 50 | torch::Tensor mm_mat1_backward(const torch::Tensor & grad, const torch::Tensor & mat2, at::IntArrayRef mat1_sizes, at::IntArrayRef mat1_strides) { 51 | // Not using at call because it requires mat1 also (in pytorch v1.5.1) 52 | // return at::native::mm_mat1_backward(grad, mat2, mat1_sizes, mat1_strides, 1); 53 | // Assuming mat1 won't be sparse 54 | if (mat1_strides[0] == 1 && mat1_strides[1] == mat1_sizes[0]) { 55 | return maybe_multiply1(mat2.mm(grad.t()).t(), 1); 56 | } else { 57 | return maybe_multiply1(grad.mm(mat2.t()), 1); 58 | } 59 | } 60 | 61 | torch::Tensor mm_mat2_backward(const torch::Tensor & grad, const torch::Tensor & mat1, at::IntArrayRef sizes, at::IntArrayRef strides) { 62 | // return torch::autograd::generated::mm_mat2_backward(grad, mat1, sizes, strides, 1); 63 | if (strides[0] == 1 && strides[1] == sizes[0]) { 64 | if (mat1.is_sparse()) { 65 | // Since mm(dense, sparse) doesn't exist, 66 | // pass a transposed output matrix to the underlying "addmm" 67 | // function directly. 68 | int64_t out_rows = mat1.size(1); 69 | int64_t out_cols = grad.size(1); 70 | Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); 71 | Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); 72 | at::addmm_out(r, t, mat1.t(), grad, 1, 1); 73 | return r; 74 | } 75 | return maybe_multiply1(grad.t().mm(mat1).t(), 1); 76 | } else { 77 | return maybe_multiply1(mat1.t().mm(grad), 1); 78 | } 79 | } 80 | 81 | torch::Tensor slice_forward(const torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) { 82 | return at::native::slice(self, dim, start, end, step); 83 | } 84 | 85 | torch::Tensor slice_backward1(const torch::Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { 86 | // return at::native::slice_backward(grad, input_sizes, dim, start, end, step); // pytorch v1.7 87 | auto grad_input = at::zeros(input_sizes, grad.options()); 88 | grad_input.slice(dim, start, end, step).copy_(grad); 89 | return grad_input; 90 | } 91 | 92 | torch::Tensor select_forward(const torch::Tensor& self, int64_t dim, int64_t index) { 93 | return at::native::select(self, dim, index); 94 | } 95 | 96 | torch::Tensor select_backward1(const torch::Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { 97 | // return at::native::select_backward(grad, input_sizes, dim, index); // pytorch v1.7 98 | auto grad_input = at::zeros(input_sizes, grad.options()); 99 | grad_input.select(dim, index).copy_(grad); 100 | return grad_input; 101 | } 102 | 103 | torch::Tensor tanh_backward1(const torch::Tensor& grad_output, const torch::Tensor& output) { 104 | return at::tanh_backward(grad_output, output); 105 | } 106 | 107 | torch::Tensor gelu_backward1(const torch::Tensor& grad_output, const torch::Tensor& input_) { 108 | return at::native::gelu_backward_cuda(grad_output, input_); 109 | } 110 | 111 | torch::Tensor rsub_const_forward(const torch::Tensor& input_, float other, int alpha) { 112 | return at::native::rsub(input_, at::Scalar(other), at::Scalar(alpha)); 113 | } 114 | 115 | // torch::Tensor zeros(int size[], c10::ScalarType dtype, c10::Layout layout, c10::Device device, bool pin) { 116 | // return 
at::zeros(size, dtype, layout, device, pin); 117 | // } 118 | 119 | torch::Tensor tofwd(torch::Tensor input_, uint8_t dtype, bool b1, bool b2) { 120 | return at::native::to(input_, static_cast(dtype), b1, b2, c10::nullopt); 121 | } 122 | 123 | torch::Tensor lr_upsample_nearest_3d_backward(const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, 124 | c10::optional scales_d, c10::optional scales_h, c10::optional scales_w) { 125 | return at::native::upsample_nearest3d_backward_cuda(grad_output, output_size, input_size, scales_d, scales_h, scales_w); 126 | } 127 | 128 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 129 | m.def("adaptive_avg_pool_backward", &adaptive_avg_pool_backward, "Adaptive Avg Pool Backward"); 130 | m.def("embedding", &embedding_forward, "Embedding forward"); 131 | m.def("embedding_backward", &embedding_backward1, "Embedding backward"); 132 | m.def("mm_mat1_backward", &mm_mat1_backward, "mm matrix1 backward"); 133 | m.def("mm_mat1_backward_t", &mm_mat1_backward, "mm matrix1 backward whose saved input2 had been transposed before mm"); 134 | m.def("mm_mat2_backward", &mm_mat2_backward, "mm matrix2 backward"); 135 | m.def("slice", &slice_forward, "Slice forward"); 136 | m.def("slice_backward", &slice_backward1, "Slice backward"); 137 | m.def("select", &select_forward, "Select forward"); 138 | m.def("select_backward", &select_backward1, "Select backward"); 139 | // m.def("softmax", &softmax_forward, "Softmax forward"); 140 | // m.def("softmax_backward", &softmax_backward1, "Softmax backward"); 141 | m.def("tanh_backward", &tanh_backward1, "Tanh backward"); 142 | m.def("gelu_backward", &gelu_backward1, "GeLU backward"); 143 | m.def("rsub_const", &rsub_const_forward, "Rsub forward (other has type Scalar)"); 144 | // m.def("zeros", &zeros, "Create zeros"); 145 | m.def("tofwd", &tofwd, "To fwd impl"); 146 | m.def("lr_upsample_nearest_3d_backward", &lr_upsample_nearest_3d_backward, "Upsample Nearest 3D Backward"); 147 | } -------------------------------------------------------------------------------- /monet/graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import lru_cache 3 | 4 | 5 | class Node: 6 | def __init__(self, shape): 7 | self.shape = shape 8 | 9 | 10 | class Input(Node): 11 | def __repr__(self): 12 | return '' % str(list(self.shape)) 13 | 14 | 15 | class Param(Node): 16 | def __repr__(self): 17 | return '' % str(list(self.shape)) 18 | 19 | 20 | class ComputeNode(Node): 21 | class Arg: 22 | pass 23 | 24 | class V(Arg): 25 | def __init__(self, v): 26 | self.value = v 27 | 28 | def __repr__(self): 29 | return '' % str(self.value) 30 | 31 | class D(Arg): 32 | def __init__(self, index, requires_grad=False): 33 | self.index = index 34 | self.requires_grad = requires_grad 35 | 36 | def __repr__(self): 37 | return '' % (self.index, self.requires_grad) 38 | 39 | def __init__(self, shape, nodeid, op, args, has_backward, is_depthwise=False): 40 | super().__init__(shape) 41 | self._op = op 42 | self._args = args 43 | self.id = nodeid 44 | self._has_backward = has_backward 45 | self._is_depthwise = is_depthwise 46 | 47 | @property 48 | @lru_cache(maxsize=512) 49 | def op(self): 50 | return self._op 51 | 52 | @property 53 | def args(self): 54 | return self._args 55 | 56 | @property 57 | def has_backward(self): 58 | return self._has_backward 59 | 60 | @property 61 | @lru_cache(maxsize=512) 62 | def is_depthwise(self): 63 | return self._is_depthwise 64 | 65 | @property 66 | 
@lru_cache(maxsize=128) 67 | def dependencies(self): 68 | return [(a.index, a.requires_grad) for a in self._args if isinstance(a, self.D)] 69 | 70 | def __repr__(self): 71 | return '' % (str(self._op), str(list(self.shape))) 72 | 73 | 74 | class Graph: 75 | def __init__(self): 76 | self._nodes = [] 77 | self._outputs = [] 78 | 79 | def _add_node(self, node): 80 | self._nodes.append(node) 81 | return len(self._nodes)-1 82 | 83 | def _add_input(self, shape): 84 | return self._add_node(Input(shape)) 85 | 86 | def _add_param(self, shape): 87 | return self._add_node(Param(shape)) 88 | 89 | def _add_op(self, shape, op, args, has_backward=False, is_depthwise=False): 90 | nodeid = len(self._nodes) 91 | return self._add_node(ComputeNode(shape, nodeid, op, args, has_backward, is_depthwise)) 92 | 93 | def _add_output(self, output_id): 94 | self._outputs.append(output_id) 95 | 96 | @property 97 | def nodes(self): 98 | return self._nodes 99 | 100 | @classmethod 101 | def create(cls, model, input_shape=(3, 224, 224)): 102 | # create a graph of the forward pass 103 | # JIT trace the model 104 | args = (torch.ones((23,) + input_shape),) 105 | graph, torch_out = torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=False) 106 | torch._C._jit_pass_constant_propagation(graph) 107 | torch._C._jit_pass_inline(graph) 108 | torch._C._jit_pass_dce(graph) 109 | torch._C._jit_pass_lint(graph) 110 | params = torch.jit._unique_state_dict(model) 111 | 112 | assert len(list(graph.inputs())) == len(args) + len(params) 113 | node_id = {} 114 | r = cls() 115 | arg_and_param_shape = [list(a.shape) for a in args] + [list(p.shape) for p in params.values()] 116 | for k, i in enumerate(graph.inputs()): 117 | if k < len(args): 118 | node_id[i.unique()] = r._add_input([-1]+arg_and_param_shape[k][1:]) 119 | else: 120 | node_id[i.unique()] = r._add_param(arg_and_param_shape[k]) 121 | 122 | const = {} 123 | 124 | # Track connected nodes in the graph 125 | track = set() 126 | track.add("input.1") 127 | for node in graph.nodes(): 128 | if node.kind()!="aten::size": 129 | for ip in node.inputs(): 130 | if ip.debugName() in track or "input" in ip.debugName(): 131 | track.add(node.output().debugName()) 132 | if "input" in node.output().debugName(): 133 | track.add(node.output().debugName()) 134 | 135 | list_contents = {} 136 | for n in graph.nodes(): 137 | assert n.kind() != 'prim::GetAttr' 138 | if n.kind() == 'prim::Constant': 139 | const[n.output().unique()] = n['value'] if n.hasAttribute('value') else None 140 | elif len(n.kind()) > 6 and n.kind()[:6] == 'aten::': 141 | args = [] 142 | for i in n.inputs(): 143 | iu = i.unique() 144 | if iu in list_contents: 145 | iu_list = list_contents[iu] 146 | else: 147 | iu_list = [iu] 148 | for iu in iu_list: 149 | if iu in const: 150 | args.append(ComputeNode.V(const[iu])) 151 | elif iu in node_id: 152 | if i.debugName() not in track and (not isinstance(r._nodes[node_id[iu]], Input)) and (not isinstance(r._nodes[node_id[iu]], Param)): # Doing this for addmm and transpose 153 | for ii in i.node().inputs(): 154 | iiu = ii.unique() 155 | assert (isinstance(r._nodes[node_id[iiu]], Input) or isinstance(r._nodes[node_id[iiu]], Param)) == True 156 | args.append(ComputeNode.D(node_id[iiu], ii.requires_grad())) 157 | else: 158 | args.append(ComputeNode.D(node_id[iu], i.requires_grad())) 159 | else: 160 | raise ValueError('Nodes %s disconnected' % repr(i)) 161 | has_backward = False 162 | if n.output().debugName() in track: 163 | has_backward = True 164 | # Identify 
depthwise conv 165 | is_depthwise = False 166 | if n.kind() == "aten::_convolution": 167 | assert isinstance(args[8], ComputeNode.V) 168 | if args[8].value > 1 and args[8].value == r.nodes[args[0].index].shape[1]: 169 | is_depthwise = True 170 | # Add node to graph 171 | node_id[n.output().unique()] = r._add_op([s if s != 23 else -1 for s in n.output().type().sizes()], 172 | n.kind(), args, has_backward, is_depthwise) 173 | elif n.kind() in ['prim::ListConstruct', 'prim::TupleConstruct']: 174 | list_contents[n.output().unique()] = [i.unique() for i in n.inputs()] 175 | else: 176 | print('Unknown OP', n.kind(), n) 177 | # Identify outputs 178 | for op in graph.outputs(): 179 | if op.node().kind()[:6] == 'aten::': 180 | r._add_output(node_id[op.unique()]) 181 | elif op.node().kind() == 'prim::TupleConstruct': 182 | for i in op.node().inputs(): 183 | r._add_output(node_id[i.unique()]) 184 | return r 185 | -------------------------------------------------------------------------------- /monet/lm_ops/relu.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .compress import * 3 | from .pack import * 4 | import torch 5 | import numpy as np 6 | 7 | lrrelu_cpp = load(name="lrrelu_cpp", sources=[this_dir/"lrrelu.cpp", this_dir/"lrrelu.cu"], extra_cflags=['-std=c++17']) 8 | 9 | @implements(['aten::relu', 'aten::relu_'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode']) 10 | class InputActReLU(OP): 11 | backward_storage = InputStorage(0) 12 | params = None 13 | inplace = False 14 | 15 | def forward(self, x): 16 | with torch.no_grad(): 17 | if self.inplace: 18 | return torch.relu_(x) 19 | return torch.relu(x) 20 | 21 | def backward(self, x, stored, nodel=False): 22 | with torch.no_grad(): 23 | ip = stored[0] 24 | if not nodel: 25 | del stored[0] 26 | d = list(ip.shape) 27 | N = np.prod(d) 28 | y = lrrelu_cpp.relu_backward(x.view(N), ip.view(N), 0, N) 29 | return y.view(d) 30 | 31 | 32 | @implements(['aten::relu', 'aten::relu_'], ['multiway', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal']) 33 | class OutputActReLU(OP): 34 | backward_storage = OutputStorage() 35 | params = None 36 | inplace = False 37 | 38 | def forward(self, x): 39 | with torch.no_grad(): 40 | if self.inplace: 41 | return torch.relu_(x) 42 | return torch.relu(x) 43 | 44 | def backward(self, x, stored, nodel=False): 45 | with torch.no_grad(): 46 | op = stored[0] 47 | if not nodel: 48 | del stored[0] 49 | d = list(op.shape) 50 | N = np.prod(d) 51 | y = lrrelu_cpp.relu_backward(x.view(N), op.view(N), 0, N) 52 | return y.view(d) 53 | 54 | @implements(['aten::relu'], ['gist']) 55 | class IPActReLU(OP): 56 | backward_storage = InputStorage(0) 57 | params = None 58 | inplace = False 59 | 60 | def forward(self, x): 61 | with torch.no_grad(): 62 | return torch.relu(x) 63 | 64 | def backward(self, x, stored, nodel=False): 65 | with torch.no_grad(): 66 | op = stored[0] 67 | if not nodel: 68 | del stored[0] 69 | d = list(op.shape) 70 | N = np.prod(d) 71 | y = lrrelu_cpp.relu_backward(x.view(N), op.view(N), 0, N) 72 | return y.view(d) 73 | 74 | @implements(['aten::relu_'], ['gist']) 75 | class OutputActReLU(OP): 76 | backward_storage = OutputStorage() 77 | params = None 78 | inplace = True 79 | 80 | def forward(self, x): 81 | with torch.no_grad(): 82 | return torch.relu_(x) 83 | 84 | def backward(self, x, stored, nodel=False): 85 | with torch.no_grad(): 86 | op = stored[0] 87 | if not nodel: 88 | del stored[0] 89 | d = list(op.shape) 90 | N = 
np.prod(d) 91 | y = lrrelu_cpp.relu_backward(x.view(N), op.view(N), 0, N) 92 | return y.view(d) 93 | 94 | @implements(['aten::relu', 'aten::relu_'], ['newnode', 'multiway_newnode', 'conv_multiway_newnode']) 95 | class BinaryActReLU(OP): 96 | backward_storage = IntermediateStorage(lambda shape: (np.prod(shape)+7)//8) 97 | params = None 98 | inplace = False 99 | 100 | def forward(self, x): 101 | with torch.no_grad(): 102 | if self.inplace: 103 | return x.clamp_min_(0), pack((x > 0).view(-1)) 104 | return x.clamp_min(0), pack((x > 0).view(-1)) 105 | 106 | def intermediate(self, x): 107 | return pack((x > 0).view(-1)) 108 | 109 | def backward(self, x, stored, nodel=False): 110 | with torch.no_grad(): 111 | sign_pack = stored[0] 112 | if not nodel: 113 | del stored[0] 114 | sign = unpack(sign_pack) 115 | del sign_pack 116 | shape = x.shape 117 | x *= sign.view(shape) 118 | return x 119 | 120 | @implements(['aten::nosave_relu_'], ['gist']) 121 | class NoSaveReLUIP(OP): 122 | backward_storage = OutputStorage() 123 | params = None 124 | inplace = True 125 | 126 | def forward(self, x): 127 | with torch.no_grad(): 128 | out = torch.relu_(x) 129 | cip, col, row = compress_csr_256(out) 130 | intcol = col.to(torch.uint8) 131 | del col 132 | params = x.shape 133 | return out, (cip, intcol, row) 134 | 135 | def backward(self, x, stored, nodel=False): 136 | with torch.no_grad(): 137 | ip = stored[0] 138 | if not nodel: 139 | del stored[0] 140 | d = list(ip.shape) 141 | N = np.prod(d) 142 | y = lrrelu_cpp.relu_backward(x.view(N), ip.view(N), 0, N) 143 | return y.view(d) 144 | 145 | @implements(['aten::nosave_relu'], ['gist']) 146 | class NoSaveReLU(OP): 147 | backward_storage = InputStorage(0) 148 | params = None 149 | inplace = False 150 | 151 | def forward(self, x): 152 | with torch.no_grad(): 153 | out = torch.relu(x) 154 | cip, col, row = compress_csr_256(out) 155 | intcol = col.to(torch.uint8) 156 | # params = x.shape 157 | return out, [cip, intcol, row] 158 | 159 | def backward(self, x, stored, nodel=False): 160 | with torch.no_grad(): 161 | ip = stored[0] 162 | if not nodel: 163 | del stored[0] 164 | d = list(ip.shape) 165 | N = np.prod(d) 166 | y = lrrelu_cpp.relu_backward(x.view(N), ip.view(N), 0, N) 167 | return y.view(d) 168 | 169 | 170 | @implements(['aten::savesign_relu'], ['gist']) 171 | class GistBinaryActReLU(OP): 172 | backward_storage = IntermediateStorage(lambda shape: (np.prod(shape)+7)//8) 173 | params = None 174 | inplace = False 175 | 176 | def forward(self, x): 177 | with torch.no_grad(): 178 | if self.inplace: 179 | return x.clamp_min_(0), pack((x > 0).view(-1)) 180 | return x.clamp_min(0), pack((x > 0).view(-1)) 181 | 182 | def intermediate(self, x): 183 | return pack((x > 0).view(-1)) 184 | 185 | def backward(self, x, stored, nodel=False): 186 | with torch.no_grad(): 187 | sign_pack = stored[0] 188 | if not nodel: 189 | del stored[0] 190 | sign = unpack(sign_pack) 191 | del sign_pack 192 | shape = x.shape 193 | x *= sign.view(shape) 194 | return x 195 | 196 | @implements(['aten::savesign_relu_'], ['gist']) 197 | class GistBinaryActReLU(OP): 198 | backward_storage = IntermediateStorage(lambda shape: (np.prod(shape)+7)//8) 199 | params = None 200 | inplace = True 201 | 202 | def forward(self, x): 203 | with torch.no_grad(): 204 | if self.inplace: 205 | return x.clamp_min_(0), pack((x > 0).view(-1)) 206 | return x.clamp_min(0), pack((x > 0).view(-1)) 207 | 208 | def intermediate(self, x): 209 | return pack((x > 0).view(-1)) 210 | 211 | def backward(self, x, stored, nodel=False): 212 
| with torch.no_grad(): 213 | sign_pack = stored[0] 214 | if not nodel: 215 | del stored[0] 216 | sign = unpack(sign_pack) 217 | del sign_pack 218 | shape = x.shape 219 | x *= sign.view(shape) 220 | return x -------------------------------------------------------------------------------- /monet/lm_ops/defaultconv.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | 5 | conv_cpp = load(name="conv_cpp", sources=[this_dir/"conv.cpp"], extra_cflags=['-std=c++17'], extra_include_paths=[str(this_dir)], with_cuda=True) 6 | 7 | 8 | @implements(['aten::_convolution'], ['none']) 9 | class PytorchConvolution(OP): 10 | backward_storage = InputStorage(0, 1) 11 | 12 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 13 | assert not transposed 14 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape 15 | self.do_grad = [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 16 | with torch.no_grad(): 17 | output = conv_cpp.forward_normal(input_, weight, stride, padding, dilation, groups) 18 | return output + bias.view((1, -1, 1, 1)) 19 | 20 | def backward(self, grad_output, stored): 21 | input_, weight = stored 22 | stride, padding, dilation, groups, input_shape, weight_shape = self.params 23 | di = dw = db = None 24 | if input_ is not None: 25 | if self.do_grad[1]: 26 | dw = conv_cpp.backward_weight_normal(weight_shape, grad_output, input_, stride, padding, dilation, groups) 27 | 28 | if weight is not None: 29 | if self.do_grad[0]: 30 | di = conv_cpp.backward_input_normal(input_shape, grad_output, weight, stride, padding, dilation, groups) 31 | 32 | if self.do_grad[2]: 33 | db = grad_output.sum([0, 2, 3]) 34 | 35 | return di, dw, db 36 | 37 | 38 | @implements(['aten::_convolution'], ['normal', 'multiway_newnode', 'multiway', 'newnode']) 39 | class DefaultAlgoConvolution(OP): 40 | backward_storage = InputStorage(0, 1) 41 | params = None 42 | algorithm = -1 43 | is_depthwise = False 44 | 45 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 46 | with torch.no_grad(): 47 | assert not transposed 48 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape, [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 49 | algorithm = -1 50 | if self.is_depthwise: 51 | if bias == None: 52 | return conv_cpp.convolution_main(input_, weight, torch.tensor(1), stride, padding, dilation, transposed, output_padding, groups) 53 | return conv_cpp.convolution_main(input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups) 54 | else: 55 | input_ = input_.detach() 56 | weight = weight.detach() 57 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, algorithm) 58 | if bias is not None: 59 | output[:] += bias.view((1, -1, 1, 1)) 60 | return output 61 | 62 | def backward(self, grad_output, stored, nodel=False): 63 | with torch.no_grad(): 64 | input_, weight = stored 65 | stride, padding, dilation, groups, input_shape, weight_shape, do_grad = self.params 66 | algorithm = self.algorithm 67 | algo = -1 68 | convtype = algorithm // 10 69 | # Delete the stored inputs 70 | if not nodel: 71 | del stored[1] 72 | del stored[0] 73 | di = dw = db = None 74 | 75 | if self.is_depthwise: 76 | di, dw = 
conv_cpp.backward_depthwise(grad_output, input_, weight, (weight.shape[2], weight.shape[3]), stride, padding, dilation, do_grad[:2]) 77 | else: 78 | if convtype == 0 or convtype == 2: 79 | if input_ is not None: 80 | if do_grad[1]: 81 | input_detached = input_.detach() 82 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, algo) 83 | if do_grad[2]: 84 | db = grad_output.sum([0, 2, 3]) 85 | del input_, input_detached 86 | 87 | if convtype == 1 or convtype == 2: 88 | if weight is not None: 89 | if do_grad[0]: 90 | weight_detached = weight.detach() 91 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, algo) 92 | 93 | if do_grad[2]: 94 | return di, dw, db 95 | return di, dw 96 | 97 | @implements(['aten::_convolution'], ['conv_multiway_newnode', 'conv_normal']) 98 | class SpecificAlgoConvolution(OP): 99 | backward_storage = InputStorage(0, 1) 100 | params = None 101 | algorithm = -1 102 | is_depthwise = False 103 | 104 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 105 | with torch.no_grad(): 106 | assert not transposed 107 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape, [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 108 | if self.is_depthwise: 109 | if bias == None: 110 | return conv_cpp.convolution_main(input_, weight, torch.tensor(1), stride, padding, dilation, transposed, output_padding, groups) 111 | return conv_cpp.convolution_main(input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups) 112 | else: 113 | algorithm = self.algorithm 114 | input_ = input_.detach() 115 | weight = weight.detach() 116 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, algorithm) 117 | if bias is not None: 118 | output[:] += bias.view((1, -1, 1, 1)) 119 | return output 120 | 121 | def backward(self, grad_output, stored, nodel=False): 122 | with torch.no_grad(): 123 | input_, weight = stored 124 | stride, padding, dilation, groups, input_shape, weight_shape, do_grad = self.params 125 | algorithm = self.algorithm 126 | algo = algorithm % 10 127 | convtype = algorithm // 10 128 | # Delete the stored inputs 129 | if not nodel: 130 | del stored[1] 131 | del stored[0] 132 | di = dw = db = None 133 | 134 | if self.is_depthwise: 135 | di, dw = conv_cpp.backward_depthwise(grad_output, input_, weight, (weight.shape[2], weight.shape[3]), stride, padding, dilation, do_grad[:2]) 136 | else: 137 | if convtype == 0 or convtype == 2: 138 | if input_ is not None: 139 | if do_grad[1]: 140 | input_detached = input_.detach() 141 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, algo) 142 | if do_grad[2]: 143 | db = grad_output.sum([0, 2, 3]) 144 | del input_, input_detached 145 | 146 | if convtype == 1 or convtype == 2: 147 | if weight is not None: 148 | if do_grad[0]: 149 | weight_detached = weight.detach() 150 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, algo) 151 | 152 | if do_grad[2]: 153 | return di, dw, db 154 | return di, dw 155 | 156 | @staticmethod 157 | def n_fwd_algos(): 158 | return conv_cpp.n_fwd_algos() 159 | 160 | @staticmethod 161 | def n_bwd_ip_algos(): 162 | return conv_cpp.n_bwd_ip_algos() 163 | 
164 | @staticmethod 165 | def n_bwd_wt_algos(): 166 | return conv_cpp.n_bwd_wt_algos() 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MONeT: Memory Optimization for Deep Networks 2 | 3 | Implemented over PyTorch, MONeT schedules allow training deep networks on a constrained memory budget with minimal computational overhead. MONeT jointly determines checkpointing as well as operator implementations, reducing GPU memory by as much as 3x with a compute overhead of 9-16%. 4 | 5 |
6 | ![MONeT concept figure](figs/monet_concept_fig.jpeg) 7 |
8 | 9 | > **Memory Optimization for Deep Networks**
10 | > Aashaka Shah, Chao-Yuan Wu, Jayashree Mohan, Vijay Chidambaram, Philipp Krähenbühl
11 | > In ICLR 2021 [[paper]](https://openreview.net/pdf?id=bnY0jm4l59) 12 | 13 | 14 | ## Installation 15 | MONeT has been tested with PyTorch 1.5.1, torchvision 0.6.1, and cudatoolkit 10.1. Create a conda environment with Python 3.7 or greater. Inside the environment, install the following packages: `cvxpy`, `gurobi`, `pandas`, `ninja-build`, `coinor-cbc`, `coinor-libcbc-dev`, `cylp`. 16 | 17 | [install.sh](install.sh) provides the installation script. 18 | 19 | Clone this repo and install the package. Ensure that the conda environment is activated. 20 | ``` 21 | git clone --recursive https://github.com/utsaslab/MONeT 22 | cd MONeT 23 | pip install -e . 24 | ``` 25 | 26 | ## Getting Started 27 | 28 | ### MONeT usage 29 | MONeT has been tested for single-GPU training and single-machine multi-GPU Distributed Data Parallel training. To get started with MONeT using solutions from the schedule zoo, add the following imports to your code: 30 | 31 | ``` 32 | from monet.cvxpy_solver import Solution 33 | from monet.monet_wrapper import MONeTWrapper 34 | ``` 35 | 36 | Wrap your model with a `MONeTWrapper`: 37 | ``` 38 | monet_model = MONeTWrapper(model, solution_file, input_shape) 39 | ``` 40 | 41 | Use the model as you normally would: 42 | ``` 43 | output = monet_model(input) # Forward pass 44 | output.sum().backward() # Backward pass 45 | ``` 46 | 47 | A working version of this code can be found at [examples/training.py](examples/training.py). 48 | 49 | For Distributed Data Parallel training, `monet_model` can be wrapped by `torch.nn.parallel.DistributedDataParallel` like any other model. 50 | Working distributed training code can be found at [examples/dist_training.py](examples/dist_training.py). 51 | 52 | [examples/imagenet.py](examples/imagenet.py) has been modified to use MONeT schedules for ImageNet training. 53 | ``` 54 | python imagenet.py DATA_DIR -a [arch] --gpu 0 \ 55 | --epochs [num_epochs] \ 56 | --batch-size [batch_size] \ 57 | --solution_file [path to solution file] 58 | ``` 59 | 60 | At higher batch sizes, the PyTorch memory allocator may report an out-of-memory error even though the schedule being executed should run without any issues. This is because of the caching nature of the memory allocator. Please pre-allocate a pool of the expected memory usage before allocating any training tensors, using the following code snippet: 61 | 62 | ``` 63 | pool = torch.zeros(expected_memory/4).cuda() 64 | del pool 65 | ``` 66 | 67 | ## Schedule zoo 68 | We have already created some schedules that can be used right off the bat. 69 | Simply install MONeT, modify your training script as in [examples/imagenet.py](examples/imagenet.py), and use the memory-efficient schedules for training! 70 | The schedule zoo is hosted in the `data` directory. 71 | You can use the results below to pick the right schedule according to your requirements. 72 | 73 | For example, the solution [solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl](https://github.com/aashaka/monet-schedules/blob/master/monet_r50_184_24hr/solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl) uses 10 GB of memory for training ResNet-50 with a batch size of 184 and, according to the results, has a 3.22% compute overhead over the original PyTorch implementation, which uses 15.06 GB.
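Putting the pieces above together, here is a minimal end-to-end sketch of training with a schedule from the zoo. The solution path, batch size, memory-pool size, and device handling are illustrative assumptions, not fixed requirements; the exact setup follows [examples/training.py](examples/training.py):

```
import torch
import torchvision
from monet.cvxpy_solver import Solution  # imported as above so the pickled schedule can be loaded
from monet.monet_wrapper import MONeTWrapper

# Illustrative: the 10 GB ResNet-50 schedule mentioned above, assumed to live
# under the `data` submodule; point this at whichever solution file you use.
solution_file = "data/monet_r50_184_24hr/solution_resnet50_184_inplace_conv_multiway_newnode_10.00.pkl"
batch_size, input_shape = 184, (3, 224, 224)

# Warm up the caching allocator with a pool of the expected memory usage
# (size is in float32 elements, i.e. bytes / 4), then free it.
expected_memory = 10 * (1024 ** 3)  # ~10 GB budget, in bytes
pool = torch.zeros(int(expected_memory / 4)).cuda()
del pool

model = torchvision.models.resnet50()
monet_model = MONeTWrapper(model, solution_file, input_shape).cuda()

images = torch.randn(batch_size, *input_shape).cuda()
output = monet_model(images)   # Forward pass under the MONeT schedule
output.sum().backward()        # Backward pass
```

The pool allocation is just the allocator warm-up trick described above for larger batch sizes; it is not required by the schedule itself.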
74 | 75 | ## Results 76 | 77 | | ResNet-50 (184) | Memory (GB) | Compute Overhead (%) | 78 | |-----------|-------------|----------------------| 79 | | PyTorch | 15.06 | 0 | 80 | | MONeT | 10.01 | 3.22% | 81 | | MONeT | 9.01 | 4.68% | 82 | | MONeT | 8.01 | 5.56% | 83 | | MONeT | 6.99 | 7.28% | 84 | | MONeT | 6.00 | 9.31% | 85 | | MONeT | 4.99 | 11.95% | 86 | 87 | | GoogleNet (320) | Memory (GB) | Compute Overhead (%) | 88 | |-----------|-------------|----------------------| 89 | | PyTorch | 14.93 | 0 | 90 | | MONeT | 9.98 | 7.13% | 91 | | MONeT | 8.99 | 7.87% | 92 | | MONeT | 8.01 | 8.44% | 93 | | MONeT | 7.02 | 9.71% | 94 | | MONeT | 6.01 | 12.14% | 95 | | MONeT | 4.99 | 15.77% | 96 | 97 | | UNet (11) | Memory (GB) | Compute Overhead (%) | 98 | |---------|-------------|----------------------| 99 | | PyTorch | 14.32 | 0 | 100 | | MONeT | 10.01 | -4.10% | 101 | | MONeT | 9.01 | -2.07% | 102 | | MONeT | 8.02 | -0.09% | 103 | | MONeT | 7.00 | 1.39% | 104 | | MONeT | 6.01 | 4.95% | 105 | | MONeT | 5.01 | 11.51% | 106 | 107 | | Mobilenet (272) | Memory (GB) | Compute Overhead (%) | 108 | |-----------|-------------|----------------------| 109 | | PyTorch | 14.46 | 0 | 110 | | MONeT | 10.02 | 2.40% | 111 | | MONeT | 9.01 | 3.10% | 112 | | MONeT | 8.02 | 4.77% | 113 | | MONeT | 7.01 | 5.53% | 114 | | MONeT | 6.01 | 7.55% | 115 | | MONeT | 5.01 | 8.72% | 116 | 117 | | VGG-16 (176) | Memory (GB) | Compute Overhead (%) | 118 | |---------|-------------|----------------------| 119 | | PyTorch | 14.12 | 0 | 120 | | MONeT | 9.71 | -5.30% | 121 | | MONeT | 8.66 | -4.64% | 122 | | MONeT | 7.88 | -2.18% | 123 | | MONeT | 6.82 | 1.99% | 124 | | MONeT | 5.90 | 5.44% | 125 | | MONeT | 5.51 | 9.11% | 126 | 127 | 128 | ## Advanced MONeT usage 129 | Obtain the Gurobi academic license from the Gurobi [website](https://www.gurobi.com/downloads/end-user-license-agreement-academic/). Login with a .edu email to get the free license. 130 | 131 | 1. To create a MONeT solution: 132 | ``` 133 | python cvxpy_solver.py MODEL BATCH_SIZE BUDGET MODE "GUROBI" --time_limit TIME_LIMIT 134 | ``` 135 | 136 | MODEL format: `"torchvision.models.()"`. For UNeT, the format is `"unet"`.
137 | BUDGET is the memory budget in GB.
138 | MODE is "inplace_conv_multiway_newnode" for complete MONeT.
139 | TIME_LIMIT is the solver time limit in seconds.
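For example, a solution similar to the 10 GB ResNet-50 schedule in the zoo could be requested with the command below; the model string and the 24-hour time limit are illustrative values, not a prescribed configuration.

```
python cvxpy_solver.py "torchvision.models.resnet50()" 184 10 inplace_conv_multiway_newnode "GUROBI" --time_limit 86400
```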
140 | The flag `--ablation` can be added to disable checkpointing when creating a solution. 141 | 142 | 2. To profile a MONeT schedule given a solution: 143 | ``` 144 | python schedule.py MODEL BATCH_SIZE BUDGET MODE "GUROBI" \ 145 | --run_bs --solution_file SOLUTION_FILE 146 | ``` 147 | The flag `--run_bs` can be replaced by `--check_runtime` to check the runtime of the schedule or `--check_diff` to check the gradients of MONeT against original PyTorch. 148 | 149 | 150 | Other modes may be used for experimenting with MONeT: 151 | - `inplace_` prefix enables operator optimization 152 | - `conv_normal` selects conv-optimization 153 | - `multiway` selects output-activated optimization 154 | - `newnode` selects intermediate-activate optimization 155 | 156 | Refer the paper for details about the optimizations. 157 | 158 | ## Citation 159 | If you use MONeT in your work, please consider citing us as 160 | 161 | ``` 162 | @misc{shah2020memory, 163 | title={Memory Optimization for Deep Networks}, 164 | author={Aashaka Shah and Chao-Yuan Wu and Jayashree Mohan and Vijay Chidambaram and Philipp Krähenbühl}, 165 | year={2020}, 166 | eprint={2010.14501}, 167 | archivePrefix={arXiv}, 168 | primaryClass={cs.LG} 169 | } 170 | ``` 171 | 172 | ## Acknowledgements 173 | The code for UNeT is taken from [Pytorch-UNet](https://github.com/milesial/Pytorch-UNet) by [milesial](https://github.com/milesial). Distributed Data Parallel training example code is borrowed from the [distributed tutorial](https://github.com/yangkky/distributed_tutorial) by [yangkky](https://github.com/yangkky). 174 | -------------------------------------------------------------------------------- /gist/gist_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import lru_cache 3 | 4 | 5 | class Node: 6 | def __init__(self, shape): 7 | self.shape = shape 8 | 9 | 10 | class Input(Node): 11 | def __repr__(self): 12 | return '' % str(list(self.shape)) 13 | 14 | 15 | class Param(Node): 16 | def __repr__(self): 17 | return '' % str(list(self.shape)) 18 | 19 | 20 | class ComputeNode(Node): 21 | class Arg: 22 | pass 23 | 24 | class V(Arg): 25 | def __init__(self, v): 26 | self.value = v 27 | 28 | def __repr__(self): 29 | return '' % str(self.value) 30 | 31 | class D(Arg): 32 | def __init__(self, index, requires_grad=False): 33 | self.index = index 34 | self.requires_grad = requires_grad 35 | 36 | def __repr__(self): 37 | return '' % (self.index, self.requires_grad) 38 | 39 | def __init__(self, shape, nodeid, op, args, has_backward, is_depthwise=False, compress_conv=-1): 40 | super().__init__(shape) 41 | self._op = op 42 | self._args = args 43 | self.id = nodeid 44 | self._has_backward = has_backward 45 | self._is_depthwise = is_depthwise 46 | self.compress_conv = compress_conv 47 | 48 | @property 49 | @lru_cache(maxsize=512) 50 | def op(self): 51 | return self._op 52 | 53 | def clear(self): 54 | ComputeNode.op.fget.cache_clear() 55 | 56 | @property 57 | def args(self): 58 | return self._args 59 | 60 | @property 61 | def has_backward(self): 62 | return self._has_backward 63 | 64 | @property 65 | @lru_cache(maxsize=512) 66 | def is_depthwise(self): 67 | return self._is_depthwise 68 | 69 | @property 70 | @lru_cache(maxsize=128) 71 | def dependencies(self): 72 | return [(a.index, a.requires_grad) for a in self._args if isinstance(a, self.D)] 73 | 74 | def __repr__(self): 75 | return '' % (str(self._op), str(list(self.shape))) 76 | 77 | 78 | class Graph: 79 | def __init__(self): 
80 | self._nodes = [] 81 | self._outputs = [] 82 | 83 | def _add_node(self, node): 84 | self._nodes.append(node) 85 | return len(self._nodes)-1 86 | 87 | def _add_input(self, shape): 88 | return self._add_node(Input(shape)) 89 | 90 | def _add_param(self, shape): 91 | return self._add_node(Param(shape)) 92 | 93 | def _add_op(self, shape, op, args, has_backward=False, is_depthwise=False, compress_conv=-1): 94 | nodeid = len(self._nodes) 95 | return self._add_node(ComputeNode(shape, nodeid, op, args, has_backward, is_depthwise, compress_conv)) 96 | 97 | def _add_output(self, output_id): 98 | self._outputs.append(output_id) 99 | 100 | @property 101 | def nodes(self): 102 | return self._nodes 103 | 104 | @classmethod 105 | def create(cls, model, input_shape=(3, 224, 224)): 106 | # create a graph of the forward pass 107 | # JIT trace the model 108 | args = (torch.ones((23,) + input_shape),) 109 | graph, torch_out = torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=False) 110 | torch._C._jit_pass_constant_propagation(graph) 111 | torch._C._jit_pass_inline(graph) 112 | torch._C._jit_pass_dce(graph) 113 | torch._C._jit_pass_lint(graph) 114 | params = torch.jit._unique_state_dict(model) 115 | 116 | assert len(list(graph.inputs())) == len(args) + len(params) 117 | node_id = {} 118 | r = cls() 119 | arg_and_param_shape = [list(a.shape) for a in args] + [list(p.shape) for p in params.values()] 120 | for k, i in enumerate(graph.inputs()): 121 | if k < len(args): 122 | node_id[i.unique()] = r._add_input([-1]+arg_and_param_shape[k][1:]) 123 | else: 124 | node_id[i.unique()] = r._add_param(arg_and_param_shape[k]) 125 | 126 | const = {} 127 | 128 | # Track connected nodes in the graph 129 | track = set() 130 | track.add("input.1") 131 | for node in graph.nodes(): 132 | if node.kind()!="aten::size": 133 | for ip in node.inputs(): 134 | if ip.debugName() in track or "input" in ip.debugName(): 135 | track.add(node.output().debugName()) 136 | if "input" in node.output().debugName(): 137 | track.add(node.output().debugName()) 138 | 139 | list_contents = {} 140 | for n in graph.nodes(): 141 | compress_conv = -1 # Compress storage of conv input because it is a ReLU output 142 | is_gist_mp = -1 # Compress max_pool indices if its input is a ReLU output 143 | assert n.kind() != 'prim::GetAttr' 144 | if n.kind() == 'prim::Constant': 145 | const[n.output().unique()] = n['value'] if n.hasAttribute('value') else None 146 | elif len(n.kind()) > 6 and n.kind()[:6] == 'aten::': 147 | args = [] 148 | for i in n.inputs(): 149 | iu = i.unique() 150 | if iu in list_contents: 151 | iu_list = list_contents[iu] 152 | else: 153 | iu_list = [iu] 154 | for iu in iu_list: 155 | if iu in const: 156 | args.append(ComputeNode.V(const[iu])) 157 | elif iu in node_id: 158 | if i.debugName() not in track and (not isinstance(r._nodes[node_id[iu]], Input)) and (not isinstance(r._nodes[node_id[iu]], Param)): # Doing this for addmm and transpose 159 | for ii in i.node().inputs(): 160 | iiu = ii.unique() 161 | assert (isinstance(r._nodes[node_id[iiu]], Input) or isinstance(r._nodes[node_id[iiu]], Param)) == True 162 | args.append(ComputeNode.D(node_id[iiu], ii.requires_grad())) 163 | else: 164 | if n.kind() == "aten::_convolution" and isinstance(r._nodes[node_id[iu]], ComputeNode) and r._nodes[node_id[iu]].op == 'aten::relu_': 165 | compress_conv = node_id[iu] 166 | r._nodes[node_id[iu]]._op = 'aten::nosave_relu_' 167 | r._nodes[node_id[iu]].clear() 168 | elif n.kind() == "aten::_convolution" and 
isinstance(r._nodes[node_id[iu]], ComputeNode) and r._nodes[node_id[iu]].op == 'aten::relu': 169 | compress_conv = node_id[iu] 170 | r._nodes[node_id[iu]]._op = 'aten::nosave_relu' 171 | r._nodes[node_id[iu]].clear() 172 | elif n.kind() == "aten::max_pool2d" and isinstance(r._nodes[node_id[iu]], ComputeNode) and r._nodes[node_id[iu]].op == 'aten::relu_': 173 | r._nodes[node_id[iu]]._op = 'aten::savesign_relu_' 174 | is_gist_mp = 1 175 | r._nodes[node_id[iu]].clear() 176 | elif n.kind() == "aten::max_pool2d" and isinstance(r._nodes[node_id[iu]], ComputeNode) and r._nodes[node_id[iu]].op == 'aten::relu': 177 | r._nodes[node_id[iu]]._op = 'aten::savesign_relu' 178 | is_gist_mp = 1 179 | r._nodes[node_id[iu]].clear() 180 | args.append(ComputeNode.D(node_id[iu], i.requires_grad())) 181 | else: 182 | raise ValueError('Nodes %s disconnected' % repr(i)) 183 | has_backward = False 184 | if n.output().debugName() in track: 185 | has_backward = True 186 | # Identify depthwise conv 187 | is_depthwise = False 188 | if n.kind() == "aten::_convolution": 189 | assert isinstance(args[8], ComputeNode.V) 190 | if args[8].value > 1 and args[8].value == r.nodes[args[0].index].shape[1]: 191 | is_depthwise = True 192 | # Add node to graph 193 | kind = n.kind() 194 | if compress_conv != -1 or is_gist_mp != -1: 195 | kind = kind[:6] + 'gist_' + kind[6:] 196 | node_id[n.output().unique()] = r._add_op([s if s != 23 else -1 for s in n.output().type().sizes()], 197 | kind, args, has_backward, is_depthwise, compress_conv) 198 | 199 | elif n.kind() in ['prim::ListConstruct', 'prim::TupleConstruct']: 200 | list_contents[n.output().unique()] = [i.unique() for i in n.inputs()] 201 | else: 202 | print('Unknown OP', n.kind(), n) 203 | # Identify outputs 204 | for op in graph.outputs(): 205 | if op.node().kind()[:6] == 'aten::': 206 | r._add_output(node_id[op.unique()]) 207 | elif op.node().kind() == 'prim::TupleConstruct': 208 | for i in op.node().inputs(): 209 | r._add_output(node_id[i.unique()]) 210 | return r 211 | -------------------------------------------------------------------------------- /monet/lm_ops/maxpool.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit 11 | #define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched 12 | 13 | static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) { 14 | return (size + pad < ((kernel - 1) * dilation + 1)) ? 
0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; 15 | } 16 | 17 | static __device__ inline int origp_start(int size, int pad, int kernel, int dilation, int stride) { 18 | return (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; 19 | } 20 | 21 | static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) { 22 | return min((size + pad) / stride + 1, pooled_size); 23 | } 24 | 25 | // kernels borrowed from Caffe 26 | __global__ void max_pool_forward_nchw(const int nthreads, const float* bottom_data, 27 | const int num, const int channels, const int height, 28 | const int width, const int pooled_height, const int pooled_width, 29 | const int kernel_h, const int kernel_w, const int stride_h, 30 | const int stride_w, const int pad_h, const int pad_w, 31 | const int dilation_h, const int dilation_w, float* top_data, 32 | uint8_t* top_mask) { 33 | CUDA_KERNEL_LOOP(index, nthreads) { 34 | int pw = index % pooled_width; 35 | int ph = (index / pooled_width) % pooled_height; 36 | int c = (index / pooled_width / pooled_height) % channels; 37 | int n = index / pooled_width / pooled_height / channels; 38 | int hstart = ph * stride_h - pad_h; 39 | int wstart = pw * stride_w - pad_w; 40 | int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height); 41 | int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width); 42 | uint8_t count_h = 0; 43 | uint8_t count_w = 0; 44 | while(hstart < 0) { 45 | hstart += dilation_h; 46 | count_h++; 47 | } 48 | while(wstart < 0) { 49 | wstart += dilation_w; 50 | count_w++; 51 | } 52 | uint8_t ch = count_h; 53 | uint8_t cw = count_w; 54 | float maxval = at::numeric_limits::lower_bound(); // -Infinity 55 | int maxidx = hstart * width + wstart; 56 | bottom_data += (n * channels + c) * height * width; 57 | for (int h = hstart; h < hend; h += dilation_h) { 58 | for (int w = wstart; w < wend; w += dilation_w) { 59 | float val = bottom_data[h * width + w]; 60 | if ((ScalarConvert::to(val) > maxval) || THCNumerics::isnan(val)) { 61 | maxidx = ch*(uint8_t)kernel_w + cw; 62 | maxval = ScalarConvert::to(val); 63 | } 64 | cw += 1; 65 | } 66 | cw = count_w; 67 | ch += 1; 68 | } 69 | top_data[index] = ScalarConvert::to(maxval); 70 | top_mask[index] = maxidx; 71 | } 72 | } 73 | 74 | static const int BLOCK_THREADS = 256; 75 | 76 | #if defined (__HIP_PLATFORM_HCC__) 77 | C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4) 78 | #else 79 | C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 8) 80 | #endif 81 | __global__ void max_pool_backward_nchw(const int nthreads, const float* top_diff, 82 | const uint8_t* top_mask, const int num, const int channels, 83 | const int height, const int width, const int pooled_height, 84 | const int pooled_width, const int kernel_h, const int kernel_w, 85 | const int stride_h, const int stride_w, const int pad_h, const int pad_w, 86 | const int dilation_h, const int dilation_w, 87 | float* bottom_diff) { 88 | // printf("Thread: %d, %d, %d, Block: %d, %d, %d\n", threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z); 89 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (height*width); index += blockDim.x * gridDim.x) { 90 | int h = index/width; 91 | int w = index - h * width; 92 | int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h); 93 | int origphstart = origp_start(h, pad_h, kernel_h, dilation_h, stride_h); 94 | int phend = p_end(h, pad_h, pooled_height, stride_h); 95 | int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w); 96 | int origpwstart = origp_start(w, pad_w, kernel_w, 
dilation_w, stride_w); 97 | int pwend = p_end(w, pad_w, pooled_width, stride_w); 98 | for (int n = blockIdx.y; n < num; n += gridDim.y) 99 | for (int c = blockIdx.z; c < channels; c+= gridDim.z) { 100 | 101 | float gradient = 0; 102 | int offset = (n * channels + c) * pooled_height * pooled_width; 103 | // top_diff += offset; 104 | // top_mask += offset; 105 | uint8_t this_element_idx; 106 | 107 | if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) { 108 | for (int ph = phstart; ph < phend; ++ph) { 109 | for (int pw = pwstart; pw < pwend; ++pw) { 110 | // This works for very few cases 111 | this_element_idx = (uint8_t)(((h-ph*stride_h+pad_h)/dilation_h)*kernel_w + (w-pw*stride_w+pad_w)/dilation_w); 112 | // printf("%d, %d, %d, %d, %d, %d \n", h,w,this_element_idx,ph,pw,ph * pooled_width + pw); 113 | if (top_mask[offset + ph * pooled_width + pw] == this_element_idx) { 114 | // printf("%d, %d, %d, %d, %d, %d, %d\n", h,w,ph,pw, (int32_t)this_element_idx, top_mask[ph * pooled_width + pw], ph * pooled_width + pw); 115 | gradient += ScalarConvert::to(top_diff[offset + ph * pooled_width + pw]); 116 | } 117 | } 118 | } 119 | } else { 120 | this_element_idx = (uint8_t)(((h - phstart * stride_h + pad_h)/dilation_h)*kernel_w + (w - pwstart*stride_w + pad_w)/dilation_w); 121 | // printf("%d, %d, %d, %d, %d, %d \n", h,w,this_element_idx,phstart,pwstart,phstart * pooled_width + pwstart); 122 | if (top_mask[offset + phstart * pooled_width + pwstart] == this_element_idx) { 123 | // printf("%d, %d, %d, %d, %d, %d, %d\n", h,w,phstart,pwstart, (int32_t)this_element_idx, top_mask[phstart * pooled_width + pwstart], phstart * pooled_width + pwstart); 124 | gradient += ScalarConvert::to(top_diff[offset+phstart * pooled_width + pwstart]); 125 | } 126 | } 127 | //printf("%d\n",(n*channels+c)*height*width+index); 128 | bottom_diff[(n*channels+c)*height*width+index] = ScalarConvert::to(gradient); 129 | } 130 | } 131 | } 132 | 133 | #define MAX_THREADS_PER_BLOCK 1024 134 | 135 | void max_pool_forward_nchw_cuda(const int nthreads, const float* bottom_data, 136 | const int num, const int channels, const int height, 137 | const int width, const int pooled_height, const int pooled_width, 138 | const int kernel_h, const int kernel_w, const int stride_h, 139 | const int stride_w, const int pad_h, const int pad_w, 140 | const int dilation_h, const int dilation_w, float* top_data, 141 | uint8_t* top_mask) { 142 | const int num_threads = std::min(MAX_THREADS_PER_BLOCK, 143 | BLOCK_THREADS); 144 | // std::cout<>>( 147 | nthreads, bottom_data, 148 | num, channels, height, 149 | width, pooled_height, pooled_width, 150 | kernel_h, kernel_w, stride_h, 151 | stride_w, pad_h, pad_w, 152 | dilation_h, dilation_w, top_data, 153 | top_mask); 154 | } 155 | 156 | void max_pool_backward_nchw_cuda(const int nthreads, const float* top_diff, 157 | const uint8_t* top_mask, const int num, const int channels, 158 | const int height, const int width, const int pooled_height, 159 | const int pooled_width, const int kernel_h, const int kernel_w, 160 | const int stride_h, const int stride_w, const int pad_h, const int pad_w, 161 | const int dilation_h, const int dilation_w, 162 | float* bottom_diff) { 163 | int imgcount = width * height; 164 | dim3 grid; 165 | const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS; 166 | grid.x = blocks; 167 | grid.y = num; 168 | uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; 169 | if (maxGridY < grid.y) grid.y = maxGridY; 170 | grid.z = channels; 171 | uint64_t maxGridZ = 
at::cuda::getCurrentDeviceProperties()->maxGridSize[2]; 172 | if (maxGridZ < grid.z) grid.z = maxGridZ; 173 | 174 | // printf("grid.x %lld, grid.y %lld, grid.z %lld, maxY %lld, maxZ %lld\n \ 175 | // nthreads %lld, num %lld, blocks %d\n \ 176 | // c %lld, h %lld, w %lld, ph %lld, pw %lld,\n \ 177 | // kh %lld, kw %lld, sh %lld, sw %lld, ph %lld, pw %lld, dh %lld, dw %lld\n", grid.x, grid.y, grid.z, maxGridY, maxGridZ, 178 | // nthreads, num, blocks, 179 | // channels, height, width, pooled_height, pooled_width, 180 | // kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,dilation_h, dilation_w); 181 | max_pool_backward_nchw 182 | <<>>( 183 | nthreads, 184 | top_diff, 185 | top_mask, 186 | num, 187 | channels, height, width, pooled_height, pooled_width, 188 | kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, 189 | dilation_h, dilation_w, bottom_diff); 190 | } -------------------------------------------------------------------------------- /monet/lm_ops/lrbn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // #include 5 | // #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | // #include 11 | #include 12 | #include 13 | // #include "lrbn.cuh" 14 | 15 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) CHECK_CONTIGUOUS(x) 17 | 18 | using namespace at; 19 | using namespace at::native; 20 | 21 | std::tuple do_forward( 22 | const torch::Tensor& input, 23 | const torch::Tensor& running_mean, 24 | const torch::Tensor& running_var, 25 | torch::Tensor& weight, 26 | torch::Tensor& bias, 27 | bool training, 28 | torch::optional momentum, 29 | double eps) { 30 | torch::NoGradGuard no_grad_guard; 31 | 32 | return torch::cudnn_batch_norm( 33 | input, 34 | weight, 35 | bias, 36 | running_mean, 37 | running_var, 38 | training, 39 | momentum.value(), 40 | eps); 41 | } 42 | 43 | std::tuple do_cudnn_backward( 44 | const torch::Tensor& input, const torch::Tensor& grad_output, 45 | const torch::Tensor& weight, const torch::Tensor& running_mean, 46 | const torch::Tensor& running_var, const torch::Tensor& save_mean, 47 | const torch::Tensor& save_var, double epsilon, 48 | const torch::Tensor& reservedSpace) { 49 | 50 | // switch (input.suggest_memory_format()) { 51 | // case torch::MemoryFormat::Preserve: 52 | // std::cout << "Preserve\n"; 53 | // break; 54 | // case torch::MemoryFormat::Contiguous: 55 | // std::cout << "Contiguous\n"; 56 | // break; 57 | // case torch::MemoryFormat::ChannelsLast: 58 | // std::cout << "ChannelsLast\n"; 59 | // break; 60 | // default: 61 | // std::cout<<"Unknown memory format\n"; 62 | // break; 63 | // } 64 | 65 | auto& ctx = torch::globalContext(); 66 | return torch::cudnn_batch_norm_backward(input, 67 | grad_output.contiguous(input.suggest_memory_format()), weight, 68 | running_mean, running_var, save_mean, save_var, 69 | epsilon, reservedSpace); 70 | } 71 | 72 | 73 | std::tuple do_native_backward( 74 | const torch::Tensor& grad_out, const torch::Tensor& self, const torch::Tensor& weight, 75 | const torch::Tensor& running_mean, const torch::Tensor& running_var, 76 | const torch::Tensor& save_mean, const torch::Tensor& save_invstd, 77 | bool train, double epsilon, std::array grad_input_mask) { 78 | 79 | return torch::native_batch_norm_backward(grad_out, self, weight, 80 | running_mean, running_var, save_mean, save_invstd, 81 | train, epsilon, grad_input_mask); 82 | } 83 | 84 | template class PtrTraits = torch::DefaultPtrTraits, 
typename index_t = int64_t> 85 | static torch::GenericPackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { 86 | if (! t.defined()) { 87 | const std::vector zeros(dim); 88 | return torch::GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); 89 | } 90 | return t.generic_packed_accessor(); 91 | } 92 | 93 | template 94 | void batch_norm_bwd1( 95 | const torch::GenericPackedTensorAccessor input, 96 | const torch::GenericPackedTensorAccessor grad_output, 97 | torch::GenericPackedTensorAccessor grad_input, 98 | torch::GenericPackedTensorAccessor grad_weight, 99 | torch::GenericPackedTensorAccessor grad_bias, 100 | const torch::GenericPackedTensorAccessor weight, 101 | const torch::GenericPackedTensorAccessor bias, 102 | const torch::GenericPackedTensorAccessor running_mean, 103 | const torch::GenericPackedTensorAccessor running_var, 104 | const torch::GenericPackedTensorAccessor save_mean, 105 | const torch::GenericPackedTensorAccessor save_invstd, 106 | bool train); 107 | 108 | template 109 | std::tuple batch_norm_backward_cuda_template1(const Tensor& grad_out_, const Tensor& input_, const Tensor& weight_, const Tensor& bias_, 110 | const Tensor& running_mean_, const Tensor& running_var_, const Tensor& save_mean_, const Tensor& save_invstd_, 111 | bool train, double epsilon, std::array grad_input_mask) { 112 | 113 | using accscalar_t = at::acc_type; 114 | Tensor grad_input_; 115 | Tensor grad_input_reshaped; 116 | Tensor grad_weight_; 117 | Tensor grad_bias_; 118 | auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); 119 | auto grad_output_reshaped = grad_out_.reshape(input_reshaped.sizes()); 120 | 121 | if (grad_input_mask[0]) { 122 | grad_input_ = at::empty_like(input_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 123 | grad_input_reshaped = grad_input_.view(input_reshaped.sizes()); 124 | } 125 | if (grad_input_mask[1]) { 126 | grad_weight_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 127 | } 128 | if (grad_input_mask[2]) { 129 | grad_bias_ = at::empty_like(weight_, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 130 | } 131 | 132 | auto input = input_reshaped.generic_packed_accessor(); 133 | auto grad_output = grad_output_reshaped.generic_packed_accessor(); 134 | auto grad_input = packed_accessor_or_dummy(grad_input_reshaped); 135 | auto weight = packed_accessor_or_dummy(weight_); 136 | auto bias = packed_accessor_or_dummy(bias_); 137 | auto grad_weight = packed_accessor_or_dummy(grad_weight_); 138 | auto grad_bias = packed_accessor_or_dummy(grad_bias_); 139 | auto running_mean = packed_accessor_or_dummy(running_mean_); 140 | auto running_var = packed_accessor_or_dummy(running_var_); 141 | auto save_mean = packed_accessor_or_dummy(save_mean_); 142 | auto save_invstd = packed_accessor_or_dummy(save_invstd_); 143 | 144 | // bnback::batch_norm_bwd1(input, grad_output, grad_input, grad_weight, grad_bias, weight, bias, running_mean, running_var, 145 | // save_mean, save_invstd, train); 146 | batch_norm_bwd1(input, grad_output, grad_input, grad_weight, grad_bias, weight, bias, running_mean, running_var, 147 | save_mean, save_invstd, train); 148 | 149 | AT_CUDA_CHECK(cudaGetLastError()); 150 | 151 | return std::make_tuple(grad_input_, grad_weight_, grad_bias_); 152 | } 153 | 154 | std::tuple output_activated_bn_backward(const Tensor& grad_out, const Tensor& self, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, 155 | const Tensor& save_mean, const Tensor& save_invstd, bool train, double epsilon, std::array 
grad_input_mask) { 156 | return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "batch_norm_backward_cuda", [&] { 157 | AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "batch_norm_backward_cuda", [&] { 158 | auto mean_st = running_mean.dtype(); 159 | auto var_st = running_var.dtype(); 160 | TORCH_CHECK(mean_st == var_st, "running_mean and running_var need to have the same data types"); 161 | // bool is_half_float = std::is_same::value && mean_st == at::kFloat; 162 | // bool is_bfloat16_float = std::is_same::value && mean_st == at::kFloat; 163 | if (at::cuda::detail::canUse32BitIndexMath(self)) { 164 | // // return batch_norm_backward_cuda_template1(grad_out, self, weight, bias, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); 165 | return batch_norm_backward_cuda_template1(grad_out, self, weight, bias, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); 166 | } else { 167 | // return batch_norm_backward_cuda_template1(grad_out, self, weight, bias, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); 168 | return batch_norm_backward_cuda_template1(grad_out, self, weight, bias, running_mean, running_var, save_mean, save_invstd, train, epsilon, grad_input_mask); 169 | } 170 | }); 171 | }); 172 | } 173 | 174 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 175 | m.def("forward", &do_forward, "BN forward"); 176 | m.def("cudnn_backward", &do_cudnn_backward, "BN backward"); 177 | m.def("native_backward", &do_native_backward, "BN backward"); 178 | m.def("output_activated_bn_backward", &output_activated_bn_backward, "Output activated BN backward"); 179 | } 180 | -------------------------------------------------------------------------------- /monet/lm_ops/funcs.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | 5 | lrfuncs_cpp = load(name="lrfuncs_cpp", sources=[this_dir/"lrfuncs.cpp"], extra_cflags=['-std=c++17']) 6 | 7 | @implements(['aten::adaptive_avg_pool2d'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 8 | class AdaptiveAvgPool2D(OP): 9 | params = None 10 | 11 | def forward(self, x, os): 12 | with torch.no_grad(): 13 | self.params = x.shape, x.device, x.requires_grad 14 | return torch._C._nn.adaptive_avg_pool2d(x, os) 15 | 16 | def backward(self, grad_output,stored, nodel=False): 17 | assert len(stored) == 0 18 | with torch.no_grad(): 19 | shape, device, grad = self.params 20 | ip = torch.zeros(shape, device=device, requires_grad=grad) 21 | return lrfuncs_cpp.adaptive_avg_pool_backward(grad_output,ip) 22 | 23 | @implements(['aten::add'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 24 | class Add(OP): 25 | backward_storage = None 26 | params = None 27 | def forward(self, a, b, s): 28 | with torch.no_grad(): 29 | self.params = s 30 | return torch.add(a, b, alpha=s) 31 | 32 | def backward(self, grad_output, stored, nodel=False): 33 | assert len(stored) == 0 34 | with torch.no_grad(): 35 | if not nodel: 36 | del stored 37 | s = self.params 38 | return grad_output, grad_output*s 39 | 40 | @implements(['aten::add_'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 41 | class AddIP(OP): 42 | backward_storage = None 43 | params = None 44 | inplace = False 45 | 46 | def forward(self, a, b, s): 47 | with torch.no_grad(): 48 | self.params = s 49 | 
if self.inplace: 50 | return a.add_(b, alpha=s) 51 | return torch.add(a, b, alpha=s) 52 | 53 | def backward(self, grad_output, stored, nodel=False): 54 | assert len(stored) == 0 55 | with torch.no_grad(): 56 | if not nodel: 57 | del stored 58 | s = self.params 59 | return grad_output, grad_output*s 60 | 61 | @implements(['aten::embedding'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 62 | class Embedding(OP): 63 | params = None 64 | backward_storage = InputStorage(1) 65 | 66 | def forward(self, weight, indices, padding_idx, scale_freq, is_sparse): 67 | with torch.no_grad(): 68 | assert not is_sparse, "Will not handle sparse Embedding" 69 | out = lrfuncs_cpp.embedding(weight, indices, padding_idx, scale_freq, is_sparse) 70 | self.params = weight.shape[0], padding_idx, scale_freq#, is_sparse 71 | return out 72 | 73 | def backward(self, grad_output, stores, nodel=False): 74 | with torch.no_grad(): 75 | indices = stores[0] 76 | if not nodel: 77 | del stores[0] 78 | num_weights, padding_idx, scale_freq = self.params 79 | return lrfuncs_cpp.embedding_backward(grad_output, indices, num_weights, padding_idx, scale_freq) 80 | 81 | @implements(['aten::t_matmul'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 82 | class LowMemTMatMul(OP): 83 | params = None 84 | backward_storage = InputStorage(0,1) 85 | 86 | def forward(self, input1, input2): 87 | with torch.no_grad(): 88 | input2_t = input2.t() 89 | out = torch.matmul(input1, input2_t) 90 | out_detached = out.detach() 91 | grad_name = out.grad_fn.next_functions[0][0].name() 92 | assert grad_name == "MmBackward" 93 | save_size2 = input2_t.size() 94 | save_stride2 = input2_t.stride() 95 | 96 | assert input1.dim() == 3 and input2.dim()==2 97 | input1_contiguous = input1.contiguous().view(-1, input1.shape[2]) 98 | save_size1 = input1_contiguous.size() 99 | save_stride1 = input1_contiguous.stride() 100 | 101 | self.params = grad_name, save_size1, save_size2, save_stride1, save_stride2, (input1.requires_grad, input2_t.requires_grad) 102 | del out 103 | return out_detached 104 | 105 | def backward(self, grad_output, stores, nodel=False): 106 | with torch.no_grad(): 107 | input1 = stores[0] 108 | input2 = stores[1] 109 | if not nodel: 110 | del stores[1] 111 | del stores[0] 112 | grad_name, size1, size2, stride1, stride2, do_grad = self.params 113 | dw, di = None, None 114 | if do_grad[1]: 115 | dw = lrfuncs_cpp.mm_mat2_backward(grad_output.view(-1,grad_output.shape[-1]), input1.view(size1), size2, stride2).t() 116 | del input1 117 | if do_grad[0]: 118 | di = lrfuncs_cpp.mm_mat1_backward(grad_output.view(-1,grad_output.shape[-1]), input2.t(), size1, stride1) 119 | return di, dw 120 | 121 | @implements(['aten::matmul'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 122 | class LowMemMatMul(OP): 123 | params = None 124 | backward_storage = InputStorage(0,1) 125 | 126 | def forward(self, input1, input2): 127 | out = torch.matmul(input1, input2) 128 | out_detached = out.detach() 129 | grad_name = out.grad_fn.next_functions[0][0].name() 130 | 131 | save_size1 = input1.size() 132 | save_size2 = input2.size() 133 | save_stride1 = input1.stride() 134 | save_stride2 = input2.stride() 135 | 136 | if input1.dim() == 3 and input2.dim()==2: 137 | input1_contiguous = input1.contiguous().view(-1, input1.shape[2]) 138 | save_size1 = input1_contiguous.size() 139 | save_stride1 = 
input1_contiguous.stride() 140 | elif input1.dim() == 4 and input2.dim() == 4: 141 | save_size1 = (input1.shape[0]*input1.shape[1],input1.shape[2], input1.shape[3]) 142 | save_size2 = (input2.shape[0]*input2.shape[1],input2.shape[2], input2.shape[3]) 143 | self.params = grad_name, save_size1, save_size2, save_stride1, save_stride2, (input1.requires_grad, input2.requires_grad) 144 | del out 145 | return out_detached 146 | 147 | def backward(self, grad_output, stores, nodel=False): 148 | with torch.no_grad(): 149 | input1 = stores[0] 150 | input2 = stores[1] 151 | if not nodel: 152 | del stores[1] 153 | del stores[0] 154 | grad_name, size1, size2, stride1, stride2, do_grad = self.params 155 | dw, di = None, None 156 | if do_grad[1]: 157 | if grad_name == "MmBackward": 158 | dw = lrfuncs_cpp.mm_mat2_backward(grad_output.view(-1,grad_output.shape[-1]), input1.view(size1), size2, stride2) 159 | elif grad_name == "BmmBackward": 160 | dw = input1.view(size1).transpose(1, 2).bmm(grad_output.view(-1, grad_output.shape[2], grad_output.shape[3])) 161 | else: 162 | raise RuntimeError("Not implemented %s" % grad_name) 163 | del input1 164 | if do_grad[0]: 165 | if grad_name == "MmBackward": 166 | di = lrfuncs_cpp.mm_mat1_backward(grad_output.view(-1,grad_output.shape[-1]), input2.view(size2), size1, stride1) 167 | elif grad_name == "BmmBackward": 168 | di = (grad_output.view(-1, grad_output.shape[2], grad_output.shape[3])).bmm(input2.view(size2).transpose(1, 2)) 169 | else: 170 | raise RuntimeError("Not implemented %s" % grad_name) 171 | return di, dw 172 | 173 | @implements(['aten::slice'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 174 | class Slice(OP): 175 | params = None 176 | backward_storage = None 177 | 178 | def forward(self, input_, dim, start, end, step): 179 | with torch.no_grad(): 180 | self.params = input_.shape, dim, start, end, step 181 | return lrfuncs_cpp.slice(input_, dim, start, end, step) 182 | 183 | def backward(self, grad_output, stores, nodel=False): 184 | with torch.no_grad(): 185 | input_shape, dim, start, end, step = self.params 186 | return lrfuncs_cpp.slice_backward(grad_output, input_shape, dim, start, end, step) 187 | 188 | @implements(['aten::select'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 189 | class Select(OP): 190 | params = None 191 | backward_storage = None 192 | 193 | def forward(self, input_, dim, index): 194 | with torch.no_grad(): 195 | self.params = input_.shape, dim, index 196 | return lrfuncs_cpp.select(input_, dim, index) 197 | 198 | def backward(self, grad_output, stores, nodel=False): 199 | with torch.no_grad(): 200 | input_shape, dim, index = self.params 201 | return lrfuncs_cpp.select_backward(grad_output, input_shape, dim, index) 202 | 203 | @implements(['aten::tanh'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 204 | class Tanh(OP): 205 | params = None 206 | backward_storage = OutputStorage() 207 | 208 | def forward(self, input_): 209 | with torch.no_grad(): 210 | return torch.tanh(input_) 211 | 212 | def backward(self, grad_output, stores, nodel=False): 213 | with torch.no_grad(): 214 | output = stores[0] 215 | if not nodel: 216 | del stores[0] 217 | return lrfuncs_cpp.tanh_backward(grad_output, output) 218 | 219 | @implements(['aten::gelu'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 220 | class GeLU(OP): 221 | 
params = None 222 | backward_storage = InputStorage(0) 223 | 224 | def forward(self, input_): 225 | with torch.no_grad(): 226 | return torch.nn.functional.gelu(input_) 227 | 228 | def backward(self, grad_output, stores, nodel=False): 229 | with torch.no_grad(): 230 | input_ = stores[0] 231 | if not nodel: 232 | del stores[0] 233 | return lrfuncs_cpp.gelu_backward(grad_output, input_) 234 | 235 | # rsub : output = other - input_ * alpha. 236 | # Till now, we have observed other and alpha both to be Scalars. 237 | # cannot be native OP, because grad_fn stores input_ unnecessarily 238 | @implements(['aten::rsub'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 239 | class RSub(OP): 240 | params = None 241 | backward_storage = None 242 | 243 | def forward(self, input_, other, alpha): 244 | with torch.no_grad(): 245 | assert not torch.is_tensor(other) 246 | self.params = alpha 247 | return lrfuncs_cpp.rsub_const(input_, other, alpha) 248 | 249 | def backward(self, grad_output, stores, nodel=False): 250 | with torch.no_grad(): 251 | alpha = self.params 252 | minus_alpha = -alpha 253 | return grad_output*minus_alpha 254 | 255 | # @implements(['aten::rsub'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 256 | # class RSub(OP): 257 | # params = None 258 | # backward_storage = None 259 | 260 | # def forward(self, input_, other, alpha): 261 | # with torch.no_grad(): 262 | # assert not torch.is_tensor(other) 263 | # self.params = alpha 264 | # return lrfuncs_cpp.rsub_const(input_, other, alpha) 265 | 266 | # def backward(self, grad_output, stores, nodel=False): 267 | # with torch.no_grad(): 268 | # alpha = self.params 269 | # minus_alpha = -alpha 270 | # return grad_output*minus_alpha 271 | 272 | # @implements(['aten::zeros'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 273 | # class Zeros(OP): 274 | # params = None 275 | # backward_storage = None 276 | 277 | # def forward(self, input_size, dtype, layout, device, pin): 278 | # with torch.no_grad(): 279 | # import config 280 | # device = config.device 281 | # return lrfuncs_cpp.zeros(input_size, dtype, layout, device, pin) 282 | 283 | # def backward(self, grad_output, stores, nodel=False): 284 | # raise NotImplementedError 285 | 286 | @implements(['aten::to'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 287 | class To(OP): 288 | params = None 289 | backward_storage = None 290 | 291 | def forward(self, input_, dtype, b1, b2, format): 292 | with torch.no_grad(): 293 | assert format is None 294 | return lrfuncs_cpp.tofwd(input_, dtype, b1, b2) 295 | 296 | def backward(self, grad_output, stores, nodel=False): 297 | raise NotImplementedError 298 | 299 | @implements(['aten::upsample_nearest3d'], ['normal', 'multiway', 'newnode', 'multiway_newnode', 'conv_multiway_newnode', 'conv_normal', 'gist']) 300 | class UpsampleNearest3D(OP): 301 | params = None 302 | 303 | def forward(self, x, size1, size2, size3, scaled, scaleh, scalew): 304 | with torch.no_grad(): 305 | size = [size1, size2, size3] 306 | out = torch._C._nn.upsample_nearest3d(x, size, scaled, scaleh, scalew) 307 | self.params = x.shape, size, scaled, scaleh, scalew 308 | return out 309 | 310 | def backward(self, grad_output, stored, nodel=False): 311 | assert len(stored) == 0 312 | with torch.no_grad(): 313 | ipshape, size, scaled, scaleh, scalew = self.params 314 | return 
lrfuncs_cpp.lr_upsample_nearest_3d_backward(grad_output, size, ipshape, scaled, scaleh, scalew) -------------------------------------------------------------------------------- /monet/lm_ops/greedyconv.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pack import * 3 | import torch 4 | import config 5 | from time import time 6 | 7 | conv_cpp = load(name="conv_cpp", sources=[this_dir/"conv.cpp"], extra_cflags=['-std=c++17'], extra_include_paths=[str(this_dir)], with_cuda=True) 8 | 9 | 10 | @implements(['aten::_convolution'], ['none']) 11 | class PytorchConvolution(OP): 12 | backward_storage = InputStorage(0, 1) 13 | 14 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 15 | assert not transposed 16 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape 17 | self.do_grad = [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 18 | with torch.no_grad(): 19 | output = conv_cpp.forward_normal(input_, weight, stride, padding, dilation, groups) 20 | return output + bias.view((1, -1, 1, 1)) 21 | 22 | def backward(self, grad_output, stored): 23 | input_, weight = stored 24 | stride, padding, dilation, groups, input_shape, weight_shape = self.params 25 | di = dw = db = None 26 | if input_ is not None: 27 | if self.do_grad[1]: 28 | dw = conv_cpp.backward_weight_normal(weight_shape, grad_output, input_, stride, padding, dilation, groups) 29 | 30 | if weight is not None: 31 | if self.do_grad[0]: 32 | di = conv_cpp.backward_input_normal(input_shape, grad_output, weight, stride, padding, dilation, groups) 33 | 34 | if self.do_grad[2]: 35 | db = grad_output.sum([0, 2, 3]) 36 | 37 | return di, dw, db 38 | 39 | 40 | 41 | @implements(['aten::_convolution'], ['normal', 'multiway_newnode', 'multiway', 'newnode']) 42 | class GreedyAlgoConvolution(OP): 43 | backward_storage = InputStorage(0, 1) 44 | params = None 45 | algorithm = -1 46 | fwd_algo = -1 47 | bwd_ip_algo = -1 48 | bwd_wt_algo = -1 49 | is_depthwise = False 50 | 51 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 52 | with torch.no_grad(): 53 | assert not transposed 54 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape, [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 55 | algorithm = -1 56 | if self.is_depthwise: 57 | if bias == None: 58 | return conv_cpp.convolution_main(input_, weight, torch.tensor(1), stride, padding, dilation, transposed, output_padding, groups) 59 | return conv_cpp.convolution_main(input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups) 60 | else: 61 | input_ = input_.detach() 62 | weight = weight.detach() 63 | if self.fwd_algo == -1: 64 | tfinal = -1 65 | torch.cuda.empty_cache() 66 | for alg in range(self.n_fwd_algos()): 67 | try: 68 | torch.cuda.reset_max_memory_allocated() 69 | torch.cuda.synchronize() 70 | t1 = time() 71 | for it in range(10): 72 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, alg) 73 | del output 74 | torch.cuda.synchronize() 75 | t2 = time() - t1 76 | if (t2 < tfinal or tfinal == -1) and torch.cuda.max_memory_allocated()/1024/1024/1024 < config.budget*1.01: 77 | self.fwd_algo = alg 78 | tfinal = t2 79 | torch.cuda.empty_cache() 80 | torch.cuda.reset_max_memory_allocated() 81 | except Exception as e: 82 | 
torch.cuda.empty_cache() 83 | torch.cuda.reset_max_memory_allocated() 84 | assert self.fwd_algo != -1 85 | try: 86 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, self.fwd_algo) 87 | except Exception as e: 88 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, -1) 89 | self.fwd_algo = 4 90 | if bias is not None: 91 | output[:] += bias.view((1, -1, 1, 1)) 92 | return output 93 | 94 | def backward(self, grad_output, stored, nodel=False): 95 | with torch.no_grad(): 96 | input_, weight = stored 97 | stride, padding, dilation, groups, input_shape, weight_shape, do_grad = self.params 98 | algorithm = self.algorithm 99 | algo = -1 100 | convtype = algorithm // 10 101 | # Delete the stored inputs 102 | if not nodel: 103 | del stored[1] 104 | del stored[0] 105 | di = dw = db = None 106 | 107 | if self.is_depthwise: 108 | di, dw = conv_cpp.backward_depthwise(grad_output, input_, weight, (weight.shape[2], weight.shape[3]), stride, padding, dilation, do_grad[:2]) 109 | else: 110 | if convtype == 0 or convtype == 2: 111 | if input_ is not None: 112 | if do_grad[1]: 113 | input_detached = input_.detach() 114 | if self.bwd_wt_algo == -1: 115 | twtfinal = -1 116 | torch.cuda.empty_cache() 117 | for alg in range(self.n_bwd_wt_algos()): 118 | try: 119 | torch.cuda.reset_max_memory_allocated() 120 | torch.cuda.synchronize() 121 | t1 = time() 122 | for it in range(10): 123 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, alg) 124 | torch.cuda.synchronize() 125 | t2 = time() - t1 126 | if (t2 < twtfinal or twtfinal == -1) and torch.cuda.max_memory_allocated()/1024/1024/1024 < config.budget*1.01: 127 | self.bwd_wt_algo = alg 128 | twtfinal = t2 129 | del dw 130 | torch.cuda.empty_cache() 131 | torch.cuda.reset_max_memory_allocated() 132 | except Exception as e: 133 | torch.cuda.empty_cache() 134 | torch.cuda.reset_max_memory_allocated() 135 | assert self.bwd_wt_algo != -1 136 | try: 137 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, self.bwd_wt_algo) 138 | except Exception as e: 139 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, -1) 140 | self.bwd_wt_algo = 1 141 | 142 | if do_grad[2]: 143 | db = grad_output.sum([0, 2, 3]) 144 | del input_, input_detached 145 | 146 | if convtype == 1 or convtype == 2: 147 | if weight is not None: 148 | if do_grad[0]: 149 | weight_detached = weight.detach() 150 | if self.bwd_ip_algo == -1: 151 | tipfinal = -1 152 | torch.cuda.empty_cache() 153 | for alg in range(self.n_bwd_ip_algos()): 154 | try: 155 | torch.cuda.reset_max_memory_allocated() 156 | torch.cuda.synchronize() 157 | t1 = time() 158 | for it in range(10): 159 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, alg) 160 | del di 161 | torch.cuda.synchronize() 162 | t2 = time() - t1 163 | if (t2 < tipfinal or tipfinal == -1) and torch.cuda.max_memory_allocated()/1024/1024/1024 < config.budget*1.01: 164 | self.bwd_ip_algo = alg 165 | tipfinal = t2 166 | torch.cuda.empty_cache() 167 | torch.cuda.reset_max_memory_allocated() 168 | except Exception as e: 169 | # print(e) 170 | torch.cuda.empty_cache() 171 | torch.cuda.reset_max_memory_allocated() 172 | assert self.bwd_ip_algo != -1 173 | weight_detached = weight.detach() 174 | try: 175 | 
di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, self.bwd_ip_algo) 176 | except Exception as e: 177 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, -1) 178 | self.bwd_ip_algo = 1 179 | 180 | if do_grad[2]: 181 | return di, dw, db 182 | return di, dw 183 | 184 | @staticmethod 185 | def n_fwd_algos(): 186 | return conv_cpp.n_fwd_algos() 187 | 188 | @staticmethod 189 | def n_bwd_ip_algos(): 190 | return conv_cpp.n_bwd_ip_algos() 191 | 192 | @staticmethod 193 | def n_bwd_wt_algos(): 194 | return conv_cpp.n_bwd_wt_algos() 195 | 196 | @implements(['aten::_convolution'], ['conv_multiway_newnode', 'conv_normal']) 197 | class SpecificAlgoConvolution(OP): 198 | backward_storage = InputStorage(0, 1) 199 | params = None 200 | algorithm = -1 201 | is_depthwise = False 202 | 203 | def forward(self, input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups, *args, **kwargs): 204 | with torch.no_grad(): 205 | assert not transposed 206 | self.params = stride, padding, dilation, groups, input_.shape, weight.shape, [input_.requires_grad, weight.requires_grad, bias is not None and bias.requires_grad] 207 | if self.is_depthwise: 208 | if bias == None: 209 | return conv_cpp.convolution_main(input_, weight, torch.tensor(1), stride, padding, dilation, transposed, output_padding, groups) 210 | return conv_cpp.convolution_main(input_, weight, bias, stride, padding, dilation, transposed, output_padding, groups) 211 | else: 212 | algorithm = self.algorithm 213 | input_ = input_.detach() 214 | weight = weight.detach() 215 | try: 216 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, algorithm) 217 | except Exception as e: 218 | output = conv_cpp.cudnn_convolution(input_, weight, padding, stride, dilation, groups, -1) 219 | self.algorithm = 4 220 | if bias is not None: 221 | output[:] += bias.view((1, -1, 1, 1)) 222 | return output 223 | 224 | def backward(self, grad_output, stored, nodel=False): 225 | with torch.no_grad(): 226 | input_, weight = stored 227 | stride, padding, dilation, groups, input_shape, weight_shape, do_grad = self.params 228 | algorithm = self.algorithm 229 | algo = algorithm % 10 230 | convtype = algorithm // 10 231 | # Delete the stored inputs 232 | if not nodel: 233 | del stored[1] 234 | del stored[0] 235 | di = dw = db = None 236 | 237 | if self.is_depthwise: 238 | di, dw = conv_cpp.backward_depthwise(grad_output, input_, weight, (weight.shape[2], weight.shape[3]), stride, padding, dilation, do_grad[:2]) 239 | else: 240 | if convtype == 0 or convtype == 2: 241 | if input_ is not None: 242 | if do_grad[1]: 243 | input_detached = input_.detach() 244 | try: 245 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, algo) 246 | except Exception as e: 247 | dw = conv_cpp.cudnn_convolution_backward_weight(weight_shape, grad_output, input_detached, padding, stride, dilation, groups, -1) 248 | self.algorithm = convtype * 10 + 1 249 | if do_grad[2]: 250 | db = grad_output.sum([0, 2, 3]) 251 | del input_, input_detached 252 | 253 | if convtype == 1 or convtype == 2: 254 | if weight is not None: 255 | if do_grad[0]: 256 | weight_detached = weight.detach() 257 | try: 258 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, algo) 259 | except 
Exception as e: 260 | di = conv_cpp.cudnn_convolution_backward_input(input_shape, grad_output, weight_detached, padding, stride, dilation, groups, -1) 261 | self.algorithm = convtype * 10 + 1 262 | 263 | if do_grad[2]: 264 | return di, dw, db 265 | return di, dw 266 | 267 | @staticmethod 268 | def n_fwd_algos(): 269 | return conv_cpp.n_fwd_algos() 270 | 271 | @staticmethod 272 | def n_bwd_ip_algos(): 273 | return conv_cpp.n_bwd_ip_algos() 274 | 275 | @staticmethod 276 | def n_bwd_wt_algos(): 277 | return conv_cpp.n_bwd_wt_algos() 278 | -------------------------------------------------------------------------------- /gist/gist_solver_info.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from functools import lru_cache 3 | import torch 4 | import monet.lm_ops as lm_ops 5 | from gist.gist_graph import * 6 | import numpy as np 7 | import sys 8 | 9 | MEMORY_MULTIPLIER = 4 10 | # TODO: Remove this and instead use dtype 11 | 12 | class SolverNode: 13 | def __init__(self): 14 | self.input_shapes = [] 15 | self.output_shapes = [] 16 | self.local_tensors = set() # Includes direct deps + deps due to branching which are kept in memory while recomputing this node 17 | self.local_memory = -1 # Memory of local tensors 18 | self.mem = -1 # Memory of output of node 19 | self.workspace_mem = [] # Extra space over its input and output reqd to compute a node 20 | self.fixed_mem = -1 # Fixed memory of the input + parameters + increasing memory of param grad 21 | self.workspace_compute = [] 22 | self.last_used = -1 # Which node uses this node last 23 | self.recompute_workspace_mem = [] 24 | self.recompute_workspace_compute = [] 25 | self.inplace_workspace_mem = [] 26 | self.inplace_workspace_compute = [] 27 | 28 | class BwdNode(SolverNode): 29 | # args.index for bwd node is wrt to solver_info nodes 30 | def __init__(self, fwd_node, fwd_id, num_bwd, bwd_op="ip_grad"): 31 | super().__init__() 32 | self.fwd_node = fwd_node 33 | self.fwd_id = fwd_id 34 | self.bwd_op = bwd_op 35 | self.stored = [] # Stored parameter gradients 36 | self.num_bwd = num_bwd 37 | self.dep_list_fwd = [[] for i in range(self.num_bwd)] 38 | self.dep_list_bwd = [[] for i in range(self.num_bwd)] 39 | self.args = [[] for i in range(self.num_bwd)] 40 | self.has_intermediates = False 41 | 42 | def __repr__(self): 43 | if self.fwd_node.op == "aten::_convolution" or self.fwd_node.op == "aten::addmm": 44 | return '' % (self.fwd_node, self.bwd_op) 45 | else: 46 | return '' % self.fwd_node.op 47 | 48 | def make_args(self): 49 | self.op = self.fwd_node.op + "_" + self.bwd_op 50 | for p in range(self.num_bwd): 51 | self.args[p] = self.dep_list_fwd[p] + self.dep_list_bwd[p] 52 | 53 | class FwdNode(SolverNode): 54 | def __init__(self, node): 55 | super().__init__() 56 | self.gnode = node 57 | self.op = node.op 58 | self.inbound_nodes = [] 59 | self.args = [[]] 60 | self.has_intermediates = False 61 | self.intermediates = [] # Tuple of op_number, solver intermediate node 62 | 63 | # Forward node will have only one set of dependencies 64 | def make_args(self): 65 | self.args = [self.inbound_nodes] 66 | 67 | def __repr__(self): 68 | return '%s' % self.gnode 69 | 70 | class IntNode(FwdNode): 71 | def __init__(self, node, solver_parent_id, op_idx): 72 | super().__init__(node) 73 | self.solver_parent_id = solver_parent_id 74 | self.op_idx = op_idx 75 | 76 | def __repr__(self): 77 | return 'Int::%s' % self.gnode 78 | 79 | class SolverInfo(): 80 | def __init__(self, bs, model_name, 
mode): 81 | self.nodes = {} 82 | self.newnodes = [] 83 | self.loss = 0 84 | self.edge_list = [] 85 | self.fwd_to_bwd = {} 86 | self.bwd_to_fwd = {} 87 | self.solver_to_graph = [] 88 | self.graph_to_solver = {} 89 | 90 | self.bs = bs 91 | self.model_name = model_name 92 | self.mode = mode 93 | self.data_path = "%s_%d_%s" % (self.model_name, self.bs, self.mode) 94 | 95 | self.select_conv_algo = "conv" in self.mode 96 | self.do_inplace = "inplace" in self.mode 97 | self.compute_newnode = "newnode" in self.mode 98 | 99 | self.mode = self.mode.replace("inplace_", "") # Remove inplace from the mode after setting do_inplace 100 | 101 | 102 | def extract(self, graph: Graph, *args): 103 | # Create solver graph of forward and backward pass annotated with memory and compute info 104 | for i, n in enumerate(graph.nodes): 105 | # Create solver forward nodes 106 | if isinstance(n, ComputeNode) and n.has_backward: 107 | self.solver_to_graph.append(i) 108 | self.graph_to_solver[i] = len(self.solver_to_graph)-1 109 | snodeid = len(self.nodes) 110 | self.nodes[snodeid] = FwdNode(n) 111 | inbound = [] 112 | for (dep, grad) in n.dependencies: 113 | if isinstance(graph.nodes[dep], ComputeNode): 114 | assert grad == True 115 | inbound.append(self.graph_to_solver[dep]) 116 | self.nodes[snodeid].inbound_nodes = inbound 117 | 118 | for idx, op in enumerate(lm_ops.list_ops(self.mode, n.op)): 119 | storage = op.backward_storage 120 | if not isinstance(storage, list): 121 | storage = [storage] 122 | for storetype in storage: 123 | if isinstance(storetype, lm_ops.IntermediateStorage): 124 | # Create intermediate node 125 | newnode_op = "int::" + op.__name__ 126 | ni = ComputeNode([storetype.size(n.shape)], -1, newnode_op, [], False) 127 | self.solver_to_graph.append(-1) 128 | sintnodeid = len(self.nodes) 129 | self.nodes[sintnodeid] = IntNode(ni, snodeid, idx) 130 | inbound = [] 131 | inbound.append(snodeid) 132 | self.nodes[sintnodeid].inbound_nodes = inbound 133 | self.newnodes.append(sintnodeid) 134 | self.nodes[snodeid].has_intermediates = True 135 | self.nodes[snodeid].intermediates.append((idx,sintnodeid)) 136 | 137 | self.loss = len(self.nodes) 138 | nloss = ComputeNode(self.nodes[self.loss-1].gnode.shape, -2, "loss::loss", [], False) 139 | self.nodes[self.loss] = FwdNode(nloss) 140 | self.nodes[self.loss].inbound_nodes = [self.loss-1] 141 | self.extract_deps(graph) 142 | self.get_mem() 143 | self.get_local_memory() 144 | self.get_fixed_memory(graph) 145 | self.size = len(self.nodes) 146 | 147 | 148 | 149 | def extract_deps(self, graph): 150 | # Create solver backward nodes 151 | for i in sorted(self.nodes, reverse=True): 152 | if i > self.loss: 153 | continue 154 | if self.nodes[i].gnode.has_backward: 155 | num_bwd = len(lm_ops.list_ops(self.mode, self.nodes[i].op)) 156 | nbwd = BwdNode(self.nodes[i].gnode, i, num_bwd) 157 | self.fwd_to_bwd[i] = len(self.nodes) 158 | self.bwd_to_fwd[len(self.nodes)] = i 159 | self.nodes[len(self.nodes)] = nbwd 160 | self.fwd_to_bwd[self.loss] = self.loss 161 | 162 | # get deps for backward 163 | for i in self.nodes: 164 | if isinstance(self.nodes[i], BwdNode): 165 | continue 166 | if self.nodes[i].gnode.has_backward: 167 | for ni in self.nodes[i].inbound_nodes: 168 | num_bwd_prev = len(lm_ops.list_ops(self.mode, self.nodes[ni].op)) 169 | for p in range(num_bwd_prev): 170 | self.nodes[self.fwd_to_bwd[ni]].dep_list_bwd[p].append(self.fwd_to_bwd[i]) 171 | 172 | ops = lm_ops.list_ops(self.mode, self.nodes[i].op) 173 | num_bwd = len(ops) 174 | 175 | for p in range(num_bwd): 176 | 
deps = ops[p]().backward_storage 177 | l = [] 178 | if not isinstance(deps, list): 179 | deps = [deps] 180 | for dep in deps: 181 | if isinstance(dep, lm_ops.InputStorage): 182 | for inids in dep.ids: 183 | arg_in = self.nodes[i].gnode.args[inids] 184 | if isinstance(arg_in, ComputeNode.D): 185 | innode = graph.nodes[arg_in.index] 186 | if isinstance(innode, ComputeNode): 187 | l.append(self.graph_to_solver[arg_in.index]) 188 | if isinstance(dep, lm_ops.OutputStorage): 189 | l.append(i) 190 | if isinstance(dep, lm_ops.IntermediateStorage): 191 | added_int = False 192 | assert self.nodes[i].has_intermediates 193 | for p_option, nint_id in self.nodes[i].intermediates: 194 | nint = self.nodes[nint_id] 195 | assert isinstance(nint, IntNode) 196 | if p_option == p: 197 | l.append(nint_id) 198 | added_int = True 199 | assert added_int == True 200 | self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p] = l 201 | 202 | if i == self.loss - 1: # inject loss node assuming we are at output node 203 | for p in range(num_bwd): 204 | self.nodes[self.fwd_to_bwd[i]].dep_list_fwd[p].append(self.loss) 205 | 206 | for i in range(self.loss+1): 207 | self.nodes[i].make_args() 208 | for i in range(self.loss+1, len(self.nodes)): 209 | self.nodes[i].make_args() 210 | fwd_node = self.nodes[i].fwd_node 211 | stored = [] 212 | output_shapes = [] 213 | # Need to have a list because of nodes like add and cat 214 | # NOTE for both aten::add and cat, we consider output of bwd as both inputs of fwd 215 | for (dep, rgrad) in (fwd_node.dependencies): 216 | nin = graph.nodes[dep] 217 | if isinstance(nin, Param) and rgrad == True: 218 | stored.append(nin) # Params have gradients which will be stored in backward 219 | elif isinstance(nin, Param): 220 | pass 221 | elif isinstance(nin, ComputeNode) and rgrad == True: 222 | output_shapes.append(list(nin.shape)) 223 | elif isinstance(nin, Input): 224 | if rgrad: 225 | output_shapes.append(list(nin.shape)) 226 | else: 227 | sys.exit("Unknown node encountered ") 228 | 229 | self.nodes[i].output_shapes = output_shapes 230 | self.nodes[i].stored = stored 231 | 232 | # Create edge_list 233 | self.edge_list = [] 234 | for v in self.nodes: 235 | for k, vdeps in enumerate(self.nodes[v].args): 236 | for u in vdeps: 237 | edge = (u, v, k) 238 | self.edge_list.append(edge) 239 | 240 | self.last_use_bwd = defaultdict(dict) 241 | for i in self.nodes: 242 | if isinstance(self.nodes[i], BwdNode): 243 | assert len(self.nodes[i].dep_list_fwd) == 1 244 | for j in self.nodes[i].dep_list_fwd[0]: 245 | if isinstance(self.nodes[j], IntNode): 246 | pj = self.nodes[j].solver_parent_id 247 | self.last_use_bwd[self.solver_to_graph[pj]]["int"] = self.solver_to_graph[self.bwd_to_fwd[i]] 248 | else: 249 | if j == self.loss: 250 | sj = graph._outputs[0] 251 | else: 252 | sj = self.solver_to_graph[j] 253 | b_graphi = self.solver_to_graph[self.bwd_to_fwd[i]] 254 | self.last_use_bwd[sj]["ip"] = b_graphi 255 | 256 | def get_mem(self): 257 | for i in self.nodes: 258 | if isinstance(self.nodes[i], BwdNode): 259 | self.nodes[i].mem = sum([ np.prod(oshape)* MEMORY_MULTIPLIER * -1 * self.bs if np.prod(oshape)<0 else 260 | np.prod(oshape)* MEMORY_MULTIPLIER for oshape in self.nodes[i].output_shapes]) 261 | else: 262 | oshape = list(self.nodes[i].gnode.shape) 263 | self.nodes[i].mem = np.prod(oshape) * MEMORY_MULTIPLIER * -1 * self.bs if np.prod(oshape)<0 else np.prod(oshape)* MEMORY_MULTIPLIER 264 | if "int::" in self.nodes[i].op: 265 | self.nodes[i].mem = int( (self.nodes[i].mem / MEMORY_MULTIPLIER )) # The size of int node 
already includes memory multipliers according to dtype 266 | 267 | 268 | def get_local_memory(self): 269 | present = -1 270 | 271 | for v in self.nodes: 272 | # fwd-fwd and bwd-bwd dependencies remain same in all paths 273 | for u in self.nodes[v].args[0]: 274 | # dont add forward nodes to backward working set 275 | if v>self.loss and u= self.loss: 297 | break 298 | for a in self.nodes[i].gnode.args: 299 | if not isinstance(a, ComputeNode.V) and (isinstance(graph.nodes[a.index], Param) or isinstance(graph.nodes[a.index], Input)): 300 | fmem = np.prod(list(graph.nodes[a.index].shape)) * MEMORY_MULTIPLIER 301 | if fmem < 0: 302 | fmem = fmem * -1 * self.bs 303 | fixed_mem_fwd = fixed_mem_fwd + fmem 304 | self.nodes[0].fixed_mem = fixed_mem_fwd 305 | 306 | for i in self.nodes: 307 | if i == 0: 308 | continue 309 | self.nodes[i].fixed_mem = self.nodes[i-1].fixed_mem 310 | if i <= self.loss: # The fixed mem will be counting always 311 | continue 312 | for param_stored in self.nodes[i].stored: 313 | fmem = np.prod(list(param_stored.shape)) * MEMORY_MULTIPLIER 314 | if fmem < 0: 315 | fmem = fmem * -1 * self.bs 316 | self.nodes[i].fixed_mem = self.nodes[i].fixed_mem + fmem 317 | 318 | # Below is extra info for checkmate solver 319 | @property 320 | @lru_cache(maxsize=None) 321 | def successor_dict(self): 322 | sucs = defaultdict(list) 323 | for eidx, (u, v, p) in enumerate(self.edge_list): 324 | assert p == 0 325 | sucs[u].append((eidx, v)) 326 | return sucs 327 | 328 | def successors(self, node): 329 | return {u for (_, u) in self.successor_dict[node]} 330 | 331 | @property 332 | @lru_cache(maxsize=None) 333 | def predecessor_dict(self): 334 | preds = defaultdict(list) 335 | for eidx, (u, v, p) in enumerate(self.edge_list): 336 | assert p == 0 337 | preds[v].append((eidx, u)) 338 | return preds 339 | 340 | def predecessors_indexed(self, node): 341 | return self.predecessor_dict[node] 342 | -------------------------------------------------------------------------------- /gist/gist_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from collections import namedtuple 4 | import monet.lm_ops as lm_ops 5 | import numpy as np 6 | 7 | from gist.gist_graph import * 8 | from monet.cvxpy_solver import * 9 | from gist.gist_solver_info import * 10 | from models.unet import UNet 11 | 12 | ScheduleType = namedtuple('ScheduleType', 'recompute store_output store_intermediate') 13 | 14 | class Schedule(Graph): 15 | def __init__(self, graph: Graph, info: SolverInfo): 16 | self.si = info 17 | self._nodes = graph.nodes 18 | self._outputs = graph._outputs 19 | self._op = [] 20 | self._last_op_params = [None] * len(self.nodes) 21 | self._fwd_schedule = [] 22 | self.real_mem = [[-1 for k in range(self.si.loss + 1)] for t in range(self.si.size - self.si.loss) ] 23 | 24 | # Stored tensors 25 | self._stored = [None] * len(self.nodes) 26 | self._stored_intermediate = [None] * len(self.nodes) 27 | self.deltensors = defaultdict(list) 28 | for j in self.si.nodes: 29 | if j= solver_min: 246 | if self.si.last_use_bwd[i]["ip"] == k: 247 | tensors[i] = None 248 | idx = sidx + 1 249 | elif isinstance(storage, lm_ops.OutputStorage): 250 | i = required_storage[idx] 251 | stored.append(tensors[i].float()) 252 | if self.si.last_use_bwd[i]["ip"] == k: 253 | tensors[i] = None 254 | idx = idx + 1 255 | elif isinstance(storage, lm_ops.IntermediateStorage): 256 | assert self._stored_intermediate[k] is not None, "Intermediate output not computed for node %d" % k 257 | 
stored.append(self._stored_intermediate[k]) 258 | self._stored_intermediate[k] = None 259 | 260 | # Call backward 261 | assert bw_tensors[k] is not None, "Backward input not computed for node %d (graph), %d (solver), %s (node)" % (k, self.si.graph_to_solver[k], self.si.nodes[self.si.graph_to_solver[k]]) 262 | bw_outs = grad_nd.backward(bw_tensors[k], stored) 263 | 264 | if not isinstance(bw_outs, (list, tuple)): 265 | bw_outs = (bw_outs,) 266 | 267 | assert len(bw_outs) == len(n.dependencies), \ 268 | "Require the same number of grad outputs as forward inputs" \ 269 | " %s (%d) , %s (%d) %s" % ( 270 | repr(bw_outs), len(bw_outs), 271 | repr(n.dependencies), len(n.dependencies), n) 272 | 273 | # Free the backward tensor 274 | bw_tensors[k] = None 275 | 276 | # Accumulate the backward gradient 277 | for (i, r), o in zip(n.dependencies, bw_outs): 278 | if r: 279 | if o is not None: 280 | if bw_tensors[i] is None: 281 | bw_tensors[i] = o 282 | else: 283 | bw_tensors[i] += o 284 | del grad_nd 285 | 286 | elif self._args[k].requires_grad: 287 | self._args[k].backward(bw_tensors[k]) 288 | bw_tensors[k] = None 289 | 290 | # Clear params for BN 291 | for k, n in enumerate(self._nodes): 292 | if k in self.computeInstance and n.op == "aten::batch_norm": 293 | self._op[0][k].params = None 294 | 295 | if __name__ == '__main__': 296 | import argparse 297 | import torchvision 298 | from time import time 299 | from pathlib import Path 300 | 301 | parser = argparse.ArgumentParser() 302 | parser.add_argument('model') 303 | parser.add_argument('bs') 304 | parser.add_argument('budget') 305 | parser.add_argument( 306 | "--check_runtime", action="store_true", 307 | help="Compute the runtime difference between gist and normal model.") 308 | args = parser.parse_args() 309 | 310 | budget = float(args.budget) 311 | bs = int(args.bs) 312 | model_name = args.model.split(".")[-1][:-2] 313 | mode = "gist" 314 | print("Batch size ", bs) 315 | print("Model", model_name) 316 | print("Mode", mode) 317 | 318 | # Initialize pool of budget 319 | pool_shape = ( int(budget * 256 * 1024 * 1024) >> 3 ) << 3 320 | t = torch.zeros(pool_shape).cuda() 321 | del t 322 | torch.cuda.reset_max_memory_allocated() 323 | 324 | if args.model == 'unet': 325 | height, width = 416, 608 326 | model = UNet(n_channels=3, n_classes=1, height=height, width=width) 327 | else: 328 | height, width = 224, 224 329 | model = eval(args.model, {'torch': torch, 'torchvision': torchvision}) 330 | 331 | if 'mobilenet_v2' in args.model: 332 | model = torch.nn.Sequential( 333 | model.features, 334 | torch.nn.AdaptiveAvgPool2d((1, 1)), torch.nn.Flatten(start_dim=1), 335 | model.classifier[0], model.classifier[1]) 336 | 337 | graph = Graph.create(model, input_shape=(3, height, width)) 338 | model.cuda() 339 | 340 | input_ = torch.randn((bs, 3, height, width)).cuda() 341 | solver_info = SolverInfo(bs=bs, model_name=model_name, mode=mode) 342 | solver_info.extract(graph, input_, *list(model.state_dict(keep_vars=True).values())) 343 | schedule = Schedule(graph, solver_info) 344 | schedule.init_schedule(mode) 345 | torch.cuda.synchronize() 346 | torch.cuda.reset_max_memory_allocated() 347 | 348 | if args.check_runtime: 349 | start_event_monet = torch.cuda.Event(enable_timing=True) 350 | end_event_monet = torch.cuda.Event(enable_timing=True) 351 | for iterid in range(120): 352 | if iterid == 100: 353 | torch.cuda.reset_max_memory_allocated() 354 | start_event_monet.record() 355 | x1 = schedule.forward(input_, *list(model.state_dict(keep_vars=True).values())) 356 | 
schedule.backward(-torch.ones_like(x1)) 357 | for v in model.parameters(): 358 | v.grad = None 359 | end_event_monet.record() 360 | torch.cuda.synchronize() 361 | del x1 362 | monet_maxmem = torch.cuda.max_memory_allocated() / 2**20 363 | 364 | print("monet: %f ms avg, %8.2f MB" % (start_event_monet.elapsed_time(end_event_monet)/20, monet_maxmem)) 365 | else: 366 | print("simple fwd bwd") 367 | x1 = schedule.forward(input_, *list(model.state_dict(keep_vars=True).values())) 368 | schedule.backward(-torch.ones_like(x1)) 369 | torch.cuda.synchronize() 370 | print("Max mem: %.3f MB" % (torch.cuda.max_memory_allocated()/1024/1024)) 371 | print("Done") 372 | --------------------------------------------------------------------------------
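Usage note: judging from the `argparse` block at the end of `gist/gist_schedule.py`, the script takes a model expression (evaluated with `torch` and `torchvision` in scope, e.g. a string ending in `()`), a batch size, and a budget, plus an optional `--check_runtime` flag; the pre-allocated pool of `budget * 256 * 1024 * 1024` float32 elements works out to roughly `budget` GiB, so the budget appears to be expressed in GiB. An illustrative invocation (argument values are examples only) would be `python gist/gist_schedule.py "torchvision.models.resnet50()" 184 10 --check_runtime`.

The memory accounting in `gist/gist_solver_info.py` treats a negative shape product as a placeholder for the batch dimension: `get_mem` multiplies such products by `-1 * bs` and then by `MEMORY_MULTIPLIER` (4 bytes per element, assuming float32). A minimal standalone sketch of that convention follows; the helper name `tensor_bytes` and the example shape are illustrative, not part of the codebase:

```
import numpy as np

MEMORY_MULTIPLIER = 4  # bytes per element, assuming float32 throughout

def tensor_bytes(shape, bs):
    # Shapes captured without a concrete batch size carry a -1 in the batch
    # position, so their product is negative; multiplying by -1 * bs restores
    # the true element count before converting to bytes.
    n = np.prod(shape)
    if n < 0:
        n = n * -1 * bs
    return n * MEMORY_MULTIPLIER

# Example: a (-1, 64, 56, 56) activation at batch size 184
print(tensor_bytes([-1, 64, 56, 56], 184))  # 147718144 bytes, ~147 MB
```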