├── .gitignore
├── README.md
├── build_notes.txt
├── environment.yml
├── experiments
│   ├── __init__.py
│   ├── experiments.txt
│   ├── loss_exp.py
│   ├── my_loss_exp.py
│   ├── speedup_exp.py
│   ├── split_size_exp.py
│   ├── throughput_exp.py
│   └── utils.py
├── misc
│   └── networkx_metis__init__.py
├── parse_declarations.py
├── partition_torchvision_networks.py
├── pytorch_Gpipe
│   ├── METIS
│   │   ├── METIS_graph_partition.py
│   │   ├── METIS_manual.pdf
│   │   ├── __init__.py
│   │   ├── libmetis.dll
│   │   └── libmetis.so
│   ├── __init__.py
│   ├── delayedNorm.py
│   ├── model_partitioning
│   │   ├── __init__.py
│   │   ├── module_generation
│   │   │   ├── __init__.py
│   │   │   ├── constructor.py
│   │   │   ├── declarations.pyi
│   │   │   ├── forward.py
│   │   │   ├── generate.py
│   │   │   └── misc.py
│   │   ├── partition_graph.py
│   │   └── process_partition.py
│   ├── model_profiling
│   │   ├── __init__.py
│   │   ├── control_flow_graph.py
│   │   └── network_profiler.py
│   ├── pipeline.py
│   └── utils.py
├── sample_models
│   ├── AlexNet.py
│   ├── DenseNet.py
│   ├── GoogleNet.py
│   ├── Inception.py
│   ├── LeNet.py
│   ├── ResNet.py
│   ├── SqueezeNet.py
│   ├── VGG.py
│   ├── WideResNet.py
│   ├── __init__.py
│   ├── amoebaNet
│   │   ├── __init__.py
│   │   ├── genotype.py
│   │   └── utils.py
│   └── torchGpipe_ref_models
│       ├── __init__.py
│       ├── amoebanet_d
│       │   ├── __init__.py
│       │   ├── genotype.py
│       │   └── surgery.py
│       ├── flatten_sequential.py
│       └── resnet
│           ├── __init__.py
│           └── bottleneck.py
└── tests
    ├── __init__.py
    ├── test_delayed_norm.py
    ├── test_metis_lib_bindings.py
    └── test_networkx.py

/.gitignore:
--------------------------------------------------------------------------------
.vscode/*
.ipynb_checkpoints/*
modelParallel/__pycache__/*
data/*
models/__pycache__/*
*/__pycache__/*
__pycache__/*
ideas/*
.idea/
.pytest_cache/**
playground/**
model_partition/graph/__pycache__/
*.pyc
main.py
experiment/stanford_car_dataset_images_in_224x224/*
experiments/stanford-car-dataset-by-classes-folder-224/*
playground.py
trace.txt
inc.pdf
*.pdf
generated_modules.py
*.log
transformer_example/**
transformer_trace.txt
generated*
*Bug*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PyTorch implementation of pipeline model parallelism (GPipe), as described in the Google AI paper: https://arxiv.org/pdf/1811.06965.pdf

This is still a work in progress and is not yet functional.
--------------------------------------------------------------------------------
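For orientation: the experiment scripts in this repository construct the pipeline through pytorch_Gpipe.pipe_model (see experiments/utils.py and experiments/throughput_exp.py). A minimal sketch of that usage follows — the model choice, device list, micro-batch size and tensor shapes are illustrative, not prescribed by the repository:

import torch
from pytorch_Gpipe import pipe_model
from sample_models import resnet101

devices = [torch.device('cuda', i) for i in range(torch.cuda.device_count())]
model = resnet101(num_classes=10).to(devices[0])
sample = torch.randn(16, 3, 224, 224, device=devices[0])  # batch used only to profile and partition the model
piped = pipe_model(model, 8, sample, devices=devices)     # 8 is the micro-batch size
out = piped(sample)                                       # runs the partitioned model as a pipeline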
/build_notes.txt:
--------------------------------------------------------------------------------
prerequisites:
1. make sure you have the packages specified in environment.yml (graphviz and python-graphviz are optional, for graph visualization)
2. clone the repository
3. if you wish to use the METIS library (recommended) to decide how to divide the model, follow the next section

METIS build instructions:
note: requires CMake to build
source code can be downloaded from http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/metis-5.1.0.tar.gz

for linux (tested):
in the file include/metis.h
1. in line 33 set IDXTYPEWIDTH to 64
2. in line 42 set REALTYPEWIDTH to 64

1. make config shared=1
2. make install
3. copy build/libmetis/libmetis.so to pytorch_Gpipe/METIS
4. run tests/test_metis_lib_bindings.py; if it passes then you have successfully integrated METIS into this project

for windows (tested):
1. in the file include/metis.h
   a. in line 33 set IDXTYPEWIDTH to 64
   b. in line 42 set REALTYPEWIDTH to 64

2. in the top level CMakeLists.txt:
   a. change line 4 set(GKLIB_PATH "GKlib" CACHE PATH "path to GKlib") to set(GKLIB_PATH "${CMAKE_SOURCE_DIR}/GKlib" CACHE PATH "path to GKlib")
   b. change line 6 set(SHARED FALSE CACHE BOOL "build a shared library") to set(SHARED TRUE CACHE BOOL "build a shared library")

3. in the file GKlib/gk_arch.h
   a. comment out line 42: #include <sys/resource.h>

4. in the file GKlib/getopt.h
   a. comment out all extern declarations, lines 53-57:
      extern int gk_getopt(int __argc, char **__argv, char *__shortopts);
      extern int gk_getopt_long(int __argc, char **__argv, char *__shortopts, struct gk_option *__longopts, int *__longind);
      extern int gk_getopt_long_only(int __argc, char **__argv, char *__shortopts, struct gk_option *__longopts, int *__longind);

5. build the libmetis.dll target
6. copy build/libmetis/libmetis.dll to pytorch_Gpipe/METIS
7. run tests/test_metis_lib_bindings.py; if it passes then you have successfully integrated METIS into this project
--------------------------------------------------------------------------------
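A quick way to sanity-check the copied shared library before running the test suite — a minimal sketch using ctypes (the project's actual bindings live in pytorch_Gpipe/METIS/METIS_graph_partition.py; the path below assumes you run it from the repository root):

import ctypes
import os

lib_name = 'libmetis.dll' if os.name == 'nt' else 'libmetis.so'
lib_path = os.path.join('pytorch_Gpipe', 'METIS', lib_name)
lib = ctypes.CDLL(lib_path)  # raises OSError if the build or copy step went wrong
print('loaded', lib)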
/environment.yml:
--------------------------------------------------------------------------------
name: Gpipe
channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - python=3.7
  - pytorch=1.2.0
  - torchvision==0.4.0
  - graphviz
  - python-graphviz
  - networkx
  - networkx-metis
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/experiments/experiments.txt:
--------------------------------------------------------------------------------
# TODO: experiments


performance:
1. compare memory consumption between the pipeline and dataParallel/singleGpu/naiveModelParallel
2. for a given network partition and a given batch size, find the optimal microbatch size
3. for a given network partition, find the optimal batch size (microbatch size is batch/nparts)

effects of the partition on performance:
1. for a given network, find the effects of depth and basic blocks on performance
2. for a given network, find the effects of the number of partitions

reproduce paper results:
1. resnet101
2. AmoebaNet-D https://github.com/tensorflow/tpu/tree/e5c126d66aa3d25e0cb066bdf7fc46f98fe59901/models/experimental/amoeba_net
--------------------------------------------------------------------------------
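Item 2 under "performance" is what experiments/split_size_exp.py implements below; stripped to its essence, such a sweep could look like this sketch (the helper name, timing approach and repeat count are illustrative, and it assumes at least two GPUs, as pipe_model's automatic partitioning does):

import torch
from pytorch_Gpipe import pipe_model

def time_pipeline(model, sample, microbatch_size, devices, repeats=5):
    piped = pipe_model(model, microbatch_size, sample, devices=devices)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeats):
        piped(sample).sum().backward()  # one forward/backward pass per repeat
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / repeats  # milliseconds per step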
print(f"epoch number {epoch}:") 82 | 83 | for i, (inputs, labels) in enumerate(train_loader): 84 | optimizer.zero_grad() 85 | 86 | outputs = model(inputs.to(input_device)) 87 | labels = labels.to(outputs.device) 88 | 89 | loss = criterion(outputs, labels) 90 | loss.backward() 91 | optimizer.step() 92 | 93 | # saving statistics every 25 batches 94 | epoch_losses.append(loss.item()) 95 | epoch_accuracies.append(get_accuracy(outputs, labels)) 96 | 97 | # average loss so far 98 | train_losses.append(np.mean(epoch_losses)) 99 | 100 | # average accuracy so far 101 | train_accuracies.append(np.mean(epoch_accuracies)) 102 | 103 | test_stats.append(test(model, test_set, batch_size, input_device)) 104 | 105 | return (train_losses, train_accuracies), tuple(zip(*test_stats)) 106 | 107 | 108 | def test(model, test_set, batch_size, input_device='cpu'): 109 | # always using same test-set loader 110 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2) 111 | 112 | criterion = nn.CrossEntropyLoss() 113 | 114 | losses, accuracies = [], [] 115 | 116 | with torch.no_grad(): 117 | for batch in test_loader: 118 | inputs, labels = batch 119 | outputs = model(inputs.to(input_device)) 120 | labels = labels.to(outputs.device) 121 | 122 | loss = criterion(outputs, labels) 123 | 124 | losses.append(loss.item()) 125 | accuracies.append(get_accuracy(outputs, labels)) 126 | 127 | return np.mean(losses), np.mean(accuracies) 128 | 129 | 130 | def plot_loss(stats_single, stats_pipe, stats_dp, training=True): 131 | loss_s, acc_s = stats_single 132 | loss_p, acc_p = stats_pipe 133 | loss_dp, acc_dp = stats_dp 134 | 135 | x = range(0, 25 * len(loss_s), 25) 136 | 137 | plt.subplot(1, 2, 1) 138 | 139 | plt.plot(x, loss_s, label='single device') 140 | plt.plot(x, loss_p, label='model parallel') 141 | plt.plot(x, loss_dp, label='data parallel') 142 | 143 | plt.xlabel('epoch') 144 | plt.ylabel('loss') 145 | plt.title(f'{"train set" if training else "test set"} Cross-Entropy Loss Comparison') 146 | plt.legend() 147 | 148 | plt.subplot(1, 2, 2) 149 | 150 | plt.plot(x, acc_s, label='single device') 151 | plt.plot(x, acc_p, label='model parallel') 152 | plt.plot(x, acc_dp, label='data parallel') 153 | 154 | plt.xlabel('epoch') 155 | plt.ylabel('accuracy') 156 | plt.title(f'{"train set" if training else "test set"} Accuracy') 157 | plt.legend() 158 | 159 | plt.show() 160 | 161 | 162 | def get_accuracy(outputs, labels): 163 | _, predictions = torch.max(outputs.data, 1) 164 | 165 | return (predictions == labels).sum().item() / labels.size(0) 166 | 167 | 168 | if __name__ == '__main__': 169 | parser = ExpParser(description='Run the speedup experiment.') 170 | args = parser.parse_args() 171 | loss_exp(**args) -------------------------------------------------------------------------------- /experiments/my_loss_exp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | import numpy as np 5 | # add our code to the path so we could import it 6 | import sys 7 | import os 8 | sys.path.append(os.path.join(os.getcwd(), '..')) 9 | 10 | from pytorch_Gpipe import pipe_model 11 | import torch.nn.functional as F 12 | from torch.optim import SGD 13 | from torch.utils.data import DataLoader 14 | import torchvision 15 | from torchvision import transforms as transforms 16 | import argparse 17 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, densenet201, \ 18 | WideResNet, AmoebaNet_D as amoebanet, 
/experiments/my_loss_exp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import time
import numpy as np
# add our code to the path so we can import it
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))

from pytorch_Gpipe import pipe_model
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms as transforms
import argparse
from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, densenet201, \
    WideResNet, AmoebaNet_D as amoebanet, amoebanetd as torchgpipe_amoebanet, torchgpipe_resnet101
import platform

MODELS = {
    "alexnet": alexnet,
    "resnet101": resnet101,
    "vgg19_bn": vgg19_bn,
    "squeezenet1_1": squeezenet1_1,
    "densenet201": densenet201,
    "WideResNet": WideResNet,
    "amoebanet": amoebanet,
    "torchgpipe_amoebanet": torchgpipe_amoebanet,
    "torchgpipe_resnet101": torchgpipe_resnet101
}

# setups


def single_gpu(model_class: nn.Module, devices, *model_args, **model_kwargs):
    model = model_class(*model_args, **model_kwargs).to(devices[0])
    used_devices = [devices[0]]

    return model, used_devices


def pipeLine(model_class: nn.Module, devices, microbatch_size, pipe_sample, model_args, model_kwargs, pipeline_args,
             pipeline_kwargs):
    net = model_class(*model_args, **model_kwargs).to(devices[0])
    net(pipe_sample)
    torch.cuda.synchronize()

    piped = pipe_model(net, microbatch_size, pipe_sample, *pipeline_args,
                       devices=devices, **pipeline_kwargs, return_graph=True)
    return piped, piped.module_devices


def dataParallel(model_class: nn.Module, devices, *model_args, **model_kwargs):
    return nn.DataParallel(model_class(*model_args, **model_kwargs), devices).to(devices[0]), devices


SETUPS = {
    "single_gpu": single_gpu,
    "pipeLine": pipeLine,
    "dataParallel": dataParallel
}

# the experiment itself
img_size = (224, 224)


def create_dataloaders(batch_train, batch_test):
    """
    Assumes the following folder structure:
    cats_dogs/cat/<cat images>
    cats_dogs/dog/<dog images>
    """
    dataset_dir = "cats_dogs"
    tfms = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.ImageFolder(root=dataset_dir, transform=tfms)
    train_set, test_set = torch.utils.data.random_split(dataset, [20000, 5000])

    train_loader = DataLoader(train_set, batch_size=batch_train, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_test, shuffle=False, num_workers=2)

    return train_loader, test_loader


def loss_exp(config):
    assert torch.cuda.is_available(), "gpus are required"
    devices = [torch.device('cuda', i)
               for i in range(torch.cuda.device_count())]

    setup = config['setup']
    model_class = config['model']
    model_args = config['model_args']
    model_kwargs = config['model_kwargs']
    model_kwargs['num_classes'] = 2
    batch_size = config['batch_size']
    pipeLine_kwargs = config['pipeLine_kwargs']
    pipeLine_args = config['pipeLine_args']
    epochs = config['epochs']
    batch_shape = config['batch_shape']
    microbatch_size = config['microbatch_size']
    profile_sample_size = config['profile_sample_size']
    if microbatch_size is None:
        microbatch_size = batch_size // len(devices)

    if setup is pipeLine:
        assert len(devices) > 1, "automatic partitioning does not work for 1 gpu"
        pipe_sample = torch.randn(
            (profile_sample_size,) + batch_shape[1:]).to(devices[0])
        model, used_devices = pipeLine(model_class, devices, microbatch_size, pipe_sample, model_args,
                                       model_kwargs, pipeLine_args, pipeLine_kwargs)
    else:
        model, used_devices = setup(
            model_class, devices, *model_args, **model_kwargs)

    used_devices = list(used_devices)
    in_device = used_devices[0]
    out_device = used_devices[-1]
    batch_size = batch_shape[0]
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

    train_loader, test_loader = create_dataloaders(batch_size, batch_size)

    throughputs = []
    elapsed_times = []
    memory_consumptions = {device: [] for device in used_devices}
    losses = []
    accuracies = []

    # run one epoch and gather statistics
    def run_epoch(epoch):
        torch.cuda.synchronize(in_device)
        tick = time.time()
        model.train()

        data_trained = 0
        steps = len(train_loader)
        for i, data in enumerate(train_loader):
            batch, target = data
            data_trained += batch.shape[0]

            output = model(batch.to(in_device))
            gt = target.to(output.device)
            loss = F.cross_entropy(output, gt)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # estimate statistics after each batch
            percent = i / steps * 100
            throughput = data_trained / (time.time() - tick)

            # print every 100 steps
            if i % 100 == 0 or i == steps - 1:
                print(
                    f"{epoch+1}/{epochs} epochs ({percent:.2f}%) | {throughput:.2f} samples/sec (estimated)")

        torch.cuda.synchronize(in_device)
        tock = time.time()

        # calculate exact statistics after the epoch is finished
        test_loss, test_accuracy = test(model, test_loader, in_device)
        losses.append(test_loss)
        accuracies.append(test_accuracy)

        elapsed_time = tock - tick
        throughput = batch_size * steps / elapsed_time
        print(
            f"{epoch+1}/{epochs} epochs | {throughput:.2f} samples/sec, {elapsed_time:.2f} sec/epoch, test_accuracy {test_accuracy:.2f}, test loss {test_loss:.2f}")

        return throughput, elapsed_time

    exp = setup.__name__
    title = f'loss experiment\n config: {exp}\n used_gpus: {len(used_devices)}\n epochs: {epochs}\n'
    print(title)

    gpus = [torch.cuda.get_device_name(device) for device in used_devices]
    print('system information\n python: %s, torch: %s, cudnn: %s, cuda: %s, \ngpus: %s' % (
        platform.python_version(),
        torch.__version__,
        torch.backends.cudnn.version(),
        torch.version.cuda,
        gpus))
    print("\n")
    for epoch in range(epochs):
        for d in used_devices:
            torch.cuda.reset_max_memory_allocated(device=d)

        throughput, elapsed_time = run_epoch(epoch)
        # the first epoch is used as a warmup
        if epoch < 1:
            continue

        for d in used_devices:
            memory_consumptions[d].append(
                torch.cuda.max_memory_allocated(device=d))

        throughputs.append(throughput)
        elapsed_times.append(elapsed_time)

    print("\n")
    n = len(throughputs)
    throughput = sum(throughputs) / n
    elapsed_time = sum(elapsed_times) / n

    avg_mem_per_epoch = {device: sum(
        memory_consumptions[device]) / n for device in used_devices}
    # find the device with the largest average memory consumption
    maximum = max(avg_mem_per_epoch, key=avg_mem_per_epoch.get)

    print(
        f'{title} {throughput:.2f} samples/sec\n{elapsed_time:.2f} sec/epoch (average)\n max memory consumption on device: {maximum} {(avg_mem_per_epoch[maximum]/1e6):.2f} MB')
    print(f"final loss {losses[-1]:.2f} final accuracy {accuracies[-1]:.2f}")


def test(model, test_loader, input_device):
    criterion = F.cross_entropy

    model.train(mode=False)
    losses, accuracies = [], []

    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch
            outputs = model(inputs.to(input_device))
            labels = labels.to(outputs.device)

            loss = criterion(outputs, labels)

            losses.append(loss.item())
            accuracies.append(get_accuracy(outputs, labels))

    return np.mean(losses), np.mean(accuracies)


def get_accuracy(outputs, labels):
    _, predictions = torch.max(outputs.data, 1)

    return (predictions == labels).sum().item() / labels.size(0)


class StoreDict(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        kv = {}
        if not isinstance(values, (list,)):
            values = (values,)
        for value in values:
            n, v = value.split('=')
            kv[n] = v
        setattr(namespace, self.dest, kv)


class ExpParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument('--setup', '-s', help='The way to run the model.',
                          choices=SETUPS.keys(),
                          required=True, dest='setup')
        self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=MODELS.keys(),
                          required=True, dest='model')
        self.add_argument('--model_args', '-ma', help='additional positional args passed to the model', nargs='*',
                          default=[])
        self.add_argument('--model_kwargs', '-mkw', help='additional kwargs passed to the model', nargs='*', action=StoreDict,
                          default={})
        self.add_argument('--batch_size', '-b', help='batch size used in the experiment, defaults to 64',
                          type=int, dest='batch_size', default=64)
        self.add_argument('--pipeLine_args', '-pa', help='additional args passed to the pipeline', nargs='*',
                          default=[])
        self.add_argument('--pipeLine_kwargs', '-pkw', help='additional kwargs passed to the pipeline', nargs='*',
                          action=StoreDict, default={})
        self.add_argument('--epochs', '-e', help="the number of training epochs used,\nthe first epoch is a warmup and will not affect results",
                          type=int, default=10)
        self.add_argument('--microbatch_size', '-mb', help="micro batch size of the pipeline,\nif not given defaults to batch_size/num_devices",
                          type=int, default=None)
        self.add_argument('--profile_sample_size', '-ps', help='size of the batch used to partition the model when testing the pipeline',
                          type=int, default=16)

    def parse_args(self, *args, **kwargs):
        res = vars(super().parse_args(*args, **kwargs))
        res['model'] = MODELS[res['model']]
        res['setup'] = SETUPS[res['setup']]
        net = res['model']

        if net.__name__.find("inception") != -1:
            sample_shape = (3, 299, 299)
        elif net.__name__.find("GoogLeNet") != -1:
            sample_shape = (3, 32, 32)
        elif net.__name__.find("LeNet") != -1:
            sample_shape = (3, 32, 32)
        else:
            sample_shape = (3, 224, 224)

        res['batch_shape'] = (res['batch_size'],) + sample_shape

        return res


if __name__ == "__main__":
    # TODO: all positional/keyword args are parsed as strings
    # example: python my_loss_exp.py -s single_gpu -m amoebanet --model_args 10 --model_kwargs kw0=1 kw1=hello

    parser = ExpParser()
    args = parser.parse_args()

    loss_exp(args)
--------------------------------------------------------------------------------
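The TODO above notes that StoreDict leaves every value as a string. One way to address it, sketched here with ast.literal_eval — a suggestion, not part of the repository:

import argparse
import ast

class TypedStoreDict(argparse.Action):
    """Like StoreDict, but coerces values: '1' -> int, '1.5' -> float, 'True' -> bool."""
    def __call__(self, parser, namespace, values, option_string=None):
        kv = {}
        if not isinstance(values, list):
            values = [values]
        for value in values:
            n, v = value.split('=')
            try:
                kv[n] = ast.literal_eval(v)  # numbers, booleans, tuples, ...
            except (ValueError, SyntaxError):
                kv[n] = v  # plain text stays a string
        setattr(namespace, self.dest, kv)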
/experiments/speedup_exp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple
from sample_models import *
from .utils import *

plt.switch_backend('Agg')


def exp_model_time(run_type, model_class, num_classes, batch_shape: Tuple[int, ...], num_repeats, num_warmups,
                   model_params: dict, tests_config: dict, pipeline_params: dict = None, num_devices=None):

    tests_config['num_classes'] = num_classes
    tests_config['batch_shape'] = batch_shape

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    if run_type in ['S', 'Single']:
        tests_config['model'] = model_class(**model_params).to(device)

        print('Single GPU:')

    elif run_type in ['P', 'Pipeline-Parallel']:
        pipeline_params['devices'] = list(range(num_devices))
        model = model_class(**model_params).to(device)
        tests_config['model'] = create_pipeline(model, batch_shape, **pipeline_params)

        print('Pipeline-Parallel:')

    elif run_type in ['D', 'Data-Parallel']:
        device_ids = list(range(num_devices))
        model = model_class(**model_params).to(device)
        tests_config['model'] = nn.DataParallel(model, device_ids=device_ids).to(device)

        print('Data-Parallel:')

    else:
        raise ValueError('Not a valid run type')

    run_times, mem_uses = track_train(num_repeats + num_warmups, **tests_config)
    run_times, mem_uses = run_times[num_warmups:], mem_uses[num_warmups:]
    rt_mean, rt_std = np.mean(run_times), np.std(run_times)
    max_mem = np.mean(mem_uses)

    print(f'\trun time mean - {rt_mean}')
    print(f'\trun time std - {rt_std}')
    print(f'\tmax memory usage - {max_mem}')


if __name__ == '__main__':
    parser = ExpParser(uses_dataset=False, description='Run the speedup experiment.')
    args = parser.parse_args()
    exp_model_time(**args)
--------------------------------------------------------------------------------
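An illustrative invocation — the flag names follow ExpParser in experiments/utils.py, the values are arbitrary, and the script runs as a module from the repository root so its relative imports resolve:

# compare a 2-GPU pipeline run of resnet101 on fake batches of shape (64, 3, 224, 224),
# with a pipeline micro-batch of 8 (create_pipeline coerces the string value to int):
# python -m experiments.speedup_exp -r P -m resnet101 -c 10 -d 2 -s 64 3 224 224 --pipeline_params microbatch_size=8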
= f"model = {call_func_stmt(create_pipeline, model_init_stmt, batch_shape, **pipeline_params)}" 30 | pp_run_times = timeit.repeat( 31 | stmt, setup, number=1, repeat=num_repeat, globals=globals()) 32 | means.append(np.mean(pp_run_times)) 33 | stds.append(np.std(pp_run_times)) 34 | print( 35 | f'Split size {split_size} has a mean execution time of {means[-1]} with standard deviation of {stds[-1]}') 36 | 37 | fig, ax = plt.subplots() 38 | ax.plot(split_sizes, means) 39 | ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro') 40 | ax.set_ylabel('ResNet50 Execution Time (Second)') 41 | ax.set_xlabel('Pipeline Split Size') 42 | ax.set_xticks(split_sizes) 43 | ax.yaxis.grid(True) 44 | plt.tight_layout() 45 | plt.savefig("split_size_tradeoff.png") 46 | plt.close(fig) 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = ExpParser(uses_dataset=False, description='Run the speedup experiment.') 51 | args = parser.parse_args() 52 | exp_split_size(**args) 53 | -------------------------------------------------------------------------------- /experiments/throughput_exp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import time 4 | 5 | # add our code to the path so we could import it 6 | import sys 7 | import os 8 | sys.path.append(os.path.join(os.getcwd(), '..')) 9 | 10 | from pytorch_Gpipe import pipe_model 11 | import torch.nn.functional as F 12 | from torch.optim import SGD 13 | import argparse 14 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, inception_v3, densenet201, GoogLeNet, LeNet, \ 15 | WideResNet, AmoebaNet_D as amoebanet, amoebanetd as torchgpipe_amoebanet, torchgpipe_resnet101 16 | import platform 17 | 18 | MODELS = { 19 | "alexnet": alexnet, 20 | "resnet101": resnet101, 21 | "vgg19_bn": vgg19_bn, 22 | "squeezenet1_1": squeezenet1_1, 23 | "inception_v3": inception_v3, 24 | "densenet201": densenet201, 25 | "GoogLeNet": GoogLeNet, 26 | "LeNet": LeNet, 27 | "WideResNet": WideResNet, 28 | "amoebanet": amoebanet, 29 | "torchgpipe_amoebanet": torchgpipe_amoebanet, 30 | "torchgpipe_resnet101": torchgpipe_resnet101 31 | } 32 | 33 | # setups 34 | 35 | 36 | def single_gpu(model_class: nn.Module, devices, *model_args, **model_kwargs): 37 | return model_class(*model_args, **model_kwargs).to(devices[0]), [devices[0]] 38 | 39 | 40 | def pipeLine(model_class: nn.Module, devices, pipe_sample, model_args, model_kwargs, pipeline_args, pipeline_kwargs): 41 | net = model_class(*model_args, **model_kwargs).to(devices[0]) 42 | net(pipe_sample) 43 | torch.cuda.synchronize() 44 | 45 | piped = pipe_model(net, pipe_sample.shape[0], pipe_sample, *pipeline_args, 46 | devices=devices, **pipeline_kwargs) 47 | # TODO assumes first is input_device and last is output device 48 | return piped, list(piped.module_devices) 49 | 50 | 51 | def dataParallel(model_class: nn.Module, devices, *model_args, **model_kwargs): 52 | return nn.DataParallel(model_class(*model_args, **model_kwargs), devices).to(devices[0]), devices 53 | 54 | 55 | SETUPS = { 56 | "single_gpu": single_gpu, 57 | "pipeLine": pipeLine, 58 | "dataParallel": dataParallel 59 | } 60 | 61 | # the experiment itself 62 | 63 | 64 | def throughput_exp(config): 65 | assert torch.cuda.is_available(), "gpus are required" 66 | devices = [torch.device('cuda', i) 67 | for i in range(torch.cuda.device_count())] 68 | 69 | setup = config['setup'] 70 | model_class = config['model'] 71 | model_args = config['model_args'] 72 | model_kwargs = 
/experiments/throughput_exp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import time

# add our code to the path so we can import it
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..'))

from pytorch_Gpipe import pipe_model
import torch.nn.functional as F
from torch.optim import SGD
import argparse
from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, inception_v3, densenet201, GoogLeNet, LeNet, \
    WideResNet, AmoebaNet_D as amoebanet, amoebanetd as torchgpipe_amoebanet, torchgpipe_resnet101
import platform

MODELS = {
    "alexnet": alexnet,
    "resnet101": resnet101,
    "vgg19_bn": vgg19_bn,
    "squeezenet1_1": squeezenet1_1,
    "inception_v3": inception_v3,
    "densenet201": densenet201,
    "GoogLeNet": GoogLeNet,
    "LeNet": LeNet,
    "WideResNet": WideResNet,
    "amoebanet": amoebanet,
    "torchgpipe_amoebanet": torchgpipe_amoebanet,
    "torchgpipe_resnet101": torchgpipe_resnet101
}

# setups


def single_gpu(model_class: nn.Module, devices, *model_args, **model_kwargs):
    return model_class(*model_args, **model_kwargs).to(devices[0]), [devices[0]]


def pipeLine(model_class: nn.Module, devices, pipe_sample, model_args, model_kwargs, pipeline_args, pipeline_kwargs):
    net = model_class(*model_args, **model_kwargs).to(devices[0])
    net(pipe_sample)
    torch.cuda.synchronize()

    piped = pipe_model(net, pipe_sample.shape[0], pipe_sample, *pipeline_args,
                       devices=devices, **pipeline_kwargs)
    # TODO: assumes the first device is the input device and the last is the output device
    return piped, list(piped.module_devices)


def dataParallel(model_class: nn.Module, devices, *model_args, **model_kwargs):
    return nn.DataParallel(model_class(*model_args, **model_kwargs), devices).to(devices[0]), devices


SETUPS = {
    "single_gpu": single_gpu,
    "pipeLine": pipeLine,
    "dataParallel": dataParallel
}

# the experiment itself


def throughput_exp(config):
    assert torch.cuda.is_available(), "gpus are required"
    devices = [torch.device('cuda', i)
               for i in range(torch.cuda.device_count())]

    setup = config['setup']
    model_class = config['model']
    model_args = config['model_args']
    model_kwargs = config['model_kwargs']
    batch_size = config['batch_size']
    pipeLine_kwargs = config['pipeLine_kwargs']
    pipeLine_args = config['pipeLine_args']
    epochs = config['epochs']
    batch_shape = config['batch_shape']
    microbatch_size = config['microbatch_size']
    profile_sample_size = config['profile_sample_size']
    if microbatch_size is None:
        microbatch_size = batch_size // len(devices)

    if setup is pipeLine:
        assert len(devices) > 1, "automatic partitioning does not work for 1 gpu"
        pipe_sample = torch.randn(
            (profile_sample_size,) + batch_shape[1:]).to(devices[0])
        model, used_devices = pipeLine(model_class, devices, pipe_sample, model_args,
                                       model_kwargs, pipeLine_args, pipeLine_kwargs)
    else:
        model, used_devices = setup(
            model_class, devices, *model_args, **model_kwargs)

    in_device = used_devices[0]
    out_device = used_devices[-1]
    batch_size = batch_shape[0]
    optimizer = SGD(model.parameters(), lr=0.1)
    # this experiment only measures training performance, not accuracy;
    # to eliminate data-loading overhead, we train on a fixed fake batch
    batch = torch.randn(batch_shape, device=in_device)
    target = torch.randint(10, (batch_size,))

    throughputs = []
    elapsed_times = []
    memory_consumptions = {device: [] for device in used_devices}

    # run one epoch and gather statistics
    def run_epoch(epoch):
        torch.cuda.synchronize(in_device)
        tick = time.time()

        data_trained = 0
        steps = 50000 // batch_size
        for i in range(steps):
            data_trained += batch_size

            output = model(batch)
            gt = target.to(output.device)
            loss = F.cross_entropy(output, gt)
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # estimate statistics after each batch
            percent = i / steps * 100
            throughput = data_trained / (time.time() - tick)

            # print every 100 steps
            if i % 100 == 0 or i == steps - 1:
                print(
                    f"{epoch+1}/{epochs} epochs ({percent:.2f}%) | {throughput:.2f} samples/sec (estimated)")

        torch.cuda.synchronize(in_device)
        tock = time.time()

        # calculate exact statistics after the epoch is finished
        elapsed_time = tock - tick
        throughput = batch_size * steps / elapsed_time
        print(
            f"{epoch+1}/{epochs} epochs | {throughput:.2f} samples/sec, {elapsed_time:.2f} sec/epoch")

        return throughput, elapsed_time

    exp = setup.__name__
    title = f'throughput experiment\n config: {exp}\n used_gpus: {len(used_devices)}\n epochs: {epochs}\n'
    print(title)

    gpus = [torch.cuda.get_device_name(device) for device in used_devices]
    print('system information\n python: %s, torch: %s, cudnn: %s, cuda: %s, \ngpus: %s' % (
        platform.python_version(),
        torch.__version__,
        torch.backends.cudnn.version(),
        torch.version.cuda,
        gpus))
    print("\n")
    for epoch in range(epochs):
        for d in used_devices:
            torch.cuda.reset_max_memory_allocated(device=d)

        throughput, elapsed_time = run_epoch(epoch)
        # the first epoch is used as a warmup
        if epoch < 1:
            continue

        for d in used_devices:
            memory_consumptions[d].append(
                torch.cuda.max_memory_allocated(device=d))

        throughputs.append(throughput)
        elapsed_times.append(elapsed_time)

    print("\n")
    n = len(throughputs)
    throughput = sum(throughputs) / n
    elapsed_time = sum(elapsed_times) / n

    avg_mem_per_epoch = {device: sum(
        memory_consumptions[device]) / n for device in used_devices}
    # find the device with the largest average memory consumption
    maximum = max(avg_mem_per_epoch, key=avg_mem_per_epoch.get)

    print(
        f'{title} {throughput:.2f} samples/sec\n{elapsed_time:.2f} sec/epoch (average)\n max memory consumption on device: {maximum} {(avg_mem_per_epoch[maximum]/1e6):.2f} MB')


class StoreDict(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        kv = {}
        if not isinstance(values, (list,)):
            values = (values,)
        for value in values:
            n, v = value.split('=')
            kv[n] = v
        setattr(namespace, self.dest, kv)


class ExpParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument('--setup', '-s', help='The way to run the model.',
                          choices=SETUPS.keys(),
                          required=True, dest='setup')
        self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=MODELS.keys(),
                          required=True, dest='model')
        self.add_argument('--model_args', '-ma', help='additional positional args passed to the model', nargs='*',
                          default=[])
        self.add_argument('--model_kwargs', '-mkw', help='additional kwargs passed to the model', nargs='*', action=StoreDict,
                          default={})
        self.add_argument('--batch_size', '-b', help='batch size used in the experiment, defaults to 64',
                          type=int, dest='batch_size', default=64)
        self.add_argument('--pipeLine_args', '-pa', help='additional args passed to the pipeline', nargs='*',
                          default=[])
        self.add_argument('--pipeLine_kwargs', '-pkw', help='additional kwargs passed to the pipeline', nargs='*',
                          action=StoreDict, default={})
        self.add_argument('--epochs', '-e', help="the number of training epochs used,\nthe first epoch is a warmup and will not affect results",
                          type=int, default=10)
        self.add_argument('--microbatch_size', '-mb', help="micro batch size of the pipeline,\nif not given defaults to batch_size/num_devices",
                          type=int, default=None)
        self.add_argument('--profile_sample_size', '-ps', help='size of the batch used to partition the model when testing the pipeline',
                          type=int, default=16)

    def parse_args(self, *args, **kwargs):
        res = vars(super().parse_args(*args, **kwargs))
        res['model'] = MODELS[res['model']]
        res['setup'] = SETUPS[res['setup']]
        net = res['model']

        if net.__name__.find("inception") != -1:
            sample_shape = (3, 299, 299)
        elif net.__name__.find("GoogLeNet") != -1:
            sample_shape = (3, 32, 32)
        elif net.__name__.find("LeNet") != -1:
            sample_shape = (3, 32, 32)
        else:
            sample_shape = (3, 224, 224)

        res['batch_shape'] = (res['batch_size'],) + sample_shape

        return res


if __name__ == "__main__":
    # TODO: all positional/keyword args are parsed as strings
    # example: python throughput_exp.py -s single_gpu -m amoebanet --model_args 10 --model_kwargs kw0=1 kw1=hello

    # run as follows from the experiments folder, e.g.:
    # python throughput_exp.py --epochs 10 --setup pipeLine -m amoebanet --microbatch_size 64 --profile_sample_size 32
    # this will check the throughput, execution time and memory consumption of amoebanet
    # across 10 training epochs, each with 50000//batch_size batches,
    # using the default batch size of 64; the first epoch is a warmup and will not affect results,
    # and the partitioning will be done using a profiling batch of size 32

    parser = ExpParser()
    args = parser.parse_args()

    throughput_exp(args)
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
from torch import optim
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from pytorch_Gpipe import pipe_model
import argparse
import sample_models
import torchvision
import sys


def kwargs_string(*pos_strings, **kwargs):
    return ', '.join(list(pos_strings) + [f'{key}={val}' for key, val in kwargs.items()])


def reset_max_memory_allocated():
    for i in range(torch.cuda.device_count()):
        torch.cuda.reset_max_memory_allocated(i)


def get_max_memory_allocated():
    max_mem = -1

    for i in range(torch.cuda.device_count()):
        mem_alloc = torch.cuda.max_memory_allocated(i)
        max_mem = max(max_mem, mem_alloc)

    return max_mem


def call_func_stmt(func, *params, **kwargs):
    if isinstance(func, str):
        func_name = func
    else:
        func_name = func.__name__

    params_str = [str(param) for param in params]

    return f'{func_name}({kwargs_string(*params_str, **kwargs)})'


def track_train(num_repeats, model, num_classes, num_batches, batch_shape):
    num_batches = int(num_batches)

    run_times = []
    mem_uses = []
    for _ in range(num_repeats):
        reset_max_memory_allocated()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        train(model, num_classes, num_batches, batch_shape)
        end.record()

        torch.cuda.synchronize()
        run_times.append(start.elapsed_time(end))
        mem_uses.append(get_max_memory_allocated())
        print('.', end='')

    return run_times, mem_uses


def train(model, num_classes, num_batches, batch_shape):
    model.train(True)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    batch_size = batch_shape[0]

    one_hot_indices = torch.LongTensor(batch_size).random_(
        0, num_classes).view(batch_size, 1)

    dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    for b in range(num_batches):
        # generate random inputs and labels
        inputs = torch.randn(*batch_shape)
        labels = torch.zeros(batch_size, num_classes).scatter_(
            1, one_hot_indices, 1)

        # run forward pass
        optimizer.zero_grad()
        outputs = model(inputs.to(dev))

        # run backward pass
        labels = labels.to(outputs.device)
        loss_fn(outputs, labels).backward()
        optimizer.step()
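# Example of what call_func_stmt above produces -- these strings are exec'd by
# timeit in the experiment scripts (the values here are illustrative):
#   call_func_stmt(train, 'model', num_classes=10, num_batches=5)
#   -> "train(model, num_classes=10, num_batches=5)"
# positional arguments are interpolated verbatim and keyword values keep their str() form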


def plot(means, stds, labels, fig_name, fig_label):
    fig, ax = plt.subplots()
    ax.bar(np.arange(len(means)), means, yerr=stds,
           align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
    ax.set_ylabel(fig_label)
    ax.set_xticks(np.arange(len(means)))
    ax.set_xticklabels(labels)
    ax.yaxis.grid(True)
    plt.tight_layout()
    plt.savefig(fig_name)
    plt.close(fig)


def create_pipeline(model, batch_shape, microbatch_size, **kwargs):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    microbatch_size = int(microbatch_size)
    return pipe_model(model.to(device), microbatch_size, torch.randn(*batch_shape, device=device), **kwargs)


class StoreDict(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        kv = {}
        if not isinstance(values, (list,)):
            values = (values,)
        for value in values:
            n, v = value.split('=')
            kv[n] = v
        setattr(namespace, self.dest, kv)


class ExpParser(argparse.ArgumentParser):
    def __init__(self, *args, uses_dataset=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.uses_dataset = uses_dataset

        models = [
            'AlexNet', 'alexnet', 'DenseNet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'GoogLeNet',
            'Inception3', 'inception_v3', 'LeNet', 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'SqueezeNet', 'squeezenet1_0', 'squeezenet1_1', 'VGG', 'vgg11', 'vgg11_bn', 'vgg13',
            'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'WideResNet', 'AmoebaNet_D', 'amoebanetd',
            'resnet101', 'torchgpipe_resnet101'
        ]

        self.add_argument('--run_type', '-r', help='The way to run the model.',
                          choices=['Single', 'Data-Parallel', 'Pipeline-Parallel', 'S', 'D', 'P'],
                          required=True, dest='run_type')
        self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=models,
                          required=True, dest='model_class')
        self.add_argument('--classes', '-c', help='The number of classes in the prediction problem.', type=int,
                          required=True, dest='num_classes')
        self.add_argument('--repeats', '-n', help='number of times to repeat the experiments.', type=int,
                          default=10, dest='num_repeats')
        self.add_argument('--warmups', '-w', help='number of times to run the experiments before tracking results.',
                          type=int, default=1, dest='num_warmups')
        self.add_argument('--model_params', help='The parameters for the model', nargs='*', action=StoreDict,
                          default={})
        self.add_argument('--devices', '-d', help='The number of devices to use in the experiment.', type=int,
                          dest='num_devices')
        self.add_argument('--pipeline_params', help='Parameters for the pipeline itself, other than devices', nargs='*',
                          action=StoreDict, default={})

        if uses_dataset:
            self.add_argument('--dataset', '-s', choices=list(torchvision.datasets.__all__), required=True)
            # no short flag here: '-r' is already taken by --run_type
            self.add_argument('--ds_root', type=str, required=True)
        else:
            self.add_argument('--batch_shape', '-s', help='The shape of one batch.', nargs='*', type=int, required=True)
            self.add_argument('--tests_config', help='Any other config kwargs for the test', nargs='*',
                              action=StoreDict, default={})

    def parse_args(self, *args, **kwargs):
        res = vars(super().parse_args(*args, **kwargs))

        res['model_params']['num_classes'] = res['num_classes']

        res['model_class'] = getattr(sys.modules['sample_models'], res['model_class'])

        if self.uses_dataset:
            ds_class = getattr(sys.modules['torchvision.datasets'], res['dataset'])
            res['dataset'] = ds_class(res['ds_root'])

        return res
--------------------------------------------------------------------------------
/misc/networkx_metis__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright (C) 2015 ysitu
# All rights reserved

"""
Wrappers of METIS graph partitioning functions.
"""

import contextlib
import decorator
import itertools
import sys

import networkx as nx
import six

from nxmetis import enums
from nxmetis import exceptions
from nxmetis import metis
from nxmetis import types

__all__ = ['node_nested_dissection', 'partition', 'vertex_separator',
           'MetisOptions']

MetisOptions = types.MetisOptions


@contextlib.contextmanager
def _zero_numbering(options):
    """Temporarily force zero-based numbering."""
    if options:
        numbering = options.numbering
        options.numbering = enums.MetisNumbering.zero
    try:
        yield
    finally:
        if options:
            options.numbering = numbering


def _convert_graph(G):
    """Convert a graph to the numbered adjacency list structure expected by
    METIS.
    """
    index = dict(zip(G, list(range(len(G)))))
    xadj = [0]
    adjncy = []
    for u in G:
        adjncy.extend(index[v] for v in G[u])
        xadj.append(len(adjncy))
    return xadj, adjncy


def _convert_exceptions(convert_type, catch_types=None):
    """Decorator to convert types of exceptions

    Parameters
    ----------
    convert_type : subclass of Exception
        Target type to convert to.

    catch_types : tuple of subclasses of Exception, optional
        Source types whose instances are to be caught and converted. If None,
        all instances of Exception will be caught and converted.

    Returns
    -------
    _convert_exceptions : function
        Function that performs exception type conversion.

    Example
    -------
    Decorate functions like this::

        @_convert_exceptions(nx.NetworkXError, (ValueError,))
        def function():
            pass
    """
    @decorator.decorator
    def _convert_exceptions(func, *args, **kwargs):
        try:
            return func(*args, **kwargs)
        except catch_types as e:
            exc = e
        except Exception as e:
            if catch_types is not None:
                raise
            exc = sys.exc_info()
        six.reraise(convert_type, convert_type(exc[1]), exc[2])
    return _convert_exceptions


@nx.utils.not_implemented_for('directed')
@nx.utils.not_implemented_for('multigraph')
@_convert_exceptions(
    nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
def node_nested_dissection(G, weight='weight', options=None):
    """Compute a node ordering of a graph that reduces fill when the Laplacian
    matrix of the graph is LU factorized. The algorithm aims to minimize the
    sum of weights of vertices in separators computed in the process.

    Parameters
    ----------
    G : NetworkX graph
        A graph.

    weight : object, optional
        The data key used to determine the weight of each node. If None, each
        node has unit weight. Default value: 'weight'.

    options : MetisOptions, optional
        METIS options. If None, the default options are used. Default value:
        None.

    Returns
    -------
    perm : list of nodes
        The node ordering.

    Raises
    ------
    NetworkXError
        If the parameters cannot be converted to valid METIS input format, or
        METIS returns an error status.
    """
    if len(G) == 0:
        return []

    vwgt = [G.nodes[u].get(weight, 1) for u in G]
    if all(w == 1 for w in vwgt):
        vwgt = None

    xadj, adjncy = _convert_graph(G)

    with _zero_numbering(options):
        perm = metis.node_nd(xadj, adjncy, vwgt, options)[0]

    nodes = list(G)
    perm = [nodes[i] for i in perm]

    return perm


@nx.utils.not_implemented_for('directed')
@nx.utils.not_implemented_for('multigraph')
@_convert_exceptions(
    nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
def partition(G, nparts, node_weight='weight', node_size='size',
              edge_weight='weight', tpwgts=None, ubvec=None, options=None,
              recursive=False):
    """Partition a graph using multilevel recursive bisection or multilevel
    multiway partitioning.

    Parameters
    ----------
    G : NetworkX graph
        An undirected graph.

    nparts : int
        Number of parts to partition the graph. It should be at least 2.

    node_weight : object, optional
        The data key used to determine the weight of each node. If None, each
        node has unit weight. Default value: 'weight'.

    node_size : object, optional
        The data key used to determine the size of each node when computing the
        total communication volume. If None, each node has unit size. Default
        value: 'size'.

    edge_weight : object, optional
        The data key used to determine the weight of each edge. If None, each
        edge has unit weight. Default value: 'weight'.

    tpwgts : list of lists of floats, optional
        The target weights of the partitions and the constraints. The target
        weight of the `i`-th partition and the `j`-th constraint is given by
        ``tpwgts[i][j]`` (the numbering for both partitions and constraints
        starts from zero). For each constraint the sum of the ``tpwgts[][]``
        entries must be 1.0 (i.e., `\sum_i \\text{tpwgts}[i][j] = 1.0`). If
        None, the graph is equally divided among the partitions. Default value:
        None.

    ubvec : list of floats, optional
        The allowed load imbalance tolerance for each constraint. For the
        `i`-th partition and the `j`-th constraint, the allowed weight is the
        ``ubvec[j] * tpwgts[i][j]`` fraction of the `j`-th constraint's total
        weight. The load imbalances must be greater than 1.0. If None, the load
        imbalance tolerance is 1.001 if there is exactly one constraint or 1.01
        if there are more. Default value: None.

    options : MetisOptions, optional
        METIS options. If None, the default options are used. Default value:
        None.

    recursive : bool, optional
        If True, multilevel recursive bisection is used. Otherwise, multilevel
        multiway partitioning is used. Default value: False.

    Returns
    -------
    objval : int
        The edge-cut or the total communication volume of the partitioning
        solution. The value returned depends on the partitioning's objective
        function.

    parts : lists of nodes
        The partitioning.

    Raises
    ------
    NetworkXNotImplemented
        If the graph is directed or is a multigraph.

    NetworkXError
        If the parameters cannot be converted to valid METIS input format, or
        METIS returns an error status.
    """
    if nparts < 1:
        raise nx.NetworkXError('nparts is less than one.')
    if nparts == 1:
        return 0, [list(G)]

    if len(G) == 0:
        return 0, [[] for i in range(nparts)]

    xadj, adjncy = _convert_graph(G)

    vwgt = [G.nodes[u].get(node_weight, 1) for u in G]
    if all(w == 1 for w in vwgt):
        vwgt = None

    vsize = [G.nodes[u].get(node_size, 1) for u in G]
    if all(w == 1 for w in vsize):
        vsize = None

    adjwgt = [G[u][v].get(edge_weight, 1) for u in G for v in G[u]]
    if all(w == 1 for w in adjwgt):
        adjwgt = None

    if tpwgts is not None:
        if len(tpwgts) != nparts:
            raise nx.NetworkXError('length of tpwgts is not equal to nparts.')
        ncon = len(tpwgts[0])
        if any(len(tpwgts[j]) != ncon for j in range(1, nparts)):
            raise nx.NetworkXError(
                'lists in tpwgts are not of the same length.')
        if ubvec is not None and len(ubvec) != ncon:
            raise nx.NetworkXError(
                'ubvec is not of the same length as tpwgts.')
        tpwgts = list(itertools.chain.from_iterable(tpwgts))

    with _zero_numbering(options):
        objval, part = metis.part_graph(xadj, adjncy, nparts, vwgt, vsize,
                                        adjwgt, tpwgts, ubvec, options,
                                        recursive)

    parts = [[] for i in range(nparts)]
    for u, i in zip(G, part):
        parts[i].append(u)

    return objval, parts


@nx.utils.not_implemented_for('directed')
@nx.utils.not_implemented_for('multigraph')
@_convert_exceptions(
    nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
def vertex_separator(G, weight='weight', options=None):
    """Compute a vertex separator that bisects a graph. The algorithm aims to
    minimize the sum of weights of vertices in the separator.

    Parameters
    ----------
    G : NetworkX graph
        A graph.

    weight : object, optional
        The data key used to determine the weight of each node. If None, each
        node has unit weight. Default value: 'weight'.

    options : MetisOptions, optional
        METIS options. If None, the default options are used. Default value:
        None.

    Returns
    -------
    sep, part1, part2 : lists of nodes
        The separator and the two parts of the bisection represented as lists.

    Raises
    ------
    NetworkXError
        If the parameters cannot be converted to valid METIS input format, or
        METIS returns an error status.
297 | """ 298 | if len(G) == 0: 299 | return [], [], [] 300 | 301 | vwgt = [G.nodes[u].get(weight, 1) for u in G] 302 | if all(w == 1 for w in vwgt): 303 | vwgt = None 304 | 305 | xadj, adjncy = _convert_graph(G) 306 | 307 | with _zero_numbering(options): 308 | part = metis.compute_vertex_separator(xadj, adjncy, vwgt, options)[1] 309 | 310 | groups = [[], [], []] 311 | for u, i in zip(G, part): 312 | groups[i].append(u) 313 | 314 | return groups[2], groups[0], groups[1] 315 | -------------------------------------------------------------------------------- /parse_declarations.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | from torch import Tensor 4 | import numpy as np 5 | import builtins 6 | 7 | ts = set() 8 | 9 | 10 | class Type(): 11 | def __call__(self, other: type): 12 | raise NotImplementedError() 13 | 14 | def __str__(self): 15 | raise NotImplementedError() 16 | 17 | 18 | class UnionType(Type): 19 | def __init__(self, *types: Type): 20 | self.types = types 21 | 22 | def __call__(self, other): 23 | return any(t(other) for t in self.types) 24 | 25 | def __str__(self): 26 | return f"Union[{','.join([str(t) for t in self.types])}]" 27 | 28 | 29 | class BasicType(Type): 30 | def __init__(self, t: type): 31 | self.type = t 32 | 33 | def __call__(self, other): 34 | return isinstance(other, self.type) 35 | 36 | def __str__(self): 37 | return self.type.__name__ 38 | 39 | 40 | class OptionalType(Type): 41 | def __init__(self, _type: Type): 42 | self.type = _type 43 | 44 | def __call__(self, other): 45 | return (other is None) or self.type(other) 46 | 47 | def __str__(self): 48 | return f"Optional[{self.type}]" 49 | 50 | 51 | class ListType(Type): 52 | def __init__(self, t): 53 | self.type = t 54 | 55 | def __call__(self, other): 56 | return(isinstance(other, list) and all(self.type(e) for e in other)) 57 | 58 | def __str__(self): 59 | return f"List[{self.type}]" 60 | 61 | 62 | class TupleType(Type): 63 | def __init__(self, types, homogeneous=True): 64 | if homogeneous: 65 | assert isinstance(types, Type) 66 | self.types = types 67 | self.homogeneous = homogeneous 68 | 69 | def __call__(self, other): 70 | if not isinstance(other, tuple): 71 | return False 72 | 73 | if self.homogeneous: 74 | return all(self.types(t) for t in other) 75 | 76 | if len(self.types) != len(other): 77 | return False 78 | 79 | return all(e(a) for e, a in zip(self.types, other)) 80 | 81 | def __str__(self): 82 | if self.homogeneous: 83 | return f"Tuple[{self.types},...]" 84 | else: 85 | return f"Tuple[{', '.join([str(t) for t in self.types])}]" 86 | 87 | 88 | class AnyType(Type): 89 | def __call__(self, other): 90 | return True 91 | 92 | def __str__(self): 93 | return "Any" 94 | 95 | 96 | def getLines(lines): 97 | for line in lines: 98 | line = line.strip().rstrip(" .=").replace(" ", "") 99 | if line in ['', '\n', '\r\n']: 100 | # skip empty lines 101 | continue 102 | elif line.startswith("import") or line.startswith("from"): 103 | # skip imports 104 | continue 105 | elif line.startswith("#"): 106 | # skip comments 107 | continue 108 | else: 109 | yield line 110 | 111 | 112 | def parse_function(line, types): 113 | function_decl = line[3:line.rindex(")") + 1] 114 | func_name = line[3:line.index("(")] 115 | 116 | args = function_decl[function_decl.index("(") + 1:-1] 117 | args = args.strip().split(",") 118 | 119 | i = 0 120 | keyword = False 121 | while i < len(args): 122 | arg = args[i] 123 | 124 | if ('[' in arg)and arg.count('[') > 
arg.count(']'):
125 | # this is a compound type, merge tokens until brackets are balanced
126 | to_merge = [arg]
127 | cnt = arg.count('[') - arg.count(']')
128 | i += 1
129 | while i < len(args):
130 | to_merge.append(args[i])
131 | cnt += (args[i].count('[') - args[i].count(']'))
132 | i += 1
133 | if cnt == 0:
134 | break
135 | i -= 1
136 | arg = ",".join(to_merge)
137 |
138 | if ('(' in arg)and arg.count('(') > arg.count(')'):
139 | # this is an unbalanced tuple, merge tokens until brackets are balanced
140 | to_merge = [arg]
141 | cnt = arg.count('(') - arg.count(')')
142 | i += 1
143 | while i < len(args):
144 | to_merge.append(args[i])
145 | cnt += (args[i].count('(') - args[i].count(')'))
146 | i += 1
147 | if cnt == 0:
148 | break
149 | i -= 1
150 | arg = ",".join(to_merge)
151 |
152 | if '*' == arg:
153 | # end of positionals
154 | keyword = True
155 | i += 1
156 | continue
157 | elif keyword or '=' in arg:
158 | # keyword
159 | keyword = True
160 | arg_name = arg.split(":")[0]
161 | if '=' in arg:
162 | arg_type, default_value = arg.split(":")[1].split("=")
163 | else:
164 | arg_type = arg.split(":")[1]
165 | elif '**' in arg:
166 | # **kwargs
167 | keyword = True
168 | arg_name = arg.split("*")[-1]
169 | arg_type = "dict"
170 | elif '*' in arg:
171 | # *args
172 | arg_name = arg.split(":")[0][1:]
173 | arg_type = f"Tuple[{arg.split(':')[1]}, ...]"
174 | keyword = True
175 | else:
176 | # a:cls
177 | arg_name, arg_type = arg.split(":")
178 |
179 | i += 1
180 | arg_type = arg_type.replace(" ", "")
181 | ts.add(arg_type)
182 | return func_name, args
183 |
184 |
185 | def generateTypes():
186 | types = dict()
187 | # builtins
188 | for s in ['int', 'bool', 'float', 'str', 'slice']:
189 | types[f"_{s}"] = BasicType(getattr(builtins, s))
190 | types[f"List[_{s}]"] = ListType(types[f"_{s}"])
191 | types[f"Tuple[_{s},...]"] = TupleType(types[f"_{s}"])
192 | types[f"Optional[_{s}]"] = OptionalType(types[f"_{s}"])
193 |
194 | types["List"] = BasicType(list)
195 | types["Tuple"] = BasicType(tuple)
196 | types["Number"] = UnionType(types['_int'], types['_bool'], types['_float'])
197 | types["Any"] = AnyType()
198 | types["_ndarray"] = BasicType(np.ndarray)
199 | types["Callable"] = inspect.isfunction
200 |
201 | # torch types
202 | for s in ['Tensor', 'layout', 'qscheme', 'Generator', "Storage", 'memory_format', "device", "dtype", "Size"]:
203 | types[f'_{s}'] = BasicType(getattr(torch, s))
204 | types[f"List[_{s}]"] = ListType(types[f"_{s}"])
205 | types[f"Tuple[_{s},...]"] = TupleType(types[f"_{s}"])
206 | types[f"Optional[_{s}]"] = OptionalType(types[f"_{s}"])
207 |
208 | # special cases
209 | types['_size'] = UnionType(*[types[s]  # UnionType takes *types, so unpack the list
210 | for s in ['_Size', 'List[_int]', 'Tuple[_int,...]']])
211 | types['Union[_Tensor,List]'] = UnionType(types['_Tensor'], types['List'])
212 | types['Union[_int,List[_int]]'] = UnionType(types['_int'],
213 | types['List[_int]'])
214 | types['Union[_Tensor,Number]'] = UnionType(types['_Tensor'],
215 | types['Number'])
216 | types['Optional[_size]'] = OptionalType(types['_size'])
217 | types['Union[_int,_size]'] = UnionType(types['_int'], types['_size'])
218 | types['Optional[Union[_device,_str]]'] = OptionalType(UnionType(types['_device'],
219 | types['_str']))
220 | types['Optional[Union[_str,_dtype]]'] = OptionalType(UnionType(types['_str'],
221 | types['_dtype']))
222 | types['Union[Tuple[_Tensor,...],List[_Tensor]]'] = UnionType(
223 | types['Tuple[_Tensor,...]'], types['List[_Tensor]'])
224 | types['Optional[Union[Tuple[_Tensor,...],List[_Tensor]]]'] = 
OptionalType(
225 | types['Union[Tuple[_Tensor,...],List[_Tensor]]'])
226 |
227 | types['Optional[Union[_int,_slice,_Tensor,List,Tuple]]'] =\
228 | OptionalType(
229 | UnionType(*[types[s] for s in ['_int', '_slice', '_Tensor', 'List', 'Tuple']]))
230 |
231 | return types
232 |
233 |
234 | def parse():
235 | decl_file = "pytorch_Gpipe/model_partitioning/module_generation/declarations.pyi"
236 | types = generateTypes()
237 | functions = dict()
238 | is_torch = False
239 | with open(decl_file, "r") as f:
240 | for line in getLines(f.readlines()):
241 |
242 | if line.startswith('class'):
243 | # class declaration
244 | current_class = line[line.index(
245 | "class") + 5:].split(":")[0]
246 | # print(current_class)
247 | elif line.startswith('def'):
248 | # function decl
249 | func, args = parse_function(line, types)
250 | assert hasattr(Tensor, func) or hasattr(torch, func)
251 | elif line.startswith('@overload'):
252 | # function overload
253 | pass
254 |
255 |
256 | # function + args in order => positional/keyword
257 | if __name__ == "__main__":
258 | parse()
259 |
--------------------------------------------------------------------------------
/partition_torchvision_networks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pytorch_Gpipe import partition_with_profiler, profileNetwork, distribute_by_memory, distribute_by_time, \
3 | distribute_using_profiler, pipe_model
4 | import torch
5 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, inception_v3, densenet201, GoogLeNet, LeNet, \
6 | WideResNet
7 | from sample_models import AmoebaNet_D as my_amoeaba, amoebanetd as ref_amoeba, torchgpipe_resnet101
8 |
9 | import torch.nn as nn
10 | from pytorch_Gpipe.utils import model_scopes
11 | import datetime
12 |
13 |
14 | def partition_torchvision(networks=None, nparts=4, depth=100, nruns=4, blocks=None,
15 | save_graph=False, show_scope_diff=False,
16 | dump_graph=False, **model_kwargs):
17 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
18 | if networks != None and not isinstance(networks, (list, tuple)):
19 | networks = [networks]
20 |
21 | if networks is None:
22 | networks = [my_amoeaba, ref_amoeba, alexnet, resnet101, torchgpipe_resnet101, vgg19_bn, squeezenet1_1,
23 | inception_v3, densenet201, GoogLeNet, LeNet, WideResNet]
24 |
25 | if not isinstance(nparts, (list, tuple)):
26 | nparts = [nparts]
27 |
28 | if not isinstance(depth, (list, tuple)):
29 | depth = [depth]
30 |
31 | for i in range(nruns):
32 | for net in networks:
33 | model = net(**model_kwargs).to(device)
34 | print("model built")
35 | basic_blocks = blocks
36 | for p in nparts:
37 | for d in depth:
38 | print(f"current net is {net.__name__}")
39 | if net.__name__.find("inception") != -1:
40 | graph = partition_with_profiler(
41 | model, torch.zeros(16, 3, 299, 299, device=device), nparts=p, max_depth=d,
42 | basic_blocks=basic_blocks)
43 | elif net.__name__.find("GoogLeNet") != -1:
44 | graph = partition_with_profiler(
45 | model, torch.zeros(16, 3, 32, 32, device=device), nparts=p, max_depth=d,
46 | basic_blocks=basic_blocks)
47 | elif net.__name__.find("LeNet") != -1:
48 | graph = partition_with_profiler(
49 | model, torch.zeros(16, 3, 32, 32, device=device), nparts=p, max_depth=d,
50 | basic_blocks=basic_blocks)
51 | elif net.__name__.find("moeba") != -1:
52 | graph = partition_with_profiler(
53 | model, torch.zeros(4, 3, 224, 224, device=device), nparts=p, max_depth=d, basic_blocks=basic_blocks)
54 | else:
55 | graph = 
partition_with_profiler( 56 | model, torch.zeros(16, 3, 224, 224, device=device), nparts=p, max_depth=d, basic_blocks=basic_blocks) 57 | 58 | time_stemp = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") 59 | filename = f"{net.__name__}_run{i}_attempted_{p}_partitions_at_depth_{d}_{time_stemp}" 60 | 61 | curr_dir = os.path.dirname(os.path.realpath(__file__)) 62 | out_dir = f"{curr_dir}\\partition_visualization" 63 | 64 | if dump_graph: 65 | graph.serialize(f"{curr_dir}\\graph_dump\\{filename}") 66 | if save_graph: 67 | graph.save(directory=out_dir, file_name=filename, 68 | show_buffs_params=False, show_weights=False) 69 | print(filename) 70 | 71 | if show_scope_diff: 72 | scopes = set(model_scopes(model, depth=d, 73 | basic_blocks=basic_blocks)) 74 | graph_scopes = graph.scopes() 75 | diff = scopes.difference(graph_scopes) 76 | print(f"scope diff {len(diff)}") 77 | for s in diff: 78 | print(s) 79 | print("\n") 80 | 81 | 82 | def distribute_torchvision(networks=None, nparts=4, depth=100, nruns=4, 83 | fake_gpus=False, save_graph=False, show_scope_diff=False, 84 | optimize_pipeline_wrappers=True, dump_graph=False, **model_kwargs): 85 | if not torch.cuda.is_available(): 86 | raise ValueError("CUDA is required") 87 | 88 | device = 'cuda:0' 89 | if networks != None and not isinstance(networks, (list, tuple)): 90 | networks = [networks] 91 | 92 | if networks is None: 93 | networks = [my_amoeaba, ref_amoeba, alexnet, resnet101, torchgpipe_resnet101, vgg19_bn, squeezenet1_1, 94 | inception_v3, densenet201, GoogLeNet, LeNet, WideResNet] 95 | 96 | if not isinstance(nparts, (list, tuple)): 97 | nparts = [nparts] 98 | 99 | if not isinstance(depth, (list, tuple)): 100 | depth = [depth] 101 | 102 | for i in range(nruns): 103 | for net in networks: 104 | model = net(**model_kwargs).to(device) 105 | print("model built") 106 | basic_blocks = None 107 | for p in nparts: 108 | if fake_gpus: 109 | devices = [f'cuda:0' for _ in range(p)] 110 | else: 111 | assert torch.cuda.device_count() == p 112 | devices = [f'cuda:{i}' for i in range(p)] 113 | for d in depth: 114 | print(f"current net is {net.__name__}") 115 | if net.__name__.find("inception") != -1: 116 | model, _, _, graph = distribute_using_profiler( 117 | model, torch.zeros(16, 3, 299, 299, device=device), 118 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d, 119 | basic_blocks=basic_blocks) 120 | elif net.__name__.find("GoogLeNet") != -1: 121 | model, _, _, graph = distribute_using_profiler( 122 | model, torch.zeros(16, 3, 32, 32, device=device), 123 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d, 124 | basic_blocks=basic_blocks) 125 | elif net.__name__.find("LeNet") != -1: 126 | model, _, _, graph = distribute_using_profiler( 127 | model, torch.zeros(16, 3, 32, 32, device=device), 128 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d, 129 | basic_blocks=basic_blocks) 130 | elif net.__name__.find("moeba") != -1: 131 | model, _, _, graph = distribute_using_profiler( 132 | model, torch.zeros(8, 3, 224, 224, device=device), 133 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d, 134 | basic_blocks=basic_blocks) 135 | else: 136 | model, _, _, graph = distribute_using_profiler( 137 | model, torch.zeros(16, 3, 224, 224, device=device), 138 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d, 139 | basic_blocks=basic_blocks) 140 | 141 | time_stemp = 
datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
142 | filename = f"{net.__name__}_run{i}_attempted_{p}_partitions_at_depth_{d}_{time_stemp}"
143 |
144 | curr_dir = os.path.dirname(os.path.realpath(__file__))
145 | out_dir = f"{curr_dir}\\distributed_models"
146 | if dump_graph:
147 | graph.serialize(f"{curr_dir}\\graph_dump\\{filename}")
148 | if save_graph:
149 | graph.save(directory=out_dir, file_name=filename,
150 | show_buffs_params=False, show_weights=False)
151 |
152 | if show_scope_diff:
153 | scopes = set(model_scopes(model, depth=d,
154 | basic_blocks=basic_blocks))
155 | graph_scopes = graph.scopes()
156 | diff = scopes.difference(graph_scopes)
157 | print(f"scope diff {len(diff)}")
158 | for s in diff:
159 | print(s)
160 | print("\n")
161 |
162 |
163 | if __name__ == "__main__":
164 | # distribute_torchvision(networks=[my_amoeaba, ref_amoeba], nparts=8,
165 | # fake_gpus=True, save_graph=True, nruns=2, num_layers=5)
166 |
167 | from sample_models.Inception import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE
168 | partition_torchvision(networks=inception_v3, save_graph=True, depth=[0, 1, 100], blocks=[
169 | BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE], nruns=2)
170 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/METIS_manual.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/METIS_manual.pdf
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/__init__.py:
--------------------------------------------------------------------------------
1 | from .METIS_graph_partition import *
2 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/libmetis.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/libmetis.dll
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/libmetis.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/libmetis.so
--------------------------------------------------------------------------------
/pytorch_Gpipe/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .model_partitioning import generatePartitionModules, partition
7 | from .model_profiling import Graph, graph_builder, profileNetwork
8 | from .pipeline import Pipeline
9 | from .utils import Devices, Tensors
10 |
11 | __all__ = ['pipe_model', 'partition_with_profiler',
12 | 'partition', 'Pipeline']
13 |
14 |
15 | def pipe_model(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, nparts: int = 4, partition_by_memory: bool = False, output_file: str = None, DEBUG=False):
16 | '''attempts to partition a model into a given number of parts using our profiler
17 | this will produce a python file with the partition config
18 |
19 | the generated python file exposes a method named {modelClass}Pipeline that creates the pipeline
20 | for this 
specific model config
21 |
22 | Parameters:
23 | model:
24 | the network we wish to model
25 | sample_batch:
26 | a sample input to use for tracing
27 | nparts:
28 | the number of partitions
29 | partition_by_memory:
30 | whether to partition by memory consumption; if False, partitions by time. defaults to False
31 | output_file:
32 | the file name in which to save the partition config
33 | if not given defaults to generated_{modelClass}{actualNumberOfPartitions}
34 | DEBUG:
35 | whether to generate the debug version of the partition, with more comments and assertions in the generated file
36 | '''
37 | def by_time(w):
38 | if hasattr(w, 'forward_time') and hasattr(w, 'backward_time'):
39 | return max(int(100 * (w.forward_time + w.backward_time) / 2), 1)
40 | return 1
41 |
42 | def by_memory(w):
43 | if hasattr(w, 'cuda_memory_forward') and hasattr(w, 'cuda_memory_backward'):
44 | return max(int(100 * (w.cuda_memory_forward + w.cuda_memory_backward) / 2), 1)
45 | return 1
46 |
47 | if partition_by_memory:
48 | w_func = by_memory
49 | else:
50 | w_func = by_time
51 |
52 | graph = partition_with_profiler(model, sample_batch, kwargs=kwargs, nparts=nparts,
53 | weighting_function=w_func)
54 |
55 | generatePartitionModules(graph, model,
56 | output_file=output_file, verbose=DEBUG)
57 |
58 | return graph
59 |
60 |
61 | def partition_with_profiler(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, nparts=4, max_depth=100, basic_blocks: Optional[List[nn.Module]] = None, weighting_function: Optional[Callable[[Any], int]] = None) -> Graph:
62 | '''
63 | return a graph representing the partitioned model with the weights given by the profiler
64 | this method does not distribute the model across devices
65 |
66 | Parameters:
67 | model:
68 | the network we wish to model
69 | sample_batch:
70 | a sample input to use for tracing
71 | nparts:
72 | the number of partitions
73 | max_depth:
74 | how far down we go in the model tree; determines the detail level of the graph
75 | basic_blocks:
76 | an optional list of modules that if encountered will not be broken down
77 | weighting_function:
78 | an optional function from node weights to non-negative integers; if not provided a default function will be used
79 | '''
80 | graph = graph_builder(model, sample_batch, kwargs=kwargs, max_depth=max_depth,
81 | basic_blocks=basic_blocks, use_profiler=True)
82 |
83 | graph = partition(graph, nparts, weighting_function=weighting_function)
84 |
85 | return graph
86 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/delayedNorm.py:
--------------------------------------------------------------------------------
1 | """Tracks the running statistics per mini-batch instead of micro-batch."""
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | import torch.nn.functional as F
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | # taken from the torchGpipe repo
10 |
11 |
12 | class DelayedBatchNorm(_BatchNorm):
13 | """A BatchNorm layer tracks multiple micro-batches to update running
14 | statistics per mini-batch. 
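The committed statistics are recovered from plain running sums; a sketch
of the arithmetic performed in :meth:`_commit` below:

    mean = sum(x) / n
    var  = sum(x ** 2) / n - mean ** 2

so the result matches computing BatchNorm statistics over the whole
mini-batch at once.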
15 | """ 16 | 17 | def __init__(self, 18 | num_features: int, 19 | eps: float = 1e-5, 20 | momentum: Optional[float] = 0.1, 21 | affine: bool = True, 22 | num_micro_batches: int = 1, 23 | is_recomputing: bool = False 24 | ): 25 | super().__init__(num_features, eps, momentum, affine, track_running_stats=True) 26 | 27 | self.register_buffer('sum', torch.zeros_like(self.running_mean)) 28 | self.register_buffer('sum_squares', torch.zeros_like(self.running_var)) 29 | 30 | self.counter = 0 31 | self.tracked = 0 32 | self.num_micro_batches = num_micro_batches 33 | self.is_recomputing = is_recomputing 34 | 35 | def _check_input_dim(self, x: Tensor): 36 | if x.dim() <= 2: 37 | raise ValueError( 38 | 'expected at least 3D input (got %dD input)' % x.dim()) 39 | 40 | def _track(self, x: Tensor) -> bool: 41 | """Tracks statistics of a micro-batch.""" 42 | # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. 43 | dim = [0] 44 | dim.extend(range(2, x.dim())) 45 | 46 | with torch.no_grad(): 47 | self.sum += x.sum(dim) 48 | self.sum_squares += (x**2).sum(dim) 49 | 50 | size = x.size().numel() // x.size(1) 51 | self.counter += size 52 | self.tracked += 1 53 | 54 | return (self.tracked == self.num_micro_batches) 55 | 56 | def _commit(self): 57 | """Updates the running statistics of a mini-batch.""" 58 | exponential_average_factor = 0.0 59 | self.num_batches_tracked += 1 60 | if self.momentum is None: # use cumulative moving average 61 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 62 | else: # use exponential moving average 63 | exponential_average_factor = self.momentum 64 | 65 | mean = self.sum / self.counter 66 | var = self.sum_squares / self.counter - mean**2 67 | 68 | # Calculate the exponential moving average here. 69 | m = exponential_average_factor 70 | 71 | self.running_mean *= 1 - m 72 | self.running_mean += mean * m 73 | 74 | self.running_var *= 1 - m 75 | self.running_var += var * m 76 | 77 | self.sum.zero_() 78 | self.sum_squares.zero_() 79 | self.counter = 0 80 | self.tracked = 0 81 | 82 | def forward(self, x: Tensor) -> Tensor: # type: ignore 83 | if not self.training: 84 | # Don't train parameters on the evaluation mode. 85 | return F.batch_norm( 86 | x, 87 | running_mean=self.running_mean, 88 | running_var=self.running_var, 89 | weight=self.weight, 90 | bias=self.bias, 91 | training=False, 92 | momentum=0.0, 93 | eps=self.eps, 94 | ) 95 | 96 | if not self.is_recomputing: 97 | # Track a micro-batch on the training mode 98 | # but not under a recomputation. 99 | tracked_enough = self._track(x) 100 | 101 | # Update the running statistics for a mini-batch 102 | # if it has tracked enough micro-batches. 103 | if tracked_enough: 104 | self._commit() 105 | 106 | # Normalize a micro-batch and train the parameters. 
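# running_mean/running_var are passed as None with momentum=0.0, so this
# call normalizes using the statistics of the current micro-batch only;
# the delayed running statistics are updated separately by _commit()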
107 | return F.batch_norm(
108 | x,
109 | running_mean=None,
110 | running_var=None,
111 | weight=self.weight,
112 | bias=self.bias,
113 | training=True,
114 | momentum=0.0,
115 | eps=self.eps,
116 | )
117 |
118 | @classmethod
119 | def convertBatchNorm(cls, module: nn.Module, num_micro_batches: int = 1) -> nn.Module:
120 | """Converts a :class:`nn.BatchNorm` module or its underlying
121 | :class:`nn.BatchNorm` submodules into :class:`DelayedBatchNorm`::
122 | from torchvision.models.resnet import resnet101
123 | from pytorch_Gpipe.delayedNorm import DelayedBatchNorm
124 | model = resnet101()
125 | model = DelayedBatchNorm.convertBatchNorm(model)
126 | """
127 | if isinstance(module, DelayedBatchNorm) and module.num_micro_batches == num_micro_batches:
128 | return module
129 |
130 | if isinstance(module, _BatchNorm) and module.track_running_stats:
131 | module_output = DelayedBatchNorm(module.num_features,
132 | module.eps,
133 | module.momentum,
134 | module.affine,
135 | num_micro_batches)
136 | if module.affine:
137 | module_output.register_parameter('weight', module.weight)
138 | module_output.register_parameter('bias', module.bias)
139 | module_output.register_buffer('running_mean', module.running_mean)
140 | module_output.register_buffer('running_var', module.running_var)
141 | module_output.register_buffer(
142 | 'num_batches_tracked', module.num_batches_tracked)
143 |
144 | return module_output
145 |
146 | for name, child in module.named_children():
147 | module.add_module(
148 | name, cls.convertBatchNorm(child, num_micro_batches))
149 |
150 | return module
151 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/__init__.py:
--------------------------------------------------------------------------------
1 | from .module_generation import generatePartitionModules
2 | from .partition_graph import partiton_graph as partition
3 | from .process_partition import post_process_partition
4 |
5 | __all__ = ["post_process_partition",
6 | "partition", "generatePartitionModules"]
7 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/__init__.py:
--------------------------------------------------------------------------------
1 | from .generate import generatePartitionModules
2 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/constructor.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Dict, Set
2 | import re
3 | from itertools import chain
4 | from torch.nn import Module
5 | tab = ' '
6 | dtab = tab + tab
7 |
8 |
9 | def generateConstructor(class_name: str, full_names: List[str], layer_classes: Dict[str, Module],
10 | is_param_dict: Dict[str, bool], buff_param_names: Set[str]) -> Tuple[str, Dict[str, str]]:
11 | '''creates the partition constructor and the mapping between layers and field ids
12 | '''
13 | class_decl = f"class {class_name}(nn.Module):"
14 | init_dec = f"{tab}def __init__(self, layers, buffers, parameters):"
15 | super_init = f'{dtab}super({class_name}, self).__init__()'
16 | layer_names = [f'self.l_{idx}' for idx, _ in enumerate(full_names)]
17 | layers_init = generate__init__layersStatements(layer_names, full_names,
18 | layer_classes)
19 | scope = dict(zip(full_names, layer_names))
20 |
21 | params, buffs = [], []
22 | for k, v in is_param_dict.items():
23 | if k not in buff_param_names:
24 | 
continue
25 | elif v:
26 | params.append(k)
27 | else:
28 | buffs.append(k)
29 |
30 | tensor_init, tensor_ids = generate__init__BuffParamStatements(buffs,
31 | params)
32 | lookup = generateLookup(scope, tensor_ids)
33 | scope.update(tensor_ids)
34 |
35 | device_id = re.search(r'\d+$', class_name).group()
36 |
37 | # we initialize it to the expected device; if DEBUG the pipeline will set it to the cpu device
38 | device = f"{dtab}self.device = torch.device('cuda:{device_id}')"
39 |
40 | return '\n'.join([class_decl, init_dec, super_init, layers_init, tensor_init, device, lookup]) + '\n', scope
41 |
42 |
43 | def generate__init__layersStatements(layer_names: List[str], full_names: List[str], layer_classes: Dict[str, Module]) -> str:
44 | ''' generates partition field initialization statements\n
45 | and comments to describe which scope is allocated to which field
46 | '''
47 | statements = [f'{dtab}# initializing partition layers',
48 | generate__init__assertGuards(len(layer_names))]
49 |
50 | for field, full_name in zip(layer_names, full_names):
51 | statements.extend([f"# {full_name}",
52 | f"assert '{full_name}' in layers, 'layer {full_name} was expected but not given'",
53 | f"{field} = layers['{full_name}']"])
54 | class_name = layer_classes[full_name].__name__
55 | error_msg = f"f'layers[{full_name}] is expected to be of type {class_name} but was of type {{type({field})}}'"
56 | statements.append(
57 | f"assert isinstance({field},{class_name}) ,{error_msg}")
58 | return f'\n{dtab}'.join(statements)
59 |
60 |
61 | def generate__init__assertGuards(nlayers: int) -> str:
62 | ''' generate assert guards ensuring we receive the necessary number of layers\n
63 | in the layers argument of the constructor
64 | '''
65 | assert_statements = f"assert isinstance(layers,dict), f'expected layers to be of type dict but got type{{type(layers)}}'\n"
66 | assert_statements += f"{dtab}assert(len(layers) == {nlayers})\n"
67 | assert_statements += f"{dtab}assert(all(isinstance(k, str) for k in layers.keys())), 'string keys are expected'\n"
68 | assert_statements += f"{dtab}assert(all(isinstance(v, nn.Module) for v in layers.values())), 'Module values are expected'"
69 | return assert_statements
70 |
71 |
72 | def generate__init__BuffParamStatements(buffers: List[str], parameters: List[str]) -> str:
73 | tensor_ids = {}
74 | lines = [f"\n{dtab}# initializing partition buffers",
75 | f"assert isinstance(buffers,dict), f'expected buffers to be of type dict got {{type(buffers)}}'",
76 | f"assert len(buffers) == {len(buffers)}, f'expected buffers to have {len(buffers)} elements but has {{len(buffers)}} elements'",
77 | f"assert all(isinstance(k,str) for k in buffers.keys()), 'string keys are expected'",
78 | f"assert all(isinstance(v,Tensor) for v in buffers.values()), 'Tensor values are expected'"]
79 | for idx, b_name in enumerate(buffers):
80 | lines.extend([f"# {b_name}",
81 | f"assert '{b_name}' in buffers, '{b_name} buffer was expected but not given'",
82 | f"self.register_buffer('b_{idx}',buffers['{b_name}'])"])
83 | tensor_ids[b_name] = f'self.b_{idx}'
84 |
85 | lines.extend([f"\n{dtab}# initializing partition parameters",
86 | f"assert isinstance(parameters,dict), f'expected parameters to be of type dict got {{type(parameters)}}'",
87 | f"assert len(parameters) == {len(parameters)}, f'expected parameters to have {len(parameters)} elements but has {{len(parameters)}} elements'",
88 | f"assert all(isinstance(k,str) for k in parameters.keys()), 'string keys are expected'",
89 | f"assert 
all(isinstance(v,Tensor) for v in parameters.values()), 'Tensor values are expected'"])
90 | for idx, p_name in enumerate(parameters):
91 | lines.extend([f"# {p_name}",
92 | f"assert '{p_name}' in parameters, '{p_name} parameter was expected but not given'",
93 | f"self.p_{idx} = parameters['{p_name}']"])
94 | tensor_ids[p_name] = f'self.p_{idx}'
95 |
96 | return f'\n{dtab}'.join(lines), tensor_ids
97 |
98 |
99 | def generateLookup(layers_to_id, tensors_to_id):
100 | # first generate lookup table
101 | # e.g. {'p_0': 'w',
102 | #       'l_1': 'module0.sub1.linear'}
103 | lookup = []
104 | for scope, id in chain(layers_to_id.items(), tensors_to_id.items()):
105 | # scope: testMod/Linear[linear0] id: l_0
106 | # we will have 2 keys: l_0.weight l_0.bias
107 | # we wish to replace l_0 with linear0
108 | # resulting in keys: linear0.weight linear0.bias
109 | # for eg scope testMod/Mod0[a]/Sub[b] => a.b
110 | fields = re.findall(r"\[[a-zA-Z0-9_]*\]", scope)
111 | fields = map(lambda s: s[1:-1:], fields)
112 | prefix = '.'.join(fields)
113 | # remove the self. part of the id
114 | lookup.append(f"'{id[5:]}': '{prefix}'")
115 | lookup = f",\n{dtab}{dtab}{dtab}".join(lookup)
116 | return f"{dtab}self.lookup = {{ {lookup}}}"
117 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/generate.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import Tensor
4 | from torch.nn import Module
5 | import torch.nn.functional as F
6 | from pytorch_Gpipe.model_profiling.control_flow_graph import Node, NodeTypes, Graph
7 | from pytorch_Gpipe.utils import traverse_model, traverse_params_buffs
8 | import string
9 | from .forward import generateForwardFunction
10 | from .constructor import generateConstructor
11 | from .misc import generateMiscMethods
12 | from typing import List, Tuple, Dict
13 | from pytorch_Gpipe.utils import OrderedSet
14 | from collections import OrderedDict
15 | import inspect
16 |
17 | tab = ' '
18 | dtab = tab + tab
19 |
20 |
21 | def generatePartitionModules(graph: Graph, model: Module, verbose=False, output_file=None):
22 | layer_classes = {scope: type(layer) for layer, scope, _
23 | in traverse_model(model, depth=graph.depth)}
24 | is_param_dict = {scope: t.requires_grad for t,
25 | scope in traverse_params_buffs(model)}
26 |
27 | parts = groupByPartition(graph.nodes)
28 |
29 | lines = generateImports(layer_classes)
30 | lines.append(connections(graph))
31 | ios = []
32 | # the main code generation loop generating a class decl
33 | # and forward function
34 | partitions_code = []
35 | ios = dict()
36 | for idx, part in parts:
37 | class_name = f'{graph.model_name}Partition{idx}'
38 | layer_names = [n.scope for n in part if n.type == NodeTypes.LAYER]
39 | buff_param_names = {n.scope for n in part
40 | if n.type == NodeTypes.BUFF_PARAM}
41 | class_decl, scope_to_class_field = generateConstructor(class_name, layer_names,
42 | layer_classes, is_param_dict,
43 | buff_param_names)
44 | misc_functions = generateMiscMethods()
45 | forward_function, io = generateForwardFunction(part, graph.output_scopes, scope_to_class_field,
46 | verbose=verbose)
47 | partitions_code.append(class_decl)
48 | partitions_code.extend(forward_function)
49 | partitions_code.append(misc_functions)
50 | ios[idx] = io
51 |
52 | lines.append(generatePiplineAndGetConfig(graph, parts, model, ios))
53 | lines += partitions_code
54 |
55 | if output_file is None:
56 | output_file = 
f'generated_{graph.model_name}{len(parts)}'
57 |
58 | output_file = output_file + '.py'
59 |
60 | with open(output_file, 'w') as f:
61 | f.write('\n'.join(lines))
62 |
63 |
64 | def groupByPartition(nodes: List[Node]) -> List[Tuple[int, List[Node]]]:
65 | # groups layers and their respective nodes according to their partition
66 | idxs = {n.part for n in nodes}
67 | parts = OrderedDict()
68 | for i in sorted(idxs):
69 | parts[i] = []
70 |
71 | for n in nodes:
72 | if n.type == NodeTypes.IN:
73 | continue
74 | elif n.type == NodeTypes.BUFF_PARAM:
75 | parts[n.part].append(n)
76 | elif n.type == NodeTypes.LAYER:
77 | parts[n.part].append(n)
78 | elif n.type == NodeTypes.OP:
79 | scope = n.scope
80 | # we handle the torch, Tensor and torch.nn.functional nameSpaces
81 | func_name = getFunctionName(scope)
82 | if hasattr(torch, func_name) or hasattr(F, func_name) or hasattr(Tensor, func_name):
83 | parts[n.part].append(n)
84 | elif 'aten::slice' in scope:
85 | parts[n.part].append(n)
86 | else:
87 | assert False, f'could not find nameSpace for {scope}'
88 | elif n.type == NodeTypes.PYTHON_PRIMITIVE:
89 | scope = n.scope
90 | assert 'prim::' in scope, f'primitive does not have prim:: prefix {scope}'
91 | func_name = scope.split('prim::')[1].rstrip(string.digits)
92 | parts[n.part].append(n)
93 | else:
94 | assert n.type == NodeTypes.CONSTANT, f'got type {n.type}'
95 | parts[n.part].append(n)
96 |
97 | return parts.items()
98 |
99 |
100 | def generateImports(layer_classes: Dict[str, Module]) -> List[str]:
101 | '''generates imports for torch, torch.nn, torch.nn.functional as F and torch.Tensor,
102 | and for every layer used, plus various other small things
103 | '''
104 | imports = 'import torch\nfrom torch import Tensor\nimport torch.nn as nn\nimport torch.nn.functional as F\n'
105 | imports += 'from itertools import chain\n'
106 | imports += 'from pytorch_Gpipe.utils import layerDict, tensorDict, OrderedSet\n'
107 | imports += 'from pytorch_Gpipe import Pipeline\n'
108 | unique_classes = set(layer_classes.values())
109 |
110 | for cls in unique_classes:
111 | imports += f'from {inspect.getmodule(cls).__name__} import {cls.__name__}\n'
112 |
113 | disclaimer = '# this is an auto generated file do not edit unless you know what you are doing\n\n'
114 |
115 | return imports.splitlines() + [disclaimer]
116 |
117 |
118 | def getFunctionName(scope: str) -> str:
119 | if 'aten::' in scope:
120 | sep = 'aten::'
121 | else:
122 | assert 'prim::' in scope, f"attempting to find function name but got {scope}"
123 | sep = 'prim::'
124 |
125 | return scope.split(sep)[1].rstrip(string.digits)
126 |
127 |
128 | def createConfig(graph: Graph, partitions: List[List[Node]], model: Module, ios: Dict[int, OrderedSet]):
129 | model_buffers = {scope: t for t, scope in traverse_params_buffs(model)
130 | if not t.requires_grad}
131 | model_parameteres = {scope: t for t, scope in traverse_params_buffs(model)
132 | if t.requires_grad}
133 | model_class = model.__class__.__name__
134 | # function header
135 | lines = [
136 | f"def createConfig(model,DEBUG=False,partitions_only=False):",
137 | "layer_dict = layerDict(model)",
138 | "tensor_dict = tensorDict(model)",
139 | f"\n{tab}# now constructing the partitions in order"
140 | ]
141 |
142 | # hard code which layers buffers and parameters belong to each partition
143 | construction_args = []
144 | for idx, part in partitions:
145 | layer_scopes = [f"'{n.scope}'"
146 | for n in part if n.type == NodeTypes.LAYER]
147 | buffer_scopes = [
148 | f"'{n.scope}'" for n in part if n.scope 
in model_buffers] 149 | parameter_scopes = [f"'{n.scope}'" for n in part 150 | if n.scope in model_parameteres] 151 | construction_args.append( 152 | (layer_scopes, buffer_scopes, parameter_scopes)) 153 | 154 | # create partition generation statements 155 | for idx, (layer_scopes, buffer_scopes, parameter_scopes) in zip(sorted(list(ios.keys())), construction_args): 156 | l_scopes = 'layer_scopes = [' + f",\n{dtab}".join(layer_scopes) + ']' 157 | b_scopes = 'buffer_scopes = [' + f",\n{dtab}".join(buffer_scopes) + ']' 158 | p_scopes = 'parameter_scopes = [' + \ 159 | f",\n{dtab}".join(parameter_scopes) + ']' 160 | lines.extend([l_scopes, b_scopes, p_scopes, 161 | f"layers = {{l: layer_dict[l] for l in layer_scopes}}", 162 | f"buffers = {{b: tensor_dict[b] for b in buffer_scopes}}", 163 | f"parameters = {{p: tensor_dict[p] for p in parameter_scopes}}", 164 | f"partition{idx} = {model_class}Partition{idx}(layers,buffers,parameters)\n"]) 165 | 166 | # create and return the partition config 167 | exp = f',\n{dtab}{tab}'.join([f"{k}: {v}" for k, v in ios.items()]) 168 | lines.append( 169 | f"# creating configuration\n{tab}config = {{{exp}\n{dtab}{tab}}}") 170 | 171 | for idx in sorted(list(ios.keys())): 172 | lines.extend([f"device = 'cpu' if DEBUG else torch.device('cuda:{idx}')", 173 | f"partition{idx}.device=device", 174 | f"config[{idx}]['model'] = partition{idx}.to(device)"]) 175 | 176 | input_ids = [f"'input{idx}'" for idx in range(graph.num_inputs)] 177 | lines.extend([f"config['model inputs'] = [{', '.join(input_ids)}]", 178 | f"config['model outputs'] = {list(graph.output_scopes)}", 179 | f"\n{tab}return [config[i]['model'] for i in range({len(ios)})] if partitions_only else config"]) 180 | 181 | return f"\n{tab}".join(lines) + "\n" 182 | 183 | 184 | def generatePiplineAndGetConfig(graph: Graph, partitions: List[List[Node]], model: Module, ios: Dict[int, OrderedSet]): 185 | '''generates function that will perform the actual partition returning a Pipeline object\n 186 | the function will have the partition config hardcoded into it,\n 187 | enabling us to perform the partition process once and use the config multiple times 188 | ''' 189 | model_class = model.__class__.__name__ 190 | config = createConfig(graph, partitions, model, ios) 191 | 192 | lines = [ 193 | f'\ndef {model_class}Pipeline(model:nn.Module,output_device=None,split_dim=0,use_delayedNorm=False,DEBUG=False):'] 194 | 195 | lines.append( 196 | f"return Pipeline(createConfig(model,DEBUG=DEBUG,partitions_only=False),output_device=output_device,split_dim=split_dim,use_delayedNorm=use_delayedNorm)\n\n", 197 | ) 198 | 199 | return f'\n{tab}'.join(lines) + f"\n{config}\n" 200 | 201 | 202 | def connections(graph: Graph): 203 | adj_matrix = [{"inputs": set(), "outputs": set()} 204 | for i in range(graph.num_parts + 2)] 205 | 206 | for node in graph.nodes: 207 | if node.idx < graph.num_inputs: 208 | for n in node.out_nodes: 209 | adj_matrix[n.part + 1]["inputs"].add(node.scope) 210 | adj_matrix[0]["outputs"].add(n.part) 211 | 212 | idx = graph.output_scopes.indexOf(node.scope) 213 | 214 | if idx >= 0: 215 | adj_matrix[graph.num_parts + 1]["inputs"].add(node.part) 216 | adj_matrix[node.part + 1]["outputs"].add(f"output{idx}") 217 | 218 | for n in node.out_nodes: 219 | if n.part != node.part: 220 | adj_matrix[node.part + 1]["outputs"].add(n.part) 221 | adj_matrix[n.part + 1]["inputs"].add(node.part) 222 | 223 | lines = ["# partition adjacency"] 224 | lines.append(f"# model inputs {adj_matrix[0]['outputs']}") 225 | for i, line in 
enumerate(adj_matrix[1:-1:]): 226 | lines.append(f"# partition {i} {line}") 227 | lines.append( 228 | f"# model outputs {adj_matrix[graph.num_parts + 1]['inputs']}") 229 | return '\n'.join(lines) + '\n' 230 | -------------------------------------------------------------------------------- /pytorch_Gpipe/model_partitioning/module_generation/misc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Dict, Set 2 | from torch.nn import Module 3 | tab = ' ' 4 | dtab = tab + tab 5 | 6 | 7 | def generateMiscMethods(): 8 | state_dict = generateStateDictFunction() 9 | load_state_dict = generateLoadStateDict() 10 | named_parameters = generateNamedParametersFunction() 11 | named_buffers = generateNamedBuffersFunction() 12 | 13 | return "\n".join([state_dict, load_state_dict, named_parameters, named_buffers]) + "\n\n" 14 | 15 | 16 | def generateStateDictFunction(): 17 | state_dict_function = ["def state_dict(self,device):", 18 | f"# we return the state dict of this part as it should be in the original model", 19 | "state = super().state_dict()", 20 | f"lookup = self.lookup", 21 | "result = dict()", 22 | "for k, v in state.items():", 23 | f"{tab}if k in lookup:", 24 | f"{dtab}result[lookup[k]] = v if device is None else v.to(device)", 25 | f"{tab}else:", 26 | f"{dtab}assert '.' in k", 27 | f"{dtab}split_idx = k.find('.')", 28 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]", 29 | f"{dtab}result[new_k] = v if device is None else v.to(device)", 30 | f"return result"] 31 | 32 | return f"{tab}" + f"\n{dtab}".join(state_dict_function) 33 | 34 | 35 | def generateNamedParametersFunction(): 36 | named_parameters_function = ["def named_parameters(self):", 37 | f"# we return the named parameters of this part as it should be in the original model", 38 | "params = super().named_parameters()", 39 | f"lookup = self.lookup", 40 | "for k, v in params:", 41 | f"{tab}if k in lookup:", 42 | f"{dtab}yield (lookup[k],v)", 43 | f"{tab}else:", 44 | f"{dtab}assert '.' in k", 45 | f"{dtab}split_idx = k.find('.')", 46 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]", 47 | f"{dtab}yield (new_k, v)"] 48 | return f"\n{tab}" + f"\n{dtab}".join(named_parameters_function) 49 | 50 | 51 | def generateNamedBuffersFunction(): 52 | named_buffers_function = ["def named_buffers(self):", 53 | f"# we return the named buffers of this part as it should be in the original model", 54 | "params = super().named_buffers()", 55 | f"lookup = self.lookup", 56 | "for k, v in params:", 57 | f"{tab}if k in lookup:", 58 | f"{dtab}yield (lookup[k],v)", 59 | f"{tab}else:", 60 | f"{dtab}assert '.' 
in k", 61 | f"{dtab}split_idx = k.find('.')", 62 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]", 63 | f"{dtab}yield (new_k, v)"] 64 | return f"\n{tab}" + f"\n{dtab}".join(named_buffers_function) 65 | 66 | 67 | def generateLoadStateDict(): 68 | func = ['def load_state_dict(self, state):', 69 | 'reverse_lookup = {v: k for k, v in self.lookup.items()}', 70 | 'ts = chain(self.named_parameters(), self.named_buffers())', 71 | 'device = list(ts)[0][1].device', 72 | 'keys = list(self.state_dict(None).keys())', 73 | 'new_state = dict()', 74 | 'for k in keys:', 75 | tab + 'if k in reverse_lookup:', 76 | dtab + 'new_state[reverse_lookup[k]] = state[k].to(device)', 77 | dtab + 'continue', 78 | tab + 'idx = k.rfind(".")', 79 | tab + 'to_replace = k[:idx]', 80 | tab + 'if to_replace in reverse_lookup:', 81 | dtab + 'key = reverse_lookup[to_replace] + k[idx:]', 82 | dtab + 'new_state[key] = state[k].to(device)', 83 | 'super().load_state_dict(new_state, strict=True)'] 84 | 85 | return f"\n{tab}" + f"\n{dtab}".join(func) 86 | -------------------------------------------------------------------------------- /pytorch_Gpipe/model_partitioning/partition_graph.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Callable, Optional 3 | 4 | from ..model_profiling import Graph 5 | from .process_partition import post_process_partition 6 | import networkx as nx 7 | import nxmetis 8 | 9 | __all__ = ["partiton_graph"] 10 | 11 | 12 | def partition_METIS(graph: Graph, num_partitions: int, weighting_function: Optional[Callable[[Any], int]] = None, **METIS_opts) -> Graph: 13 | ''' 14 | partition the graph using METIS's PartGraphKway and then optimizes it to our needs 15 | 16 | Parameters 17 | ---------- 18 | graph: 19 | the Graph object to partition 20 | num_partitions: 21 | the requested number of partitions 22 | weighting_function: 23 | a weighting function that transforms the graph weights to non negative integers 24 | if not specified a default function will be used 25 | METIS_opts: 26 | additional options to pass to METIS 27 | for eg. 
for the option METIS_OPTION_SEED pass seed=value
28 | '''
29 | from ..METIS import METIS_partition
30 | wfunc = weighting_function if weighting_function != None else default_weight_func
31 |
32 | adjlist = graph.adjacency_list()
33 | nodew = graph.get_weights().values()
34 |
35 | assert(len(adjlist) == len(nodew))
36 |
37 | weights = [wfunc(w) for w in nodew]
38 |
39 | if 'seed' not in METIS_opts:
40 | METIS_opts['seed'] = 0
41 |
42 | if 'contig' not in METIS_opts:
43 | METIS_opts['contig'] = 1
44 |
45 | partition, _ = METIS_partition(adjlist, nparts=num_partitions, algorithm="metis",
46 | nodew=weights, **METIS_opts)
47 |
48 | post_process_partition(graph, partition)
49 |
50 | actual_nparts = len({n.part for n in graph.nodes})
51 |
52 | if(actual_nparts < num_partitions):
53 | print(
54 | f"expected {num_partitions} partitions but only {actual_nparts} found, indicating that the model to partition is too small")
55 | print("consider increasing the depth of the graph or disabling the basic blocks option")
56 | return graph
57 |
58 |
59 | def default_weight_func(w):
60 | if hasattr(w, 'forward_time') and hasattr(w, 'backward_time'):
61 | return max(int(100 * (w.forward_time + w.backward_time) / 2), 1)
62 | return 1
63 |
64 |
65 | def partiton_graph(graph: Graph, num_partitions: int, weighting_function: Optional[Callable[[Any], int]] = None, **METIS_opts):
66 | wfunc = weighting_function if weighting_function != None else default_weight_func
67 |
68 | weights = {node.idx: wfunc(node.weight) for node in graph.nodes}
69 |
70 | G = graph.asNetworkx()
71 | nx.set_node_attributes(G, weights, 'weight')
72 |
73 | _, parts = nxmetis.partition(G, num_partitions)
74 |
75 | parts = sorted((idx, n) for n, p in enumerate(parts) for idx in p)
76 | parts = [n for _, n in parts]
77 |
78 | post_process_partition(graph, parts)
79 |
80 | actual_nparts = len({n.part for n in graph.nodes})
81 |
82 | if(actual_nparts < num_partitions):
83 | print(
84 | f"expected {num_partitions} partitions but only {actual_nparts} found, indicating that the model to partition is too small")
85 | print("consider increasing the depth of the graph or disabling the basic blocks option")
86 | return graph
87 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/process_partition.py:
--------------------------------------------------------------------------------
1 | from collections import Counter, deque
2 | from typing import Dict, List
3 |
4 | from ..model_profiling import Graph, NodeTypes
5 |
6 | __all__ = ["post_process_partition"]
7 |
8 |
9 | def post_process_partition(graph: Graph, part: List[int]):
10 | '''
11 | process the partition and optimize it
12 | called as part of the partition_graph method
13 |
14 | Parameters:
15 | ----------
16 | graph:
17 | the Graph object that was partitioned
18 | part:
19 | a list of the nodes' partition indices
20 | '''
21 |
22 | for node, idx in zip(graph.nodes, part):
23 | node.part = idx
24 |
25 | cannonize_partition_indices(graph)
26 | make_partitions_change_only_at_end_of_scope(graph)
27 | # make sure no scc in the graph is split between different parts
28 | scc_partition_correction(graph)
29 | ensure_dag(graph, part)
30 |
31 | cannonize_partition_indices(graph)
32 | # TODO we disabled this optimization
33 | # fix_arithmetic_inputs(graph)
34 | return
35 |
36 |
37 | def fix_arithmetic_inputs(graph: Graph):
38 | while True:
39 | changed = False
40 | for node in graph.nodes:
41 | if node.type is NodeTypes.OP:
42 | for n in 
node.in_nodes:
43 | if n.part != node.part:
44 | n.part = node.part
45 | changed = True
46 | if not changed:
47 | break
48 |
49 |
50 | def ensure_dag(graph: Graph, node_parts: List[int]):
51 | flag = True
52 | while flag:
53 | flag, prob_edge = not_dag(graph, node_parts)
54 |
55 | if flag:
56 | fix_problem_node(graph, prob_edge)
57 |
58 |
59 | def not_dag(graph: Graph, node_parts):
60 | part_edges = []
61 | num_parts = len(set(node_parts))
62 | for node in graph.nodes:
63 | for out_node in node.out_nodes:
64 | if node.part != out_node.part:
65 | part_edge = (node.part, out_node.part)
66 | if part_edge not in part_edges:
67 | part_edges.append(part_edge)
68 |
69 | for num_part1 in range(num_parts):
70 | for num_part2 in range(num_parts):
71 | if (num_part1, num_part2) in part_edges and (num_part2, num_part1) in part_edges and num_part1 < num_part2:
72 | return True, (num_part1, num_part2)
73 |
74 | return False, (-1, -1)
75 |
76 |
77 | def fix_problem_node(graph: Graph, prob_edge: tuple):
78 | first_part, second_part = prob_edge
79 | for node in graph.nodes:
80 | if node.part == second_part:
81 | for o_node in node.out_nodes:
82 | if o_node.part == first_part:
83 | node.part = first_part
84 |
85 |
86 | def scc_partition_correction(graph: Graph):
87 | # create the scc graph
88 | vertices = [v.idx for v in graph.nodes]
89 | edges = {}
90 | for v in graph.nodes:
91 | idx_out_nodes = [h.idx for h in v.out_nodes]
92 | edges.update({v.idx: idx_out_nodes})
93 |
94 | for scc in strongly_connected_components_iterative(vertices, edges):
95 | # check if the scc is split between 2 or more parts
96 | scc_parts = []
97 | for v in scc:
98 | if graph.nodes[v].part not in scc_parts:
99 | scc_parts.append(graph.nodes[v].part)
100 | if len(scc_parts) >= 2:
101 | break
102 | # if it is split:
103 | if len(scc_parts) >= 2:
104 | output_part = -1
105 | # find out what part edges go to from this scc
106 | for v in scc:
107 | for out in graph.nodes[v].out_nodes:
108 | if out.idx not in scc:
109 | output_part = graph.nodes[out.idx].part
110 | break
111 | if output_part != -1:
112 | break
113 | # update the scc part to the part we found
114 | for v in scc:
115 | graph.nodes[v].part = output_part
116 |
117 |
118 | def strongly_connected_components_iterative(vertices: List[int], edges: Dict[int, List[int]]):
119 | identified = set()
120 | stack = []
121 | index = {}
122 | boundaries = []
123 |
124 | for v in vertices:
125 | if v not in index:
126 | to_do = [('VISIT', v)]
127 | while to_do:
128 | operation_type, v = to_do.pop()
129 | if operation_type == 'VISIT':
130 | index[v] = len(stack)
131 | stack.append(v)
132 | boundaries.append(index[v])
133 | to_do.append(('POSTVISIT', v))
134 | # We reverse to keep the search order identical to that of
135 | # the recursive code; the reversal is not necessary for
136 | # correctness, and can be omitted. 
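# an explicit work stack is used instead of recursion so that very deep
# graphs do not hit Python's recursion limit; 'boundaries' holds, for each
# component still open on the stack, the index of its root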
137 | to_do.extend(
138 | reversed([('VISITEDGE', w) for w in edges[v]]))
139 | elif operation_type == 'VISITEDGE':
140 | if v not in index:
141 | to_do.append(('VISIT', v))
142 | elif v not in identified:
143 | while index[v] < boundaries[-1]:
144 | boundaries.pop()
145 | else:
146 | # operation_type == 'POSTVISIT'
147 | if boundaries[-1] == index[v]:
148 | boundaries.pop()
149 | scc = set(stack[index[v]:])
150 | del stack[index[v]:]
151 | identified.update(scc)
152 | yield scc
153 |
154 |
155 | def cannonize_partition_indices(graph: Graph):
156 | num_parts = len({n.part for n in graph.nodes})
157 | num_taken = 0
158 | model_inputs = [node for node in graph.nodes if node.type == NodeTypes.IN]
159 | open_nodes = deque(model_inputs)
160 | closed = set()
161 | cannonical_parts = dict()
162 |
163 | while num_taken < num_parts:
164 | node = open_nodes.popleft()
165 | if node in closed or node in open_nodes:
166 | continue
167 | if node.part not in cannonical_parts:
168 | cannonical_parts[node.part] = num_taken
169 | num_taken += 1
170 |
171 | closed.add(node)
172 | edges = node.out_nodes.union(node.in_nodes)
173 | nodes = edges.difference(closed, set(open_nodes))
174 | open_nodes.extend(nodes)
175 |
176 | for node in graph.nodes:
177 | node.part = cannonical_parts[node.part]
178 |
179 | graph.num_parts = len(cannonical_parts)
180 |
181 |
182 | def make_partitions_change_only_at_end_of_scope(graph: Graph):
183 | def is_first_in_partition(node):
184 | return any(other.part != node.part for other in node.in_nodes)
185 |
186 | first_nodes_of_partition = filter(is_first_in_partition, graph.nodes)
187 |
188 | for node in first_nodes_of_partition:
189 | scope_depth = node.scope.count('/') - 1
190 | # don't do it too shallow
191 | if scope_depth >= 2: # TODO think about threshold
192 | parent_scope = node.scope.rsplit('/', 1)[0]
193 |
194 | def in_scope(n):
195 | return parent_scope == n.scope.rsplit('/', 1)[0]
196 |
197 | scope_nodes = list(filter(in_scope, graph.nodes))
198 | parts = [n.part for n in scope_nodes]
199 | part_histogram = Counter(parts)
200 | most_common, num_layers = part_histogram.most_common(1)[0]
201 | if num_layers >= len(parts) // 2:
202 | for other in scope_nodes:
203 | other.part = most_common
204 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_profiling/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..utils import Tensors, traverse_model, traverse_params_buffs, model_scopes, _count_elements
7 | from .control_flow_graph import Graph, NodeTypes
8 | from .network_profiler import profileNetwork
9 |
10 | __all__ = ['graph_builder', 'profileNetwork']
11 |
12 |
13 | def graph_builder(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, max_depth: int = 1000, weights: Optional[Dict[str, Any]] = None, basic_blocks: Optional[List[nn.Module]] = None, use_profiler=False) -> Graph:
14 | '''
15 | returns a graph that models the control flow of the given network by tracing its forward pass
16 |
17 | Parameters:
18 | model:
19 | the network we wish to model
20 | sample_batch:
21 | a sample input to use for tracing
22 | kwargs:
23 | keyword args to pass to the model
24 | max_depth:
25 | how far down we go in the model tree; determines the detail level of the graph
26 | basic_blocks:
27 | an optional list of modules that if encountered will not be broken down
28 | 
weights:
29 | an optional dictionary from scopes to Node weights
30 | use_profiler:
31 | whether to use weights given by our profiler
32 | this option supersedes the weights option. defaults to False
33 | '''
34 | weights = weights if weights != None else {}
35 | if kwargs is None:
36 | kwargs = {}
37 | if not isinstance(sample_batch, tuple):
38 | sample_batch = (sample_batch,)
39 |
40 | if use_profiler:
41 | weights = profileNetwork(model, sample_batch, kwargs=kwargs, max_depth=max_depth,
42 | basic_blocks=basic_blocks)
43 |
44 | buffer_param_names = map(lambda t: t[1], traverse_params_buffs(model))
45 | buffer_param_names = list(buffer_param_names)
46 |
47 | layerNames = model_scopes(model, depth=max_depth,
48 | basic_blocks=basic_blocks)
49 | layerNames = list(layerNames)
50 |
51 | # trace the model and build a graph
52 | with torch.no_grad():
53 | trace_graph, _ = torch.jit.get_trace_graph(
54 | model, sample_batch, kwargs)
55 | trace_graph = trace_graph.graph()
56 |
57 | num_inputs = _count_elements(*sample_batch) + len(kwargs)
58 |
59 | graph = Graph(layerNames, num_inputs, buffer_param_names,
60 | trace_graph, weights, basic_blocks, max_depth)
61 |
62 | return graph
63 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_profiling/network_profiler.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import namedtuple
3 | from typing import Dict, List, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from ..utils import Tensors, _detach_inputs, _get_size, get_device, traverse_model
9 |
10 | __all__ = ['profileNetwork', 'Profile']
11 |
12 | Profile = namedtuple('Profile',
13 | 'forward_time backward_time cuda_memory_forward cuda_memory_backward layer_size')
14 |
15 |
16 | def profileNetwork(net: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, basic_blocks: Optional[List[nn.Module]] = None, max_depth=100) -> Dict[str, Profile]:
17 | '''
18 | profiles a network's computation time (forward/backward) and memory consumption
19 | returns a dictionary from layer_scope to Profile
20 |
21 | Parameters
22 | ----------
23 | net:
24 | the network we wish to profile, a nn.Module
25 |
26 | sample_batch:
27 | a sample batch that will be used to measure execution time of the network
28 | can be single/multiple inputs
29 |
30 | kwargs:
31 | keyword args to pass to the profiled model
32 |
33 | basic_blocks:
34 | a tuple of nn.Module classes that the profiler will regard as a cohesive unit
35 | for eg. 
if nn.Sequential is in basic_blocks then the profiler will not break it down into its components
36 |
37 | max_depth:
38 | determines how far the profiler will go in the model tree
39 |
40 |
41 |
42 | '''
43 | if kwargs is None:
44 | kwargs = {}
45 | if not isinstance(sample_batch, tuple):
46 | sample_batch = (sample_batch,)
47 |
48 | # wrap all individual layers for profiling
49 | layers_dict = _wrap_profiled_layers(net, max_depth, basic_blocks)
50 |
51 | # perform 2 symbolic forward-backward runs; the first one is a warmup, as we have seen the first time measurements are higher
52 | _perform_forward_backward_pass(net, *sample_batch, **kwargs)
53 | _perform_forward_backward_pass(net, *sample_batch, **kwargs)
54 |
55 | # gather forward and backward execution times
56 | backward_times = [layer.backward_time
57 | for layer in layers_dict.values()]
58 | forward_times = [layer.forward_time
59 | for layer in layers_dict.values()]
60 |
61 | # gather input and output sizes
62 | layer_input_sizes = [layer.input_size for layer in layers_dict.values()]
63 | layer_output_sizes = [layer.output_size for layer in layers_dict.values()]
64 |
65 | # gather all individual layer sizes
66 | param_sizes = [layer.parameters_size for layer in layers_dict.values()]
67 | buffer_sizes = [layer.buffers_size for layer in layers_dict.values()]
68 |
69 | # gather cuda memory consumption
70 | cuda_memory = [(layer.forward_cuda_mem, layer.backward_cuda_mem)
71 | for layer in layers_dict.values()]
72 |
73 | # prepare profiling results
74 | layers_profile = {name: Profile(forward, backward, *cuda_mem, param_size + buffer_size + in_size + out_size) for name, forward, backward, param_size, buffer_size, in_size, out_size, cuda_mem in zip(
75 | layers_dict.keys(), forward_times, backward_times, param_sizes, buffer_sizes, layer_input_sizes, layer_output_sizes, cuda_memory)}
76 |
77 | _unwrap_layers(net)
78 |
79 | return layers_profile
80 |
81 |
82 | def _perform_forward_backward_pass(net, *sample_batch: Tensors, **kwargs: Dict):
83 | device = get_device(sample_batch)
84 | if device.type == "cuda":
85 | torch.cuda.synchronize(device=device)
86 | out = net(*sample_batch, **kwargs)
87 | torch.cuda.synchronize(device=device)
88 | else:
89 | out = net(*sample_batch, **kwargs)
90 | net.zero_grad()
91 | return out
92 |
93 |
94 | def _wrap_profiled_layers(module: nn.Module, depth, basic_blocks: List[nn.Module]):
95 | layers_dict = {}
96 |
97 | for sub_layer, scope, parent in traverse_model(module, depth, basic_blocks):
98 | name = scope[scope.rfind('[') + 1:-1]
99 | wrapper = Wrapper(sub_layer)
100 | parent.add_module(name, wrapper)
101 | layers_dict[scope] = wrapper
102 |
103 | return layers_dict
104 |
105 |
106 | def _unwrap_layers(module: nn.Module):
107 | for name, sub_module in module.named_children():
108 | if isinstance(sub_module, Wrapper):
109 | module.add_module(name, sub_module.layer)
110 | else:
111 | _unwrap_layers(sub_module)
112 |
113 |
114 | class Wrapper(nn.Module):
115 | '''
116 | a module whose purpose is to profile a given layer\n
117 | when the wrapper performs forward propagation it records the following metrics:\n
118 | forward_time: the execution time of a forward pass of the underlying layer in milliseconds\n
119 | backward_time: the execution time of a backward pass of the underlying layer in milliseconds\n
120 | input_size: the input size in MB
121 | output_size: the layer output size in MB
122 | parameters_size: the size of parameters of the layer in MB
123 | buffers_size: the size of buffers of the layer in MB
124 | 
114 | class Wrapper(nn.Module): 115 | ''' 116 | a module whose purpose is to profile a given layer\n 117 | when the wrapper performs forward propagation it records the following metrics:\n 118 | forward_time: the execution time of a forward pass of the underlying layer in milliseconds\n 119 | backward_time: the execution time of a backward pass of the underlying layer in milliseconds\n 120 | input_size: the input size in MB 121 | output_size: the layer output size in MB 122 | parameters_size: the size of parameters of the layer in MB 123 | buffers_size: the size of buffers of the layer in MB 124 | forward_cuda_mem: the peak CUDA memory usage during the forward pass in MB 125 | backward_cuda_mem: the peak CUDA memory usage during the backward pass in MB 126 | 127 | Parameters 128 | ---------- 129 | sub_module: 130 | a nn.Module to be profiled 131 | 132 | ''' 133 | 134 | def __init__(self, sub_module: nn.Module): 135 | super(Wrapper, self).__init__() 136 | self.layer = sub_module 137 | self.forward_time = 0 138 | self.backward_time = 0 139 | self.input_size = 0 140 | self.output_size = 0 141 | self.parameters_size, self.buffers_size = self._layer_size() 142 | self.forward_cuda_mem = 0 143 | self.backward_cuda_mem = 0 144 | 145 | def _layer_size(self): 146 | ''' 147 | return the size of the layer considering parameters and buffers 148 | ''' 149 | parameters_size = buffers_size = 0 150 | for param in self.layer.parameters(): 151 | parameters_size += param.nelement() * param.element_size() 152 | for buffer in self.layer.buffers(): 153 | buffers_size += buffer.nelement() * buffer.element_size() 154 | 155 | return parameters_size, buffers_size 156 | 157 | def forward(self, *inputs: Tensors, **kwargs: Dict): 158 | ''' 159 | perform forward and backward passes of the underlying layer and measure metrics 160 | ''' 161 | # detach inputs from previous history, enabling us to measure execution time 162 | # only for this layer 163 | device = get_device(inputs) 164 | detached_inputs = _detach_inputs(inputs) 165 | 166 | self.forward_time, outputs, self.forward_cuda_mem = self._time_op( 167 | self.layer, *detached_inputs, **kwargs) 168 | 169 | # reduce outputs to calculate a dummy loss 170 | loss = torch.zeros(1, requires_grad=True, device=device) 171 | for out in outputs: 172 | loss = loss + out.norm() 173 | 174 | # measure backward execution time 175 | self.backward_time, _, self.backward_cuda_mem = self._time_op( 176 | torch.autograd.backward, loss) 177 | 178 | # input and output size 179 | self.input_size = _get_size(inputs) 180 | self.output_size = _get_size(outputs) 181 | 182 | # sizes in megabytes; parameters/buffers are recomputed from bytes here so repeated forward calls do not rescale them 183 | self.backward_cuda_mem /= 1e6 184 | self.forward_cuda_mem /= 1e6 185 | self.input_size /= 1e6 186 | self.output_size /= 1e6 187 | self.parameters_size, self.buffers_size = self._layer_size() 188 | self.parameters_size, self.buffers_size = self.parameters_size / 1e6, self.buffers_size / 1e6 189 | 190 | return outputs 191 | 192 | def _time_op(self, func, *inputs: Tensors, **kwargs: Dict): 193 | exec_time = 0 194 | cuda_mem = 0 195 | device = get_device(inputs) 196 | if device.type == 'cuda': 197 | torch.cuda.reset_max_memory_allocated(device=device) 198 | base_mem = torch.cuda.max_memory_allocated(device=device) 199 | 200 | # measure execution time 201 | torch.cuda.synchronize(device=device) 202 | start = torch.cuda.Event(enable_timing=True) 203 | end = torch.cuda.Event(enable_timing=True) 204 | start.record() 205 | out = func(*inputs, **kwargs) 206 | end.record() 207 | torch.cuda.synchronize(device=device) 208 | exec_time = (start.elapsed_time(end)) 209 | 210 | # record memory usage 211 | peak_usage = torch.cuda.max_memory_allocated(device=device) 212 | cuda_mem = peak_usage - base_mem 213 | else: 214 | # time on the CPU, converting seconds to milliseconds 215 | start = time.time() 216 | out = func(*inputs, **kwargs) 217 | end = time.time() 218 | exec_time = 1000 * (end - start) 219 | 220 | return exec_time, out, cuda_mem 221 |
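# editor's note on the timing above (added commentary, not original code):
# CUDA kernels launch asynchronously, so start.elapsed_time(end) is only
# meaningful after the torch.cuda.synchronize() that follows end.record();
# the CPU branch needs no such barrier. reset_max_memory_allocated is the
# PyTorch 1.2-era API for clearing the peak-memory counter (later versions
# renamed it reset_peak_memory_stats).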
222 | # in case these operations are required, we forward them to the profiled layer 223 | 224 | def __iter__(self): 225 | return iter(self.layer) 226 | 227 | def __getitem__(self, key): 228 | return self.layer[key] 229 | 230 | def __setitem__(self, key, value): 231 | self.layer[key] = value 232 | 233 | def __delitem__(self, idx): 234 | del self.layer[idx] 235 | 236 | def __len__(self): 237 | return len(self.layer) 238 | 239 | def __contains__(self, key): 240 | return key in self.layer 241 | 242 | def __getattr__(self, name): 243 | try: 244 | return super().__getattr__(name) 245 | except Exception: 246 | return getattr(self.layer, name) 247 | -------------------------------------------------------------------------------- /sample_models/AlexNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | 8 | model_urls = { 9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 10 | } 11 | 12 | 13 | class AlexNet(nn.Module): 14 | 15 | def __init__(self, num_classes=1000): 16 | super(AlexNet, self).__init__() 17 | self.features = nn.Sequential( 18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 19 | nn.ReLU(inplace=True), 20 | nn.MaxPool2d(kernel_size=3, stride=2), 21 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=3, stride=2), 24 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.MaxPool2d(kernel_size=3, stride=2), 31 | ) 32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 33 | self.classifier = nn.Sequential( 34 | nn.Dropout(), 35 | nn.Linear(256 * 6 * 6, 4096), 36 | nn.ReLU(inplace=True), 37 | nn.Dropout(), 38 | nn.Linear(4096, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, num_classes), 41 | ) 42 | 43 | def forward(self, x): 44 | x = self.features(x) 45 | x = self.avgpool(x) 46 | x = x.view(x.size(0), 256 * 6 * 6) 47 | x = self.classifier(x) 48 | return x 49 | 50 | 51 | def alexnet(pretrained=False, **kwargs): 52 | r"""AlexNet model architecture from the 53 | `"One weird trick..." `_ paper. 
54 | 55 | Args: 56 | pretrained (bool): If True, returns a model pre-trained on ImageNet 57 | """ 58 | model = AlexNet(**kwargs) 59 | if pretrained: 60 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 61 | return model 62 | -------------------------------------------------------------------------------- /sample_models/DenseNet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from collections import OrderedDict 6 | 7 | __all__ = ['DenseNet', 'densenet121', 8 | 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | 11 | model_urls = { 12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 16 | } 17 | 18 | 19 | class _DenseLayer(nn.Sequential): 20 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 21 | super(_DenseLayer, self).__init__() 22 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 23 | self.add_module('relu1', nn.ReLU(inplace=True)), 24 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 25 | growth_rate, kernel_size=1, stride=1, bias=False)), 26 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 27 | self.add_module('relu2', nn.ReLU(inplace=True)), 28 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 29 | kernel_size=3, stride=1, padding=1, bias=False)), 30 | self.drop_rate = drop_rate 31 | if self.drop_rate > 0: 32 | self.add_module("dropout", nn.Dropout(p=self.drop_rate)) 33 | 34 | def forward(self, x): 35 | new_features = super(_DenseLayer, self).forward(x) 36 | return torch.cat([x, new_features], 1) 37 | 38 | 39 | class _DenseBlock(nn.Sequential): 40 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 41 | super(_DenseBlock, self).__init__() 42 | for i in range(num_layers): 43 | layer = _DenseLayer(num_input_features + i * 44 | growth_rate, growth_rate, bn_size, drop_rate) 45 | self.add_module('denselayer%d' % (i + 1), layer) 46 | 47 | 48 | class _Transition(nn.Sequential): 49 | def __init__(self, num_input_features, num_output_features): 50 | super(_Transition, self).__init__() 51 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 52 | self.add_module('relu', nn.ReLU(inplace=True)) 53 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 54 | kernel_size=1, stride=1, bias=False)) 55 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 56 | 57 | 58 | class DenseNet(nn.Module): 59 | r"""Densenet-BC model class, based on 60 | `"Densely Connected Convolutional Networks" `_ 61 | 62 | Args: 63 | growth_rate (int) - how many filters to add each layer (`k` in paper) 64 | block_config (list of 4 ints) - how many layers in each pooling block 65 | num_init_features (int) - the number of filters to learn in the first convolution layer 66 | bn_size (int) - multiplicative factor for number of bottle neck layers 67 | (i.e. 
bn_size * k features in the bottleneck layer) 68 | drop_rate (float) - dropout rate after each dense layer 69 | num_classes (int) - number of classification classes 70 | """ 71 | 72 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 73 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 74 | 75 | super(DenseNet, self).__init__() 76 | 77 | # First convolution 78 | self.features = nn.Sequential(OrderedDict([ 79 | ('conv0', nn.Conv2d(3, num_init_features, 80 | kernel_size=7, stride=2, padding=3, bias=False)), 81 | ('norm0', nn.BatchNorm2d(num_init_features)), 82 | ('relu0', nn.ReLU(inplace=True)), 83 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 84 | ])) 85 | 86 | # Each denseblock 87 | num_features = num_init_features 88 | for i, num_layers in enumerate(block_config): 89 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 90 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 91 | self.features.add_module('denseblock%d' % (i + 1), block) 92 | num_features = num_features + num_layers * growth_rate 93 | if i != len(block_config) - 1: 94 | trans = _Transition( 95 | num_input_features=num_features, num_output_features=num_features // 2) 96 | self.features.add_module('transition%d' % (i + 1), trans) 97 | num_features = num_features // 2 98 | 99 | # Final batch norm 100 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 101 | 102 | self.relu = nn.ReLU() 103 | self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d((1, 1)) 104 | 105 | # Linear layer 106 | self.classifier = nn.Linear(num_features, num_classes) 107 | 108 | # Official init from torch repo. 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.constant_(m.weight, 1) 114 | nn.init.constant_(m.bias, 0) 115 | elif isinstance(m, nn.Linear): 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def forward(self, x): 119 | features = self.features(x) 120 | out = self.relu(features) 121 | out = self.adaptive_avg_pool2d(out) 122 | out = out.view(out.size(0), -1) 123 | out = self.classifier(out) 124 | return out 125 | 126 | 127 | def densenet121(pretrained=False, **kwargs): 128 | r"""Densenet-121 model from 129 | `"Densely Connected Convolutional Networks" `_ 130 | 131 | Args: 132 | pretrained (bool): If True, returns a model pre-trained on ImageNet 133 | """ 134 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 135 | **kwargs) 136 | if pretrained: 137 | # '.'s are no longer allowed in module names, but previous _DenseLayer 138 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 139 | # They are also in the checkpoints in model_urls. This pattern is used 140 | # to find such keys. 
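# (illustrative example of the renaming performed below:
#  'features.denseblock1.denselayer1.norm.1.weight'
#  -> 'features.denseblock1.denselayer1.norm1.weight')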
141 | pattern = re.compile( 142 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 143 | state_dict = model_zoo.load_url(model_urls['densenet121']) 144 | for key in list(state_dict.keys()): 145 | res = pattern.match(key) 146 | if res: 147 | new_key = res.group(1) + res.group(2) 148 | state_dict[new_key] = state_dict[key] 149 | del state_dict[key] 150 | model.load_state_dict(state_dict) 151 | return model 152 | 153 | 154 | def densenet169(pretrained=False, **kwargs): 155 | r"""Densenet-169 model from 156 | `"Densely Connected Convolutional Networks" `_ 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 162 | **kwargs) 163 | if pretrained: 164 | # '.'s are no longer allowed in module names, but previous _DenseLayer 165 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 166 | # They are also in the checkpoints in model_urls. This pattern is used 167 | # to find such keys. 168 | pattern = re.compile( 169 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 170 | state_dict = model_zoo.load_url(model_urls['densenet169']) 171 | for key in list(state_dict.keys()): 172 | res = pattern.match(key) 173 | if res: 174 | new_key = res.group(1) + res.group(2) 175 | state_dict[new_key] = state_dict[key] 176 | del state_dict[key] 177 | model.load_state_dict(state_dict) 178 | return model 179 | 180 | 181 | def densenet201(pretrained=False, **kwargs): 182 | r"""Densenet-201 model from 183 | `"Densely Connected Convolutional Networks" `_ 184 | 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 189 | **kwargs) 190 | if pretrained: 191 | # '.'s are no longer allowed in module names, but previous _DenseLayer 192 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 193 | # They are also in the checkpoints in model_urls. This pattern is used 194 | # to find such keys. 195 | pattern = re.compile( 196 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 197 | state_dict = model_zoo.load_url(model_urls['densenet201']) 198 | for key in list(state_dict.keys()): 199 | res = pattern.match(key) 200 | if res: 201 | new_key = res.group(1) + res.group(2) 202 | state_dict[new_key] = state_dict[key] 203 | del state_dict[key] 204 | model.load_state_dict(state_dict) 205 | return model 206 | 207 | 208 | def densenet161(pretrained=False, **kwargs): 209 | r"""Densenet-161 model from 210 | `"Densely Connected Convolutional Networks" `_ 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 216 | **kwargs) 217 | if pretrained: 218 | # '.'s are no longer allowed in module names, but previous _DenseLayer 219 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 220 | # They are also in the checkpoints in model_urls. This pattern is used 221 | # to find such keys. 
222 | pattern = re.compile( 223 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 224 | state_dict = model_zoo.load_url(model_urls['densenet161']) 225 | for key in list(state_dict.keys()): 226 | res = pattern.match(key) 227 | if res: 228 | new_key = res.group(1) + res.group(2) 229 | state_dict[new_key] = state_dict[key] 230 | del state_dict[key] 231 | model.load_state_dict(state_dict) 232 | return model 233 | -------------------------------------------------------------------------------- /sample_models/GoogleNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | __all__ = ["GoogLeNet"] 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, kernel_1_x, kernel_size=1), 13 | nn.BatchNorm2d(kernel_1_x), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, kernel_3_in, kernel_size=1), 20 | nn.BatchNorm2d(kernel_3_in), 21 | nn.ReLU(True), 22 | nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(kernel_3_x), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, kernel_5_in, kernel_size=1), 30 | nn.BatchNorm2d(kernel_5_in), 31 | nn.ReLU(True), 32 | nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(kernel_5_x), 34 | nn.ReLU(True), 35 | nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(kernel_5_x), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1, y2, y3, y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self, num_classes=1000): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 69 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 70 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 71 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 72 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 73 | 74 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 75 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 76 | 77 | self.linear = nn.Linear(1024, num_classes) 78 | 79 | def forward(self, x): 80 | x = self.pre_layers(x) 81 | x = self.a3(x) 82 | x = self.b3(x) 83 | x = F.max_pool2d(x, 3, stride=2, padding=1) 84 | x = self.a4(x) 85 | x = self.b4(x) 86 | x = self.c4(x) 87 | x = self.d4(x) 88 | x = self.e4(x) 89 | x = F.max_pool2d(x, 3, stride=2, padding=1) 90 | x = self.a5(x) 91 | x = self.b5(x) 92 | x = F.avg_pool2d(x, 8, stride=1) 93 | x = x.view(x.size(0), -1) 94 | x = self.linear(x) 95 | return x 96 | 
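# editor's note (added): the closing F.avg_pool2d(x, 8) together with
# nn.Linear(1024, num_classes) implies CIFAR-style 32x32 inputs, e.g.
# (illustrative) GoogLeNet()(torch.randn(2, 3, 32, 32)) -> shape (2, 1000).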
-------------------------------------------------------------------------------- /sample_models/LeNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ["LeNet"] 4 | 5 | 6 | class LeNet(nn.Module): 7 | def __init__(self, num_classes=1000): 8 | super(LeNet, self).__init__() 9 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 10 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 11 | self.fc1 = nn.Linear(16*5*5, 120) 12 | self.fc2 = nn.Linear(120, 84) 13 | self.fc3 = nn.Linear(84, num_classes) 14 | self.relu1 = nn.ReLU() 15 | self.relu2 = nn.ReLU() 16 | self.relu3 = nn.ReLU() 17 | self.relu4 = nn.ReLU() 18 | self.max_pool2d1 = nn.MaxPool2d(2) 19 | self.max_pool2d2 = nn.MaxPool2d(2) 20 | 21 | def forward(self, x): 22 | x = self.relu1(self.conv1(x)) 23 | x = self.max_pool2d1(x) 24 | x = self.relu2(self.conv2(x)) 25 | x = self.max_pool2d2(x) 26 | x = x.view(x.size(0), -1) 27 | x = self.relu3(self.fc1(x)) 28 | x = self.relu4(self.fc2(x)) 29 | x = self.fc3(x) 30 | return x 31 | -------------------------------------------------------------------------------- /sample_models/SqueezeNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 8 | 9 | 10 | model_urls = { 11 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 12 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 13 | } 14 | 15 | 16 | class Fire(nn.Module): 17 | 18 | def __init__(self, inplanes, squeeze_planes, 19 | expand1x1_planes, expand3x3_planes): 20 | super(Fire, self).__init__() 21 | self.inplanes = inplanes 22 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 23 | self.squeeze_activation = nn.ReLU(inplace=True) 24 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 25 | kernel_size=1) 26 | self.expand1x1_activation = nn.ReLU(inplace=True) 27 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 28 | kernel_size=3, padding=1) 29 | self.expand3x3_activation = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | x = self.squeeze_activation(self.squeeze(x)) 33 | return torch.cat([ 34 | self.expand1x1_activation(self.expand1x1(x)), 35 | self.expand3x3_activation(self.expand3x3(x)) 36 | ], 1) 37 | 38 | 39 | class SqueezeNet(nn.Module): 40 | 41 | def __init__(self, version=1.0, num_classes=1000): 42 | super(SqueezeNet, self).__init__() 43 | if version not in [1.0, 1.1]: 44 | raise ValueError("Unsupported SqueezeNet version {version}:" 45 | "1.0 or 1.1 expected".format(version=version)) 46 | self.num_classes = num_classes 47 | if version == 1.0: 48 | self.features = nn.Sequential( 49 | nn.Conv2d(3, 96, kernel_size=7, stride=2), 50 | nn.ReLU(inplace=True), 51 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 52 | Fire(96, 16, 64, 64), 53 | Fire(128, 16, 64, 64), 54 | Fire(128, 32, 128, 128), 55 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 56 | Fire(256, 32, 128, 128), 57 | Fire(256, 48, 192, 192), 58 | Fire(384, 48, 192, 192), 59 | Fire(384, 64, 256, 256), 60 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 61 | Fire(512, 64, 256, 256), 62 | ) 63 | else: 64 | self.features = nn.Sequential( 65 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 66 | nn.ReLU(inplace=True), 67 | nn.MaxPool2d(kernel_size=3, stride=2, 
ceil_mode=True), 68 | Fire(64, 16, 64, 64), 69 | Fire(128, 16, 64, 64), 70 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 71 | Fire(128, 32, 128, 128), 72 | Fire(256, 32, 128, 128), 73 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 74 | Fire(256, 48, 192, 192), 75 | Fire(384, 48, 192, 192), 76 | Fire(384, 64, 256, 256), 77 | Fire(512, 64, 256, 256), 78 | ) 79 | # Final convolution is initialized differently from the rest 80 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 81 | self.classifier = nn.Sequential( 82 | nn.Dropout(p=0.5), 83 | final_conv, 84 | nn.ReLU(inplace=True), 85 | nn.AdaptiveAvgPool2d((1, 1)) 86 | ) 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | if m is final_conv: 91 | init.normal_(m.weight, mean=0.0, std=0.01) 92 | else: 93 | init.kaiming_uniform_(m.weight) 94 | if m.bias is not None: 95 | init.constant_(m.bias, 0) 96 | 97 | def forward(self, x): 98 | x = self.features(x) 99 | x = self.classifier(x) 100 | return x.view(x.size(0), self.num_classes) 101 | 102 | 103 | def squeezenet1_0(pretrained=False, **kwargs): 104 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 105 | accuracy with 50x fewer parameters and <0.5MB model size" 106 | `_ paper. 107 | 108 | Args: 109 | pretrained (bool): If True, returns a model pre-trained on ImageNet 110 | """ 111 | model = SqueezeNet(version=1.0, **kwargs) 112 | if pretrained: 113 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) 114 | return model 115 | 116 | 117 | def squeezenet1_1(pretrained=False, **kwargs): 118 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 119 | `_. 120 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 121 | than SqueezeNet 1.0, without sacrificing accuracy. 
122 | 123 | Args: 124 | pretrained (bool): If True, returns a model pre-trained on ImageNet 125 | """ 126 | model = SqueezeNet(version=1.1, **kwargs) 127 | if pretrained: 128 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) 129 | return model 130 | -------------------------------------------------------------------------------- /sample_models/VGG.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = [ 6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 7 | 'vgg19_bn', 'vgg19', 8 | ] 9 | 10 | 11 | model_urls = { 12 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 13 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 14 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 15 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 16 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 17 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 18 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 19 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=1000, init_weights=True): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 29 | self.classifier = nn.Sequential( 30 | nn.Linear(512 * 7 * 7, 4096), 31 | nn.ReLU(True), 32 | nn.Dropout(), 33 | nn.Linear(4096, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, num_classes), 37 | ) 38 | if init_weights: 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.features(x) 43 | x = self.avgpool(x) 44 | x = x.view(x.size(0), -1) 45 | x = self.classifier(x) 46 | return x 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_( 52 | m.weight, mode='fan_out', nonlinearity='relu') 53 | if m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | nn.init.normal_(m.weight, 0, 0.01) 60 | nn.init.constant_(m.bias, 0) 61 | 62 | 63 | def make_layers(cfg, batch_norm=False): 64 | layers = [] 65 | in_channels = 3 66 | for v in cfg: 67 | if v == 'M': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 71 | if batch_norm: 72 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 73 | else: 74 | layers += [conv2d, nn.ReLU(inplace=True)] 75 | in_channels = v 76 | return nn.Sequential(*layers) 77 | 78 | 79 | cfg = { 80 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 81 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 82 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 83 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 84 | } 85 | 86 | 87 | def vgg11(pretrained=False, **kwargs): 88 | """VGG 11-layer model (configuration "A") 89 | 90 | Args: 91 | pretrained (bool): If True, returns a model pre-trained on ImageNet 92 | """ 93 | if pretrained: 94 | kwargs['init_weights'] = False 95 | 
model = VGG(make_layers(cfg['A']), **kwargs) 96 | if pretrained: 97 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 98 | return model 99 | 100 | 101 | def vgg11_bn(pretrained=False, **kwargs): 102 | """VGG 11-layer model (configuration "A") with batch normalization 103 | 104 | Args: 105 | pretrained (bool): If True, returns a model pre-trained on ImageNet 106 | """ 107 | if pretrained: 108 | kwargs['init_weights'] = False 109 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 110 | if pretrained: 111 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 112 | return model 113 | 114 | 115 | def vgg13(pretrained=False, **kwargs): 116 | """VGG 13-layer model (configuration "B") 117 | 118 | Args: 119 | pretrained (bool): If True, returns a model pre-trained on ImageNet 120 | """ 121 | if pretrained: 122 | kwargs['init_weights'] = False 123 | model = VGG(make_layers(cfg['B']), **kwargs) 124 | if pretrained: 125 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 126 | return model 127 | 128 | 129 | def vgg13_bn(pretrained=False, **kwargs): 130 | """VGG 13-layer model (configuration "B") with batch normalization 131 | 132 | Args: 133 | pretrained (bool): If True, returns a model pre-trained on ImageNet 134 | """ 135 | if pretrained: 136 | kwargs['init_weights'] = False 137 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 138 | if pretrained: 139 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 140 | return model 141 | 142 | 143 | def vgg16(pretrained=False, **kwargs): 144 | """VGG 16-layer model (configuration "D") 145 | 146 | Args: 147 | pretrained (bool): If True, returns a model pre-trained on ImageNet 148 | """ 149 | if pretrained: 150 | kwargs['init_weights'] = False 151 | model = VGG(make_layers(cfg['D']), **kwargs) 152 | if pretrained: 153 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 154 | return model 155 | 156 | 157 | def vgg16_bn(pretrained=False, **kwargs): 158 | """VGG 16-layer model (configuration "D") with batch normalization 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | if pretrained: 164 | kwargs['init_weights'] = False 165 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 166 | if pretrained: 167 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 168 | return model 169 | 170 | 171 | def vgg19(pretrained=False, **kwargs): 172 | """VGG 19-layer model (configuration "E") 173 | 174 | Args: 175 | pretrained (bool): If True, returns a model pre-trained on ImageNet 176 | """ 177 | if pretrained: 178 | kwargs['init_weights'] = False 179 | model = VGG(make_layers(cfg['E']), **kwargs) 180 | if pretrained: 181 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 182 | return model 183 | 184 | 185 | def vgg19_bn(pretrained=False, **kwargs): 186 | """VGG 19-layer model (configuration 'E') with batch normalization 187 | 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | if pretrained: 192 | kwargs['init_weights'] = False 193 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 194 | if pretrained: 195 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 196 | return model 197 | -------------------------------------------------------------------------------- /sample_models/WideResNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 
import torch.nn as nn 4 | 5 | __all__ = ["WideResNet"] 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0): 10 | super(BasicBlock, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.relu1 = nn.ReLU(inplace=True) 13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(out_planes) 16 | self.relu2 = nn.ReLU(inplace=True) 17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 18 | padding=1, bias=False) 19 | self.droprate = drop_rate 20 | self.equalInOut = (in_planes == out_planes) 21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 22 | padding=0, bias=False) or None 23 | if self.droprate > 0: 24 | self.dropout = nn.Dropout(p=self.droprate) 25 | 26 | def forward(self, x): 27 | if not self.equalInOut: 28 | x = self.relu1(self.bn1(x)) 29 | else: 30 | out = self.relu1(self.bn1(x)) 31 | 32 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 33 | if self.droprate > 0: 34 | out = self.dropout(out) 35 | out = self.conv2(out) 36 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 37 | 38 | 39 | class NetworkBlock(nn.Module): 40 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 41 | super(NetworkBlock, self).__init__() 42 | self.layer = self._make_layer( 43 | block, in_planes, out_planes, nb_layers, stride, dropRate) 44 | 45 | @staticmethod 46 | def _make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, 50 | out_planes, i == 0 and stride or 1, dropRate)) 51 | return nn.Sequential(*layers) 52 | 53 | def forward(self, x): 54 | return self.layer(x) 55 | 56 | 57 | class WideResNet(nn.Module): 58 | def __init__(self, depth=10, num_classes=1000, widen_factor=1, drop_rate=0.0): 59 | super(WideResNet, self).__init__() 60 | n_channels = [16, 16 * widen_factor, 61 | 32 * widen_factor, 64 * widen_factor] 62 | assert ((depth - 4) % 6 == 0) 63 | n = int((depth - 4) / 6) 64 | block = BasicBlock 65 | # 1st conv before any network block 66 | self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1, 67 | padding=1, bias=False) 68 | # 1st block 69 | self.block1 = NetworkBlock( 70 | n, n_channels[0], n_channels[1], block, 1, drop_rate) 71 | # 2nd block 72 | self.block2 = NetworkBlock( 73 | n, n_channels[1], n_channels[2], block, 2, drop_rate) 74 | # 3rd block 75 | self.block3 = NetworkBlock( 76 | n, n_channels[2], n_channels[3], block, 2, drop_rate) 77 | # global average pooling and classifier 78 | self.bn1 = nn.BatchNorm2d(n_channels[3]) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.avg_pool = nn.AvgPool2d(8) 81 | self.fc = nn.Linear(n_channels[3], num_classes) 82 | 83 | self.nChannels = n_channels[3] 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | elif isinstance(m, nn.Linear): 93 | m.bias.data.zero_() 94 | 95 | def forward(self, x): 96 | out = self.conv1(x) 97 | out = self.block1(out) 98 | out = self.block2(out) 99 | out = self.block3(out) 100 | out = self.relu(self.bn1(out)) 101 | out = self.avg_pool(out) 102 | out = out.view(-1, self.nChannels) 103 | return self.fc(out) 104 | -------------------------------------------------------------------------------- /sample_models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .AlexNet import * 3 | from .VGG import * 4 | from .ResNet import * 5 | from .LeNet import * 6 | from .DenseNet import * 7 | from .GoogleNet import * 8 | from .WideResNet import * 9 | from .Inception import * 10 | from .SqueezeNet import * 11 | from .amoebaNet import AmoebaNet_D 12 | from .torchGpipe_ref_models import amoebanetd, resnet101 as torchgpipe_resnet101 13 | -------------------------------------------------------------------------------- /sample_models/amoebaNet/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Optional, List 3 | 4 | import torch 5 | from torch import nn 6 | from torch import Tensor 7 | from .genotype import Genotype, amoebanetd_genotype 8 | from .utils import Conv_3x3, Conv_7x1_1x7, Conv_Cell, FactorizedReduce, Pool_Operation 9 | 10 | __all__ = ["AmoebaNet_D"] 11 | 12 | # based on torchGpipe implementation https://github.com/kakaobrain/torchgpipe/blob/master/examples/amoebanet/__init__.py 13 | 14 | 15 | def conv_1x1(channel: int, stride: int, affine: bool) -> Conv_Cell: 16 | return Conv_Cell(channel, channel, 1, stride, 0, affine, use_relu=True) 17 | 18 | 19 | def avg_pool_3x3(channel: int, stride: int, affine: bool) -> Pool_Operation: 20 | return Pool_Operation('avg', 3, channel, stride, affine) 21 | 22 | 23 | def max_pool_2x2(channel: int, stride: int, affine: bool) -> Pool_Operation: 24 | return Pool_Operation('max', 2, channel, stride, affine) 25 | 26 | 27 | def skip_connect(channel: int, stride: int, affine: bool) -> nn.Module: 28 | if stride == 1: 29 | return nn.Identity() 30 | return FactorizedReduce(channel, channel, affine) 31 | 32 | 33 | class Classifier(nn.Module): 34 | def __init__(self, channel_prev: int, num_classes: int): 35 | super().__init__() 36 | 37 | self.global_pooling = nn.AvgPool2d(7) 38 | self.classifier = nn.Linear(channel_prev, num_classes) 39 | 40 | def forward(self, x: Tensor) -> Tensor: 41 | s1 = self.global_pooling(x) 42 | y = self.classifier(s1.view(s1.size(0), -1)) 43 | return y 44 | 45 | 46 | class Stem(nn.Module): 47 | def __init__(self, channel: int): 48 | super(Stem, self).__init__() 49 | self.conv_cell = Conv_Cell(3, channel, 3, 2, 1, False) 50 | 51 | def forward(self, x: Tensor) -> Tuple[Tensor]: 52 | out = self.conv_cell(x) 53 | return out 54 | 55 | 56 | OPS = { 57 | 'skip_connect': skip_connect, 58 | 'avg_pool_3x3': avg_pool_3x3, 59 | 'max_pool_2x2': max_pool_2x2, 60 | 'conv_7x1_1x7': Conv_7x1_1x7, 61 | 'conv_1x1____': conv_1x1, 62 | 'conv_3x3____': Conv_3x3, 63 | } 64 | 65 | 66 | class AmoebaNet_D(nn.Module): 67 | """an AmoebaNet-D model for ImageNet.""" 68 | 69 | def __init__(self, num_classes: int = 10, num_layers: int = 4, 70 | num_filters: int = 512, 71 | genotype: Optional[Genotype] = None): 72 | super(AmoebaNet_D, self).__init__() 73 | 74 | genotype = amoebanetd_genotype 
if genotype is None else genotype 75 | assert isinstance(genotype, Genotype) 76 | channel = num_filters // 4 77 | 78 | channel_prev_prev, channel_prev, channel_curr = channel, channel, channel 79 | cells = [] 80 | 81 | # reduction 82 | channel_curr *= 2 83 | reduction_prev = False 84 | reduction = True 85 | cell = Amoeba_Cell(genotype, channel_prev_prev, 86 | channel_prev, channel_curr, reduction, reduction_prev) 87 | multiplier = len(cell.concat_indices) 88 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 89 | cells.append(cell) 90 | 91 | # reduction 92 | channel_curr *= 2 93 | reduction_prev = True 94 | reduction = True 95 | cell = Amoeba_Cell(genotype, channel_prev_prev, 96 | channel_prev, channel_curr, reduction, reduction_prev) 97 | multiplier = len(cell.concat_indices) 98 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 99 | cells.append(cell) 100 | 101 | # not reduction 102 | reduction_prev = True 103 | reduction = False 104 | for _ in range(num_layers): 105 | cell = Amoeba_Cell(genotype, channel_prev_prev, 106 | channel_prev, channel_curr, reduction, reduction_prev) 107 | multiplier = len(cell.concat_indices) 108 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 109 | cells.append(cell) 110 | reduction_prev = False 111 | 112 | # reduction 113 | channel_curr *= 2 114 | reduction_prev = False 115 | reduction = True 116 | cell = Amoeba_Cell(genotype, channel_prev_prev, 117 | channel_prev, channel_curr, reduction, reduction_prev) 118 | multiplier = len(cell.concat_indices) 119 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 120 | cells.append(cell) 121 | 122 | # not reduction 123 | reduction_prev = True 124 | reduction = False 125 | for _ in range(num_layers): 126 | cell = Amoeba_Cell(genotype, channel_prev_prev, 127 | channel_prev, channel_curr, reduction, reduction_prev) 128 | multiplier = len(cell.concat_indices) 129 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 130 | cells.append(cell) 131 | reduction_prev = False 132 | 133 | # reduction 134 | channel_curr *= 2 135 | reduction_prev = False 136 | reduction = True 137 | cell = Amoeba_Cell(genotype, channel_prev_prev, 138 | channel_prev, channel_curr, reduction, reduction_prev) 139 | multiplier = len(cell.concat_indices) 140 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 141 | cells.append(cell) 142 | 143 | # not reduction 144 | reduction_prev = True 145 | reduction = False 146 | for _ in range(num_layers): 147 | cell = Amoeba_Cell(genotype, channel_prev_prev, 148 | channel_prev, channel_curr, reduction, reduction_prev) 149 | multiplier = len(cell.concat_indices) 150 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr 151 | cells.append(cell) 152 | reduction_prev = False 153 | 154 | self.stem = Stem(channel) 155 | self.cells = nn.Sequential(*cells) 156 | self.classifier = Classifier(channel_prev, num_classes) 157 | 158 | def forward(self, x: Tensor) -> Tensor: 159 | out = self.stem(x) 160 | out = self.cells((out, out)) 161 | return self.classifier(out[1]) 162 | 163 | 164 | class Amoeba_Cell(nn.Module): 165 | def __init__(self, genotype: Genotype, 166 | channel_prev_prev: int, channel_prev: int, channel: int, 167 | reduction: bool, reduction_prev: bool): 168 | super(Amoeba_Cell, self).__init__() 169 | 170 | preprocess0 = nn.Sequential() 171 | if reduction_prev: 172 | preprocess0 = FactorizedReduce(channel_prev_prev, channel) 173 | elif 
channel_prev_prev != channel: 174 | preprocess0 = Conv_Cell(channel_prev_prev, channel, 1, 1, 0, True) 175 | 176 | preprocess0: nn.Module = preprocess0 177 | preprocess1 = Conv_Cell(channel_prev, channel, 1, 1, 0, True) 178 | 179 | if reduction: 180 | op_names, indices = zip(*genotype.reduce) 181 | concat = genotype.reduce_concat 182 | else: 183 | op_names, indices = zip(*genotype.normal) 184 | concat = genotype.normal_concat 185 | 186 | ops = [] 187 | for name, index in zip(op_names, indices): 188 | if reduction and index < 2: 189 | stride = 2 190 | else: 191 | stride = 1 192 | op = OPS[name](channel, stride, True) 193 | ops.append((op, index)) 194 | 195 | self.preprocess0 = preprocess0 196 | self.preprocess1 = preprocess1 197 | 198 | layers = [] 199 | assert (len(ops) % 2) == 0 200 | for i in range(len(ops) // 2): 201 | op0, i0 = ops[i * 2] 202 | op1, i1 = ops[i * 2 + 1] 203 | layers.extend([ 204 | InputOne(op0, i=i0, insert=2 + i), 205 | # Output: x..., op0(x[i0]), skip] 206 | 207 | InputOne(op1, i=i1, insert=2 + i + 1), 208 | # Output: x..., op0(x[i0]), op1(x[i1]), skip 209 | 210 | MergeTwo(2 + i, 2 + i + 1), 211 | # Output: x..., op0(x[i0]) + op1(x[i1]), skip 212 | ]) 213 | self.layers = nn.Sequential(*layers) 214 | 215 | self.concat_indices = concat 216 | 217 | assert len(concat) > 0 and all(i < (3 + (len(ops) // 2) - 1) 218 | for i in concat) 219 | 220 | def forward(self, xs): 221 | preprocessed = self.preprocess(xs) 222 | # preprocess(x0),preprocess(x1),x1 223 | out = preprocessed 224 | out = self.layers(out) 225 | # x,........,skip 226 | reduced = self.reduce_channels(self.concat_indices, out) 227 | # skip,concat 228 | return reduced 229 | 230 | def preprocess(self, xs): 231 | x0, x1 = xs 232 | return self.preprocess0(x0), self.preprocess1(x1), x1 233 | 234 | def reduce_channels(self, indices, xs): 235 | # indices = 4,5,6 236 | # x0,x1,x2,x3,x4,x5,x6,x7 237 | # x7,x0,x1,x2,x3,x4,x5,x6 238 | # x7,concat(x4,x5,x6) 239 | return xs[-1], torch.cat([xs[i] for i in indices], dim=1) 240 | 241 | 242 | Tensors = Tuple[Tensor, ...] 243 | 244 | 245 | class Hack(nn.Module): 246 | 247 | def forward(self, *args, **kwargs): 248 | raise NotImplementedError() 249 | 250 | 251 | class InputOne(Hack): 252 | """Picks one tensor for the underlying module input:: 253 | a -----> a 254 | b --f--> f(b) 255 | c -----> c 256 | """ 257 | 258 | def __init__(self, module: nn.Module, i: int, insert: Optional[int] = None): 259 | super().__init__() 260 | self.module = module 261 | self.i = i 262 | self.insert = insert 263 | 264 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore 265 | i = self.i 266 | 267 | # for t in tensors: 268 | # print(t.shape[1:]) 269 | # print("\n") 270 | input = tensors[i] 271 | output = self.module(input) 272 | 273 | if not isinstance(output, tuple): 274 | output = (output,) 275 | 276 | if self.insert is None: 277 | # Replace with the input. 
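# (i.e. the module output replaces tensors[i]; illustrative: for i=1,
#  insert=None, (a, b, c) -> (a, f(b), c); with insert=2 the output is
#  spliced in instead: (a, b, c) -> (a, b, f(b), c))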
278 | return tensors[:i] + output + tensors[i + 1:] 279 | 280 | return tensors[:self.insert] + output + tensors[self.insert:] 281 | 282 | 283 | class MergeTwo(Hack): 284 | """Merges the last two tensors and replace them with the result:: 285 | a -----> a 286 | b --+--> b+c 287 | c --+ 288 | """ 289 | 290 | def __init__(self, i: int, j: int): 291 | super().__init__() 292 | self.i = i 293 | self.j = j 294 | 295 | def forward(self, *tensors: Tensors) -> Tensors: # type: ignore 296 | if len(tensors) > 1: 297 | return sum(tensors) 298 | 299 | tensors = tensors[0] 300 | i = self.i 301 | j = self.j 302 | # Set the initial value as the first tensor 303 | # to type as 'Tensor' instead of 'Union[Tensor, int]'. 304 | merged = sum(tensors[i + 1:j + 1], tensors[i]) 305 | 306 | return tensors[:i] + (merged,) + tensors[j + 1:] 307 | -------------------------------------------------------------------------------- /sample_models/amoebaNet/genotype.py: -------------------------------------------------------------------------------- 1 | from typing import List, NamedTuple, Tuple 2 | 3 | __all__ = ['amoebanetd_genotype'] 4 | 5 | 6 | class Genotype(NamedTuple): 7 | normal: List[Tuple[str, int]] 8 | normal_concat: List[int] 9 | reduce: List[Tuple[str, int]] 10 | reduce_concat: List[int] 11 | 12 | 13 | # The AmoebaNet-D genotype is based on the 'Regularized Evolution for Image Classifier 14 | # Architecture Search' paper (https://arxiv.org/pdf/1802.01548.pdf). 15 | amoebanetd_genotype = Genotype( 16 | normal=[ 17 | ('avg_pool_3x3', 0), 18 | ('conv_1x1____', 0), 19 | ('skip_connect', 2), 20 | ('avg_pool_3x3', 2), 21 | ('skip_connect', 0), 22 | ('conv_7x1_1x7', 1), 23 | ('conv_1x1____', 1), 24 | ('conv_7x1_1x7', 1), 25 | ('avg_pool_3x3', 0), 26 | ('conv_1x1____', 3), 27 | ], 28 | normal_concat=[4, 5, 6], 29 | reduce=[ 30 | ('max_pool_2x2', 1), 31 | ('avg_pool_3x3', 1), 32 | ('conv_3x3____', 0), 33 | ('skip_connect', 2), 34 | ('conv_7x1_1x7', 2), 35 | ('avg_pool_3x3', 2), 36 | ('avg_pool_3x3', 2), 37 | ('conv_1x1____', 3), 38 | ('skip_connect', 3), 39 | ('max_pool_2x2', 0), 40 | ], 41 | reduce_concat=[4, 5, 6] 42 | ) 43 | -------------------------------------------------------------------------------- /sample_models/amoebaNet/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from torch import nn 5 | from torch import Tensor 6 | 7 | 8 | class Pool_Operation(nn.Module): 9 | def __init__(self, pool_type: str, kernel_size: int, channel: int, stride: int, affine: bool): 10 | super(Pool_Operation, self).__init__() 11 | assert pool_type in['avg', 'max'] 12 | 13 | if pool_type == 'avg': 14 | self.pool = nn.AvgPool2d(kernel_size, stride=stride, 15 | padding=1, count_include_pad=False) 16 | else: 17 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=0) 18 | 19 | self.conv_cell = Conv_Cell(channel, channel, 1, 20 | 1, 0, affine, use_relu=False) 21 | 22 | def forward(self, x: Tensor) -> Tensor: 23 | out = self.pool(x) 24 | return self.conv_cell(out) 25 | 26 | 27 | class Conv_Cell(nn.Module): 28 | def __init__(self, in_channels: int, out_channels: int, kernel, stride, padding, affine: bool, use_relu: bool = True): 29 | super(Conv_Cell, self).__init__() 30 | self.relu = nn.ReLU(inplace=False) if use_relu else None 31 | self.conv = nn.Conv2d(in_channels, out_channels, 32 | kernel, stride=stride, padding=padding, bias=False) 33 | self.norm = nn.BatchNorm2d(out_channels, affine=affine) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | out = x if 
self.relu is None else self.relu(x) 37 | out = self.conv(out) 38 | return self.norm(out) 39 | 40 | 41 | class Conv_3x3(nn.Module): 42 | def __init__(self, channel: int, stride: int, affine: bool): 43 | super(Conv_3x3, self).__init__() 44 | 45 | self.conv1_1x1 = Conv_Cell(channel, channel//4, 1, 46 | 1, 0, affine) 47 | self.conv2_3x3 = Conv_Cell(channel//4, channel//4, 3, 48 | (stride, stride), 1, affine) 49 | self.conv3_1x1 = Conv_Cell(channel//4, channel, 1, 50 | 1, 0, affine) 51 | 52 | def forward(self, x: Tensor) -> Tensor: 53 | out = self.conv1_1x1(x) 54 | out = self.conv2_3x3(out) 55 | return self.conv3_1x1(out) 56 | 57 | 58 | class Conv_7x1_1x7(nn.Module): 59 | def __init__(self, channel: int, stride: int, affine: bool): 60 | super(Conv_7x1_1x7, self).__init__() 61 | 62 | self.conv1_1x1 = Conv_Cell(channel, channel//4, 1, 63 | 1, 0, affine) 64 | 65 | self.conv2_1x7 = Conv_Cell(channel//4, channel//4, (1, 7), 66 | (1, stride), (0, 3), affine) 67 | 68 | self.conv3_7x1 = Conv_Cell(channel//4, channel//4, (7, 1), 69 | (stride, 1), (3, 0), affine) 70 | 71 | self.conv4_1x1 = Conv_Cell(channel//4, channel, 1, 72 | 1, 0, affine) 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | out = self.conv1_1x1(x) 76 | out = self.conv2_1x7(out) 77 | out = self.conv3_7x1(out) 78 | return self.conv4_1x1(out) 79 | 80 | 81 | class FactorizedReduce(nn.Module): 82 | """Operation Factorized reduce""" 83 | 84 | def __init__(self, in_planes: int, out_planes: int, affine: bool = True): 85 | super().__init__() 86 | self.relu = nn.ReLU(inplace=False) 87 | self.conv_1 = nn.Conv2d(in_planes, out_planes // 88 | 2, 1, stride=2, padding=0, bias=False) 89 | self.conv_2 = nn.Conv2d(in_planes, out_planes // 90 | 2, 1, stride=2, padding=0, bias=False) 91 | self.bn = nn.BatchNorm2d(out_planes, affine=affine) 92 | 93 | def forward(self, x: Tensor) -> Tensor: 94 | x = self.relu(x) 95 | branch1 = self.conv_1(x) 96 | branch2 = self.conv_2(x[:, :, 1:, 1:]) 97 | out = torch.cat([branch1, branch2], dim=1) 98 | out = self.bn(out) 99 | return out 100 | -------------------------------------------------------------------------------- /sample_models/torchGpipe_ref_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet101 2 | from .amoebanet_d import amoebanetd 3 | -------------------------------------------------------------------------------- /sample_models/torchGpipe_ref_models/amoebanet_d/__init__.py: -------------------------------------------------------------------------------- 1 | """An AmoebaNet-D implementation but using :class:`nn.Sequential`. :func:`amoebanetd` 2 | returns a :class:`nn.Sequential`. 
3 | """ 4 | from collections import OrderedDict 5 | from typing import Tuple 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from ..flatten_sequential import flatten_sequential 11 | from .genotype import Genotype, amoebanetd_genotype 12 | from .surgery import Concat, FirstAnd, InputOne, MergeTwo, Shift, Twin, TwinLast 13 | 14 | __all__ = ['amoebanetd'] 15 | 16 | 17 | class Identity(nn.Module): 18 | """Through the tensor:: 19 | x ---> x 20 | """ 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore 23 | return x 24 | 25 | 26 | class ReLUConvBN(nn.Module): 27 | """Operation ReLU + Conv2d + BatchNorm2d""" 28 | 29 | def __init__(self, in_planes: int, out_planes: int, kernel_size: int, stride: int, 30 | padding: int, affine: bool = True): 31 | super().__init__() 32 | self.op = nn.Sequential( 33 | nn.ReLU(inplace=False), 34 | nn.Conv2d(in_planes, out_planes, kernel_size, 35 | stride=stride, padding=padding, bias=False), 36 | nn.BatchNorm2d(out_planes, affine=affine) 37 | ) 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore 40 | return self.op(x) 41 | 42 | 43 | class FactorizedReduce(nn.Module): 44 | """Operation Factorized reduce""" 45 | 46 | def __init__(self, in_planes: int, out_planes: int, affine: bool = True): 47 | super().__init__() 48 | self.relu = nn.ReLU(inplace=False) 49 | self.conv_1 = nn.Conv2d(in_planes, out_planes // 50 | 2, 1, stride=2, padding=0, bias=False) 51 | self.conv_2 = nn.Conv2d(in_planes, out_planes // 52 | 2, 1, stride=2, padding=0, bias=False) 53 | self.bn = nn.BatchNorm2d(out_planes, affine=affine) 54 | 55 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore 56 | x = self.relu(x) 57 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 58 | out = self.bn(out) 59 | return out 60 | 61 | 62 | def skip_connect(channel: int, stride: int, affine: bool) -> nn.Module: 63 | if stride == 1: 64 | return Identity() 65 | return FactorizedReduce(channel, channel, affine=affine) 66 | 67 | 68 | def avg_pool_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential: 69 | return nn.Sequential( 70 | nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 71 | nn.Conv2d(channel, channel, (1, 1), stride=( 72 | 1, 1), padding=(0, 0), bias=False), 73 | nn.BatchNorm2d(channel, affine=affine)) 74 | 75 | 76 | def max_pool_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential: 77 | return nn.Sequential( 78 | nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 79 | nn.Conv2d(channel, channel, (1, 1), stride=( 80 | 1, 1), padding=(0, 0), bias=False), 81 | nn.BatchNorm2d(channel, affine=affine)) 82 | 83 | 84 | def max_pool_2x2(channel: int, stride: int, affine: bool) -> nn.Sequential: 85 | return nn.Sequential( 86 | nn.MaxPool2d(2, stride=stride, padding=0), 87 | nn.Conv2d(channel, channel, (1, 1), stride=( 88 | 1, 1), padding=(0, 0), bias=False), 89 | nn.BatchNorm2d(channel, affine=affine)) 90 | 91 | 92 | def conv_7x1_1x7(channel: int, stride: int, affine: bool) -> nn.Sequential: 93 | return nn.Sequential( 94 | nn.ReLU(inplace=False), 95 | nn.Conv2d(channel, channel // 4, (1, 1), 96 | stride=(1, 1), padding=(0, 0), bias=False), 97 | nn.BatchNorm2d(channel // 4, affine=affine), 98 | nn.ReLU(inplace=False), 99 | nn.Conv2d(channel // 4, channel // 4, (1, 7), 100 | stride=(1, stride), padding=(0, 3), bias=False), 101 | nn.BatchNorm2d(channel // 4, affine=affine), 102 | nn.ReLU(inplace=False), 103 | nn.Conv2d(channel // 4, channel // 4, (7, 1), 104 | stride=(stride, 1), 
padding=(3, 0), bias=False), 105 | nn.BatchNorm2d(channel // 4, affine=affine), 106 | nn.ReLU(inplace=False), 107 | nn.Conv2d(channel // 4, channel, (1, 1), 108 | stride=(1, 1), padding=(0, 0), bias=False), 109 | nn.BatchNorm2d(channel, affine=affine)) 110 | 111 | 112 | def conv_1x1(channel: int, stride: int, affine: bool) -> nn.Sequential: 113 | return nn.Sequential( 114 | nn.ReLU(inplace=False), 115 | nn.Conv2d(channel, channel, (1, 1), stride=( 116 | stride, stride), padding=(0, 0), bias=False), 117 | nn.BatchNorm2d(channel, affine=affine)) 118 | 119 | 120 | def conv_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential: 121 | return nn.Sequential( 122 | nn.ReLU(inplace=False), 123 | nn.Conv2d(channel, channel // 4, (1, 1), 124 | stride=(1, 1), padding=(0, 0), bias=False), 125 | nn.BatchNorm2d(channel // 4, affine=affine), 126 | nn.ReLU(inplace=False), 127 | nn.Conv2d(channel // 4, channel // 4, (3, 3), 128 | stride=(stride, stride), padding=(1, 1), bias=False), 129 | nn.BatchNorm2d(channel // 4, affine=affine), 130 | nn.ReLU(inplace=False), 131 | nn.Conv2d(channel // 4, channel, (1, 1), 132 | stride=(1, 1), padding=(0, 0), bias=False), 133 | nn.BatchNorm2d(channel, affine=affine)) 134 | 135 | 136 | OPS = { 137 | 'skip_connect': skip_connect, 138 | 'avg_pool_3x3': avg_pool_3x3, 139 | 'max_pool_3x3': max_pool_3x3, 140 | 'max_pool_2x2': max_pool_2x2, 141 | 'conv_7x1_1x7': conv_7x1_1x7, 142 | 'conv_1x1____': conv_1x1, 143 | 'conv_3x3____': conv_3x3, 144 | } 145 | 146 | 147 | class Classifier(nn.Module): 148 | 149 | def __init__(self, channel_prev: int, num_classes: int): 150 | super().__init__() 151 | 152 | self.global_pooling = nn.AvgPool2d(7) 153 | self.classifier = nn.Linear(channel_prev, num_classes) 154 | 155 | def forward(self, x: torch.Tensor) -> nn.Linear: # type: ignore 156 | s1 = self.global_pooling(x[1]) 157 | y = self.classifier(s1.view(s1.size(0), -1)) 158 | 159 | return y 160 | 161 | 162 | def make_stem(channel: int) -> nn.Sequential: 163 | return nn.Sequential( 164 | nn.ReLU(inplace=False), 165 | nn.Conv2d(3, channel, 3, stride=2, padding=1, bias=False), 166 | nn.BatchNorm2d(channel), 167 | Twin(), 168 | ) 169 | 170 | 171 | def make_cell(genotype: Genotype, 172 | channel_prev_prev: int, channel_prev: int, channel: int, 173 | reduction: bool, reduction_prev: bool 174 | ) -> Tuple[nn.Sequential, int]: 175 | 176 | preprocess0: nn.Module = nn.Sequential() 177 | 178 | if reduction_prev: 179 | preprocess0 = FactorizedReduce(channel_prev_prev, channel) 180 | elif channel_prev_prev != channel: 181 | preprocess0 = ReLUConvBN(channel_prev_prev, channel, 1, 1, 0) 182 | preprocess1 = ReLUConvBN(channel_prev, channel, 1, 1, 0) 183 | 184 | if reduction: 185 | op_names, indices = zip(*genotype.reduce) 186 | concat = genotype.reduce_concat 187 | else: 188 | op_names, indices = zip(*genotype.normal) 189 | concat = genotype.normal_concat 190 | 191 | ops = [] 192 | for name, index in zip(op_names, indices): 193 | if reduction and index < 2: 194 | stride = 2 195 | else: 196 | stride = 1 197 | 198 | op = OPS[name](channel, stride, True) 199 | ops.append((op, index)) 200 | 201 | layers = [ 202 | InputOne(preprocess0, i=0), 203 | TwinLast(), 204 | InputOne(preprocess1, i=1), 205 | # Output: (preprocess0(x[0]), preprocess1(x[1]), x[1]) 206 | # The last tensor x[1] is passed until the cell output. 207 | # The comments below call x[1] "skip". 
208 |     ]
209 | 
210 |     for i in range(len(ops) // 2):
211 |         op0, i0 = ops[i*2]
212 |         op1, i1 = ops[i*2 + 1]
213 | 
214 |         layers.extend([
215 |             InputOne(op0, i=i0, insert=2 + i),
216 |             # Output: x..., op0(x[i0]), skip
217 | 
218 |             InputOne(op1, i=i1, insert=2 + i + 1),
219 |             # Output: x..., op0(x[i0]), op1(x[i1]), skip
220 | 
221 |             MergeTwo(2 + i, 2 + i + 1),
222 |             # Output: x..., op0(x[i0]) + op1(x[i1]), skip
223 |         ])
224 | 
225 |     layers.extend([
226 |         # Move skip to the head.
227 |         Shift(),
228 |         # Output: skip, x...
229 | 
230 |         FirstAnd(Concat(concat)),
231 |         # Output: skip, concat(x...)
232 |     ])
233 | 
234 |     multiplier = len(concat)
235 | 
236 |     return nn.Sequential(*layers), multiplier
237 | 
238 | 
239 | def amoebanetd(num_classes: int = 10,
240 |                num_layers: int = 4,
241 |                num_filters: int = 512,
242 |                ) -> nn.Sequential:
243 |     """Builds an AmoebaNet-D model for ImageNet."""
244 |     channel = num_filters // 4
245 | 
246 |     def make_layer(channel: int,
247 |                    num_layers: int,
248 |                    genotype: Genotype,
249 |                    ) -> Tuple[nn.Sequential, int]:
250 |         n = num_layers
251 |         channel_prev_prev, channel_prev, channel_curr = channel, channel, channel
252 |         cells = []
253 | 
254 |         reduction_prev = False
255 |         reduction = True
256 |         channel_curr *= 2
257 |         cell, multiplier = make_cell(genotype, channel_prev_prev,
258 |                                      channel_prev, channel_curr, reduction, reduction_prev)
259 |         channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
260 |         cells.append(cell)
261 | 
262 |         reduction_prev = True
263 |         reduction = True
264 |         channel_curr *= 2
265 |         cell, multiplier = make_cell(genotype, channel_prev_prev,
266 |                                      channel_prev, channel_curr, reduction, reduction_prev)
267 |         channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
268 |         cells.append(cell)
269 | 
270 |         reduction = False
271 |         reduction_prev = True
272 |         for _ in range(n):
273 |             cell, multiplier = make_cell(genotype, channel_prev_prev,
274 |                                          channel_prev, channel_curr, reduction, reduction_prev)
275 |             channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
276 |             cells.append(cell)
277 |             reduction_prev = False
278 | 
279 |         reduction_prev = False
280 |         reduction = True
281 |         channel_curr *= 2
282 |         cell, multiplier = make_cell(genotype, channel_prev_prev,
283 |                                      channel_prev, channel_curr, reduction, reduction_prev)
284 |         channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
285 |         cells.append(cell)
286 | 
287 |         reduction = False
288 |         reduction_prev = True
289 |         for _ in range(n):
290 |             cell, multiplier = make_cell(genotype, channel_prev_prev,
291 |                                          channel_prev, channel_curr, reduction, reduction_prev)
292 |             channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
293 |             cells.append(cell)
294 |             reduction_prev = False
295 | 
296 |         reduction_prev = False
297 |         reduction = True
298 |         channel_curr *= 2
299 |         cell, multiplier = make_cell(genotype, channel_prev_prev,
300 |                                      channel_prev, channel_curr, reduction, reduction_prev)
301 |         channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
302 |         cells.append(cell)
303 | 
304 |         reduction = False
305 |         reduction_prev = True
306 |         for _ in range(n):
307 |             cell, multiplier = make_cell(genotype, channel_prev_prev,
308 |                                          channel_prev, channel_curr, reduction, reduction_prev)
309 |             channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
310 |             cells.append(cell)
311 |             reduction_prev = False
312 | 
313 |         return nn.Sequential(*cells), channel_prev
314 | 
315 |     cells, channel_prev = make_layer(channel, num_layers, amoebanetd_genotype)
316 | 
317 |     model = nn.Sequential(OrderedDict([
318 |         ('stem', make_stem(channel)),
319 |         ('cells', cells),
320 |         ('fin', Classifier(channel_prev, num_classes))
321 |     ]))
322 | 
323 |     return flatten_sequential(model)
324 | 
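44 | 
45 | # Usage sketch (not part of the reference code): each genotype entry is an
46 | # (op name, input index) pair, and make_cell in __init__.py recovers the two
47 | # sequences with zip(*...). For example:
48 | if __name__ == '__main__':
49 |     op_names, indices = zip(*amoebanetd_genotype.reduce)
50 |     print(op_names[0], indices[0])  # max_pool_2x2 1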
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/amoebanet_d/genotype.py:
--------------------------------------------------------------------------------
1 | from typing import List, NamedTuple, Tuple
2 | 
3 | __all__ = ['Genotype', 'amoebanetd_genotype']
4 | 
5 | 
6 | class Genotype(NamedTuple):
7 |     normal: List[Tuple[str, int]]
8 |     normal_concat: List[int]
9 |     reduce: List[Tuple[str, int]]
10 |     reduce_concat: List[int]
11 | 
12 | 
13 | # The AmoebaNet-D genotype is based on the 'Regularized Evolution for Image Classifier
14 | # Architecture Search' paper (https://arxiv.org/pdf/1802.01548.pdf).
15 | amoebanetd_genotype = Genotype(
16 |     normal=[
17 |         ('max_pool_3x3', 0),
18 |         ('conv_1x1____', 0),
19 |         ('skip_connect', 2),
20 |         ('max_pool_3x3', 2),
21 |         ('skip_connect', 0),
22 |         ('conv_7x1_1x7', 1),
23 |         ('conv_1x1____', 1),
24 |         ('conv_7x1_1x7', 1),
25 |         ('avg_pool_3x3', 0),
26 |         ('conv_1x1____', 3),
27 |     ],
28 |     normal_concat=[4, 5, 6],
29 |     reduce=[
30 |         ('max_pool_2x2', 1),
31 |         ('max_pool_3x3', 1),
32 |         ('conv_3x3____', 0),
33 |         ('skip_connect', 2),
34 |         ('conv_7x1_1x7', 2),
35 |         ('max_pool_3x3', 2),
36 |         ('avg_pool_3x3', 2),
37 |         ('conv_1x1____', 3),
38 |         ('skip_connect', 3),
39 |         ('max_pool_2x2', 0),
40 |     ],
41 |     reduce_concat=[4, 5, 6]
42 | )
43 | 
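145 | 
146 | # A minimal composition sketch (not part of the reference code): the hacks
147 | # above let a branch-and-merge computation live inside nn.Sequential, which
148 | # is what make_cell in __init__.py relies on. The nn.Linear here is only a
149 | # stand-in module for illustration.
150 | if __name__ == '__main__':
151 |     block = nn.Sequential(
152 |         Twin(),                          # x -> (x, x)
153 |         InputOne(nn.Linear(4, 4), i=0),  # -> (fc(x), x)
154 |         MergeTwo(0, 1),                  # -> (fc(x) + x,)
155 |     )
156 |     out, = block(torch.rand(2, 4))
157 |     print(out.shape)  # expected: torch.Size([2, 4])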
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/amoebanet_d/surgery.py:
--------------------------------------------------------------------------------
1 | """Utility modules for breaking a complex module into a flat nn.Sequential.
2 | """
3 | from typing import List, Optional, Tuple
4 | 
5 | import torch
6 | from torch import nn
7 | 
8 | __all__ = ['Concat', 'FirstAnd', 'InputOne',
9 |            'MergeTwo', 'Shift', 'Twin', 'TwinLast']
10 | 
11 | 
12 | Tensors = Tuple[torch.Tensor, ...]
13 | 
14 | 
15 | class Hack(nn.Module):
16 | 
17 |     def forward(self, *args, **kwargs):
18 |         raise NotImplementedError()
19 | 
20 | 
21 | class Twin(Hack):
22 |     """Duplicates the tensor::
23 |            ┌──────┐
24 |       a --│ Twin │--> a
25 |           │  '---│--> a
26 |           └──────┘
27 |     """
28 | 
29 |     def forward(self, tensor: torch.Tensor) -> Tensors:  # type: ignore
30 |         return tensor, tensor
31 | 
32 | 
33 | class TwinLast(Hack):
34 |     """Duplicates the last tensor::
35 |         a -----> a
36 |         b -----> b
37 |         c --+--> c
38 |             +--> c'
39 |     """
40 | 
41 |     def forward(self, tensors: Tensors) -> Tensors:  # type: ignore
42 |         return tensors + (tensors[-1],)
43 | 
44 | 
45 | class InputOne(Hack):
46 |     """Picks one tensor for the underlying module input::
47 |         a -----> a
48 |         b --f--> f(b)
49 |         c -----> c
50 |     """
51 | 
52 |     def __init__(self, module: nn.Module, i: int, insert: Optional[int] = None):
53 |         super().__init__()
54 |         self.module = module
55 |         self.i = i
56 |         self.insert = insert
57 | 
58 |     def forward(self, tensors: Tensors) -> Tensors:  # type: ignore
59 |         i = self.i
60 | 
61 |         # Take the i-th tensor as the module input; all other tensors pass
62 |         # through unchanged. If `insert` is given, the output is spliced in
63 |         # at that position instead of replacing input i.
64 |         input = tensors[i]
65 |         output = self.module(input)
66 | 
67 |         if not isinstance(output, tuple):
68 |             output = (output,)
69 | 
70 |         if self.insert is None:
71 |             # Replace input i with the module output.
72 |             return tensors[:i] + output + tensors[i+1:]
73 | 
74 |         return tensors[:self.insert] + output + tensors[self.insert:]
75 | 
76 | 
77 | class Shift(Hack):
78 |     """Moves the last tensor ahead of the tensors::
79 |             +--> c
80 |         a --|--> a
81 |         b --|--> b
82 |         c --+
83 |     """
84 | 
85 |     def forward(self, tensors: Tensors) -> Tensors:  # type: ignore
86 |         return (tensors[-1],) + tensors[:-1]
87 | 
88 | 
89 | class MergeTwo(Hack):
90 |     """Sums the tensors at the two given indices and replaces them with the result::
91 |         a -----> a
92 |         b --+--> b+c
93 |         c --+
94 |     """
95 | 
96 |     def __init__(self, i: int, j: int):
97 |         super().__init__()
98 |         self.i = i
99 |         self.j = j
100 | 
101 |     def forward(self, tensors: Tensors) -> Tensors:  # type: ignore
102 |         i = self.i
103 |         j = self.j
104 | 
105 |         # Set the initial value as the first tensor
106 |         # to type as 'Tensor' instead of 'Union[Tensor, int]'.
107 |         merged = sum(tensors[i+1:j+1], tensors[i])
108 | 
109 |         return tensors[:i] + (merged,) + tensors[j+1:]
110 | 
111 | 
112 | class FirstAnd(Hack):
113 |     """Skips the first tensor, executes the underlying module on the remaining
114 |     tensors::
115 |         a -----> a
116 |         b --+--> f(b, c)
117 |         c --+
118 |     """
119 | 
120 |     def __init__(self, module: nn.Module):
121 |         super().__init__()
122 |         self.module = module
123 | 
124 |     def forward(self, tensors: Tensors) -> Tensors:  # type: ignore
125 |         output = self.module(tensors[1:])
126 |         if not isinstance(output, tuple):
127 |             output = (output,)
128 |         return (tensors[0],) + output
129 | 
130 | 
131 | class Concat(Hack):
132 |     """Concatenates the tensors at the given indices along the channel dimension::
133 |         a --+
134 |         b --+--> concat(a, b, c)
135 |         c --+
136 |     """
137 | 
138 |     def __init__(self, indices: List):
139 |         super().__init__()
140 |         self.indices = indices
141 | 
142 |     def forward(self, tensors: Tensors) -> torch.Tensor:  # type: ignore
143 |         return torch.cat([tensors[i] for i in self.indices], dim=1)
144 | 
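26 | 
27 | # Usage sketch (not part of the reference code): nested names are joined
28 | # with '_' so every layer ends up at the top level of one nn.Sequential.
29 | if __name__ == '__main__':
30 |     net = nn.Sequential(OrderedDict([
31 |         ('block', nn.Sequential(OrderedDict([('fc', nn.Linear(2, 2))]))),
32 |         ('act', nn.ReLU()),
33 |     ]))
34 |     print([name for name, _ in flatten_sequential(net).named_children()])
35 |     # expected: ['block_fc', 'act']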
4 | """ 5 | from collections import OrderedDict 6 | from typing import Any, List 7 | 8 | from torch import Tensor 9 | import torch.nn as nn 10 | 11 | from .bottleneck import bottleneck 12 | from ..flatten_sequential import flatten_sequential 13 | 14 | __all__ = ['resnet101'] 15 | 16 | 17 | class Flat(nn.Module): 18 | """Flattens any input tensor into an 1-d tensor.""" 19 | 20 | def forward(self, x: Tensor): # type: ignore 21 | return x.view(x.size(0), -1) 22 | 23 | 24 | def build_resnet(layers: List[int], 25 | num_classes: int = 1000, 26 | ) -> nn.Sequential: 27 | """Builds a ResNet as a simple sequential model. 28 | Note: 29 | The implementation is copied from :mod:`torchvision.models.resnet`. 30 | """ 31 | inplanes = 64 32 | 33 | def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential: 34 | nonlocal inplanes 35 | 36 | downsample = None 37 | if stride != 1 or inplanes != planes * 4: 38 | downsample = nn.Sequential( 39 | nn.Conv2d(inplanes, planes * 4, 40 | kernel_size=1, stride=stride, bias=False), 41 | nn.BatchNorm2d(planes * 4), 42 | ) 43 | 44 | layers = [] 45 | layers.append(bottleneck(inplanes, planes, stride, downsample)) 46 | inplanes = planes * 4 47 | for _ in range(1, blocks): 48 | layers.append(bottleneck(inplanes, planes)) 49 | 50 | return nn.Sequential(*layers) 51 | 52 | # Build ResNet as a sequential model. 53 | model = nn.Sequential(OrderedDict([ 54 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)), 55 | ('bn1', nn.BatchNorm2d(64)), 56 | ('relu', nn.ReLU()), 57 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 58 | 59 | ('layer1', make_layer(64, layers[0])), 60 | ('layer2', make_layer(128, layers[1], stride=2)), 61 | ('layer3', make_layer(256, layers[2], stride=2)), 62 | ('layer4', make_layer(512, layers[3], stride=2)), 63 | 64 | ('avgpool', nn.AdaptiveAvgPool2d((1, 1))), 65 | ('flat', Flat()), 66 | ('fc', nn.Linear(512 * 4, num_classes)), 67 | ])) 68 | 69 | # Flatten nested sequentials. 70 | model = flatten_sequential(model) 71 | 72 | # Initialize weights for Conv2d and BatchNorm2d layers. 73 | def init_weight(m: nn.Module) -> None: 74 | if isinstance(m, nn.Conv2d): 75 | assert isinstance(m.kernel_size, tuple) 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | 78 | m.weight.requires_grad = False 79 | m.weight.normal_(0, 2. / n**0.5) 80 | m.weight.requires_grad = True 81 | 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.requires_grad = False 84 | m.weight.fill_(1) 85 | m.weight.requires_grad = True 86 | 87 | m.bias.requires_grad = False 88 | m.bias.zero_() 89 | m.bias.requires_grad = True 90 | 91 | model.apply(init_weight) 92 | 93 | return model 94 | 95 | 96 | def resnet101(**kwargs: Any) -> nn.Sequential: 97 | """Constructs a ResNet-101 model.""" 98 | return build_resnet([3, 4, 23, 3], **kwargs) 99 | -------------------------------------------------------------------------------- /sample_models/torchGpipe_ref_models/resnet/bottleneck.py: -------------------------------------------------------------------------------- 1 | """A ResNet bottleneck implementation but using :class:`nn.Sequential`.""" 2 | from collections import OrderedDict 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | from torch import Tensor 6 | import torch.nn as nn 7 | 8 | __all__ = ['bottleneck'] 9 | 10 | Tensors = Tuple[Tensor, ...] 
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | """A ResNet implementation but using :class:`nn.Sequential`. :func:`resnet101`
2 | returns a :class:`nn.Sequential` instead of ``ResNet``.
3 | This code is transformed from :mod:`torchvision.models.resnet`.
4 | """
5 | from collections import OrderedDict
6 | from typing import Any, List
7 | 
8 | from torch import Tensor
9 | import torch.nn as nn
10 | 
11 | from .bottleneck import bottleneck
12 | from ..flatten_sequential import flatten_sequential
13 | 
14 | __all__ = ['resnet101']
15 | 
16 | 
17 | class Flat(nn.Module):
18 |     """Flattens each input sample into a 1-d tensor."""
19 | 
20 |     def forward(self, x: Tensor) -> Tensor:  # type: ignore
21 |         return x.view(x.size(0), -1)
22 | 
23 | 
24 | def build_resnet(layers: List[int],
25 |                  num_classes: int = 1000,
26 |                  ) -> nn.Sequential:
27 |     """Builds a ResNet as a simple sequential model.
28 |     Note:
29 |         The implementation is copied from :mod:`torchvision.models.resnet`.
30 |     """
31 |     inplanes = 64
32 | 
33 |     def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
34 |         nonlocal inplanes
35 | 
36 |         downsample = None
37 |         if stride != 1 or inplanes != planes * 4:
38 |             downsample = nn.Sequential(
39 |                 nn.Conv2d(inplanes, planes * 4,
40 |                           kernel_size=1, stride=stride, bias=False),
41 |                 nn.BatchNorm2d(planes * 4),
42 |             )
43 | 
44 |         layers = []
45 |         layers.append(bottleneck(inplanes, planes, stride, downsample))
46 |         inplanes = planes * 4
47 |         for _ in range(1, blocks):
48 |             layers.append(bottleneck(inplanes, planes))
49 | 
50 |         return nn.Sequential(*layers)
51 | 
52 |     # Build ResNet as a sequential model.
53 |     model = nn.Sequential(OrderedDict([
54 |         ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
55 |         ('bn1', nn.BatchNorm2d(64)),
56 |         ('relu', nn.ReLU()),
57 |         ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
58 | 
59 |         ('layer1', make_layer(64, layers[0])),
60 |         ('layer2', make_layer(128, layers[1], stride=2)),
61 |         ('layer3', make_layer(256, layers[2], stride=2)),
62 |         ('layer4', make_layer(512, layers[3], stride=2)),
63 | 
64 |         ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
65 |         ('flat', Flat()),
66 |         ('fc', nn.Linear(512 * 4, num_classes)),
67 |     ]))
68 | 
69 |     # Flatten nested sequentials.
70 |     model = flatten_sequential(model)
71 | 
72 |     # Initialize weights for Conv2d and BatchNorm2d layers.
73 |     def init_weight(m: nn.Module) -> None:
74 |         if isinstance(m, nn.Conv2d):
75 |             assert isinstance(m.kernel_size, tuple)
76 |             n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
77 | 
78 |             m.weight.requires_grad = False
79 |             m.weight.normal_(0, (2. / n) ** 0.5)  # std = sqrt(2/fan_out), He init
80 |             m.weight.requires_grad = True
81 | 
82 |         elif isinstance(m, nn.BatchNorm2d):
83 |             m.weight.requires_grad = False
84 |             m.weight.fill_(1)
85 |             m.weight.requires_grad = True
86 | 
87 |             m.bias.requires_grad = False
88 |             m.bias.zero_()
89 |             m.bias.requires_grad = True
90 | 
91 |     model.apply(init_weight)
92 | 
93 |     return model
94 | 
95 | 
96 | def resnet101(**kwargs: Any) -> nn.Sequential:
97 |     """Constructs a ResNet-101 model."""
98 |     return build_resnet([3, 4, 23, 3], **kwargs)
99 | 
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/resnet/bottleneck.py:
--------------------------------------------------------------------------------
1 | """A ResNet bottleneck implementation but using :class:`nn.Sequential`."""
2 | from collections import OrderedDict
3 | from typing import Dict, Optional, Tuple, Union
4 | 
5 | from torch import Tensor
6 | import torch.nn as nn
7 | 
8 | __all__ = ['bottleneck']
9 | 
10 | Tensors = Tuple[Tensor, ...]
11 | TensorOrTensors = Union[Tensor, Tensors]
12 | 
13 | 
14 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
15 |     """3x3 convolution with padding"""
16 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 |                      padding=1, bias=False)
18 | 
19 | 
20 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
21 |     """1x1 convolution"""
22 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 | 
24 | 
25 | class Twin(nn.Module):
26 |     #        ┌──────┐
27 |     #   a --│ Twin │--> a
28 |     #       │  '---│--> a
29 |     #        └──────┘
30 |     def forward(self,  # type: ignore
31 |                 tensor: Tensor,
32 |                 ) -> Tuple[Tensor, Tensor]:
33 |         return tensor, tensor
34 | 
35 | 
36 | class Gutter(nn.Module):
37 |     #        ┌───────────┐
38 |     #   a --│ Gutter[M] │--> M(a)
39 |     #   b --│-----------│--> b
40 |     #        └───────────┘
41 |     def __init__(self, module: nn.Module):
42 |         super().__init__()
43 |         self.module = module
44 | 
45 |     def forward(self,  # type: ignore
46 |                 input_and_skip: Tuple[Tensor, Tensor],
47 |                 ) -> Tuple[Tensor, Tensor]:
48 |         input, skip = input_and_skip
49 |         output = self.module(input)
50 |         return output, skip
51 | 
52 | 
53 | class Residual(nn.Module):
54 |     """A residual block for ResNet."""
55 | 
56 |     def __init__(self, downsample: Optional[nn.Module] = None):
57 |         super().__init__()
58 |         self.downsample = downsample
59 | 
60 |     def forward(self,  # type: ignore
61 |                 input_and_identity: Tuple[Tensor, Tensor],
62 |                 ) -> Tensor:
63 |         input, identity = input_and_identity
64 |         if self.downsample is not None:
65 |             identity = self.downsample(identity)
66 |         return input + identity
67 | 
68 | 
69 | def bottleneck(inplanes: int,
70 |                planes: int,
71 |                stride: int = 1,
72 |                downsample: Optional[nn.Module] = None,
73 |                ) -> nn.Sequential:
74 |     """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`."""
75 |     layers: Dict[str, nn.Module] = OrderedDict()
76 |     layers['twin'] = Twin()
77 | 
78 |     layers['conv1'] = Gutter(conv1x1(inplanes, planes))
79 |     layers['bn1'] = Gutter(nn.BatchNorm2d(planes))
80 |     layers['relu1'] = Gutter(nn.ReLU())  # nonlinearity between stages, as in torchvision's Bottleneck
81 |     layers['conv2'] = Gutter(conv3x3(planes, planes, stride))
82 |     layers['bn2'] = Gutter(nn.BatchNorm2d(planes))
83 |     layers['relu2'] = Gutter(nn.ReLU())
84 |     layers['conv3'] = Gutter(conv1x1(planes, planes * 4))
85 |     layers['bn3'] = Gutter(nn.BatchNorm2d(planes * 4))
86 |     layers['residual'] = Residual(downsample)
87 |     layers['relu'] = nn.ReLU()
88 | 
89 |     return nn.Sequential(layers)
90 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_delayed_norm.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from itertools import chain
3 | 
4 | import pytest
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | 
9 | from pytorch_Gpipe.delayedNorm import DelayedBatchNorm
10 | 
11 | NUM_MICRO_BATCHES = 4
12 | 
13 | # taken from torchGpipe repo
14 | 
15 | 
16 | def tilt_dist(x):
17 |     # Tilt variance by channel.
18 |     rgb = x.transpose(0, 1)
19 |     rgb[0] *= 1
20 |     rgb[1] *= 10
21 |     rgb[2] *= 100
22 | 
23 |     # Tilt mean by sample.
24 | for i, single in enumerate(x): 25 | single += 2**i 26 | 27 | return x 28 | 29 | 30 | def chunked_forward(model, x, micro_batches=NUM_MICRO_BATCHES): 31 | output_micro_batches = [] 32 | 33 | for chunk in x.chunk(micro_batches): 34 | output_micro_batches.append(model(chunk)) 35 | 36 | return torch.cat(output_micro_batches) 37 | 38 | 39 | @pytest.mark.parametrize('micro_batches', [1, 4]) 40 | @pytest.mark.parametrize('x_requires_grad', [True, False]) 41 | def test_transparency(micro_batches, x_requires_grad): 42 | bn = nn.BatchNorm2d(3) 43 | dbn = DelayedBatchNorm.convertBatchNorm( 44 | deepcopy(bn), num_micro_batches=micro_batches) 45 | 46 | x1 = torch.rand(16, 3, 224, 224) 47 | x1 = tilt_dist(x1) 48 | x2 = x1.clone() 49 | x1.requires_grad = x_requires_grad 50 | x2.requires_grad = x_requires_grad 51 | 52 | output1 = chunked_forward(bn, x1, micro_batches=micro_batches) 53 | output2 = chunked_forward(dbn, x2, micro_batches=micro_batches) 54 | 55 | assert torch.allclose(output1, output2, atol=1e-4) 56 | 57 | output1.mean().backward() 58 | output2.mean().backward() 59 | 60 | assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) 61 | 62 | if x_requires_grad: 63 | assert x1.grad is not None 64 | assert x2.grad is not None 65 | assert torch.allclose(x1.grad, x2.grad, atol=1e-4) 66 | 67 | 68 | @pytest.mark.parametrize('momentum', [0.1, None]) 69 | def test_running_stats(momentum): 70 | bn = nn.BatchNorm2d(3, momentum=momentum) 71 | dbn = DelayedBatchNorm.convertBatchNorm( 72 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES) 73 | 74 | x = torch.rand(16, 3, 224, 224) 75 | x = tilt_dist(x) 76 | 77 | bn(x) 78 | chunked_forward(dbn, x) 79 | 80 | assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) 81 | assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) 82 | 83 | 84 | def test_convert(): 85 | bn = nn.BatchNorm2d(3, track_running_stats=False) 86 | bn = DelayedBatchNorm.convertBatchNorm( 87 | bn, num_micro_batches=NUM_MICRO_BATCHES) 88 | assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False 89 | 90 | dbn = DelayedBatchNorm(3, num_micro_batches=NUM_MICRO_BATCHES) 91 | dbn_again = DelayedBatchNorm.convertBatchNorm( 92 | dbn, num_micro_batches=NUM_MICRO_BATCHES) 93 | assert dbn.weight is dbn_again.weight 94 | assert dbn.bias is dbn_again.bias 95 | assert dbn.running_mean is dbn_again.running_mean 96 | assert dbn.running_var is dbn_again.running_var 97 | 98 | 99 | def test_eval(): 100 | bn = nn.BatchNorm2d(3) 101 | dbn = DelayedBatchNorm.convertBatchNorm( 102 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES) 103 | 104 | x = torch.rand(16, 3, 224, 224) 105 | x = tilt_dist(x) 106 | 107 | bn(x) 108 | chunked_forward(dbn, x) 109 | 110 | bn.eval() 111 | dbn.eval() 112 | 113 | assert torch.allclose(bn(x), dbn(x), atol=1e-4) 114 | 115 | 116 | def test_optimize(): 117 | bn = nn.BatchNorm2d(3) 118 | dbn = DelayedBatchNorm.convertBatchNorm( 119 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES) 120 | 121 | opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) 122 | 123 | for i in range(5): 124 | x = torch.rand(16, 3, 224, 224) 125 | x = tilt_dist(x) 126 | 127 | # train 128 | y = bn(x) 129 | a = y.sum() 130 | a.backward() 131 | 132 | y = chunked_forward(dbn, x) 133 | b = y.sum() 134 | b.backward() 135 | 136 | opt.step() 137 | 138 | # eval 139 | bn.eval() 140 | dbn.eval() 141 | 142 | with torch.no_grad(): 143 | assert torch.allclose(bn(x), dbn(x), atol=1e-1 * (10**i)) 144 | 145 | 146 | def test_conv_bn(): 147 | bn = 
nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) 148 | dbn = DelayedBatchNorm.convertBatchNorm( 149 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES) 150 | 151 | x = torch.rand(16, 3, 224, 224) 152 | x = tilt_dist(x) 153 | 154 | opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) 155 | 156 | # 1st step 157 | a = bn(x) 158 | b = chunked_forward(dbn, x) 159 | 160 | # Outputs are different. (per-mini-batch vs. per-micro-batch) 161 | assert not torch.allclose(a, b) 162 | 163 | a.sum().backward() 164 | b.sum().backward() 165 | opt.step() 166 | opt.zero_grad() 167 | 168 | # Conv layers are also trained differently because of their different outputs. 169 | assert not torch.allclose(bn[0].weight, dbn[0].weight) 170 | 171 | # But BNs track identical running stats. 172 | assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) 173 | assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3) 174 | 175 | # 2nd step 176 | a = bn(x) 177 | b = chunked_forward(dbn, x) 178 | a.sum().backward() 179 | b.sum().backward() 180 | 181 | # BNs can't track identical running stats due to the different conv layers. 182 | assert not torch.allclose( 183 | bn[1].running_mean, dbn[1].running_mean, atol=1e-4) 184 | assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3) 185 | 186 | 187 | def test_x_requiring_grad(): 188 | dbn = DelayedBatchNorm(3, num_micro_batches=NUM_MICRO_BATCHES) 189 | 190 | x = torch.rand(16, 3, 224, 224, requires_grad=True) 191 | x = tilt_dist(x) 192 | 193 | chunked_forward(dbn, x) 194 | 195 | assert not dbn.sum.requires_grad 196 | assert dbn.sum.grad_fn is None 197 | -------------------------------------------------------------------------------- /tests/test_metis_lib_bindings.py: -------------------------------------------------------------------------------- 1 | from pytorch_Gpipe.METIS import METIS_partition, mdbglvl_et 2 | 3 | 4 | # ------------------------------------------------------------------------- 5 | # basic tests 6 | # ------------------------------------------------------------------------- 7 | 8 | 9 | def example_adjlist(): 10 | return [[1, 2, 3, 4], [0], [0], [0], [0, 5], [4, 6], [13, 5, 7], 11 | [8, 6], [9, 10, 11, 12, 7], [8], [8], [8], [8], [14, 6], [13, 15], 12 | [16, 17, 18, 14], [15], [15], [15]] 13 | 14 | 15 | def test_1(): 16 | adjlist = example_adjlist() 17 | 18 | print("Testing k-way cut") 19 | parts, cuts = METIS_partition(adjlist, 3, algorithm="metis", 20 | dbglvl=mdbglvl_et.METIS_DBG_ALL) 21 | assert cuts == 2 22 | assert set(parts) == set([0, 1, 2]) 23 | 24 | print("Testing recursive cut") 25 | parts, cuts = METIS_partition(adjlist, 3, algorithm="metis_recursive", 26 | dbglvl=mdbglvl_et.METIS_DBG_ALL) 27 | assert cuts == 2 28 | assert set(parts) == set([0, 1, 2]) 29 | 30 | # print("METIS appears to be working.") 31 | 32 | 33 | def test_2(): 34 | nVertices = 6 35 | nParts = 2 36 | 37 | # Indexes of starting points in adjacent array 38 | adj_idx = [0, 2, 5, 7, 9, 12, 14] 39 | 40 | # Adjacent vertices in consecutive index order 41 | adjv = [1, 3, 0, 4, 2, 1, 5, 0, 4, 3, 1, 5, 4, 2] 42 | 43 | adjlist = [adjv[adj_idx[i]:adj_idx[i+1]] for i in range(nVertices)] 44 | 45 | # Weights of vertices 46 | # if all weights are equal then can be set to NULL 47 | nodew = [i*nVertices for i in range(nVertices)] 48 | 49 | # int ret = METIS_PartGraphRecursive( & nVertices, & nWeights, xadj, adjncy, 50 | # NULL, NULL, NULL, & nParts, NULL, 51 | # NULL, NULL, & objval, part) 52 | 53 | parts, _ = METIS_partition( 54 | 
adjlist, nParts, algorithm="metis", nodew=nodew, contig=1) 55 | 56 | assert len(set(parts)) == nParts 57 | -------------------------------------------------------------------------------- /tests/test_networkx.py: -------------------------------------------------------------------------------- 1 | 2 | import itertools 3 | 4 | 5 | import networkx as nx 6 | 7 | import nxmetis 8 | from nxmetis import exceptions 9 | from nxmetis import metis 10 | from nxmetis import types 11 | import pytest 12 | 13 | 14 | def make_cycle(n): 15 | xadj = list(range(0, 2 * n + 1, 2)) 16 | adjncy = list( 17 | itertools.chain.from_iterable( 18 | zip(itertools.chain([n - 1], range(n - 1)), 19 | itertools.chain(range(1, n), [0])))) 20 | return xadj, adjncy 21 | 22 | 23 | def assert_equal(a, b): 24 | assert a == b 25 | 26 | 27 | def assert_not_equal(a, b): 28 | assert not (a == b) 29 | 30 | 31 | class TestMetis(object): 32 | 33 | def setup_method(self, test_method): 34 | self.node_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 35 | 1, 2, 3, 4, 5, 6] 36 | self.G = nx.Graph() 37 | nx.add_path(self.G, self.node_list) 38 | self.G.add_edge(self.node_list[-1], self.node_list[0]) 39 | 40 | def test_node_nested_dissection_unweighted(self): 41 | 42 | node_ordering = nxmetis.node_nested_dissection(self.G) 43 | assert_equal(len(self.G), len(node_ordering)) 44 | assert_equal(set(self.G), set(node_ordering)) 45 | 46 | # Tests for exercising package's ability to handle self-loops 47 | # METIS crashes on self loops. networkx-metis should not 48 | self.G.add_edge(1, 1) 49 | self.G.add_edge('a', 'a') 50 | node_ordering = nxmetis.node_nested_dissection(self.G) 51 | assert_equal(len(self.G), len(node_ordering)) 52 | assert_equal(set(self.G), set(node_ordering)) 53 | 54 | def test_partition(self): 55 | partition = nxmetis.partition(self.G, 4) 56 | # When we choose one node from one part of the partitioned Graph, 57 | # It must be adjacent to one or more of the nodes in the same part. 58 | # This is to verify the continuity of the chain of nodes. 
59 | parts = partition[1] # List containing partitioned node lists 60 | 61 | assert_equal(partition[0], 4) 62 | assert_equal(len(partition[1]), 4) 63 | 64 | for part in parts: 65 | assert_not_equal(0, len(part)) # Non-empty set 66 | assert_equal( 67 | len(part), len(set(part))) # Duplicate-free 68 | assert (nx.is_connected(self.G.subgraph(part))) # Connected 69 | 70 | # Disjoint sets 71 | for part1, part2 in itertools.combinations(parts, 2): 72 | assert_equal(set(), set(part1) & set(part2)) 73 | 74 | # These parts must be exhaustive with the node list of the Graph 75 | parts_combined = parts[0] + parts[1] + parts[2] + parts[3] 76 | assert_equal(set(parts_combined), set(self.G)) 77 | 78 | def test_vertex_separator(self): 79 | sep, part1, part2 = nxmetis.vertex_separator(self.G) 80 | 81 | # The two separator nodes must not be present in the 82 | # two bisected chains 83 | assert (sep[0] not in part1) 84 | assert (sep[0] not in part2) 85 | assert (sep[1] not in part1) 86 | assert (sep[1] not in part2) 87 | 88 | # There should be two different separator nodes 89 | assert_equal(len(sep), 2) 90 | assert_not_equal(sep[0], sep[1]) 91 | 92 | # The lists should be exhaustive with the node list of the Graph 93 | assert_equal(set(sep) | set(part1) | set(part2), 94 | set(self.G)) 95 | 96 | # The parts must be disjoint sets 97 | assert_equal(set(), set(part1) & set(part2)) 98 | 99 | # Non-empty set 100 | assert_not_equal(len(part1), 0) 101 | assert_not_equal(len(part2), 0) 102 | 103 | # Duplicate-free 104 | assert_equal(len(part1), len(set(part1))) 105 | assert_equal(len(part2), len(set(part2))) 106 | 107 | # Connected 108 | assert (nx.is_connected(self.G.subgraph(part1))) 109 | assert (nx.is_connected(self.G.subgraph(part2))) 110 | 111 | # def test_MetisOptions(self): 112 | # n = 16 113 | # xadj, adjncy = make_cycle(n) 114 | # options = types.MetisOptions(niter=-2) 115 | # nose.tools.assert_raises_regexp(exceptions.MetisError, 116 | # 'Input Error: Incorrect niter.', 117 | # metis.part_graph, xadj, adjncy, 2, 118 | # options=options) 119 | --------------------------------------------------------------------------------