├── .gitignore
├── README.md
├── build_notes.txt
├── environment.yml
├── experiments
│   ├── __init__.py
│   ├── experiments.txt
│   ├── loss_exp.py
│   ├── my_loss_exp.py
│   ├── speedup_exp.py
│   ├── split_size_exp.py
│   ├── throughput_exp.py
│   └── utils.py
├── misc
│   └── networkx_metis__init__.py
├── parse_declarations.py
├── partition_torchvision_networks.py
├── pytorch_Gpipe
│   ├── METIS
│   │   ├── METIS_graph_partition.py
│   │   ├── METIS_manual.pdf
│   │   ├── __init__.py
│   │   ├── libmetis.dll
│   │   └── libmetis.so
│   ├── __init__.py
│   ├── delayedNorm.py
│   ├── model_partitioning
│   │   ├── __init__.py
│   │   ├── module_generation
│   │   │   ├── __init__.py
│   │   │   ├── constructor.py
│   │   │   ├── declarations.pyi
│   │   │   ├── forward.py
│   │   │   ├── generate.py
│   │   │   └── misc.py
│   │   ├── partition_graph.py
│   │   └── process_partition.py
│   ├── model_profiling
│   │   ├── __init__.py
│   │   ├── control_flow_graph.py
│   │   └── network_profiler.py
│   ├── pipeline.py
│   └── utils.py
├── sample_models
│   ├── AlexNet.py
│   ├── DenseNet.py
│   ├── GoogleNet.py
│   ├── Inception.py
│   ├── LeNet.py
│   ├── ResNet.py
│   ├── SqueezeNet.py
│   ├── VGG.py
│   ├── WideResNet.py
│   ├── __init__.py
│   ├── amoebaNet
│   │   ├── __init__.py
│   │   ├── genotype.py
│   │   └── utils.py
│   └── torchGpipe_ref_models
│       ├── __init__.py
│       ├── amoebanet_d
│       │   ├── __init__.py
│       │   ├── genotype.py
│       │   └── surgery.py
│       ├── flatten_sequential.py
│       └── resnet
│           ├── __init__.py
│           └── bottleneck.py
└── tests
    ├── __init__.py
    ├── test_delayed_norm.py
    ├── test_metis_lib_bindings.py
    └── test_networkx.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | .ipynb_checkpoints/*
3 | modelParallel/__pycache__/*
4 | data/*
5 | models/__pycache__/*
6 | */__pycache__/*
7 | __pycache__/*
8 | ideas/*
9 | .idea/
10 | .pytest_cache/**
11 | playground/**
12 | model_partition/graph/__pycache__/
13 | *.pyc
14 | main.py
15 | experiment/stanford_car_dataset_images_in_224x224/*
16 | experiments/stanford-car-dataset-by-classes-folder-224/*
17 | playground.py
18 | trace.txt
19 | inc.pdf
20 | *.pdf
21 | generated_modules.py
22 | *.log
23 | transformer_example/**
24 | transformer_trace.txt
25 | generated*
26 | *Bug*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | PyTorch implementation of pipeline model parallelism (GPipe), as described in the Google AI paper: https://arxiv.org/pdf/1811.06965.pdf
2 |
3 | This is still a work in progress and not yet functional.
4 |
--------------------------------------------------------------------------------
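A minimal usage sketch, based on how `pipe_model` is invoked in the experiments (illustrative only while the API is in flux and the project is not yet functional):

```python
import torch
from pytorch_Gpipe import pipe_model
from sample_models import resnet101

# one logical model, split automatically across all visible GPUs
devices = [torch.device('cuda', i) for i in range(torch.cuda.device_count())]
net = resnet101().to(devices[0])

# profile on a sample batch, partition, and wrap with the pipeline;
# the second argument is the microbatch size used for pipelining
sample = torch.randn(16, 3, 224, 224, device=devices[0])
piped = pipe_model(net, 8, sample, devices=devices)

out = piped(torch.randn(64, 3, 224, 224, device=devices[0]))
```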
/build_notes.txt:
--------------------------------------------------------------------------------
1 | prerequisites:
2 | 1. make sure you have the packages specified in environment.yml (graphviz and python-graphviz are optional, used only for graph visualization)
3 | 2. clone the repository
4 | 3. if you wish to use the METIS library (recommended) to decide how to partition the model, follow the next section
5 |
6 | METIS build instructions:
7 | requires CMake to build
8 | source code can be downloaded from http://glaros.dtc.umn.edu/gkhome/fetch/sw/metis/metis-5.1.0.tar.gz
9 |
10 | for linux (tested):
11 | in the file include/metis.h
12 |    1. in line 33 set IDXTYPEWIDTH to 64
13 |    2. in line 42 set REALTYPEWIDTH to 64
14 |
15 | 1. make config shared=1
16 | 2. make install
17 | 3. copy build/libmetis/libmetis.so to pytorch_Gpipe/METIS
18 | 4. run tests/test_metis_lib_bindings.py; if it passes, you have successfully integrated METIS into this project
19 |
20 | for windows (tested):
21 | 1. in the file include/metis.h
22 |    a. in line 33 set IDXTYPEWIDTH to 64
23 |    b. in line 42 set REALTYPEWIDTH to 64
24 |
25 | 2. in the top level CMakeLists.txt:
26 |    a. change line 4 set(GKLIB_PATH "GKlib" CACHE PATH "path to GKlib") to set(GKLIB_PATH "${CMAKE_SOURCE_DIR}/GKlib" CACHE PATH "path to GKlib")
27 |    b. change line 6 set(SHARED FALSE CACHE BOOL "build a shared library") to set(SHARED TRUE CACHE BOOL "build a shared library")
28 |
29 | 3. in the file GKlib/gk_arch.h
30 |    a. comment out line 42: #include <sys/resource.h>
31 |
32 | 4. in the file GKlib/getopt.h
33 |    a. comment out all extern declarations, lines 53-57:
34 |       extern int gk_getopt(int __argc, char **__argv, char *__shortopts);
35 |       extern int gk_getopt_long(int __argc, char **__argv, char *__shortopts, struct gk_option *__longopts, int *__longind);
36 |       extern int gk_getopt_long_only(int __argc, char **__argv, char *__shortopts, struct gk_option *__longopts, int *__longind);
37 |
38 | 5. build the libmetis.dll target
39 | 6. copy build/libmetis/libmetis.dll to pytorch_Gpipe/METIS
40 | 7. run tests/test_metis_lib_bindings.py; if it passes, you have successfully integrated METIS into this project
--------------------------------------------------------------------------------
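As a quick sanity check before running the tests, the copied library can be loaded directly from Python (a sketch, assuming it is run from the repository root):

```python
import ctypes
import os

# pick the artifact matching your platform, as produced by the steps above
lib = 'libmetis.dll' if os.name == 'nt' else 'libmetis.so'
ctypes.CDLL(os.path.join('pytorch_Gpipe', 'METIS', lib))  # raises OSError if the build or copy failed
print('METIS shared library loaded successfully')
```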
/environment.yml:
--------------------------------------------------------------------------------
1 | name: Gpipe
2 | channels:
3 | - pytorch
4 | - defaults
5 | - conda-forge
6 | dependencies:
7 | - python=3.7
8 | - pytorch=1.2.0
9 |   - torchvision=0.4.0
10 | - graphviz
11 | - python-graphviz
12 | - networkx
13 | - networkx-metis
14 |
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/experiments/experiments.txt:
--------------------------------------------------------------------------------
1 | #TODO experiments
2 |
3 |
4 | performance:
5 | 1. compare memory consumption between the pipeline and dataParallel/singleGpu/naiveModelParallel
6 | 2. for a given network partition and a given batch size find the optimal microbatch size (see the sketch below)
7 | 3. for a given network partition find the optimal batch size (microbatch size is batch/nparts)
8 |
9 | effects of partitioning on performance:
10 | 1. for a given network find the effects of depth and basic blocks on performance
11 | 2. for a given network find the effects of the number of partitions
12 |
13 | reproduce paper results:
14 | 1. resnet 101
15 | 2. AmoebaNet-D https://github.com/tensorflow/tpu/tree/e5c126d66aa3d25e0cb066bdf7fc46f98fe59901/models/experimental/amoeba_net
16 |
17 |
--------------------------------------------------------------------------------
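For item 2 under performance, a minimal sweep could look like the sketch below (`make_pipe` is a hypothetical factory, e.g. `lambda mb: pipe_model(net, mb, sample, devices=devices)`):

```python
import time

import torch


def sweep_microbatch_sizes(make_pipe, batch, candidates=(1, 2, 4, 8, 16, 32)):
    """Time one training step per candidate microbatch size and report the best."""
    timings = {}
    for mb in candidates:
        model = make_pipe(mb)
        torch.cuda.synchronize()
        tick = time.time()
        model(batch).sum().backward()  # forward + backward through the pipeline
        torch.cuda.synchronize()
        timings[mb] = time.time() - tick
    return min(timings, key=timings.get), timings
```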
/experiments/loss_exp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import DataLoader
4 | import torch.nn as nn
5 | import torchvision.transforms as transforms
6 | from sample_models import *
7 | import torchvision
8 | import matplotlib.pyplot as plt
9 | import torch.optim as optim
10 | from .utils import ExpParser
11 |
12 | from sample_models.AlexNet import alexnet
13 | from pytorch_Gpipe import pipe_model
14 |
15 |
16 | def loss_exp(model_class, num_devices: int, batch_size: int, model_params: dict, pipeline_params: dict,
17 | dataset: torch.utils.data.Dataset, train_ratio: float = 0.8, num_epochs=500):
18 | device_single = 'cuda:0' if torch.cuda.is_available() else 'cpu'
19 |
20 | pipeline_params['devices'] = list(range(num_devices))
21 |     pipeline_params['sample_batch'] = torch.randn(batch_size, *dataset[0][0].shape)  # dataset yields (sample, label) pairs
22 |
23 | train_amount = int(train_ratio * len(dataset))
24 | test_amount = len(dataset) - train_amount
25 | train_set, test_set = torch.utils.data.random_split(dataset, (train_amount, test_amount))
26 |
27 | # the models to compare
28 | model_single = model_class(**model_params).to(device_single)
29 |
30 | model_pipe = pipe_model(model_class(**model_params), **pipeline_params)
31 | model_pipe.zero_grad()
32 |
33 | model_dp = nn.DataParallel(model_class(**model_params), device_ids=pipeline_params['devices'])
34 |
35 | # train the models
36 | print(f"Training model on {device_single}")
37 | stats_single_train, stats_single_test = train_with_stats_saved(model_single, train_set, test_set, num_epochs,
38 | batch_size, device_single)
39 |
40 | print("Training model using G-pipe")
41 | stats_pipe_train, stats_pipe_test = train_with_stats_saved(model_pipe, train_set, test_set, num_epochs, batch_size)
42 |
43 | print("Training model using data parallel")
44 | stats_dp_train, stats_dp_test = train_with_stats_saved(model_dp, train_set, test_set, num_epochs, batch_size)
45 |
46 | # plot results
47 | plot_loss(stats_single_train, stats_pipe_train, stats_dp_train)
48 | plot_loss(stats_single_test, stats_pipe_test, stats_dp_test, False)
49 |
50 | # print(stats_single)
51 | # print(stats_pipe)
52 | # print(stats_dp)
53 |
54 | accuracy_single = stats_single_test[1][-1]
55 | accuracy_pipe = stats_pipe_test[1][-1]
56 | accuracy_dp = stats_dp_test[1][-1]
57 |
58 | print(f"Test accuracy on single device: {accuracy_single}")
59 | print(f"Test accuracy on model parallel: {accuracy_pipe}")
60 | print(f"Test accuracy on data parallel: {accuracy_dp}")
61 |
62 |     # assert that the pipeline's accuracy is within 5% of the single-device model
63 | assert accuracy_single * 0.95 < accuracy_pipe < accuracy_single * 1.05
64 |
65 |
66 | def train_with_stats_saved(model, train_set, test_set, num_epochs, batch_size, input_device='cpu'):
67 |     # note: no seed is fixed here, so each call gets a freshly shuffled train-set loader
68 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
69 |
70 | criterion = nn.CrossEntropyLoss()
71 | optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
72 |
73 | train_losses, train_accuracies = [], []
74 | test_stats = []
75 |
76 | for epoch in range(num_epochs):
77 | epoch_losses = []
78 | epoch_accuracies = []
79 |
80 | if epoch % 5 == 0:
81 | print(f"epoch number {epoch}:")
82 |
83 | for i, (inputs, labels) in enumerate(train_loader):
84 | optimizer.zero_grad()
85 |
86 | outputs = model(inputs.to(input_device))
87 | labels = labels.to(outputs.device)
88 |
89 | loss = criterion(outputs, labels)
90 | loss.backward()
91 | optimizer.step()
92 |
93 |             # accumulate per-batch statistics
94 | epoch_losses.append(loss.item())
95 | epoch_accuracies.append(get_accuracy(outputs, labels))
96 |
97 |         # average loss over the epoch
98 | train_losses.append(np.mean(epoch_losses))
99 |
100 |         # average accuracy over the epoch
101 | train_accuracies.append(np.mean(epoch_accuracies))
102 |
103 | test_stats.append(test(model, test_set, batch_size, input_device))
104 |
105 | return (train_losses, train_accuracies), tuple(zip(*test_stats))
106 |
107 |
108 | def test(model, test_set, batch_size, input_device='cpu'):
109 | # always using same test-set loader
110 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
111 |
112 | criterion = nn.CrossEntropyLoss()
113 |
114 | losses, accuracies = [], []
115 |
116 | with torch.no_grad():
117 | for batch in test_loader:
118 | inputs, labels = batch
119 | outputs = model(inputs.to(input_device))
120 | labels = labels.to(outputs.device)
121 |
122 | loss = criterion(outputs, labels)
123 |
124 | losses.append(loss.item())
125 | accuracies.append(get_accuracy(outputs, labels))
126 |
127 | return np.mean(losses), np.mean(accuracies)
128 |
129 |
130 | def plot_loss(stats_single, stats_pipe, stats_dp, training=True):
131 | loss_s, acc_s = stats_single
132 | loss_p, acc_p = stats_pipe
133 | loss_dp, acc_dp = stats_dp
134 |
135 |     x = range(len(loss_s))  # one point per epoch
136 |
137 | plt.subplot(1, 2, 1)
138 |
139 | plt.plot(x, loss_s, label='single device')
140 | plt.plot(x, loss_p, label='model parallel')
141 | plt.plot(x, loss_dp, label='data parallel')
142 |
143 | plt.xlabel('epoch')
144 | plt.ylabel('loss')
145 | plt.title(f'{"train set" if training else "test set"} Cross-Entropy Loss Comparison')
146 | plt.legend()
147 |
148 | plt.subplot(1, 2, 2)
149 |
150 | plt.plot(x, acc_s, label='single device')
151 | plt.plot(x, acc_p, label='model parallel')
152 | plt.plot(x, acc_dp, label='data parallel')
153 |
154 | plt.xlabel('epoch')
155 | plt.ylabel('accuracy')
156 | plt.title(f'{"train set" if training else "test set"} Accuracy')
157 | plt.legend()
158 |
159 | plt.show()
160 |
161 |
162 | def get_accuracy(outputs, labels):
163 | _, predictions = torch.max(outputs.data, 1)
164 |
165 | return (predictions == labels).sum().item() / labels.size(0)
166 |
167 |
168 | if __name__ == '__main__':
169 |     parser = ExpParser(description='Run the loss experiment.')
170 | args = parser.parse_args()
171 | loss_exp(**args)
--------------------------------------------------------------------------------
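A hypothetical programmatic invocation of `loss_exp` (the CLI path goes through `ExpParser` from experiments/utils.py; CIFAR10 stands in here for any torchvision dataset):

```python
import torchvision
import torchvision.transforms as T

from experiments.loss_exp import loss_exp
from sample_models import resnet50

# compare single-GPU, pipelined and data-parallel training across 2 devices
dataset = torchvision.datasets.CIFAR10('./data', download=True, transform=T.ToTensor())
loss_exp(resnet50, num_devices=2, batch_size=64,
         model_params={'num_classes': 10}, pipeline_params={},
         dataset=dataset, num_epochs=5)
```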
/experiments/my_loss_exp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 | import numpy as np
5 | # add our code to the path so we can import it
6 | import sys
7 | import os
8 | sys.path.append(os.path.join(os.getcwd(), '..'))
9 |
10 | from pytorch_Gpipe import pipe_model
11 | import torch.nn.functional as F
12 | from torch.optim import SGD
13 | from torch.utils.data import DataLoader
14 | import torchvision
15 | from torchvision import transforms as transforms
16 | import argparse
17 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, densenet201, \
18 | WideResNet, AmoebaNet_D as amoebanet, amoebanetd as torchgpipe_amoebanet, torchgpipe_resnet101
19 | import platform
20 |
21 | MODELS = {
22 | "alexnet": alexnet,
23 | "resnet101": resnet101,
24 | "vgg19_bn": vgg19_bn,
25 | "squeezenet1_1": squeezenet1_1,
26 | "densenet201": densenet201,
27 | "WideResNet": WideResNet,
28 | "amoebanet": amoebanet,
29 | "torchgpipe_amoebanet": torchgpipe_amoebanet,
30 | "torchgpipe_resnet101": torchgpipe_resnet101
31 | }
32 |
33 | # setups
34 |
35 |
36 | def single_gpu(model_class: nn.Module, devices, *model_args, **model_kwargs):
37 | model = model_class(*model_args, **model_kwargs).to(devices[0])
38 | used_devices = [devices[0]]
39 |
40 | return model, used_devices
41 |
42 |
43 | def pipeLine(model_class: nn.Module, devices, microbatch_size, pipe_sample, model_args, model_kwargs, pipeline_args,
44 | pipeline_kwargs):
45 | net = model_class(*model_args, **model_kwargs).to(devices[0])
46 | net(pipe_sample)
47 | torch.cuda.synchronize()
48 |
49 | piped = pipe_model(net, microbatch_size, pipe_sample, *pipeline_args,
50 | devices=devices, **pipeline_kwargs, return_graph=True)
51 | return piped, piped.module_devices
52 |
53 |
54 | def dataParallel(model_class: nn.Module, devices, *model_args, **model_kwargs):
55 | return nn.DataParallel(model_class(*model_args, **model_kwargs), devices).to(devices[0]), devices
56 |
57 |
58 | SETUPS = {
59 | "single_gpu": single_gpu,
60 | "pipeLine": pipeLine,
61 | "dataParallel": dataParallel
62 | }
63 |
64 | # the experiment itself
65 | img_size = (224, 224)
66 |
67 |
68 | def create_dataloaders(batch_train, batch_test):
69 |     """
70 |     Assumes the following folder structure,
71 |     one subfolder per class with the images inside:
72 |
73 |     cats_dogs/cat/...
74 |
75 |     cats_dogs/dog/...
76 |
77 |     (torchvision's ImageFolder infers the labels from these subfolders)
78 |     """
79 | dataset_dir = "cats_dogs"
80 | tfms = transforms.Compose([
81 | transforms.Resize(img_size),
82 | transforms.RandomHorizontalFlip(),
83 | transforms.RandomRotation(15),
84 | transforms.ToTensor(),
85 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
86 | ])
87 |
88 | dataset = torchvision.datasets.ImageFolder(root=dataset_dir, transform=tfms)
89 | train_set, test_set = torch.utils.data.random_split(dataset, [20000, 5000])
90 |
91 | train_loader = DataLoader(train_set, batch_size=batch_train, shuffle=True, num_workers=2)
92 | test_loader = DataLoader(test_set, batch_size=batch_test, shuffle=False, num_workers=2)
93 |
94 | return train_loader, test_loader
95 |
96 |
97 | def loss_exp(config):
98 | assert torch.cuda.is_available(), "gpus are required"
99 | devices = [torch.device('cuda', i)
100 | for i in range(torch.cuda.device_count())]
101 |
102 | setup = config['setup']
103 | model_class = config['model']
104 | model_args = config['model_args']
105 | model_kwargs = config['model_kwargs']
106 | model_kwargs['num_classes'] = 2
107 | batch_size = config['batch_size']
108 | pipeLine_kwargs = config['pipeLine_kwargs']
109 | pipeLine_args = config['pipeLine_args']
110 | epochs = config['epochs']
111 | batch_shape = config['batch_shape']
112 | microbatch_size = config['microbatch_size']
113 | profile_sample_size = config['profile_sample_size']
114 | if microbatch_size is None:
115 | microbatch_size = batch_size // len(devices)
116 |
117 | if setup is pipeLine:
118 | assert len(devices) > 1, "automatic partitioning does not work for 1 gpu"
119 | pipe_sample = torch.randn(
120 | (profile_sample_size,) + batch_shape[1:]).to(devices[0])
121 | model, used_devices = pipeLine(model_class, devices, microbatch_size, pipe_sample, model_args,
122 | model_kwargs, pipeLine_args, pipeLine_kwargs)
123 | else:
124 | model, used_devices = setup(
125 | model_class, devices, *model_args, **model_kwargs)
126 |
127 | used_devices = list(used_devices)
128 | in_device = used_devices[0]
129 | out_device = used_devices[-1]
130 | batch_size = batch_shape[0]
131 | optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
132 |
133 | train_loader, test_loader = create_dataloaders(batch_size, batch_size)
134 |
135 | throughputs = []
136 | elapsed_times = []
137 | memory_consumptions = {device: [] for device in used_devices}
138 | losses = []
139 | accuracies = []
140 |
141 | # run one epoch and gather statistics
142 | def run_epoch(epoch):
143 | torch.cuda.synchronize(in_device)
144 | tick = time.time()
145 | model.train()
146 |
147 | data_trained = 0
148 | steps = len(train_loader)
149 | for i, data in enumerate(train_loader):
150 | batch, target = data
151 | data_trained += batch.shape[0]
152 |
153 | output = model(batch.to(in_device))
154 | gt = target.to(output.device)
155 | loss = F.cross_entropy(output, gt)
156 | loss.backward()
157 |
158 | optimizer.step()
159 | optimizer.zero_grad()
160 |
161 | # estimate statistics after each batch
162 | percent = i / steps * 100
163 | throughput = data_trained / (time.time() - tick)
164 |
165 | # print every 100 steps
166 | if i % 100 == 0 or i == steps - 1:
167 | print(
168 |                     f"{epoch+1}/{epochs} epochs ({percent:.2f}%) | {throughput:.2f} samples/sec (estimated)")
169 |
170 | torch.cuda.synchronize(in_device)
171 | tock = time.time()
172 |
173 | # calculate exact statistics after epoch is finished
174 | test_loss, test_accuracy = test(model, test_loader, used_devices[0])
175 | losses.append(test_loss)
176 | accuracies.append(test_accuracy)
177 |
178 | elapsed_time = tock - tick
179 | throughput = batch_size * steps / elapsed_time
180 | print(
181 | f"{epoch+1}/{epochs} epochs | {throughput:.2f} samples/sec, {elapsed_time:.2f} sec/epoch, test_accuracy {test_accuracy:.2f}, test loss {test_loss:.2f}")
182 |
183 | return throughput, elapsed_time
184 |
185 | exp = setup.__name__
186 | title = f'loss experiment\n config: {exp}\n used_gpus: {len(used_devices)}\n epochs: {epochs}\n'
187 | print(title)
188 |
189 | gpus = [torch.cuda.get_device_name(device) for device in used_devices]
190 | print('system information\n python: %s, torch: %s, cudnn: %s, cuda: %s, \ngpus: %s' % (
191 | platform.python_version(),
192 | torch.__version__,
193 | torch.backends.cudnn.version(),
194 | torch.version.cuda,
195 | gpus))
196 | print("\n")
197 | for epoch in range(epochs):
198 | for d in used_devices:
199 | torch.cuda.reset_max_memory_allocated(device=d)
200 |
201 | throughput, elapsed_time = run_epoch(epoch)
202 | # first epoch is used as a warmup
203 | if epoch < 1:
204 | continue
205 |
206 | for d in used_devices:
207 | memory_consumptions[d].append(
208 | torch.cuda.max_memory_allocated(device=d))
209 |
210 | throughputs.append(throughput)
211 | elapsed_times.append(elapsed_time)
212 |
213 | print("\n")
214 | n = len(throughputs)
215 | throughput = sum(throughputs) / n
216 | elapsed_time = sum(elapsed_times) / n
217 |
218 | avg_mem_per_epoch = {device: sum(
219 | memory_consumptions[device]) / n for device in used_devices}
220 |     # device with the highest average memory consumption ('min' would give the lowest)
221 | maximum = max(avg_mem_per_epoch, key=avg_mem_per_epoch.get)
222 |
223 | print(
224 | f'{title} {throughput:.2f} samples/sec\n{elapsed_time:.2f} sec/epoch (average)\n max memory consumption on device: {maximum} {(avg_mem_per_epoch[maximum]/1e6):.2f} MB')
225 | print(f"final loss {losses[-1]:.2f} final accuracy {accuracies[-1]:.2f}")
226 |
227 |
228 | def test(model, test_loader, input_device):
229 | criterion = F.cross_entropy
230 |
231 | model.train(mode=False)
232 | losses, accuracies = [], []
233 |
234 | with torch.no_grad():
235 | for batch in test_loader:
236 | inputs, labels = batch
237 | outputs = model(inputs.to(input_device))
238 | labels = labels.to(outputs.device)
239 |
240 | loss = criterion(outputs, labels)
241 |
242 | losses.append(loss.item())
243 | accuracies.append(get_accuracy(outputs, labels))
244 |
245 | return np.mean(losses), np.mean(accuracies)
246 |
247 |
248 | def get_accuracy(outputs, labels):
249 | _, predictions = torch.max(outputs.data, 1)
250 |
251 | return (predictions == labels).sum().item() / labels.size(0)
252 |
253 |
254 | class StoreDict(argparse.Action):
255 | def __call__(self, parser, namespace, values, option_string=None):
256 | kv = {}
257 | if not isinstance(values, (list,)):
258 | values = (values,)
259 | for value in values:
260 | n, v = value.split('=')
261 | kv[n] = v
262 | setattr(namespace, self.dest, kv)
263 |
264 |
265 | class ExpParser(argparse.ArgumentParser):
266 | def __init__(self, *args, **kwargs):
267 | super().__init__(*args, **kwargs)
268 |
269 | self.add_argument('--setup', '-s', help='The way to run the model.',
270 | choices=SETUPS.keys(),
271 | required=True, dest='setup')
272 | self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=MODELS.keys(),
273 | required=True, dest='model')
274 | self.add_argument('--model_args', '-ma', help='additional args passed to the pipeline', nargs='*',
275 | default=[])
276 | self.add_argument('--model_kwargs', '-mkw', help='additional kwargs passed to the model', nargs='*', action=StoreDict,
277 | default={})
278 | self.add_argument('--batch_size', '-b', help='batch size used in the experiment defaults to 64',
279 | type=int, dest='batch_size', default=64)
280 | self.add_argument('--pipeLine_args', '-pa', help='additional args passed to the pipeline', nargs='*',
281 | default=[])
282 | self.add_argument('--pipeLine_kwargs', '-pkw', help='additional kwargs passed to the pipeline', nargs='*',
283 | action=StoreDict, default={})
284 |         self.add_argument('--epochs', '-e', help="the number of training epochs used,\nthe first epoch is a warmup and will not affect results",
285 | type=int, default=10)
286 | self.add_argument('--microbatch_size', '-mb', help="micro batch size of the pipeline\nif not given defaults to batch_size/num_devices",
287 | type=int, default=None)
288 | self.add_argument('--profile_sample_size', '-ps', help='size of batch used to partition the model if testing the pipeline',
289 | type=int, default=16)
290 |
291 | def parse_args(self, *args, **kwargs):
292 | res = vars(super().parse_args(*args, **kwargs))
293 | res['model'] = MODELS[res['model']]
294 | res['setup'] = SETUPS[res['setup']]
295 | net = res['model']
296 |
297 | if net.__name__.find("inception") != -1:
298 | sample_shape = (3, 299, 299)
299 | elif net.__name__.find("GoogLeNet") != -1:
300 | sample_shape = (3, 32, 32)
301 | elif net.__name__.find("LeNet") != -1:
302 | sample_shape = (3, 32, 32)
303 | else:
304 | sample_shape = (3, 224, 224)
305 |
306 | res['batch_shape'] = (res['batch_size'],) + sample_shape
307 |
308 | return res
309 |
310 |
311 | if __name__ == "__main__":
312 |     # TODO: all var/kw args are treated as strings, e.g. in:
313 |     # my_loss_exp.py -s single_gpu -m amoebanet --model_args 10 --model_kwargs kw0=1 kw1=hello
314 |
315 | parser = ExpParser()
316 | args = parser.parse_args()
317 |
318 | loss_exp(args)
319 |
--------------------------------------------------------------------------------
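One way to address the TODO above (a sketch, not part of the repo): coerce each `StoreDict` value with `ast.literal_eval` so that `kw0=1` arrives as an `int` rather than the string `'1'`:

```python
import ast


def coerce(value):
    """Best-effort conversion of a CLI string to a Python literal."""
    try:
        return ast.literal_eval(value)  # handles ints, floats, bools, tuples, ...
    except (ValueError, SyntaxError):
        return value  # leave bare words such as 'hello' as strings

# inside StoreDict.__call__, replace `kv[n] = v` with `kv[n] = coerce(v)`
```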
/experiments/speedup_exp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from typing import Tuple
5 | import timeit
6 | from sample_models import *
7 | from .utils import *
8 |
9 | plt.switch_backend('Agg')
10 |
11 |
12 | def exp_model_time(run_type, model_class, num_classes, batch_shape: Tuple[int, ...], num_repeats, num_warmups,
13 | model_params: dict, tests_config: dict, pipeline_params: dict = None, num_devices=None):
14 |
15 | tests_config['num_classes'] = num_classes
16 | tests_config['batch_shape'] = batch_shape
17 |
18 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
19 |
20 | if run_type in ['S', 'Single']:
21 | tests_config['model'] = model_class(**model_params).to(device)
22 |
23 | print('Single GPU:')
24 |
25 | elif run_type in ['P', 'Pipeline-Parallel']:
26 | pipeline_params['devices'] = list(range(num_devices))
27 | model = model_class(**model_params).to(device)
28 | tests_config['model'] = create_pipeline(model, batch_shape, **pipeline_params)
29 |
30 | print('Pipeline-Parallel:')
31 |
32 | elif run_type in ['D', 'Data-Parallel']:
33 | devices_ids = list(range(num_devices))
34 | model = model_class(**model_params).to(device)
35 | tests_config['model'] = nn.DataParallel(model, device_ids=devices_ids).to(device)
36 |
37 | print('Data-Parallel:')
38 |
39 | else:
40 | raise ValueError('Not a valid run type')
41 |
42 | run_times, mem_uses = track_train(num_repeats + num_warmups, **tests_config)
43 | run_times, mem_uses = run_times[num_warmups:], mem_uses[num_warmups:]
44 | rt_mean, rt_std = np.mean(run_times), np.std(run_times)
45 | max_mem = np.mean(mem_uses)
46 |
47 | print(f'\trun time mean - {rt_mean}')
48 | print(f'\trun time std - {rt_std}')
49 | print(f'\tmax memory usage - {max_mem}')
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = ExpParser(uses_dataset=False, description='Run the speedup experiment.')
54 | args = parser.parse_args()
55 | exp_model_time(**args)
56 |
--------------------------------------------------------------------------------
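A hypothetical invocation from the repository root (flags as defined by `ExpParser` in experiments/utils.py: with `uses_dataset=False` the parser expects `--batch_shape`, and `num_batches` must be supplied via `--tests_config`):

```
python -m experiments.speedup_exp -r P -m resnet50 -c 10 -d 4 \
    -s 64 3 224 224 --pipeline_params microbatch_size=16 --tests_config num_batches=100
```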
/experiments/split_size_exp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from sample_models import *
5 | from typing import Tuple
6 | import timeit
7 | from .utils import *
8 |
9 |
10 | def exp_split_size(model_class, num_devices: int, num_classes: int, batch_shape: Tuple[int, ...], model_params: dict,
11 | tests_config: dict, pipeline_params: dict):
12 | num_repeat = 10
13 |
14 | tests_config['num_classes'] = num_classes
15 | tests_config['batch_shape'] = batch_shape
16 |
17 | pipeline_params['devices'] = list(range(num_devices))
18 |
19 | stmt = call_func_stmt(train, 'model', **tests_config)
20 |
21 | model_init_stmt = call_func_stmt(model_class, **model_params)
22 |
23 | means = []
24 | stds = []
25 | split_sizes = [1, 3, 5, 8, 10, 12, 20, 40, 60]
26 |
27 | for split_size in split_sizes:
28 | pipeline_params['microbatch_size'] = split_size
29 | setup = f"model = {call_func_stmt(create_pipeline, model_init_stmt, batch_shape, **pipeline_params)}"
30 | pp_run_times = timeit.repeat(
31 | stmt, setup, number=1, repeat=num_repeat, globals=globals())
32 | means.append(np.mean(pp_run_times))
33 | stds.append(np.std(pp_run_times))
34 | print(
35 | f'Split size {split_size} has a mean execution time of {means[-1]} with standard deviation of {stds[-1]}')
36 |
37 | fig, ax = plt.subplots()
38 | ax.plot(split_sizes, means)
39 | ax.errorbar(split_sizes, means, yerr=stds, ecolor='red', fmt='ro')
40 |     ax.set_ylabel('Execution Time (seconds)')
41 | ax.set_xlabel('Pipeline Split Size')
42 | ax.set_xticks(split_sizes)
43 | ax.yaxis.grid(True)
44 | plt.tight_layout()
45 | plt.savefig("split_size_tradeoff.png")
46 | plt.close(fig)
47 |
48 |
49 | if __name__ == '__main__':
50 |     parser = ExpParser(uses_dataset=False, description='Run the split size experiment.')
51 | args = parser.parse_args()
52 | exp_split_size(**args)
53 |
--------------------------------------------------------------------------------
/experiments/throughput_exp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import time
4 |
5 | # add our code to the path so we can import it
6 | import sys
7 | import os
8 | sys.path.append(os.path.join(os.getcwd(), '..'))
9 |
10 | from pytorch_Gpipe import pipe_model
11 | import torch.nn.functional as F
12 | from torch.optim import SGD
13 | import argparse
14 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, inception_v3, densenet201, GoogLeNet, LeNet, \
15 | WideResNet, AmoebaNet_D as amoebanet, amoebanetd as torchgpipe_amoebanet, torchgpipe_resnet101
16 | import platform
17 |
18 | MODELS = {
19 | "alexnet": alexnet,
20 | "resnet101": resnet101,
21 | "vgg19_bn": vgg19_bn,
22 | "squeezenet1_1": squeezenet1_1,
23 | "inception_v3": inception_v3,
24 | "densenet201": densenet201,
25 | "GoogLeNet": GoogLeNet,
26 | "LeNet": LeNet,
27 | "WideResNet": WideResNet,
28 | "amoebanet": amoebanet,
29 | "torchgpipe_amoebanet": torchgpipe_amoebanet,
30 | "torchgpipe_resnet101": torchgpipe_resnet101
31 | }
32 |
33 | # setups
34 |
35 |
36 | def single_gpu(model_class: nn.Module, devices, *model_args, **model_kwargs):
37 | return model_class(*model_args, **model_kwargs).to(devices[0]), [devices[0]]
38 |
39 |
40 | def pipeLine(model_class: nn.Module, devices, pipe_sample, model_args, model_kwargs, pipeline_args, pipeline_kwargs):
41 | net = model_class(*model_args, **model_kwargs).to(devices[0])
42 | net(pipe_sample)
43 | torch.cuda.synchronize()
44 |
45 | piped = pipe_model(net, pipe_sample.shape[0], pipe_sample, *pipeline_args,
46 | devices=devices, **pipeline_kwargs)
47 | # TODO assumes first is input_device and last is output device
48 | return piped, list(piped.module_devices)
49 |
50 |
51 | def dataParallel(model_class: nn.Module, devices, *model_args, **model_kwargs):
52 | return nn.DataParallel(model_class(*model_args, **model_kwargs), devices).to(devices[0]), devices
53 |
54 |
55 | SETUPS = {
56 | "single_gpu": single_gpu,
57 | "pipeLine": pipeLine,
58 | "dataParallel": dataParallel
59 | }
60 |
61 | # the experiment itself
62 |
63 |
64 | def throughput_exp(config):
65 | assert torch.cuda.is_available(), "gpus are required"
66 | devices = [torch.device('cuda', i)
67 | for i in range(torch.cuda.device_count())]
68 |
69 | setup = config['setup']
70 | model_class = config['model']
71 | model_args = config['model_args']
72 | model_kwargs = config['model_kwargs']
73 | batch_size = config['batch_size']
74 | pipeLine_kwargs = config['pipeLine_kwargs']
75 | pipeLine_args = config['pipeLine_args']
76 | epochs = config['epochs']
77 | batch_shape = config['batch_shape']
78 | microbatch_size = config['microbatch_size']
79 | profile_sample_size = config['profile_sample_size']
80 | if microbatch_size is None:
81 | microbatch_size = batch_size // len(devices)
82 |
83 | if setup is pipeLine:
84 | assert len(devices) > 1, "automatic partitioning does not work for 1 gpu"
85 | pipe_sample = torch.randn(
86 | (profile_sample_size,) + batch_shape[1:]).to(devices[0])
87 | model, used_devices = pipeLine(model_class, devices, pipe_sample, model_args,
88 | model_kwargs, pipeLine_args, pipeLine_kwargs)
89 | else:
90 | model, used_devices = setup(
91 | model_class, devices, *model_args, **model_kwargs)
92 |
93 | in_device = used_devices[0]
94 | out_device = used_devices[-1]
95 | batch_size = batch_shape[0]
96 | optimizer = SGD(model.parameters(), lr=0.1)
97 |     # This experiment cares only about training performance, not accuracy.
98 |     # To eliminate any overhead due to data loading, we use fake data
99 | batch = torch.randn(batch_shape, device=in_device)
100 | target = torch.randint(10, (batch_size,))
101 |
102 | throughputs = []
103 | elapsed_times = []
104 | memory_consumptions = {device: [] for device in used_devices}
105 |
106 | # run one epoch and gather statistics
107 | def run_epoch(epoch):
108 | torch.cuda.synchronize(in_device)
109 | tick = time.time()
110 |
111 | data_trained = 0
112 | steps = 50000 // batch_size
113 | for i in range(steps):
114 | data_trained += batch_size
115 |
116 | output = model(batch)
117 | gt = target.to(output.device)
118 | loss = F.cross_entropy(output, gt)
119 | loss.backward()
120 |
121 | optimizer.step()
122 | optimizer.zero_grad()
123 |
124 | # estimate statistics after each batch
125 | percent = i / steps * 100
126 | throughput = data_trained / (time.time() - tick)
127 |
128 | # print every 100 steps
129 | if i % 100 == 0 or i == steps - 1:
130 | print(
131 |                     f"{epoch+1}/{epochs} epochs ({percent:.2f}%) | {throughput:.2f} samples/sec (estimated)")
132 |
133 | torch.cuda.synchronize(in_device)
134 | tock = time.time()
135 |
136 | # calculate exact statistics after epoch is finished
137 |
138 | elapsed_time = tock - tick
139 | throughput = batch_size * steps / elapsed_time
140 | print(
141 | f"{epoch+1}/{epochs} epochs | {throughput:.2f} samples/sec, {elapsed_time:.2f} sec/epoch")
142 |
143 | return throughput, elapsed_time
144 |
145 | exp = setup.__name__
146 | title = f'throughput experiment\n config: {exp}\n used_gpus: {len(used_devices)}\n epochs: {epochs}\n'
147 | print(title)
148 |
149 | gpus = [torch.cuda.get_device_name(device) for device in used_devices]
150 | print('system information\n python: %s, torch: %s, cudnn: %s, cuda: %s, \ngpus: %s' % (
151 | platform.python_version(),
152 | torch.__version__,
153 | torch.backends.cudnn.version(),
154 | torch.version.cuda,
155 | gpus))
156 | print("\n")
157 | for epoch in range(epochs):
158 | for d in used_devices:
159 | torch.cuda.reset_max_memory_allocated(device=d)
160 |
161 | throughput, elapsed_time = run_epoch(epoch)
162 | # first epoch is used as a warmup
163 | if epoch < 1:
164 | continue
165 |
166 | for d in used_devices:
167 | memory_consumptions[d].append(
168 | torch.cuda.max_memory_allocated(device=d))
169 |
170 | throughputs.append(throughput)
171 | elapsed_times.append(elapsed_time)
172 |
173 | print("\n")
174 | n = len(throughputs)
175 | throughput = sum(throughputs) / n
176 | elapsed_time = sum(elapsed_times) / n
177 |
178 | avg_mem_per_epoch = {device: sum(
179 | memory_consumptions[device]) / n for device in used_devices}
180 |     # device with the highest average memory consumption ('min' would give the lowest)
181 | maximum = max(avg_mem_per_epoch, key=avg_mem_per_epoch.get)
182 |
183 | print(
184 | f'{title} {throughput:.2f} samples/sec\n{elapsed_time:.2f} sec/epoch (average)\n max memory consumption on device: {maximum} {(avg_mem_per_epoch[maximum]/1e6):.2f} MB')
185 |
186 |
187 | class StoreDict(argparse.Action):
188 | def __call__(self, parser, namespace, values, option_string=None):
189 | kv = {}
190 | if not isinstance(values, (list,)):
191 | values = (values,)
192 | for value in values:
193 | n, v = value.split('=')
194 | kv[n] = v
195 | setattr(namespace, self.dest, kv)
196 |
197 |
198 | class ExpParser(argparse.ArgumentParser):
199 | def __init__(self, *args, **kwargs):
200 | super().__init__(*args, **kwargs)
201 |
202 | self.add_argument('--setup', '-s', help='The way to run the model.',
203 | choices=SETUPS.keys(),
204 | required=True, dest='setup')
205 | self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=MODELS.keys(),
206 | required=True, dest='model')
207 | self.add_argument('--model_args', '-ma', help='additional args passed to the pipeline', nargs='*',
208 | default=[])
209 | self.add_argument('--model_kwargs', '-mkw', help='additional kwargs passed to the model', nargs='*', action=StoreDict,
210 | default={})
211 | self.add_argument('--batch_size', '-b', help='batch size used in the experiment defaults to 64',
212 | type=int, dest='batch_size', default=64)
213 | self.add_argument('--pipeLine_args', '-pa', help='additional args passed to the pipeline', nargs='*',
214 | default=[])
215 | self.add_argument('--pipeLine_kwargs', '-pkw', help='additional kwargs passed to the pipeline', nargs='*',
216 | action=StoreDict, default={})
217 |         self.add_argument('--epochs', '-e', help="the number of training epochs used,\nthe first epoch is a warmup and will not affect results",
218 | type=int, default=10)
219 | self.add_argument('--microbatch_size', '-mb', help="micro batch size of the pipeline\nif not given defaults to batch_size/num_devices",
220 | type=int, default=None)
221 | self.add_argument('--profile_sample_size', '-ps', help='size of batch used to partition the model if testing the pipeline',
222 | type=int, default=16)
223 |
224 | def parse_args(self, *args, **kwargs):
225 | res = vars(super().parse_args(*args, **kwargs))
226 | res['model'] = MODELS[res['model']]
227 | res['setup'] = SETUPS[res['setup']]
228 | net = res['model']
229 |
230 | if net.__name__.find("inception") != -1:
231 | sample_shape = (3, 299, 299)
232 | elif net.__name__.find("GoogLeNet") != -1:
233 | sample_shape = (3, 32, 32)
234 | elif net.__name__.find("LeNet") != -1:
235 | sample_shape = (3, 32, 32)
236 | else:
237 | sample_shape = (3, 224, 224)
238 |
239 | res['batch_shape'] = (res['batch_size'],) + sample_shape
240 |
241 | return res
242 |
243 |
244 | if __name__ == "__main__":
245 |     # TODO: all var/kw args are treated as strings, e.g. the following are all strings:
246 |     # throughput_exp.py -s single_gpu -m amoebanet --model_args 10 --model_kwargs kw0=1 kw1=hello
247 |
248 |     # run as follows from the experiments folder, e.g.:
249 |     # python throughput_exp.py --epochs 10 --setup pipeLine -m amoebanet --microbatch_size 64 --profile_sample_size 32
250 |     # this will check throughput, execution time and memory consumption of amoebanet
251 |     # across 10 training epochs, each with 50000//batch_size batches,
252 |     # with batch size 64 (the default); the first epoch is a warmup and will not affect results;
253 |     # the partition will be done using a batch of size 32
254 |
255 |
256 | parser = ExpParser()
257 | args = parser.parse_args()
258 |
259 | throughput_exp(args)
260 |
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | import torch
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from pytorch_Gpipe import pipe_model
7 | import argparse
8 | import sample_models
9 | import torchvision
10 | import sys
11 |
12 |
13 | def kwargs_string(*pos_strings, **kwargs):
14 | return ', '.join(list(pos_strings) + [f'{key}={val}' for key, val in kwargs.items()])
15 |
16 |
17 | def reset_max_memory_allocated():
18 | for i in range(torch.cuda.device_count()):
19 | torch.cuda.reset_max_memory_allocated(i)
20 |
21 |
22 | def get_max_memory_allocated():
23 | max_mem = -1
24 |
25 | for i in range(torch.cuda.device_count()):
26 | mem_alloc = torch.cuda.max_memory_allocated(i)
27 | max_mem = max(max_mem, mem_alloc)
28 |
29 | return max_mem
30 |
31 |
32 | def call_func_stmt(func, *params, **kwargs):
33 | if isinstance(func, str):
34 | func_name = func
35 | else:
36 | func_name = func.__name__
37 |
38 | params_str = [str(param) for param in params]
39 |
40 | return f'{func_name}({kwargs_string(*params_str, **kwargs)})'
41 |
42 |
43 | def track_train(num_repeats, model, num_classes, num_batches, batch_shape):
44 | num_batches = int(num_batches)
45 |
46 | run_times = []
47 | mem_uses = []
48 | for _ in range(num_repeats):
49 |         reset_max_memory_allocated()
50 | start = torch.cuda.Event(enable_timing=True)
51 | end = torch.cuda.Event(enable_timing=True)
52 |
53 | start.record()
54 | train(model, num_classes, num_batches, batch_shape)
55 | end.record()
56 |
57 | torch.cuda.synchronize()
58 | run_times.append(start.elapsed_time(end))
59 | mem_uses.append(get_max_memory_allocated())
60 | print('.', end='')
61 |
62 | return run_times, mem_uses
63 |
64 |
65 | def train(model, num_classes, num_batches, batch_shape):
66 | model.train(True)
67 | loss_fn = nn.MSELoss()
68 | optimizer = optim.SGD(model.parameters(), lr=0.001)
69 |
70 | batch_size = batch_shape[0]
71 |
72 | one_hot_indices = torch.LongTensor(batch_size).random_(
73 | 0, num_classes).view(batch_size, 1)
74 |
75 | dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
76 |
77 | for b in range(num_batches):
78 | # generate random inputs and labels
79 | inputs = torch.randn(*batch_shape)
80 | labels = torch.zeros(batch_size, num_classes).scatter_(
81 | 1, one_hot_indices, 1)
82 |
83 | # run forward pass
84 | optimizer.zero_grad()
85 | outputs = model(inputs.to(dev))
86 |
87 | # run backward pass
88 | labels = labels.to(outputs.device)
89 | loss_fn(outputs, labels).backward()
90 | optimizer.step()
91 |
92 |
93 | def plot(means, stds, labels, fig_name, fig_label):
94 | fig, ax = plt.subplots()
95 | ax.bar(np.arange(len(means)), means, yerr=stds,
96 | align='center', alpha=0.5, ecolor='red', capsize=10, width=0.6)
97 | ax.set_ylabel(fig_label)
98 | ax.set_xticks(np.arange(len(means)))
99 | ax.set_xticklabels(labels)
100 | ax.yaxis.grid(True)
101 | plt.tight_layout()
102 | plt.savefig(fig_name)
103 | plt.close(fig)
104 |
105 |
106 | def create_pipeline(model, batch_shape, microbatch_size, **kwargs):
107 | device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
108 | microbatch_size = int(microbatch_size)
109 | return pipe_model(model.to(device), microbatch_size, torch.randn(*batch_shape, device=device), **kwargs)
110 |
111 |
112 | class StoreDict(argparse.Action):
113 | def __call__(self, parser, namespace, values, option_string=None):
114 | kv = {}
115 | if not isinstance(values, (list,)):
116 | values = (values,)
117 | for value in values:
118 | n, v = value.split('=')
119 | kv[n] = v
120 | setattr(namespace, self.dest, kv)
121 |
122 |
123 | class ExpParser(argparse.ArgumentParser):
124 | def __init__(self, *args, uses_dataset=True, **kwargs):
125 | super().__init__(*args, **kwargs)
126 | self.uses_dataset = uses_dataset
127 |
128 | models = [
129 | 'AlexNet', 'alexnet', 'DenseNet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'GoogLeNet',
130 | 'Inception3', 'inception_v3', 'LeNet', 'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
131 | 'resnet152', 'SqueezeNet', 'squeezenet1_0', 'squeezenet1_1', 'VGG', 'vgg11', 'vgg11_bn', 'vgg13',
132 | 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'WideResNet', 'AmoebaNet_D', 'amoebanetd',
133 | 'resnet101', 'torchgpipe_resnet101'
134 | ]
135 |
136 | self.add_argument('--run_type', '-r', help='The way to run the model.',
137 | choices=['Single', 'Data-Parallel', 'Pipeline-Parallel', 'S', 'D', 'P'],
138 | required=True, dest='run_type')
139 | self.add_argument('--model', '-m', help='The model we want to run the experiment on.', choices=models,
140 | required=True, dest='model_class')
141 | self.add_argument('--classes', '-c', help='The number of classes in the prediction problem.', type=int,
142 | required=True, dest='num_classes')
143 | self.add_argument('--repeats', '-n', help='amount of times to repeat the experiments.', type=int,
144 | default=10, dest='num_repeats')
145 | self.add_argument('--warmups', '-w', help='amount of times to run the experiments before tracking results.',
146 | type=int, default=1, dest='num_warmups')
147 | self.add_argument('--model_params', help='The parameters for the model', nargs='*', action=StoreDict,
148 | default={})
149 | self.add_argument('--devices', '-d', help='The number of devices to use in the experiment.', type=int,
150 | dest='num_devices')
151 |         self.add_argument('--pipeline_params', help='Parameters for the pipeline itself other than devices', nargs='*',
152 | action=StoreDict, default={})
153 |
154 | if uses_dataset:
155 | self.add_argument('--dataset', '-s', choices=list(torchvision.datasets.__all__), required=True)
156 |             self.add_argument('--ds_root', type=str, required=True)  # no short flag: '-r' is already taken by --run_type
157 | else:
158 | self.add_argument('--batch_shape', '-s', help='The shape of one batch.', nargs='*', type=int, required=True)
159 | self.add_argument('--tests_config', help='Any other config kwargs for the test', nargs='*',
160 | action=StoreDict, default={})
161 |
162 | def parse_args(self, *args, **kwargs):
163 | res = vars(super().parse_args(*args, **kwargs))
164 |
165 | res['model_params']['num_classes'] = res['num_classes']
166 |
167 | res['model_class'] = getattr(sys.modules['sample_models'], res['model_class'])
168 |
169 | if self.uses_dataset:
170 | ds_class = getattr(sys.modules['torchvision.datasets'], res['dataset'])
171 | res['dataset'] = ds_class(res['ds_root'])
172 |
173 | return res
174 |
--------------------------------------------------------------------------------
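For illustration, `call_func_stmt` above builds the source string that `timeit` executes in split_size_exp.py:

```python
>>> call_func_stmt(train, 'model', num_classes=10, num_batches=100, batch_shape=(64, 3, 224, 224))
"train(model, num_classes=10, num_batches=100, batch_shape=(64, 3, 224, 224))"
```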
/misc/networkx_metis__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (C) 2015 ysitu
4 | # All rights reserved
5 |
6 | """
7 | Wrappers of METIS graph partitioning functions.
8 | """
9 |
10 | import contextlib
11 | import decorator
12 | import itertools
13 | import sys
14 |
15 | import networkx as nx
16 | import six
17 |
18 | from nxmetis import enums
19 | from nxmetis import exceptions
20 | from nxmetis import metis
21 | from nxmetis import types
22 |
23 | __all__ = ['node_nested_dissection', 'partition', 'vertex_separator',
24 | 'MetisOptions']
25 |
26 | MetisOptions = types.MetisOptions
27 |
28 |
29 | @contextlib.contextmanager
30 | def _zero_numbering(options):
31 | """Temporarily force zero-based numbering."""
32 | if options:
33 | numbering = options.numbering
34 | options.numbering = enums.MetisNumbering.zero
35 | try:
36 | yield
37 | finally:
38 | if options:
39 | options.numbering = numbering
40 |
41 |
42 | def _convert_graph(G):
43 | """Convert a graph to the numbered adjacency list structure expected by
44 | METIS.
45 | """
46 | index = dict(zip(G, list(range(len(G)))))
47 | xadj = [0]
48 | adjncy = []
49 | for u in G:
50 | adjncy.extend(index[v] for v in G[u])
51 | xadj.append(len(adjncy))
52 | return xadj, adjncy
53 |
54 |
55 | def _convert_exceptions(convert_type, catch_types=None):
56 | """Decorator to convert types of exceptions
57 |
58 | Parameters
59 | ----------
60 | convert_type : subclass of Exception
61 | Target type to convert to.
62 |
63 | catch_types : tuple of subclasses of Exception, optional
64 | Source types whose instances are to be caught and converted. If None,
65 | all instances of Exception will be caught and converted.
66 |
67 | Returns
68 | -------
69 | _convert_exceptions : function
70 | Function that performs exception type conversion.
71 |
72 | Example
73 | -------
74 | Decorate functions like this::
75 |
76 | @_convert_exceptions(nx.NetworkXError, (ValueError,))
77 | def function():
78 | pass
79 | """
80 | @decorator.decorator
81 | def _convert_exceptions(func, *args, **kwargs):
82 | try:
83 | return func(*args, **kwargs)
84 | except catch_types as e:
85 | exc = e
86 | except Exception as e:
87 | if catch_types is not None:
88 | raise
89 | exc = sys.exc_info()
90 | six.reraise(convert_type, convert_type(exc[1]), exc[2])
91 | return _convert_exceptions
92 |
93 |
94 | @nx.utils.not_implemented_for('directed')
95 | @nx.utils.not_implemented_for('multigraph')
96 | @_convert_exceptions(
97 | nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
98 | def node_nested_dissection(G, weight='weight', options=None):
99 | """Compute a node ordering of a graph that reduces fill when the Laplacian
100 | matrix of the graph is LU factorized. The algorithm aims to minimize the
101 | sum of weights of vertices in separators computed in the process.
102 |
103 | Parameters
104 | ----------
105 | G : NetworkX graph
106 | A graph.
107 |
108 | weight : object, optional
109 | The data key used to determine the weight of each node. If None, each
110 | node has unit weight. Default value: 'weight'.
111 |
112 | options : MetisOptions, optional
113 | METIS options. If None, the default options are used. Default value:
114 | None.
115 |
116 | Returns
117 | -------
118 | perm : list of nodes
119 | The node ordering.
120 |
121 | Raises
122 | ------
123 | NetworkXError
124 | If the parameters cannot be converted to valid METIS input format, or
125 | METIS returns an error status.
126 | """
127 | if len(G) == 0:
128 | return []
129 |
130 | vwgt = [G.nodes[u].get(weight, 1) for u in G]
131 | if all(w == 1 for w in vwgt):
132 | vwgt = None
133 |
134 | xadj, adjncy = _convert_graph(G)
135 |
136 | with _zero_numbering(options):
137 | perm = metis.node_nd(xadj, adjncy, vwgt, options)[0]
138 |
139 | nodes = list(G)
140 | perm = [nodes[i] for i in perm]
141 |
142 | return perm
143 |
144 |
145 | @nx.utils.not_implemented_for('directed')
146 | @nx.utils.not_implemented_for('multigraph')
147 | @_convert_exceptions(
148 | nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
149 | def partition(G, nparts, node_weight='weight', node_size='size',
150 | edge_weight='weight', tpwgts=None, ubvec=None, options=None,
151 | recursive=False):
152 | """Partition a graph using multilevel recursive bisection or multilevel
153 | multiway partitioning.
154 |
155 | Parameters
156 | ----------
157 | G : NetworkX graph
158 | An undirected graph.
159 |
160 | nparts : int
161 | Number of parts to partition the graph. It should be at least 2.
162 |
163 | node_weight : object, optional
164 | The data key used to determine the weight of each node. If None, each
165 | node has unit weight. Default value: 'weight'.
166 |
167 | node_size : object, optional
168 | The data key used to determine the size of each node when computing the
169 |         total communication volume. If None, each node has unit size. Default
170 | value: 'size'
171 |
172 | edge_weight : object, optional
173 | The data key used to determine the weight of each edge. If None, each
174 | edge has unit weight. Default value: 'weight'.
175 |
176 | tpwgts : list of lists of floats, optional
177 | The target weights of the partitions and the constraints. The target
178 | weight of the `i`-th partition and the `j`-th constraint is given by
179 | ``tpwgts[i][j]`` (the numbering for both partitions and constraints
180 | starts from zero). For each constraint the sum of the ``tpwgts[][]``
181 |         entries must be 1.0 (i.e., `\\sum_i \\text{tpwgts}[i][j] = 1.0`). If
182 | None, the graph is equally divided among the partitions. Default value:
183 | None.
184 |
185 | ubvec : list of floats, optional
186 | The allowed load imbalance tolerance for each constraint. For the
187 | `i`-th and the `j`-th constraint, the allowed weight is the
188 | ``ubvec[j] * tpwgts[i][j]`` fraction of the `j`-th constraint's total
189 | weight. The load imbalances must be greater than 1.0. If None, the load
190 | imbalance tolerance is 1.001 if there is exactly one constraint or 1.01
191 | if there are more. Default value: None.
192 |
193 | options : MetisOptions, optional.
194 | METIS options. If None, the default options are used. Default value:
195 | None.
196 |
197 | recursive : bool, optional
198 |         If True, multilevel recursive bisection is used. Otherwise,
199 |         multilevel multiway partitioning is used. Default value: False.
200 |
201 | Returns
202 | -------
203 | objval : int
204 | The edge-cut or the total communication volume of the partitioning
205 |         solution. The value returned depends on the partitioning's objective
206 | function.
207 |
208 | parts : lists of nodes
209 | The partitioning.
210 |
211 | Raises
212 | ------
213 | NetworkXNotImplemented
214 | If the graph is directed or is a multigraph.
215 |
216 | NetworkXError
217 | If the parameters cannot be converted to valid METIS input format, or
218 | METIS returns an error status.
219 | """
220 | if nparts < 1:
221 | raise nx.NetworkXError('nparts is less than one.')
222 | if nparts == 1:
223 | return 0, [list(G)]
224 |
225 | if len(G) == 0:
226 | return 0, [[] for i in range(nparts)]
227 |
228 | xadj, adjncy = _convert_graph(G)
229 |
230 | vwgt = [G.nodes[u].get(node_weight, 1) for u in G]
231 | if all(w == 1 for w in vwgt):
232 | vwgt = None
233 |
234 | vsize = [G.nodes[u].get(node_size, 1) for u in G]
235 | if all(w == 1 for w in vsize):
236 | vsize = None
237 |
238 | adjwgt = [G[u][v].get(edge_weight, 1) for u in G for v in G[u]]
239 | if all(w == 1 for w in adjwgt):
240 | adjwgt = None
241 |
242 | if tpwgts is not None:
243 | if len(tpwgts) != nparts:
244 | raise nx.NetworkXError('length of tpwgts is not equal to nparts.')
245 | ncon = len(tpwgts[0])
246 | if any(len(tpwgts[j]) != ncon for j in range(1, nparts)):
247 | raise nx.NetworkXError(
248 | 'lists in tpwgts are not of the same length.')
249 | if ubvec is not None and len(ubvec) != ncon:
250 | raise nx.NetworkXError(
251 | 'ubvec is not of the same length as tpwgts.')
252 | tpwgts = list(itertools.chain.from_iterable(tpwgts))
253 |
254 | with _zero_numbering(options):
255 | objval, part = metis.part_graph(xadj, adjncy, nparts, vwgt, vsize,
256 | adjwgt, tpwgts, ubvec, options,
257 | recursive)
258 |
259 | parts = [[] for i in range(nparts)]
260 | for u, i in zip(G, part):
261 | parts[i].append(u)
262 |
263 | return objval, parts
264 |
265 |
266 | @nx.utils.not_implemented_for('directed')
267 | @nx.utils.not_implemented_for('multigraph')
268 | @_convert_exceptions(
269 | nx.NetworkXError, (ValueError, TypeError, exceptions.MetisError))
270 | def vertex_separator(G, weight='weight', options=None):
271 | """Compute a vertex separator that bisects a graph. The algorithm aims to
272 | minimize the sum of weights of vertices in the separator.
273 |
274 | Parameters
275 | ----------
276 | G : NetworkX graph
277 | A graph.
278 |
279 | weight : object, optional
280 | The data key used to determine the weight of each node. If None, each
281 | node has unit weight. Default value: 'weight'.
282 |
283 | options : MetisOptions, optional
284 | METIS options. If None, the default options are used. Default value:
285 | None.
286 |
287 | Returns
288 | -------
289 | sep, part1, part2 : lists of nodes
290 | The separator and the two parts of the bisection represented as lists.
291 |
292 | Raises
293 | ------
294 | NetworkXError
295 | If the parameters cannot be converted to valid METIS input format, or
296 | METIS returns an error status.
297 | """
298 | if len(G) == 0:
299 | return [], [], []
300 |
301 | vwgt = [G.nodes[u].get(weight, 1) for u in G]
302 | if all(w == 1 for w in vwgt):
303 | vwgt = None
304 |
305 | xadj, adjncy = _convert_graph(G)
306 |
307 | with _zero_numbering(options):
308 | part = metis.compute_vertex_separator(xadj, adjncy, vwgt, options)[1]
309 |
310 | groups = [[], [], []]
311 | for u, i in zip(G, part):
312 | groups[i].append(u)
313 |
314 | return groups[2], groups[0], groups[1]
315 |
--------------------------------------------------------------------------------
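A small usage sketch of `partition` above (assuming the `nxmetis` bindings and the METIS shared library are in place):

```python
import networkx as nx

G = nx.barbell_graph(5, 0)       # two 5-cliques joined by a single edge
objval, parts = partition(G, 2)  # multilevel multiway partitioning into 2 parts
print(objval)                    # edge-cut of the solution (expected: 1)
print(parts)                     # the two lists of nodes
```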
/parse_declarations.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import torch
3 | from torch import Tensor
4 | import numpy as np
5 | import builtins
6 |
7 | ts = set()
8 |
9 |
10 | class Type():
11 | def __call__(self, other: type):
12 | raise NotImplementedError()
13 |
14 | def __str__(self):
15 | raise NotImplementedError()
16 |
17 |
18 | class UnionType(Type):
19 | def __init__(self, *types: Type):
20 | self.types = types
21 |
22 | def __call__(self, other):
23 | return any(t(other) for t in self.types)
24 |
25 | def __str__(self):
26 | return f"Union[{','.join([str(t) for t in self.types])}]"
27 |
28 |
29 | class BasicType(Type):
30 | def __init__(self, t: type):
31 | self.type = t
32 |
33 | def __call__(self, other):
34 | return isinstance(other, self.type)
35 |
36 | def __str__(self):
37 | return self.type.__name__
38 |
39 |
40 | class OptionalType(Type):
41 | def __init__(self, _type: Type):
42 | self.type = _type
43 |
44 | def __call__(self, other):
45 | return (other is None) or self.type(other)
46 |
47 | def __str__(self):
48 | return f"Optional[{self.type}]"
49 |
50 |
51 | class ListType(Type):
52 | def __init__(self, t):
53 | self.type = t
54 |
55 | def __call__(self, other):
56 |         return isinstance(other, list) and all(self.type(e) for e in other)
57 |
58 | def __str__(self):
59 | return f"List[{self.type}]"
60 |
61 |
62 | class TupleType(Type):
63 | def __init__(self, types, homogeneous=True):
64 | if homogeneous:
65 | assert isinstance(types, Type)
66 | self.types = types
67 | self.homogeneous = homogeneous
68 |
69 | def __call__(self, other):
70 | if not isinstance(other, tuple):
71 | return False
72 |
73 | if self.homogeneous:
74 | return all(self.types(t) for t in other)
75 |
76 | if len(self.types) != len(other):
77 | return False
78 |
79 | return all(e(a) for e, a in zip(self.types, other))
80 |
81 | def __str__(self):
82 | if self.homogeneous:
83 | return f"Tuple[{self.types},...]"
84 | else:
85 | return f"Tuple[{', '.join([str(t) for t in self.types])}]"
86 |
87 |
88 | class AnyType(Type):
89 | def __call__(self, other):
90 | return True
91 |
92 | def __str__(self):
93 | return "Any"
94 |
95 |
96 | def getLines(lines):
97 | for line in lines:
98 | line = line.strip().rstrip(" .=").replace(" ", "")
99 | if line in ['', '\n', '\r\n']:
100 | # skip empty lines
101 | continue
102 | elif line.startswith("import") or line.startswith("from"):
103 | # skip imports
104 | continue
105 | elif line.startswith("#"):
106 | # skip comments
107 | continue
108 | else:
109 | yield line
110 |
111 |
112 | def parse_function(line, types):
113 | function_decl = line[3:line.rindex(")") + 1]
114 | func_name = line[3:line.index("(")]
115 |
116 | args = function_decl[function_decl.index("(") + 1:-1]
117 | args = args.strip().split(",")
118 |
119 | i = 0
120 | keyword = False
121 | while i < len(args):
122 | arg = args[i]
123 |
124 | if ('[' in arg) and arg.count('[') > arg.count(']'):
125 | # this is a compound type, merge tokens until brackets are balanced
126 | to_merge = [arg]
127 | cnt = arg.count('[') - arg.count(']')
128 | i += 1
129 | while i < len(args):
130 | to_merge.append(args[i])
131 | cnt += (args[i].count('[') - args[i].count(']'))
132 | i += 1
133 | if cnt == 0:
134 | break
135 | i -= 1
136 | arg = ",".join(to_merge)
137 |
138 | if ('(' in arg) and arg.count('(') > arg.count(')'):
139 | # this is an unbalanced tuple, merge tokens until brackets are balanced
140 | to_merge = [arg]
141 | cnt = arg.count('(') - arg.count(')')
142 | i += 1
143 | while i < len(args):
144 | to_merge.append(args[i])
145 | cnt += (args[i].count('(') - args[i].count(')'))
146 | i += 1
147 | if cnt == 0:
148 | break
149 | i -= 1
150 | arg = ",".join(to_merge)
151 |
152 | if '*' == arg:
153 | # end of positionals
154 | keyword = True
155 | i += 1
156 | continue
157 | elif '**' in arg:
158 | # **kwargs; must be checked before the generic keyword case,
159 | # otherwise an unannotated **kwargs token would hit the ':' split below
160 | keyword = True
161 | arg_name = arg.split("*")[-1]
162 | arg_type = "dict"
163 | elif keyword or '=' in arg:
164 | # keyword
165 | keyword = True
166 | arg_name = arg.split(":")[0]
167 | if '=' in arg:
168 | arg_type, default_value = arg.split(":")[1].split("=")
169 | else:
170 | arg_type = arg.split(":")[1]
170 | elif '*' in arg:
171 | # *args
172 | arg_name = arg.split(":")[0][1:]
173 | arg_type = f"Tuple[{arg.split(':')[1]}, ...]"
174 | keyword = True
175 | else:
176 | # a:cls
177 | arg_name, arg_type = arg.split(":")
178 |
179 | i += 1
180 | arg_type = arg_type.replace(" ", "")
181 | ts.add(arg_type)
182 | return func_name, args
183 |
184 |
185 | def generateTypes():
186 | types = dict()
187 | # builtins
188 | for s in ['int', 'bool', 'float', 'str', 'slice']:
189 | types[f"_{s}"] = BasicType(getattr(builtins, s))
190 | types[f"List[_{s}]"] = ListType(types[f"_{s}"])
191 | types[f"Tuple[_{s},...]"] = TupleType(types[f"_{s}"])
192 | types[f"Optional[_{s}]"] = OptionalType(types[f"_{s}"])
193 |
194 | types["List"] = BasicType(list)
195 | types["Tuple"] = BasicType(tuple)
196 | types["Number"] = UnionType(types['_int'], types['_bool'], types['_float'])
197 | types["Any"] = AnyType()
198 | types["_ndarray"] = BasicType(np.ndarray)
199 | types["Callable"] = inspect.isfunction
200 |
201 | # torch types
202 | for s in ['Tensor', 'layout', 'qscheme', 'Generator', "Storage", 'memory_format', "device", "dtype", "Size"]:
203 | types[f'_{s}'] = BasicType(getattr(torch, s))
204 | types[f"List[_{s}]"] = ListType(types[f"_{s}"])
205 | types[f"Tuple[_{s},...]"] = TupleType(types[f"_{s}"])
206 | types[f"Optional[_{s}]"] = OptionalType(types[f"_{s}"])
207 |
208 | # special cases
209 | types['_size'] = UnionType(*[types[s]
210 | for s in ['_Size', 'List[_int]', 'Tuple[_int,...]']])
211 | types['Union[_Tensor,List]'] = UnionType(types['_Tensor'], types['List'])
212 | types['Union[_int,List[_int]]'] = UnionType(types['_int'],
213 | types['List[_int]'])
214 | types['Union[_Tensor,Number]'] = UnionType(types['_Tensor'],
215 | types['Number'])
216 | types['Optional[_size]'] = OptionalType(types['_size'])
217 | types['Union[_int,_size]'] = UnionType(types['_int'], types['_size'])
218 | types['Optional[Union[_device,_str]]'] = OptionalType(UnionType(types['_device'],
219 | types['_str']))
220 | types['Optional[Union[_str,_dtype]]'] = OptionalType(UnionType(types['_str'],
221 | types['_dtype']))
222 | types['Union[Tuple[_Tensor,...],List[_Tensor]]'] = UnionType(
223 | types['Tuple[_Tensor,...]'], types['List[_Tensor]'])
224 | types['Optional[Union[Tuple[_Tensor,...],List[_Tensor]]]'] = OptionalType(
225 | types['Union[Tuple[_Tensor,...],List[_Tensor]]'])
226 |
227 | types['Optional[Union[_int,_slice,_Tensor,List,Tuple]]'] =\
228 | OptionalType(
229 | UnionType(*[types[s] for s in ['_int', '_slice', '_Tensor', 'List', 'Tuple']]))
230 |
231 | return types
232 |
233 |
234 | def parse():
235 | decl_file = "pytorch_Gpipe/model_partitioning/module_generation/declarations.pyi"
236 | types = generateTypes()
237 | functions = dict()
238 | is_torch = False
239 | with open(decl_file, "r") as f:
240 | for line in getLines(f.readlines()):
241 |
242 | if line.startswith('class'):
243 | # class declaration
244 | current_class = line[line.index(
245 | "class") + 5:].split(":")[0]
246 | # print(current_class)
247 | elif line.startswith('def'):
248 | # function decl
249 | func, args = parse_function(line, types)
250 | assert hasattr(Tensor, func) or hasattr(torch, func)
251 | elif line.startswith('@overload'):
252 | # function overload
253 | pass
254 |
255 |
256 | # function + args in order => positional/keyword
257 | if __name__ == "__main__":
258 | parse()
259 |
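260 | # Usage sketch (illustrative): the Type combinators above compose into
261 | # checkers that validate a runtime value against a declared type.
262 | # >>> checker = OptionalType(ListType(BasicType(int)))
263 | # >>> checker(None), checker([1, 2]), checker([1.5])
264 | # (True, True, False)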
--------------------------------------------------------------------------------
/partition_torchvision_networks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pytorch_Gpipe import partition_with_profiler, profileNetwork, distribute_by_memory, distribute_by_time, \
3 | distribute_using_profiler, pipe_model
4 | import torch
5 | from sample_models import alexnet, resnet101, vgg19_bn, squeezenet1_1, inception_v3, densenet201, GoogLeNet, LeNet, \
6 | WideResNet
7 | from sample_models import AmoebaNet_D as my_amoeba, amoebanetd as ref_amoeba, torchgpipe_resnet101
8 |
9 | import torch.nn as nn
10 | from pytorch_Gpipe.utils import model_scopes
11 | import datetime
12 |
13 |
14 | def partition_torchvision(networks=None, nparts=4, depth=100, nruns=4, blocks=None,
15 | save_graph=False, show_scope_diff=False,
16 | dump_graph=False, **model_kwargs):
17 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
18 | if networks is not None and not isinstance(networks, (list, tuple)):
19 | networks = [networks]
20 |
21 | if networks is None:
22 | networks = [my_amoeba, ref_amoeba, alexnet, resnet101, torchgpipe_resnet101, vgg19_bn, squeezenet1_1,
23 | inception_v3, densenet201, GoogLeNet, LeNet, WideResNet]
24 |
25 | if not isinstance(nparts, (list, tuple)):
26 | nparts = [nparts]
27 |
28 | if not isinstance(depth, (list, tuple)):
29 | depth = [depth]
30 |
31 | for i in range(nruns):
32 | for net in networks:
33 | model = net(**model_kwargs).to(device)
34 | print("model built")
35 | basic_blocks = blocks
36 | for p in nparts:
37 | for d in depth:
38 | print(f"current net is {net.__name__}")
39 | if net.__name__.find("inception") != -1:
40 | graph = partition_with_profiler(
41 | model, torch.zeros(16, 3, 299, 299, device=device), nparts=p, max_depth=d,
42 | basic_blocks=basic_blocks)
43 | elif net.__name__.find("GoogLeNet") != -1:
44 | graph = partition_with_profiler(
45 | model, torch.zeros(16, 3, 32, 32, device=device), nparts=p, max_depth=d,
46 | basic_blocks=basic_blocks)
47 | elif net.__name__.find("LeNet") != -1:
48 | graph = partition_with_profiler(
49 | model, torch.zeros(16, 3, 32, 32, device=device), nparts=p, max_depth=d,
50 | basic_blocks=basic_blocks)
51 | elif net.__name__.find("moeba") != -1:
52 | graph = partition_with_profiler(
53 | model, torch.zeros(4, 3, 224, 224, device=device), nparts=p, max_depth=d, basic_blocks=basic_blocks)
54 | else:
55 | graph = partition_with_profiler(
56 | model, torch.zeros(16, 3, 224, 224, device=device), nparts=p, max_depth=d, basic_blocks=basic_blocks)
57 |
58 | time_stamp = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
59 | filename = f"{net.__name__}_run{i}_attempted_{p}_partitions_at_depth_{d}_{time_stamp}"
60 |
61 | curr_dir = os.path.dirname(os.path.realpath(__file__))
62 | out_dir = os.path.join(curr_dir, "partition_visualization")
63 |
64 | if dump_graph:
65 | graph.serialize(os.path.join(curr_dir, "graph_dump", filename))
66 | if save_graph:
67 | graph.save(directory=out_dir, file_name=filename,
68 | show_buffs_params=False, show_weights=False)
69 | print(filename)
70 |
71 | if show_scope_diff:
72 | scopes = set(model_scopes(model, depth=d,
73 | basic_blocks=basic_blocks))
74 | graph_scopes = graph.scopes()
75 | diff = scopes.difference(graph_scopes)
76 | print(f"scope diff {len(diff)}")
77 | for s in diff:
78 | print(s)
79 | print("\n")
80 |
81 |
82 | def distribute_torchvision(networks=None, nparts=4, depth=100, nruns=4,
83 | fake_gpus=False, save_graph=False, show_scope_diff=False,
84 | optimize_pipeline_wrappers=True, dump_graph=False, **model_kwargs):
85 | if not torch.cuda.is_available():
86 | raise ValueError("CUDA is required")
87 |
88 | device = 'cuda:0'
89 | if networks is not None and not isinstance(networks, (list, tuple)):
90 | networks = [networks]
91 |
92 | if networks is None:
93 | networks = [my_amoeba, ref_amoeba, alexnet, resnet101, torchgpipe_resnet101, vgg19_bn, squeezenet1_1,
94 | inception_v3, densenet201, GoogLeNet, LeNet, WideResNet]
95 |
96 | if not isinstance(nparts, (list, tuple)):
97 | nparts = [nparts]
98 |
99 | if not isinstance(depth, (list, tuple)):
100 | depth = [depth]
101 |
102 | for i in range(nruns):
103 | for net in networks:
104 | model = net(**model_kwargs).to(device)
105 | print("model built")
106 | basic_blocks = None
107 | for p in nparts:
108 | if fake_gpus:
109 | devices = ['cuda:0' for _ in range(p)]
110 | else:
111 | assert torch.cuda.device_count() == p
112 | devices = [f'cuda:{i}' for i in range(p)]
113 | for d in depth:
114 | print(f"current net is {net.__name__}")
115 | if net.__name__.find("inception") != -1:
116 | model, _, _, graph = distribute_using_profiler(
117 | model, torch.zeros(16, 3, 299, 299, device=device),
118 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d,
119 | basic_blocks=basic_blocks)
120 | elif net.__name__.find("GoogLeNet") != -1:
121 | model, _, _, graph = distribute_using_profiler(
122 | model, torch.zeros(16, 3, 32, 32, device=device),
123 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d,
124 | basic_blocks=basic_blocks)
125 | elif net.__name__.find("LeNet") != -1:
126 | model, _, _, graph = distribute_using_profiler(
127 | model, torch.zeros(16, 3, 32, 32, device=device),
128 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d,
129 | basic_blocks=basic_blocks)
130 | elif net.__name__.find("moeba") != -1:
131 | model, _, _, graph = distribute_using_profiler(
132 | model, torch.zeros(8, 3, 224, 224, device=device),
133 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d,
134 | basic_blocks=basic_blocks)
135 | else:
136 | model, _, _, graph = distribute_using_profiler(
137 | model, torch.zeros(16, 3, 224, 224, device=device),
138 | optimize_pipeline_wrappers=optimize_pipeline_wrappers, devices=devices, max_depth=d,
139 | basic_blocks=basic_blocks)
140 |
141 | time_stamp = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
142 | filename = f"{net.__name__}_run{i}_attempted_{p}_partitions_at_depth_{d}_{time_stamp}"
143 |
144 | curr_dir = os.path.dirname(os.path.realpath(__file__))
145 | out_dir = os.path.join(curr_dir, "distributed_models")
146 | if dump_graph:
147 | graph.serialize(os.path.join(curr_dir, "graph_dump", filename))
148 | if save_graph:
149 | graph.save(directory=out_dir, file_name=filename,
150 | show_buffs_params=False, show_weights=False)
151 |
152 | if show_scope_diff:
153 | scopes = set(model_scopes(model, depth=d,
154 | basic_blocks=basic_blocks))
155 | graph_scopes = graph.scopes()
156 | diff = scopes.difference(graph_scopes)
157 | print(f"scope diff {len(diff)}")
158 | for s in diff:
159 | print(s)
160 | print("\n")
161 |
162 |
163 | if __name__ == "__main__":
164 | # distribute_torchvision(networks=[my_amoeba, ref_amoeba], nparts=8,
165 | # fake_gpus=True, save_graph=True, nruns=2, num_layers=5)
166 |
167 | from sample_models.Inception import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE
168 | partition_torchvision(networks=inception_v3, save_graph=True, depth=[0, 1, 100], blocks=[
169 | BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE], nruns=2)
170 |
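171 | # Usage sketch (illustrative): partitioning a single torchvision network into
172 | # 8 parts at full depth and saving the graph visualization; model kwargs are
173 | # forwarded to the network constructor.
174 | # partition_torchvision(networks=resnet101, nparts=8, depth=100,
175 | #                       save_graph=True, nruns=1)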
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/METIS_manual.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/METIS_manual.pdf
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/__init__.py:
--------------------------------------------------------------------------------
1 | from .METIS_graph_partition import *
2 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/libmetis.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/libmetis.dll
--------------------------------------------------------------------------------
/pytorch_Gpipe/METIS/libmetis.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/pytorch_Gpipe/METIS/libmetis.so
--------------------------------------------------------------------------------
/pytorch_Gpipe/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .model_partitioning import generatePartitionModules, partition
7 | from .model_profiling import Graph, graph_builder, profileNetwork
8 | from .pipeline import Pipeline
9 | from .utils import Devices, Tensors
10 |
11 | __all__ = ['pipe_model', 'partition_with_profiler',
12 | 'partition', 'Pipeline']
13 |
14 |
15 | def pipe_model(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, nparts: int = 4, partition_by_memory: bool = False, output_file: Optional[str] = None, DEBUG=False):
16 | '''attempts to partition a model into the given number of parts using our profiler
17 | this will produce a python file with the partition config
18 |
19 | the generated python file exposes a method named {modelClass}Pipeline that creates the pipeline
20 | for this specific model config
21 |
22 | Parameters:
23 | model:
24 | the network we wish to model
25 | sample_batch:
26 | a sample input to use for tracing
27 | nparts:
28 | the number of partitions
29 | partition_by_memory:
30 | whether to partition by memory consumption; if False, partitions by time. defaults to False
31 | output_file:
32 | the file name in which to save the partition config
33 | if not given defaults to generated_{modelClass}{actualNumberOfPartitions}
34 | DEBUG:
35 | whether to generate the debug version of the partition, with more comments and assertions in the generated file
36 | '''
37 | def by_time(w):
38 | if hasattr(w, 'forward_time') and hasattr(w, 'backward_time'):
39 | return max(int(100 * (w.forward_time + w.backward_time) / 2), 1)
40 | return 1
41 |
42 | def by_memory(w):
43 | if hasattr(w, 'cuda_memory_forward') and hasattr(w, 'cuda_memory_backward'):
44 | return max(int(100 * (w.cuda_memory_forward + w.cuda_memory_backward) / 2), 1)
45 | return 1
46 |
47 | if partition_by_memory:
48 | w_func = by_memory
49 | else:
50 | w_func = by_time
51 |
52 | graph = partition_with_profiler(model, sample_batch, kwargs=kwargs, nparts=nparts,
53 | weighting_function=w_func)
54 |
55 | generatePartitionModules(graph, model,
56 | output_file=output_file, verbose=DEBUG)
57 |
58 | return graph
59 |
60 |
61 | def partition_with_profiler(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, nparts=4, max_depth=100, basic_blocks: Optional[List[nn.Module]] = None, weighting_function: Optional[Callable[[Any], int]] = None) -> Graph:
62 | '''
63 | return a graph representing the partitioned model with the weights given by the profiler
64 | this method does not distribute the model across devices
65 |
66 | Parameters:
67 | model:
68 | the network we wish to model
69 | sample_batch:
70 | a sample input to use for tracing
71 | nparts:
72 | the number of partitions
73 | max_depth:
74 | how far down we go in the model tree; determines the detail level of the graph
75 | basic_blocks:
76 | an optional list of modules that if encountered will not be broken down
77 | weighting_function:
78 | an optional function from node weights to non-negative integers; if not provided a default function will be used
79 | '''
80 | graph = graph_builder(model, sample_batch, kwargs=kwargs, max_depth=max_depth,
81 | basic_blocks=basic_blocks, use_profiler=True)
82 |
83 | graph = partition(graph, nparts, weighting_function=weighting_function)
84 |
85 | return graph
86 |
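87 | # Usage sketch (illustrative): partitioning a model into 4 pipeline stages.
88 | # resnet18 here is an arbitrary example model; for a model class ResNet the
89 | # generated file will expose a ResNetPipeline factory as described above.
90 | # >>> from torchvision.models import resnet18
91 | # >>> graph = pipe_model(resnet18(), torch.zeros(4, 3, 224, 224), nparts=4)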
--------------------------------------------------------------------------------
/pytorch_Gpipe/delayedNorm.py:
--------------------------------------------------------------------------------
1 | """Tracks the running statistics per mini-batch instead of micro-batch."""
2 | from typing import Optional
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | import torch.nn.functional as F
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | # taken form torchGpipe repo
10 |
11 |
12 | class DelayedBatchNorm(_BatchNorm):
13 | """A BatchNorm layer tracks multiple micro-batches to update running
14 | statistics per mini-batch.
15 | """
16 |
17 | def __init__(self,
18 | num_features: int,
19 | eps: float = 1e-5,
20 | momentum: Optional[float] = 0.1,
21 | affine: bool = True,
22 | num_micro_batches: int = 1,
23 | is_recomputing: bool = False
24 | ):
25 | super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
26 |
27 | self.register_buffer('sum', torch.zeros_like(self.running_mean))
28 | self.register_buffer('sum_squares', torch.zeros_like(self.running_var))
29 |
30 | self.counter = 0
31 | self.tracked = 0
32 | self.num_micro_batches = num_micro_batches
33 | self.is_recomputing = is_recomputing
34 |
35 | def _check_input_dim(self, x: Tensor):
36 | if x.dim() <= 2:
37 | raise ValueError(
38 | 'expected at least 3D input (got %dD input)' % x.dim())
39 |
40 | def _track(self, x: Tensor) -> bool:
41 | """Tracks statistics of a micro-batch."""
42 | # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
43 | dim = [0]
44 | dim.extend(range(2, x.dim()))
45 |
46 | with torch.no_grad():
47 | self.sum += x.sum(dim)
48 | self.sum_squares += (x**2).sum(dim)
49 |
50 | size = x.size().numel() // x.size(1)
51 | self.counter += size
52 | self.tracked += 1
53 |
54 | return (self.tracked == self.num_micro_batches)
55 |
56 | def _commit(self):
57 | """Updates the running statistics of a mini-batch."""
58 | exponential_average_factor = 0.0
59 | self.num_batches_tracked += 1
60 | if self.momentum is None: # use cumulative moving average
61 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
62 | else: # use exponential moving average
63 | exponential_average_factor = self.momentum
64 |
65 | mean = self.sum / self.counter
66 | var = self.sum_squares / self.counter - mean**2
67 |
68 | # Calculate the exponential moving average here.
69 | m = exponential_average_factor
70 |
71 | self.running_mean *= 1 - m
72 | self.running_mean += mean * m
73 |
74 | self.running_var *= 1 - m
75 | self.running_var += var * m
76 |
77 | self.sum.zero_()
78 | self.sum_squares.zero_()
79 | self.counter = 0
80 | self.tracked = 0
81 |
82 | def forward(self, x: Tensor) -> Tensor: # type: ignore
83 | if not self.training:
84 | # Don't train parameters on the evaluation mode.
85 | return F.batch_norm(
86 | x,
87 | running_mean=self.running_mean,
88 | running_var=self.running_var,
89 | weight=self.weight,
90 | bias=self.bias,
91 | training=False,
92 | momentum=0.0,
93 | eps=self.eps,
94 | )
95 |
96 | if not self.is_recomputing:
97 | # Track a micro-batch on the training mode
98 | # but not under a recomputation.
99 | tracked_enough = self._track(x)
100 |
101 | # Update the running statistics for a mini-batch
102 | # if it has tracked enough micro-batches.
103 | if tracked_enough:
104 | self._commit()
105 |
106 | # Normalize a micro-batch and train the parameters.
107 | return F.batch_norm(
108 | x,
109 | running_mean=None,
110 | running_var=None,
111 | weight=self.weight,
112 | bias=self.bias,
113 | training=True,
114 | momentum=0.0,
115 | eps=self.eps,
116 | )
117 |
118 | @classmethod
119 | def convertBatchNorm(cls, module: nn.Module, num_micro_batches: int = 1) -> nn.Module:
120 | """Converts a :class:`nn.BatchNorm` or underlying
121 | :class:`nn.BatchNorm`s into :class:`DelayedBatchNorm`::
122 | from torchvision.models.resnet import resnet101
123 | from pytorch_Gpipe.delayedNorm import DelayedBatchNorm
124 | model = resnet101()
125 | model = DelayedBatchNorm.convertBatchNorm(model)
126 | """
127 | if isinstance(module, DelayedBatchNorm) and module.num_micro_batches is num_micro_batches:
128 | return module
129 |
130 | if isinstance(module, _BatchNorm) and module.track_running_stats:
131 | module_output = DelayedBatchNorm(module.num_features,
132 | module.eps,
133 | module.momentum,
134 | module.affine,
135 | num_micro_batches)
136 | if module.affine:
137 | module_output.register_parameter('weight', module.weight)
138 | module_output.register_parameter('bias', module.bias)
139 | module_output.register_buffer('running_mean', module.running_mean)
140 | module_output.register_buffer('running_var', module.running_var)
141 | module_output.register_buffer(
142 | 'num_batches_tracked', module.num_batches_tracked)
143 |
144 | return module_output
145 |
146 | for name, child in module.named_children():
147 | module.add_module(
148 | name, cls.convertBatchNorm(child, num_micro_batches))
149 |
150 | return module
151 |
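152 | # Usage sketch (illustrative): with num_micro_batches=4 the running statistics
153 | # are committed once every 4 tracked micro-batches, i.e. once per mini-batch.
154 | # >>> from torchvision.models.resnet import resnet101
155 | # >>> model = DelayedBatchNorm.convertBatchNorm(resnet101(), num_micro_batches=4)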
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/__init__.py:
--------------------------------------------------------------------------------
1 | from .module_generation import generatePartitionModules
2 | from .partition_graph import partiton_graph as partition
3 | from .process_partition import post_process_partition
4 |
5 | __all__ = ["post_process_partition",
6 | "partition", "generatePartitionModules"]
7 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/__init__.py:
--------------------------------------------------------------------------------
1 | from .generate import generatePartitionModules
2 |
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/constructor.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Dict, Set
2 | import re
3 | from itertools import chain
4 | from torch.nn import Module
5 | tab = ' '
6 | dtab = tab + tab
7 |
8 |
9 | def generateConstructor(class_name: str, full_names: List[str], layer_classes: Dict[str, Module],
10 | is_param_dict: Dict[str, bool], buff_param_names: Set[str]) -> Tuple[str, Dict[str, str]]:
11 | '''creates the partition constructor and the mapping between layers and field ids
12 | '''
13 | class_decl = f"class {class_name}(nn.Module):"
14 | init_dec = f"{tab}def __init__(self, layers, buffers, parameters):"
15 | super_init = f'{dtab}super({class_name}, self).__init__()'
16 | layer_names = [f'self.l_{idx}' for idx, _ in enumerate(full_names)]
17 | layers_init = generate__init__layersStatements(layer_names, full_names,
18 | layer_classes)
19 | scope = dict(zip(full_names, layer_names))
20 |
21 | params, buffs = [], []
22 | for k, v in is_param_dict.items():
23 | if k not in buff_param_names:
24 | continue
25 | elif v:
26 | params.append(k)
27 | else:
28 | buffs.append(k)
29 |
30 | tensor_init, tensor_ids = generate__init__BuffParamStatements(buffs,
31 | params)
32 | lookup = generateLookup(scope, tensor_ids)
33 | scope.update(tensor_ids)
34 |
35 | device_id = re.search(r'\d+$', class_name).group()
36 |
37 | # we initialize it to the expected device; if DEBUG, the pipeline will set it to the cpu device
38 | device = f"{dtab}self.device = torch.device('cuda:{device_id}')"
39 |
40 | return '\n'.join([class_decl, init_dec, super_init, layers_init, tensor_init, device, lookup]) + '\n', scope
41 |
42 |
43 | def generate__init__layersStatements(layer_names: List[str], full_names: List[str], layer_classes: Dict[str, Module]) -> str:
44 | ''' generates partition field initialization statements\n
45 | and comments to describe which scope is allocated to which field
46 | '''
47 | statements = [f'{dtab}# initializing partition layers',
48 | generate__init__assertGuards(len(layer_names))]
49 |
50 | for field, full_name in zip(layer_names, full_names):
51 | statements.extend([f"# {full_name}",
52 | f"assert '{full_name}' in layers, 'layer {full_name} was expected but not given'",
53 | f"{field} = layers['{full_name}']"])
54 | class_name = layer_classes[full_name].__name__
55 | error_msg = f"f'layers[{full_name}] is expected to be of type {class_name} but was of type {{type({field})}}'"
56 | statements.append(
57 | f"assert isinstance({field},{class_name}) ,{error_msg}")
58 | return f'\n{dtab}'.join(statements)
59 |
60 |
61 | def generate__init__assertGuards(nlayers: int) -> str:
62 | ''' generate assert guards ensuring we receive the necessary amount of layers\n
63 | in the layers dict argument of the constructor
64 | '''
65 | assert_statements = f"assert isinstance(layers,dict), f'expected layers to be of type dict but got type{{type(layers)}}'\n"
66 | assert_statements += f"{dtab}assert(len(layers) == {nlayers})\n"
67 | assert_statements += f"{dtab}assert(all(isinstance(k, str) for k in layers.keys())), 'string keys are expected'\n"
68 | assert_statements += f"{dtab}assert(all(isinstance(v, nn.Module) for v in layers.values())), 'Module values are expected'"
69 | return assert_statements
70 |
71 |
72 | def generate__init__BuffParamStatements(buffers: List[str], parameters: List[str]) -> Tuple[str, Dict[str, str]]:
73 | tensor_ids = {}
74 | lines = [f"\n{dtab}# initializing partition buffers",
75 | f"assert isinstance(buffers,dict), f'expected buffers to be of type dict got {{type(buffers)}}'",
76 | f"assert len(buffers) == {len(buffers)}, f'expected buffers to have {len(buffers)} elements but has {{len(buffers)}} elements'",
77 | f"assert all(isinstance(k,str) for k in buffers.keys()), 'string keys are expected'",
78 | f"assert all(isinstance(v,Tensor) for v in buffers.values()), 'Tensor values are expected'"]
79 | for idx, b_name in enumerate(buffers):
80 | lines.extend([f"# {b_name}",
81 | f"assert '{b_name}' in buffers, '{b_name} buffer was expected but not given'",
82 | f"self.register_buffer('b_{idx}',buffers['{b_name}'])"])
83 | tensor_ids[b_name] = f'self.b_{idx}'
84 |
85 | lines.extend([f"\n{dtab}# initializing partition parameters",
86 | f"assert isinstance(parameters,dict), f'expected parameters to be of type dict got {{type(parameters)}}'",
87 | f"assert len(parameters) == {len(parameters)}, f'expected parameters to have {len(parameters)} elements but has {{len(parameters)}} elements'",
88 | f"assert all(isinstance(k,str) for k in parameters.keys()), 'string keys are expected'",
89 | f"assert all(isinstance(v,Tensor) for v in parameters.values()), 'Tensor values are expected'"])
90 | for idx, p_name in enumerate(parameters):
91 | lines.extend([f"# {p_name}",
92 | f"assert '{p_name}' in parameters, '{p_name} parameter was expected but not given'",
93 | f"self.p_{idx} = parameters['{p_name}']"])
94 | tensor_ids[p_name] = f'self.p_{idx}'
95 |
96 | return f'\n{dtab}'.join(lines), tensor_ids
97 |
98 |
99 | def generateLookup(layers_to_id, tensors_to_id):
100 | # first generate the lookup table, e.g.
101 | # {'p_0': 'w',
102 | #  'l_1': 'module0.sub1.linear'}
103 | lookup = []
104 | for scope, id in chain(layers_to_id.items(), tensors_to_id.items()):
105 | # scope: testMod/Linear[linear0] id: l_0
106 | # we will have 2 keys: l_0.weight l_0.bias
107 | # we wish to replace l_0 with linear0
108 | # resulting in keys: linear0.weight linear0.bias
109 | # for eg scope testMod/Mod0[a]/Sub[b] => a.b
110 | fields = re.findall(r"\[[a-zA-Z0-9_]*\]", scope)
111 | fields = map(lambda s: s[1:-1:], fields)
112 | prefix = '.'.join(fields)
113 | # remove the self. part of the id
114 | lookup.append(f"'{id[5:]}': '{prefix}'")
115 | lookup = f",\n{dtab}{dtab}{dtab}".join(lookup)
116 | return f"{dtab}self.lookup = {{ {lookup}}}"
117 |
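118 | # Example (illustrative): for scope 'testMod/Mod0[a]/Sub[b]' assigned to field
119 | # 'self.l_0', generateLookup emits the entry 'l_0': 'a.b', which lets the
120 | # generated partition translate its state dict keys back to the original
121 | # model's keys.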
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/generate.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import Tensor
4 | from torch.nn import Module
5 | import torch.nn.functional as F
6 | from pytorch_Gpipe.model_profiling.control_flow_graph import Node, NodeTypes, Graph
7 | from pytorch_Gpipe.utils import traverse_model, traverse_params_buffs
8 | import string
9 | from .forward import generateForwardFunction
10 | from .constructor import generateConstructor
11 | from .misc import generateMiscMethods
12 | from typing import List, Tuple, Dict
13 | from pytorch_Gpipe.utils import OrderedSet
14 | from collections import OrderedDict
15 | import inspect
16 |
17 | tab = ' '
18 | dtab = tab + tab
19 |
20 |
21 | def generatePartitionModules(graph: Graph, model: Module, verbose=False, output_file=None):
22 | layer_classes = {scope: type(layer) for layer, scope, _
23 | in traverse_model(model, depth=graph.depth)}
24 | is_param_dict = {scope: t.requires_grad for t,
25 | scope in traverse_params_buffs(model)}
26 |
27 | parts = groupByPartition(graph.nodes)
28 |
29 | lines = generateImports(layer_classes)
30 | lines.append(connections(graph))
31 | # the main code generation loop generating a class decl
32 | # and forward function
33 | partitions_code = []
34 | ios = dict()
35 |
36 | for idx, part in parts:
37 | class_name = f'{graph.model_name}Partition{idx}'
38 | layer_names = [n.scope for n in part if n.type == NodeTypes.LAYER]
39 | buff_param_names = {n.scope for n in part
40 | if n.type == NodeTypes.BUFF_PARAM}
41 | class_decl, scope_to_class_field = generateConstructor(class_name, layer_names,
42 | layer_classes, is_param_dict,
43 | buff_param_names)
44 | misc_functions = generateMiscMethods()
45 | forward_function, io = generateForwardFunction(part, graph.output_scopes, scope_to_class_field,
46 | verbose=verbose)
47 | partitions_code.append(class_decl)
48 | partitions_code.extend(forward_function)
49 | partitions_code.append(misc_functions)
50 | ios[idx] = io
51 |
52 | lines.append(generatePipelineAndGetConfig(graph, parts, model, ios))
53 | lines += partitions_code
54 |
55 | if output_file is None:
56 | output_file = f'generated_{graph.model_name}{len(parts)}'
57 |
58 | output_file = output_file + '.py'
59 |
60 | with open(output_file, 'w') as f:
61 | f.write('\n'.join(lines))
62 |
63 |
64 | def groupByPartition(nodes: List[Node]) -> List[Tuple[int, List[Node]]]:
65 | # groups layers and their respective nodes according to their partition
66 | idxs = {n.part for n in nodes}
67 | parts = OrderedDict()
68 | for i in sorted(idxs):
69 | parts[i] = []
70 |
71 | for n in nodes:
72 | if n.type == NodeTypes.IN:
73 | continue
74 | elif n.type == NodeTypes.BUFF_PARAM:
75 | parts[n.part].append(n)
76 | elif n.type == NodeTypes.LAYER:
77 | parts[n.part].append(n)
78 | elif n.type == NodeTypes.OP:
79 | scope = n.scope
80 | # we handle the torch, Tensor and torch.nn.functional namespaces
81 | func_name = getFunctionName(scope)
82 | if hasattr(torch, func_name) or hasattr(F, func_name) or hasattr(Tensor, func_name):
83 | parts[n.part].append(n)
84 | elif 'aten::slice' in scope:
85 | parts[n.part].append(n)
86 | else:
87 | assert False, f'could not find namespace for {scope}'
88 | elif n.type == NodeTypes.PYTHON_PRIMITIVE:
89 | scope = n.scope
90 | assert 'prim::' in scope, f'primitive does not have prim:: prefix {scope}'
91 | func_name = scope.split('prim::')[1].rstrip(string.digits)
92 | parts[n.part].append(n)
93 | else:
94 | assert n.type == NodeTypes.CONSTANT, f'got type {n.type}'
95 | parts[n.part].append(n)
96 |
97 | return parts.items()
98 |
99 |
100 | def generateImports(layer_classes: Dict[str, Module]) -> List[str]:
101 | '''generates imports to torch, torch.nn, torch.nn.functional as F and torch.Tensor,
102 | and to every layer used and various other small things
103 | '''
104 | imports = 'import torch\nfrom torch import Tensor\nimport torch.nn as nn\nimport torch.nn.functional as F\n'
105 | imports += 'from itertools import chain\n'
106 | imports += 'from pytorch_Gpipe.utils import layerDict, tensorDict, OrderedSet\n'
107 | imports += 'from pytorch_Gpipe import Pipeline\n'
108 | unique_classes = set(layer_classes.values())
109 |
110 | for cls in unique_classes:
111 | imports += f'from {inspect.getmodule(cls).__name__} import {cls.__name__}\n'
112 |
113 | disclaimer = '# this is an auto generated file do not edit unless you know what you are doing\n\n'
114 |
115 | return imports.splitlines() + [disclaimer]
116 |
117 |
118 | def getFunctionName(scope: str) -> str:
119 | if 'aten::' in scope:
120 | sep = 'aten::'
121 | else:
122 | assert 'prim::' in scope, f"attempting to find function name but got {scope}"
123 | sep = 'prim::'
124 |
125 | return scope.split(sep)[1].rstrip(string.digits)
126 |
127 |
128 | def createConfig(graph: Graph, partitions: List[List[Node]], model: Module, ios: Dict[int, OrderedSet]):
129 | model_buffers = {scope: t for t, scope in traverse_params_buffs(model)
130 | if not t.requires_grad}
131 | model_parameteres = {scope: t for t, scope in traverse_params_buffs(model)
132 | if t.requires_grad}
133 | model_class = model.__class__.__name__
134 | # function header
135 | lines = [
136 | f"def createConfig(model,DEBUG=False,partitions_only=False):",
137 | "layer_dict = layerDict(model)",
138 | "tensor_dict = tensorDict(model)",
139 | f"\n{tab}# now constructing the partitions in order"
140 | ]
141 |
142 | # hard code which layers buffers and parameters belong to each partition
143 | construction_args = []
144 | for idx, part in partitions:
145 | layer_scopes = [f"'{n.scope}'"
146 | for n in part if n.type == NodeTypes.LAYER]
147 | buffer_scopes = [
148 | f"'{n.scope}'" for n in part if n.scope in model_buffers]
149 | parameter_scopes = [f"'{n.scope}'" for n in part
150 | if n.scope in model_parameteres]
151 | construction_args.append(
152 | (layer_scopes, buffer_scopes, parameter_scopes))
153 |
154 | # create partition generation statements
155 | for idx, (layer_scopes, buffer_scopes, parameter_scopes) in zip(sorted(list(ios.keys())), construction_args):
156 | l_scopes = 'layer_scopes = [' + f",\n{dtab}".join(layer_scopes) + ']'
157 | b_scopes = 'buffer_scopes = [' + f",\n{dtab}".join(buffer_scopes) + ']'
158 | p_scopes = 'parameter_scopes = [' + \
159 | f",\n{dtab}".join(parameter_scopes) + ']'
160 | lines.extend([l_scopes, b_scopes, p_scopes,
161 | f"layers = {{l: layer_dict[l] for l in layer_scopes}}",
162 | f"buffers = {{b: tensor_dict[b] for b in buffer_scopes}}",
163 | f"parameters = {{p: tensor_dict[p] for p in parameter_scopes}}",
164 | f"partition{idx} = {model_class}Partition{idx}(layers,buffers,parameters)\n"])
165 |
166 | # create and return the partition config
167 | exp = f',\n{dtab}{tab}'.join([f"{k}: {v}" for k, v in ios.items()])
168 | lines.append(
169 | f"# creating configuration\n{tab}config = {{{exp}\n{dtab}{tab}}}")
170 |
171 | for idx in sorted(list(ios.keys())):
172 | lines.extend([f"device = 'cpu' if DEBUG else torch.device('cuda:{idx}')",
173 | f"partition{idx}.device=device",
174 | f"config[{idx}]['model'] = partition{idx}.to(device)"])
175 |
176 | input_ids = [f"'input{idx}'" for idx in range(graph.num_inputs)]
177 | lines.extend([f"config['model inputs'] = [{', '.join(input_ids)}]",
178 | f"config['model outputs'] = {list(graph.output_scopes)}",
179 | f"\n{tab}return [config[i]['model'] for i in range({len(ios)})] if partitions_only else config"])
180 |
181 | return f"\n{tab}".join(lines) + "\n"
182 |
183 |
184 | def generatePipelineAndGetConfig(graph: Graph, partitions: List[List[Node]], model: Module, ios: Dict[int, OrderedSet]):
185 | '''generates a function that will perform the actual partition, returning a Pipeline object\n
186 | the function will have the partition config hardcoded into it,\n
187 | enabling us to perform the partition process once and use the config multiple times
188 | '''
189 | model_class = model.__class__.__name__
190 | config = createConfig(graph, partitions, model, ios)
191 |
192 | lines = [
193 | f'\ndef {model_class}Pipeline(model:nn.Module,output_device=None,split_dim=0,use_delayedNorm=False,DEBUG=False):']
194 |
195 | lines.append(
196 | f"return Pipeline(createConfig(model,DEBUG=DEBUG,partitions_only=False),output_device=output_device,split_dim=split_dim,use_delayedNorm=use_delayedNorm)\n\n",
197 | )
198 |
199 | return f'\n{tab}'.join(lines) + f"\n{config}\n"
200 |
201 |
202 | def connections(graph: Graph):
203 | adj_matrix = [{"inputs": set(), "outputs": set()}
204 | for i in range(graph.num_parts + 2)]
205 |
206 | for node in graph.nodes:
207 | if node.idx < graph.num_inputs:
208 | for n in node.out_nodes:
209 | adj_matrix[n.part + 1]["inputs"].add(node.scope)
210 | adj_matrix[0]["outputs"].add(n.part)
211 |
212 | idx = graph.output_scopes.indexOf(node.scope)
213 |
214 | if idx >= 0:
215 | adj_matrix[graph.num_parts + 1]["inputs"].add(node.part)
216 | adj_matrix[node.part + 1]["outputs"].add(f"output{idx}")
217 |
218 | for n in node.out_nodes:
219 | if n.part != node.part:
220 | adj_matrix[node.part + 1]["outputs"].add(n.part)
221 | adj_matrix[n.part + 1]["inputs"].add(node.part)
222 |
223 | lines = ["# partition adjacency"]
224 | lines.append(f"# model inputs {adj_matrix[0]['outputs']}")
225 | for i, line in enumerate(adj_matrix[1:-1:]):
226 | lines.append(f"# partition {i} {line}")
227 | lines.append(
228 | f"# model outputs {adj_matrix[graph.num_parts + 1]['inputs']}")
229 | return '\n'.join(lines) + '\n'
230 |
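231 | # Example (illustrative): for a model class Net split into two partitions the
232 | # generated file defines NetPartition0 and NetPartition1, a
233 | # createConfig(model, DEBUG, partitions_only) function with the partition
234 | # config hardcoded, and a NetPipeline(...) factory returning a Pipeline.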
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/module_generation/misc.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Dict, Set
2 | from torch.nn import Module
3 | tab = ' '
4 | dtab = tab + tab
5 |
6 |
7 | def generateMiscMethods():
8 | state_dict = generateStateDictFunction()
9 | load_state_dict = generateLoadStateDict()
10 | named_parameters = generateNamedParametersFunction()
11 | named_buffers = generateNamedBuffersFunction()
12 |
13 | return "\n".join([state_dict, load_state_dict, named_parameters, named_buffers]) + "\n\n"
14 |
15 |
16 | def generateStateDictFunction():
17 | state_dict_function = ["def state_dict(self,device):",
18 | f"# we return the state dict of this part as it should be in the original model",
19 | "state = super().state_dict()",
20 | f"lookup = self.lookup",
21 | "result = dict()",
22 | "for k, v in state.items():",
23 | f"{tab}if k in lookup:",
24 | f"{dtab}result[lookup[k]] = v if device is None else v.to(device)",
25 | f"{tab}else:",
26 | f"{dtab}assert '.' in k",
27 | f"{dtab}split_idx = k.find('.')",
28 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]",
29 | f"{dtab}result[new_k] = v if device is None else v.to(device)",
30 | f"return result"]
31 |
32 | return f"{tab}" + f"\n{dtab}".join(state_dict_function)
33 |
34 |
35 | def generateNamedParametersFunction():
36 | named_parameters_function = ["def named_parameters(self):",
37 | f"# we return the named parameters of this part as it should be in the original model",
38 | "params = super().named_parameters()",
39 | f"lookup = self.lookup",
40 | "for k, v in params:",
41 | f"{tab}if k in lookup:",
42 | f"{dtab}yield (lookup[k],v)",
43 | f"{tab}else:",
44 | f"{dtab}assert '.' in k",
45 | f"{dtab}split_idx = k.find('.')",
46 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]",
47 | f"{dtab}yield (new_k, v)"]
48 | return f"\n{tab}" + f"\n{dtab}".join(named_parameters_function)
49 |
50 |
51 | def generateNamedBuffersFunction():
52 | named_buffers_function = ["def named_buffers(self):",
53 | f"# we return the named buffers of this part as it should be in the original model",
54 | "params = super().named_buffers()",
55 | f"lookup = self.lookup",
56 | "for k, v in params:",
57 | f"{tab}if k in lookup:",
58 | f"{dtab}yield (lookup[k],v)",
59 | f"{tab}else:",
60 | f"{dtab}assert '.' in k",
61 | f"{dtab}split_idx = k.find('.')",
62 | f"{dtab}new_k = lookup[k[:split_idx]] + k[split_idx:]",
63 | f"{dtab}yield (new_k, v)"]
64 | return f"\n{tab}" + f"\n{dtab}".join(named_buffers_function)
65 |
66 |
67 | def generateLoadStateDict():
68 | func = ['def load_state_dict(self, state):',
69 | 'reverse_lookup = {v: k for k, v in self.lookup.items()}',
70 | 'ts = chain(self.named_parameters(), self.named_buffers())',
71 | 'device = list(ts)[0][1].device',
72 | 'keys = list(self.state_dict(None).keys())',
73 | 'new_state = dict()',
74 | 'for k in keys:',
75 | tab + 'if k in reverse_lookup:',
76 | dtab + 'new_state[reverse_lookup[k]] = state[k].to(device)',
77 | dtab + 'continue',
78 | tab + 'idx = k.rfind(".")',
79 | tab + 'to_replace = k[:idx]',
80 | tab + 'if to_replace in reverse_lookup:',
81 | dtab + 'key = reverse_lookup[to_replace] + k[idx:]',
82 | dtab + 'new_state[key] = state[k].to(device)',
83 | 'super().load_state_dict(new_state, strict=True)']
84 |
85 | return f"\n{tab}" + f"\n{dtab}".join(func)
86 |
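87 | # Example (illustrative): with lookup = {'l_0': 'linear0'} the generated
88 | # state_dict maps the partition-internal key 'l_0.weight' back to the
89 | # original model key 'linear0.weight', and the generated load_state_dict
90 | # reverses the mapping via reverse_lookup.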
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/partition_graph.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Any, Callable, Optional
3 |
4 | from ..model_profiling import Graph
5 | from .process_partition import post_process_partition
6 | import networkx as nx
7 | import nxmetis
8 |
9 | __all__ = ["partiton_graph"]
10 |
11 |
12 | def partition_METIS(graph: Graph, num_partitions: int, weighting_function: Optional[Callable[[Any], int]] = None, **METIS_opts) -> Graph:
13 | '''
14 | partition the graph using METIS's PartGraphKway and then optimizes it to our needs
15 |
16 | Parameters
17 | ----------
18 | graph:
19 | the Graph object to partition
20 | num_partitions:
21 | the requested number of partitions
22 | weighting_function:
23 | a weighting function that transforms the graph weights to non negative integers
24 | if not specified a default function will be used
25 | METIS_opts:
26 | additional options to pass to METIS
27 | for eg. for the option METIS_OPTION_SEED pass seed=value
28 | '''
29 | from ..METIS import METIS_partition
30 | wfunc = weighting_function if weighting_function is not None else default_weight_func
31 |
32 | adjlist = graph.adjacency_list()
33 | nodew = graph.get_weights().values()
34 |
35 | assert len(adjlist) == len(nodew)
36 |
37 | weights = [wfunc(w) for w in nodew]
38 |
39 | if 'seed' not in METIS_opts:
40 | METIS_opts['seed'] = 0
41 |
42 | if 'contig' not in METIS_opts:
43 | METIS_opts['contig'] = 1
44 |
45 | partition, _ = METIS_partition(adjlist, nparts=num_partitions, algorithm="metis",
46 | nodew=weights, **METIS_opts)
47 |
48 | post_process_partition(graph, partition)
49 |
50 | actual_nparts = len({n.part for n in graph.nodes})
51 |
52 | if actual_nparts < num_partitions:
53 | print(
54 | f"expected {num_partitions} partitions but only {actual_nparts} were found, indicating that the model to partition is too small")
55 | print("consider increasing the depth of the graph or disabling the basic blocks option")
56 | return graph
57 |
58 |
59 | def default_weight_func(w):
60 | if hasattr(w, 'forward_time') and hasattr(w, 'backward_time'):
61 | return max(int(100 * (w.forward_time + w.backward_time) / 2), 1)
62 | return 1
63 |
64 |
65 | def partiton_graph(graph: Graph, num_partitions: int, weighting_function: Optional[Callable[[Any], int]] = None, **METIS_opts):
66 | wfunc = weighting_function if weighting_function is not None else default_weight_func
67 |
68 | weights = {node.idx: wfunc(node.weight) for node in graph.nodes}
69 |
70 | G = graph.asNetworkx()
71 | nx.set_node_attributes(G, weights, 'weight')
72 |
73 | _, parts = nxmetis.partition(G, num_partitions)
74 |
75 | parts = sorted((idx, n) for n, p in enumerate(parts) for idx in p)
76 | parts = [n for _, n in parts]
77 |
78 | post_process_partition(graph, parts)
79 |
80 | actual_nparts = len({n.part for n in graph.nodes})
81 |
82 | if actual_nparts < num_partitions:
83 | print(
84 | f"expected {num_partitions} partitions but only {actual_nparts} were found, indicating that the model to partition is too small")
85 | print("consider increasing the depth of the graph or disabling the basic blocks option")
86 | return graph
87 |
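88 | # Usage sketch (illustrative): partitioning a profiled graph with a custom
89 | # weighting function; assumes `graph` was built with use_profiler=True so the
90 | # node weights carry a forward_time attribute.
91 | # >>> g = partiton_graph(graph, num_partitions=4,
92 | # ...                    weighting_function=lambda w: max(int(w.forward_time), 1))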
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_partitioning/process_partition.py:
--------------------------------------------------------------------------------
1 | from collections import Counter, deque
2 | from typing import Dict, List
3 |
4 | from ..model_profiling import Graph, NodeTypes
5 |
6 | __all__ = ["post_process_partition"]
7 |
8 |
9 | def post_process_partition(graph: Graph, part: List[int]):
10 | '''
11 | process the partition and optimize it
12 | called as part of partition_graph method
13 |
14 | Parameters:
15 | ----------
16 | graph:
17 | the Graph object that was partitioned
18 | part:
19 | a list of the nodes partition indices
20 | '''
21 |
22 | for node, idx in zip(graph.nodes, part):
23 | node.part = idx
24 |
25 | canonize_partition_indices(graph)
26 | make_partitions_change_only_at_end_of_scope(graph)
27 | # make sure no scc in the graph is split between different parts
28 | scc_partition_correction(graph)
29 | ensure_dag(graph, part)
30 |
31 | canonize_partition_indices(graph)
32 | # TODO we disabled this optimization
33 | # fix_arithmetic_inputs(graph)
34 | return
35 |
36 |
37 | def fix_arithmetic_inputs(graph: Graph):
38 | while True:
39 | changed = False
40 | for node in graph.nodes:
41 | if node.type is NodeTypes.OP:
42 | for n in node.in_nodes:
43 | if n.part != node.part:
44 | n.part = node.part
45 | changed = True
46 | if not changed:
47 | break
48 |
49 |
50 | def ensure_dag(graph: Graph, node_parts: List[int]):
51 | flag = True
52 | while flag:
53 | flag, prob_edge = not_dag(graph, node_parts)
54 |
55 | if flag:
56 | fix_problem_node(graph, prob_edge)
57 |
58 |
59 | def not_dag(graph: Graph, node_parts):
60 | part_edges = []
61 | num_parts = len(set(node_parts))
62 | for node in graph.nodes:
63 | for out_node in node.out_nodes:
64 | if node.part != out_node.part:
65 | part_edge = (node.part, out_node.part)
66 | if part_edge not in part_edges:
67 | part_edges.append(part_edge)
68 |
69 | for num_part1 in range(num_parts):
70 | for num_part2 in range(num_parts):
71 | if (num_part1, num_part2) in part_edges and (num_part2, num_part1) in part_edges and num_part1 < num_part2:
72 | return True, (num_part1, num_part2)
73 |
74 | return False, (-1, -1)
75 |
76 |
77 | def fix_problem_node(graph: Graph, prob_edge: tuple):
78 | first_part, second_part = prob_edge
79 | for node in graph.nodes:
80 | if node.part == second_part:
81 | for o_node in node.out_nodes:
82 | if o_node.part == first_part:
83 | node.part = first_part
84 |
85 |
86 | def scc_partition_correction(graph: Graph):
87 | # create the scc graph
88 | vertices = [v.idx for v in graph.nodes]
89 | edges = {}
90 | for v in graph.nodes:
91 | idx_out_nodes = [h.idx for h in v.out_nodes]
92 | edges.update({v.idx: idx_out_nodes})
93 |
94 | for scc in strongly_connected_components_iterative(vertices, edges):
95 | # check if the scc is split across 2 or more parts
96 | scc_parts = []
97 | for v in scc:
98 | if graph.nodes[v].part not in scc_parts:
99 | scc_parts.append(graph.nodes[v].part)
100 | if len(scc_parts) >= 2:
101 | break
102 | # if it is split:
103 | if len(scc_parts) >= 2:
104 | output_part = -1
105 | # find out what part edges go to from this scc
106 | for v in scc:
107 | for out in graph.nodes[v].out_nodes:
108 | if out.idx not in scc:
109 | output_part = graph.nodes[out.idx].part
110 | break
111 | if output_part != -1:
112 | break
113 | # update the scc part to the part we found
114 | for v in scc:
115 | graph.nodes[v].part = output_part
116 |
117 |
118 | def strongly_connected_components_iterative(vertices: List[int], edges: Dict[int, List[int]]):
119 | identified = set()
120 | stack = []
121 | index = {}
122 | boundaries = []
123 |
124 | for v in vertices:
125 | if v not in index:
126 | to_do = [('VISIT', v)]
127 | while to_do:
128 | operation_type, v = to_do.pop()
129 | if operation_type == 'VISIT':
130 | index[v] = len(stack)
131 | stack.append(v)
132 | boundaries.append(index[v])
133 | to_do.append(('POSTVISIT', v))
134 | # We reverse to keep the search order identical to that of
135 | # the recursive code; the reversal is not necessary for
136 | # correctness, and can be omitted.
137 | to_do.extend(
138 | reversed([('VISITEDGE', w) for w in edges[v]]))
139 | elif operation_type == 'VISITEDGE':
140 | if v not in index:
141 | to_do.append(('VISIT', v))
142 | elif v not in identified:
143 | while index[v] < boundaries[-1]:
144 | boundaries.pop()
145 | else:
146 | # operation_type == 'POSTVISIT'
147 | if boundaries[-1] == index[v]:
148 | boundaries.pop()
149 | scc = set(stack[index[v]:])
150 | del stack[index[v]:]
151 | identified.update(scc)
152 | yield scc
153 |
154 |
155 | def canonize_partition_indices(graph: Graph):
156 | num_parts = len({n.part for n in graph.nodes})
157 | num_taken = 0
158 | model_inputs = [node for node in graph.nodes if node.type == NodeTypes.IN]
159 | open_nodes = deque(model_inputs)
160 | closed = set()
161 | canonical_parts = dict()
162 |
163 | while num_taken < num_parts:
164 | node = open_nodes.popleft()
165 | if node in closed or node in open_nodes:
166 | continue
167 | if node.part not in canonical_parts:
168 | canonical_parts[node.part] = num_taken
169 | num_taken += 1
170 |
171 | closed.add(node)
172 | edges = node.out_nodes.union(node.in_nodes)
173 | nodes = edges.difference(closed, set(open_nodes))
174 | open_nodes.extend(nodes)
175 |
176 | for node in graph.nodes:
177 | node.part = canonical_parts[node.part]
178 |
179 | graph.num_parts = len(canonical_parts)
180 |
181 |
182 | def make_partitions_change_only_at_end_of_scope(graph: Graph):
183 | def is_first_in_partition(node):
184 | return any(other.part != node.part for other in node.in_nodes)
185 |
186 | first_nodes_of_partition = filter(is_first_in_partition, graph.nodes)
187 |
188 | for node in first_nodes_of_partition:
189 | scope_depth = node.scope.count('/') - 1
190 | # don't apply this at too shallow a depth
191 | if scope_depth >= 2: # TODO think about threshold
192 | parent_scope = node.scope.rsplit('/', 1)[0]
193 |
194 | def in_scope(n):
195 | return parent_scope == n.scope.rsplit('/', 1)[0]
196 |
197 | scope_nodes = list(filter(in_scope, graph.nodes))
198 | parts = [n.part for n in scope_nodes]
199 | part_histogram = Counter(parts)
200 | most_common, num_layers = part_histogram.most_common(1)[0]
201 | if num_layers >= len(parts) // 2:
202 | for other in scope_nodes:
203 | other.part = most_common
204 |
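205 | # Note (illustrative summary): after post_process_partition the partition
206 | # indices are canonical (0..graph.num_parts - 1 in traversal order), no scc is
207 | # split across partitions, and the induced partition graph is a DAG, which the
208 | # generated pipeline stages rely on.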
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_profiling/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..utils import Tensors, traverse_model, traverse_params_buffs, model_scopes, _count_elements
7 | from .control_flow_graph import Graph, NodeTypes
8 | from .network_profiler import profileNetwork
9 |
10 | __all__ = ['graph_builder', 'profileNetwork']
11 |
12 |
13 | def graph_builder(model: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, max_depth: int = 1000, weights: Optional[Dict[str, Any]] = None, basic_blocks: Optional[List[nn.Module]] = None, use_profiler=False) -> Graph:
14 | '''
15 | returns a graph that models the control flow of the given network by tracing its forward pass
16 |
17 | Parameters:
18 | model:
19 | the network we wish to model
20 | sample_batch:
21 | a sample input to use for tracing
22 | kwargs:
23 | keyword args to pass to the model
24 | max_depth:
25 | how far down we go in the model tree; determines the detail level of the graph
26 | basic_blocks:
27 | an optional list of modules that if encountered will not be broken down
28 | weights:
29 | an optional dictionary from scopes to Node weights
30 | use_profiler:
31 | whether to use weights given by our profiler
32 | this option supersedes the weights option. defaults to False
33 | '''
34 | weights = weights if weights is not None else {}
35 | if kwargs is None:
36 | kwargs = {}
37 | if not isinstance(sample_batch, tuple):
38 | sample_batch = (sample_batch,)
39 |
40 | if use_profiler:
41 | weights = profileNetwork(model, sample_batch, kwargs=kwargs, max_depth=max_depth,
42 | basic_blocks=basic_blocks)
43 |
44 | buffer_param_names = map(lambda t: t[1], traverse_params_buffs(model))
45 | buffer_param_names = list(buffer_param_names)
46 |
47 | layerNames = model_scopes(model, depth=max_depth,
48 | basic_blocks=basic_blocks)
49 | layerNames = list(layerNames)
50 |
51 | # trace the model and build a graph
52 | with torch.no_grad():
53 | trace_graph, _ = torch.jit.get_trace_graph(
54 | model, sample_batch, kwargs)
55 | trace_graph = trace_graph.graph()
56 |
57 | num_inputs = _count_elements(*sample_batch) + len(kwargs)
58 |
59 | graph = Graph(layerNames, num_inputs, buffer_param_names,
60 | trace_graph, weights, basic_blocks, max_depth)
61 |
62 | return graph
63 |
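64 | # Usage sketch (illustrative): building a profiled control-flow graph.
65 | # >>> from sample_models import alexnet
66 | # >>> g = graph_builder(alexnet(), torch.zeros(1, 3, 224, 224),
67 | # ...                   max_depth=100, use_profiler=True)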
--------------------------------------------------------------------------------
/pytorch_Gpipe/model_profiling/network_profiler.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import namedtuple
3 | from typing import Dict, List, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from ..utils import Tensors, _detach_inputs, _get_size, get_device, traverse_model
9 |
10 | __all__ = ['profileNetwork', 'Profile']
11 |
12 | Profile = namedtuple('Profile',
13 | 'forward_time backward_time cuda_memory_forward cuda_memory_backward layer_size')
14 |
15 |
16 | def profileNetwork(net: nn.Module, sample_batch: Tensors, kwargs: Optional[Dict] = None, basic_blocks: Optional[List[nn.Module]] = None, max_depth=100) -> Dict[str, Profile]:
17 | '''
18 | profiles a network's computation time(forward/backward) and memory consumption
19 | returns a dictionary from layer_scope to Profile
20 |
21 | Parameters
22 | ----------
23 | net:
24 | the network we wish to profile a nn.Module
25 |
26 | sample_batch:
27 | a sample batch that will be used to measure the execution time of the network
28 | can be single/multiple inputs
29 |
30 | kwargs:
31 | keyword args to pass to the profiled model
32 |
33 | basic_blocks:
34 | a tuple of nn.Module classes that the profiler will regard as a cohesive unit
35 | e.g. if basic_blocks = (nn.Sequential,) then the profiler will not break such modules down into their components
36 |
37 | max_depth:
38 | determines how far the profiler will go in the model tree
39 |
40 |
41 |
42 | '''
43 | if kwargs is None:
44 | kwargs = {}
45 | if not isinstance(sample_batch, tuple):
46 | sample_batch = (sample_batch,)
47 |
48 | # wrap all individual layers for profiling
49 | layers_dict = _wrap_profiled_layers(net, max_depth, basic_blocks)
50 |
51 | # perform 2 forward/backward passes; the first is a warmup, as we have seen that first-time measurements are higher
52 | _perform_forward_backward_pass(net, *sample_batch, **kwargs)
53 | _perform_forward_backward_pass(net, *sample_batch, **kwargs)
54 |
55 | # gather forward and backward execution times
56 | backward_times = [layer.backward_time
57 | for layer in layers_dict.values()]
58 | forward_times = [layer.forward_time
59 | for layer in layers_dict.values()]
60 |
61 | # gather input and output sizes
62 | layer_input_sizes = [layer.input_size for layer in layers_dict.values()]
63 | layer_output_sizes = [layer.output_size for layer in layers_dict.values()]
64 |
65 | # gather all individual layer sizes
66 | param_sizes = [layer.parameters_size for layer in layers_dict.values()]
67 | buffer_sizes = [layer.buffers_size for layer in layers_dict.values()]
68 |
69 | # gather cuda memory consumption
70 | cuda_memory = [(layer.forward_cuda_mem, layer.backward_cuda_mem)
71 | for layer in layers_dict.values()]
72 |
73 | # prepare profiling results
74 | layers_profile = {name: Profile(forward, backward, *cuda_mem, param_size + buffer_size + in_size + out_size) for name, forward, backward, param_size, buffer_size, in_size, out_size, cuda_mem in zip(
75 | layers_dict.keys(), forward_times, backward_times, param_sizes, buffer_sizes, layer_input_sizes, layer_output_sizes, cuda_memory)}
76 |
77 | _unwrap_layers(net)
78 |
79 | return layers_profile
80 |
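
Example usage, as a minimal sketch (a CPU run on a toy nn.Sequential; any nn.Module should behave the same):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    profiles = profileNetwork(model, torch.randn(4, 8), max_depth=2)
    for scope, prof in profiles.items():
        print(scope, prof.forward_time, prof.backward_time, prof.layer_size)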
81 |
82 | def _perform_forward_backward_pass(net, *sample_batch: Tensors, **kwargs: Dict):
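# note: the per-layer backward passes are triggered inside each Wrapper.forward,
# so driving the wrapped network end to end exercises both directions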
83 | device = get_device(sample_batch)
84 | if device.type == "cuda":
85 | torch.cuda.synchronize(device=device)
86 | out = net(*sample_batch, **kwargs)
87 | torch.cuda.synchronize(device=device)
88 | else:
89 | out = net(*sample_batch, **kwargs)
90 | net.zero_grad()
91 | return out
92 |
93 |
94 | def _wrap_profiled_layers(module: nn.Module, depth, basic_blocks: List[nn.Module]):
95 | layers_dict = {}
96 |
97 | for sub_layer, scope, parent in traverse_model(module, depth, basic_blocks):
98 | name = scope[scope.rfind('[') + 1:-1]
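# e.g. a scope such as 'AlexNet/Sequential[features]/Conv2d[0]' yields the name '0'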
99 | wrapper = Wrapper(sub_layer)
100 | parent.add_module(name, wrapper)
101 | layers_dict[scope] = wrapper
102 |
103 | return layers_dict
104 |
105 |
106 | def _unwrap_layers(module: nn.Module):
107 | for name, sub_module in module.named_children():
108 | if isinstance(sub_module, Wrapper):
109 | module.add_module(name, sub_module.layer)
110 | else:
111 | _unwrap_layers(sub_module)
112 |
113 |
114 | class Wrapper(nn.Module):
115 | '''
116 | a module whose purpose is to profile a given layer\n
117 | when the wrapper performs forward propagation it records the following metrics:\n
118 | forward_time: the execution time of a forward pass of the underlying layer in milliseconds\n
119 | backward_time: the execution time of a backward pass of the underlying layer in milliseconds\n
120 | input_size: the input size in MB
121 | output_size: the layer output size in MB
122 | parameters_size: the size of parameters of the layer in MB
123 | buffers_size: the size of buffers of the layer in MB
124 | forward_cuda_mem: the peak CUDA memory usage during the forward pass in MB
125 | backward_cuda_mem: the peak CUDA memory usage during the backward pass in MB
126 |
127 | Parameters
128 | ----------
129 | sub_module:
130 | a nn.module to be profiled
131 |
132 | '''
133 |
134 | def __init__(self, sub_module: nn.Module):
135 | super(Wrapper, self).__init__()
136 | self.layer = sub_module
137 | self.forward_time = 0
138 | self.backward_time = 0
139 | self.input_size = 0
140 | self.output_size = 0
141 | self.parameters_size, self.buffers_size = self._layer_size()
142 | self.forward_cuda_mem = 0
143 | self.backward_cuda_mem = 0
144 |
145 | def _layer_size(self):
146 | '''
147 | return the size of the layer considering parameters and buffers
148 | '''
149 | parameters_size = buffers_size = 0
150 | for param in self.layer.parameters():
151 | parameters_size += param.nelement() * param.element_size()
152 | for buffer in self.layer.buffers():
153 | buffers_size += buffer.nelement() * buffer.element_size()
154 |
155 | return parameters_size, buffers_size
156 |
157 | def forward(self, *inputs: Tensors, **kwargs: Dict):
158 | '''
159 | perform forward and backward pass of the underlying layer and measure metrics
160 | '''
161 | # detach inputs from any previous graph history, enabling us to measure
162 | # execution time for this layer only
163 | device = get_device(inputs)
164 | detached_inputs = _detach_inputs(inputs)
165 |
166 | self.forward_time, outputs, self.forward_cuda_mem = self._time_op(
167 | self.layer, *detached_inputs, **kwargs)
168 |
169 | # reduce outputs to calculate dummy loss
170 | loss = torch.zeros(1, requires_grad=True, device=device)
171 | for out in outputs:
172 | loss = loss + out.norm()
173 |
174 | # measure backward execution time
175 | self.backward_time, _, self.backward_cuda_mem = self._time_op(
176 | torch.autograd.backward, loss)
177 |
178 | # input and output size
179 | self.input_size = _get_size(inputs)
180 | self.output_size = _get_size(outputs)
181 |
182 | # convert sizes to megabytes
183 | self.backward_cuda_mem /= 1e6
184 | self.forward_cuda_mem /= 1e6
185 | self.input_size /= 1e6
186 | self.output_size /= 1e6
187 | self.parameters_size /= 1e6
188 | self.buffers_size /= 1e6
189 |
190 | return outputs
191 |
192 | def _time_op(self, func, *inputs: Tensors, **kwargs: Dict):
193 | exec_time = 0
194 | cuda_mem = 0
195 | device = get_device(inputs)
196 | if device.type == 'cuda':
197 | torch.cuda.reset_max_memory_allocated(device=device)
198 | base_mem = torch.cuda.max_memory_allocated(device=device)
199 |
200 | # measure execution time
201 | torch.cuda.synchronize(device=device)
202 | start = torch.cuda.Event(enable_timing=True)
203 | end = torch.cuda.Event(enable_timing=True)
204 | start.record()
205 | out = func(*inputs, **kwargs)
206 | end.record()
207 | torch.cuda.synchronize(device=device)
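# elapsed_time is only valid once both events have completed, hence the synchronize above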
208 | exec_time = start.elapsed_time(end)  # milliseconds
209 |
210 | # record memory usage
211 | peak_usage = torch.cuda.max_memory_allocated(device=device)
212 | cuda_mem = peak_usage - base_mem
213 | else:
214 | start = time.time()
215 | out = func(*inputs, **kwargs)
216 | end = time.time()
217 | # convert seconds to milliseconds
218 | exec_time = 1000 * (end - start)
219 |
220 | return exec_time, out, cuda_mem
221 |
222 | # in case these operations are required, we delegate them to the profiled layer
223 |
224 | def __iter__(self):
225 | return iter(self.layer)
226 |
227 | def __getitem__(self, key):
228 | return self.layer[key]
229 |
230 | def __setitem__(self, key, value):
231 | self.layer[key] = value
232 |
233 | def __delitem__(self, idx):
234 | del self.layer[idx]
235 |
236 | def __len__(self):
237 | return len(self.layer)
238 |
239 | def __contains__(self, key):
240 | return key in self.layer
241 |
242 | def __getattr__(self, name):
243 | try:
244 | return super().__getattr__(name)
245 | except Exception:
246 | return getattr(self.layer, name)
247 |
--------------------------------------------------------------------------------
/sample_models/AlexNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['AlexNet', 'alexnet']
6 |
7 |
8 | model_urls = {
9 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
10 | }
11 |
12 |
13 | class AlexNet(nn.Module):
14 |
15 | def __init__(self, num_classes=1000):
16 | super(AlexNet, self).__init__()
17 | self.features = nn.Sequential(
18 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
19 | nn.ReLU(inplace=True),
20 | nn.MaxPool2d(kernel_size=3, stride=2),
21 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
22 | nn.ReLU(inplace=True),
23 | nn.MaxPool2d(kernel_size=3, stride=2),
24 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.MaxPool2d(kernel_size=3, stride=2),
31 | )
32 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
33 | self.classifier = nn.Sequential(
34 | nn.Dropout(),
35 | nn.Linear(256 * 6 * 6, 4096),
36 | nn.ReLU(inplace=True),
37 | nn.Dropout(),
38 | nn.Linear(4096, 4096),
39 | nn.ReLU(inplace=True),
40 | nn.Linear(4096, num_classes),
41 | )
42 |
43 | def forward(self, x):
44 | x = self.features(x)
45 | x = self.avgpool(x)
46 | x = x.view(x.size(0), 256 * 6 * 6)
47 | x = self.classifier(x)
48 | return x
49 |
50 |
51 | def alexnet(pretrained=False, **kwargs):
52 | r"""AlexNet model architecture from the
53 | `"One weird trick..." `_ paper.
54 |
55 | Args:
56 | pretrained (bool): If True, returns a model pre-trained on ImageNet
57 | """
58 | model = AlexNet(**kwargs)
59 | if pretrained:
60 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
61 | return model
62 |
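Example usage, as a minimal sketch (random weights; pass pretrained=True to load the ImageNet checkpoint from model_urls):

    import torch

    model = alexnet(num_classes=10)
    logits = model(torch.randn(1, 3, 224, 224))  # shape (1, 10)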
--------------------------------------------------------------------------------
/sample_models/DenseNet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | from collections import OrderedDict
6 |
7 | __all__ = ['DenseNet', 'densenet121',
8 | 'densenet169', 'densenet201', 'densenet161']
9 |
10 |
11 | model_urls = {
12 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
13 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
14 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
15 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
16 | }
17 |
18 |
19 | class _DenseLayer(nn.Sequential):
20 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
21 | super(_DenseLayer, self).__init__()
22 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
23 | self.add_module('relu1', nn.ReLU(inplace=True)),
24 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
25 | growth_rate, kernel_size=1, stride=1, bias=False)),
26 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
27 | self.add_module('relu2', nn.ReLU(inplace=True)),
28 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
29 | kernel_size=3, stride=1, padding=1, bias=False)),
30 | self.drop_rate = drop_rate
31 | if self.drop_rate > 0:
32 | self.add_module("dropout", nn.Dropout(p=self.drop_rate))
33 |
34 | def forward(self, x):
35 | new_features = super(_DenseLayer, self).forward(x)
36 | return torch.cat([x, new_features], 1)
37 |
38 |
39 | class _DenseBlock(nn.Sequential):
40 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
41 | super(_DenseBlock, self).__init__()
42 | for i in range(num_layers):
43 | layer = _DenseLayer(num_input_features + i *
44 | growth_rate, growth_rate, bn_size, drop_rate)
45 | self.add_module('denselayer%d' % (i + 1), layer)
46 |
47 |
48 | class _Transition(nn.Sequential):
49 | def __init__(self, num_input_features, num_output_features):
50 | super(_Transition, self).__init__()
51 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
52 | self.add_module('relu', nn.ReLU(inplace=True))
53 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
54 | kernel_size=1, stride=1, bias=False))
55 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
56 |
57 |
58 | class DenseNet(nn.Module):
59 | r"""Densenet-BC model class, based on
60 | `"Densely Connected Convolutional Networks" `_
61 |
62 | Args:
63 | growth_rate (int) - how many filters to add each layer (`k` in paper)
64 | block_config (list of 4 ints) - how many layers in each pooling block
65 | num_init_features (int) - the number of filters to learn in the first convolution layer
66 | bn_size (int) - multiplicative factor for number of bottle neck layers
67 | (i.e. bn_size * k features in the bottleneck layer)
68 | drop_rate (float) - dropout rate after each dense layer
69 | num_classes (int) - number of classification classes
70 | """
71 |
72 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
73 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
74 |
75 | super(DenseNet, self).__init__()
76 |
77 | # First convolution
78 | self.features = nn.Sequential(OrderedDict([
79 | ('conv0', nn.Conv2d(3, num_init_features,
80 | kernel_size=7, stride=2, padding=3, bias=False)),
81 | ('norm0', nn.BatchNorm2d(num_init_features)),
82 | ('relu0', nn.ReLU(inplace=True)),
83 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
84 | ]))
85 |
86 | # Each denseblock
87 | num_features = num_init_features
88 | for i, num_layers in enumerate(block_config):
89 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
90 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
91 | self.features.add_module('denseblock%d' % (i + 1), block)
92 | num_features = num_features + num_layers * growth_rate
93 | if i != len(block_config) - 1:
94 | trans = _Transition(
95 | num_input_features=num_features, num_output_features=num_features // 2)
96 | self.features.add_module('transition%d' % (i + 1), trans)
97 | num_features = num_features // 2
98 |
99 | # Final batch norm
100 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
101 |
102 | self.relu = nn.ReLU()
103 | self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d((1, 1))
104 |
105 | # Linear layer
106 | self.classifier = nn.Linear(num_features, num_classes)
107 |
108 | # Official init from torch repo.
109 | for m in self.modules():
110 | if isinstance(m, nn.Conv2d):
111 | nn.init.kaiming_normal_(m.weight)
112 | elif isinstance(m, nn.BatchNorm2d):
113 | nn.init.constant_(m.weight, 1)
114 | nn.init.constant_(m.bias, 0)
115 | elif isinstance(m, nn.Linear):
116 | nn.init.constant_(m.bias, 0)
117 |
118 | def forward(self, x):
119 | features = self.features(x)
120 | out = self.relu(features)
121 | out = self.adaptive_avg_pool2d(out)
122 | out = out.view(out.size(0), -1)
123 | out = self.classifier(out)
124 | return out
125 |
126 |
127 | def densenet121(pretrained=False, **kwargs):
128 | r"""Densenet-121 model from
129 | `"Densely Connected Convolutional Networks" `_
130 |
131 | Args:
132 | pretrained (bool): If True, returns a model pre-trained on ImageNet
133 | """
134 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
135 | **kwargs)
136 | if pretrained:
137 | # '.'s are no longer allowed in module names, but previous _DenseLayer
138 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
139 | # They are also in the checkpoints in model_urls. This pattern is used
140 | # to find such keys.
141 | pattern = re.compile(
142 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
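# e.g. 'features.denseblock1.denselayer1.norm.1.weight'
#  ->  'features.denseblock1.denselayer1.norm1.weight'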
143 | state_dict = model_zoo.load_url(model_urls['densenet121'])
144 | for key in list(state_dict.keys()):
145 | res = pattern.match(key)
146 | if res:
147 | new_key = res.group(1) + res.group(2)
148 | state_dict[new_key] = state_dict[key]
149 | del state_dict[key]
150 | model.load_state_dict(state_dict)
151 | return model
152 |
153 |
154 | def densenet169(pretrained=False, **kwargs):
155 | r"""Densenet-169 model from
156 | `"Densely Connected Convolutional Networks" `_
157 |
158 | Args:
159 | pretrained (bool): If True, returns a model pre-trained on ImageNet
160 | """
161 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
162 | **kwargs)
163 | if pretrained:
164 | # '.'s are no longer allowed in module names, but previous _DenseLayer
165 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
166 | # They are also in the checkpoints in model_urls. This pattern is used
167 | # to find such keys.
168 | pattern = re.compile(
169 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
170 | state_dict = model_zoo.load_url(model_urls['densenet169'])
171 | for key in list(state_dict.keys()):
172 | res = pattern.match(key)
173 | if res:
174 | new_key = res.group(1) + res.group(2)
175 | state_dict[new_key] = state_dict[key]
176 | del state_dict[key]
177 | model.load_state_dict(state_dict)
178 | return model
179 |
180 |
181 | def densenet201(pretrained=False, **kwargs):
182 | r"""Densenet-201 model from
183 | `"Densely Connected Convolutional Networks" `_
184 |
185 | Args:
186 | pretrained (bool): If True, returns a model pre-trained on ImageNet
187 | """
188 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
189 | **kwargs)
190 | if pretrained:
191 | # '.'s are no longer allowed in module names, but previous _DenseLayer
192 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
193 | # They are also in the checkpoints in model_urls. This pattern is used
194 | # to find such keys.
195 | pattern = re.compile(
196 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
197 | state_dict = model_zoo.load_url(model_urls['densenet201'])
198 | for key in list(state_dict.keys()):
199 | res = pattern.match(key)
200 | if res:
201 | new_key = res.group(1) + res.group(2)
202 | state_dict[new_key] = state_dict[key]
203 | del state_dict[key]
204 | model.load_state_dict(state_dict)
205 | return model
206 |
207 |
208 | def densenet161(pretrained=False, **kwargs):
209 | r"""Densenet-161 model from
210 | `"Densely Connected Convolutional Networks" `_
211 |
212 | Args:
213 | pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | """
215 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
216 | **kwargs)
217 | if pretrained:
218 | # '.'s are no longer allowed in module names, but previous _DenseLayer
219 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
220 | # They are also in the checkpoints in model_urls. This pattern is used
221 | # to find such keys.
222 | pattern = re.compile(
223 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
224 | state_dict = model_zoo.load_url(model_urls['densenet161'])
225 | for key in list(state_dict.keys()):
226 | res = pattern.match(key)
227 | if res:
228 | new_key = res.group(1) + res.group(2)
229 | state_dict[new_key] = state_dict[key]
230 | del state_dict[key]
231 | model.load_state_dict(state_dict)
232 | return model
233 |
--------------------------------------------------------------------------------
/sample_models/GoogleNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | __all__ = ["GoogLeNet"]
5 |
6 |
7 | class Inception(nn.Module):
8 | def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
9 | super(Inception, self).__init__()
10 | # 1x1 conv branch
11 | self.b1 = nn.Sequential(
12 | nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
13 | nn.BatchNorm2d(kernel_1_x),
14 | nn.ReLU(True),
15 | )
16 |
17 | # 1x1 conv -> 3x3 conv branch
18 | self.b2 = nn.Sequential(
19 | nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
20 | nn.BatchNorm2d(kernel_3_in),
21 | nn.ReLU(True),
22 | nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
23 | nn.BatchNorm2d(kernel_3_x),
24 | nn.ReLU(True),
25 | )
26 |
27 | # 1x1 conv -> 5x5 conv branch
28 | self.b3 = nn.Sequential(
29 | nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
30 | nn.BatchNorm2d(kernel_5_in),
31 | nn.ReLU(True),
32 | nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(kernel_5_x),
34 | nn.ReLU(True),
35 | nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(kernel_5_x),
37 | nn.ReLU(True),
38 | )
39 |
40 | # 3x3 pool -> 1x1 conv branch
41 | self.b4 = nn.Sequential(
42 | nn.MaxPool2d(3, stride=1, padding=1),
43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1),
44 | nn.BatchNorm2d(pool_planes),
45 | nn.ReLU(True),
46 | )
47 |
48 | def forward(self, x):
49 | y1 = self.b1(x)
50 | y2 = self.b2(x)
51 | y3 = self.b3(x)
52 | y4 = self.b4(x)
53 | return torch.cat([y1, y2, y3, y4], 1)
54 |
55 |
56 | class GoogLeNet(nn.Module):
57 | def __init__(self, num_classes=1000):
58 | super(GoogLeNet, self).__init__()
59 | self.pre_layers = nn.Sequential(
60 | nn.Conv2d(3, 192, kernel_size=3, padding=1),
61 | nn.BatchNorm2d(192),
62 | nn.ReLU(True),
63 | )
64 |
65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
67 |
68 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
69 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
70 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
71 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
72 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
73 |
74 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
75 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
76 |
77 | self.linear = nn.Linear(1024, num_classes)
78 |
79 | def forward(self, x):
80 | x = self.pre_layers(x)
81 | x = self.a3(x)
82 | x = self.b3(x)
83 | x = F.max_pool2d(x, 3, stride=2, padding=1)
84 | x = self.a4(x)
85 | x = self.b4(x)
86 | x = self.c4(x)
87 | x = self.d4(x)
88 | x = self.e4(x)
89 | x = F.max_pool2d(x, 3, stride=2, padding=1)
90 | x = self.a5(x)
91 | x = self.b5(x)
92 | x = F.avg_pool2d(x, 8, stride=1)
93 | x = x.view(x.size(0), -1)
94 | x = self.linear(x)
95 | return x
96 |
--------------------------------------------------------------------------------
/sample_models/LeNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | __all__ = ["LeNet"]
4 |
5 |
6 | class LeNet(nn.Module):
7 | def __init__(self, num_classes=1000):
8 | super(LeNet, self).__init__()
9 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
10 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
11 | self.fc1 = nn.Linear(16*5*5, 120)
12 | self.fc2 = nn.Linear(120, 84)
13 | self.fc3 = nn.Linear(84, num_classes)
14 | self.relu1 = nn.ReLU()
15 | self.relu2 = nn.ReLU()
16 | self.relu3 = nn.ReLU()
17 | self.relu4 = nn.ReLU()
18 | self.max_pool2d1 = nn.MaxPool2d(2)
19 | self.max_pool2d2 = nn.MaxPool2d(2)
20 |
21 | def forward(self, x):
22 | x = self.relu1(self.conv1(x))
23 | x = self.max_pool2d1(x)
24 | x = self.relu2(self.conv2(x))
25 | x = self.max_pool2d2(x)
26 | x = x.view(x.size(0), -1)
27 | x = self.relu3(self.fc1(x))
28 | x = self.relu4(self.fc2(x))
29 | x = self.fc3(x)
30 | return x
31 |
--------------------------------------------------------------------------------
/sample_models/SqueezeNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.utils.model_zoo as model_zoo
5 |
6 |
7 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
8 |
9 |
10 | model_urls = {
11 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
12 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
13 | }
14 |
15 |
16 | class Fire(nn.Module):
17 |
18 | def __init__(self, inplanes, squeeze_planes,
19 | expand1x1_planes, expand3x3_planes):
20 | super(Fire, self).__init__()
21 | self.inplanes = inplanes
22 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
23 | self.squeeze_activation = nn.ReLU(inplace=True)
24 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
25 | kernel_size=1)
26 | self.expand1x1_activation = nn.ReLU(inplace=True)
27 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
28 | kernel_size=3, padding=1)
29 | self.expand3x3_activation = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x):
32 | x = self.squeeze_activation(self.squeeze(x))
33 | return torch.cat([
34 | self.expand1x1_activation(self.expand1x1(x)),
35 | self.expand3x3_activation(self.expand3x3(x))
36 | ], 1)
37 |
38 |
39 | class SqueezeNet(nn.Module):
40 |
41 | def __init__(self, version=1.0, num_classes=1000):
42 | super(SqueezeNet, self).__init__()
43 | if version not in [1.0, 1.1]:
44 | raise ValueError("Unsupported SqueezeNet version {version}:"
45 | "1.0 or 1.1 expected".format(version=version))
46 | self.num_classes = num_classes
47 | if version == 1.0:
48 | self.features = nn.Sequential(
49 | nn.Conv2d(3, 96, kernel_size=7, stride=2),
50 | nn.ReLU(inplace=True),
51 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
52 | Fire(96, 16, 64, 64),
53 | Fire(128, 16, 64, 64),
54 | Fire(128, 32, 128, 128),
55 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
56 | Fire(256, 32, 128, 128),
57 | Fire(256, 48, 192, 192),
58 | Fire(384, 48, 192, 192),
59 | Fire(384, 64, 256, 256),
60 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
61 | Fire(512, 64, 256, 256),
62 | )
63 | else:
64 | self.features = nn.Sequential(
65 | nn.Conv2d(3, 64, kernel_size=3, stride=2),
66 | nn.ReLU(inplace=True),
67 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
68 | Fire(64, 16, 64, 64),
69 | Fire(128, 16, 64, 64),
70 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
71 | Fire(128, 32, 128, 128),
72 | Fire(256, 32, 128, 128),
73 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
74 | Fire(256, 48, 192, 192),
75 | Fire(384, 48, 192, 192),
76 | Fire(384, 64, 256, 256),
77 | Fire(512, 64, 256, 256),
78 | )
79 | # Final convolution is initialized differently from the rest
80 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
81 | self.classifier = nn.Sequential(
82 | nn.Dropout(p=0.5),
83 | final_conv,
84 | nn.ReLU(inplace=True),
85 | nn.AdaptiveAvgPool2d((1, 1))
86 | )
87 |
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | if m is final_conv:
91 | init.normal_(m.weight, mean=0.0, std=0.01)
92 | else:
93 | init.kaiming_uniform_(m.weight)
94 | if m.bias is not None:
95 | init.constant_(m.bias, 0)
96 |
97 | def forward(self, x):
98 | x = self.features(x)
99 | x = self.classifier(x)
100 | return x.view(x.size(0), self.num_classes)
101 |
102 |
103 | def squeezenet1_0(pretrained=False, **kwargs):
104 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
105 | accuracy with 50x fewer parameters and <0.5MB model size"
106 | `_ paper.
107 |
108 | Args:
109 | pretrained (bool): If True, returns a model pre-trained on ImageNet
110 | """
111 | model = SqueezeNet(version=1.0, **kwargs)
112 | if pretrained:
113 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
114 | return model
115 |
116 |
117 | def squeezenet1_1(pretrained=False, **kwargs):
118 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
119 | `_.
120 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
121 | than SqueezeNet 1.0, without sacrificing accuracy.
122 |
123 | Args:
124 | pretrained (bool): If True, returns a model pre-trained on ImageNet
125 | """
126 | model = SqueezeNet(version=1.1, **kwargs)
127 | if pretrained:
128 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
129 | return model
130 |
--------------------------------------------------------------------------------
/sample_models/VGG.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = [
6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
7 | 'vgg19_bn', 'vgg19',
8 | ]
9 |
10 |
11 | model_urls = {
12 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
13 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
14 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
15 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
16 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
17 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
18 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
19 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
20 | }
21 |
22 |
23 | class VGG(nn.Module):
24 |
25 | def __init__(self, features, num_classes=1000, init_weights=True):
26 | super(VGG, self).__init__()
27 | self.features = features
28 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
29 | self.classifier = nn.Sequential(
30 | nn.Linear(512 * 7 * 7, 4096),
31 | nn.ReLU(True),
32 | nn.Dropout(),
33 | nn.Linear(4096, 4096),
34 | nn.ReLU(True),
35 | nn.Dropout(),
36 | nn.Linear(4096, num_classes),
37 | )
38 | if init_weights:
39 | self._initialize_weights()
40 |
41 | def forward(self, x):
42 | x = self.features(x)
43 | x = self.avgpool(x)
44 | x = x.view(x.size(0), -1)
45 | x = self.classifier(x)
46 | return x
47 |
48 | def _initialize_weights(self):
49 | for m in self.modules():
50 | if isinstance(m, nn.Conv2d):
51 | nn.init.kaiming_normal_(
52 | m.weight, mode='fan_out', nonlinearity='relu')
53 | if m.bias is not None:
54 | nn.init.constant_(m.bias, 0)
55 | elif isinstance(m, nn.BatchNorm2d):
56 | nn.init.constant_(m.weight, 1)
57 | nn.init.constant_(m.bias, 0)
58 | elif isinstance(m, nn.Linear):
59 | nn.init.normal_(m.weight, 0, 0.01)
60 | nn.init.constant_(m.bias, 0)
61 |
62 |
63 | def make_layers(cfg, batch_norm=False):
64 | layers = []
65 | in_channels = 3
66 | for v in cfg:
67 | if v == 'M':
68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
69 | else:
70 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
71 | if batch_norm:
72 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
73 | else:
74 | layers += [conv2d, nn.ReLU(inplace=True)]
75 | in_channels = v
76 | return nn.Sequential(*layers)
77 |
78 |
79 | cfg = {
80 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
81 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
82 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
83 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
84 | }
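# each integer is the output-channel count of a 3x3 conv (plus ReLU, and BatchNorm
# when batch_norm=True); 'M' inserts a 2x2 max-pool, see make_layers above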
85 |
86 |
87 | def vgg11(pretrained=False, **kwargs):
88 | """VGG 11-layer model (configuration "A")
89 |
90 | Args:
91 | pretrained (bool): If True, returns a model pre-trained on ImageNet
92 | """
93 | if pretrained:
94 | kwargs['init_weights'] = False
95 | model = VGG(make_layers(cfg['A']), **kwargs)
96 | if pretrained:
97 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
98 | return model
99 |
100 |
101 | def vgg11_bn(pretrained=False, **kwargs):
102 | """VGG 11-layer model (configuration "A") with batch normalization
103 |
104 | Args:
105 | pretrained (bool): If True, returns a model pre-trained on ImageNet
106 | """
107 | if pretrained:
108 | kwargs['init_weights'] = False
109 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
110 | if pretrained:
111 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
112 | return model
113 |
114 |
115 | def vgg13(pretrained=False, **kwargs):
116 | """VGG 13-layer model (configuration "B")
117 |
118 | Args:
119 | pretrained (bool): If True, returns a model pre-trained on ImageNet
120 | """
121 | if pretrained:
122 | kwargs['init_weights'] = False
123 | model = VGG(make_layers(cfg['B']), **kwargs)
124 | if pretrained:
125 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
126 | return model
127 |
128 |
129 | def vgg13_bn(pretrained=False, **kwargs):
130 | """VGG 13-layer model (configuration "B") with batch normalization
131 |
132 | Args:
133 | pretrained (bool): If True, returns a model pre-trained on ImageNet
134 | """
135 | if pretrained:
136 | kwargs['init_weights'] = False
137 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
138 | if pretrained:
139 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
140 | return model
141 |
142 |
143 | def vgg16(pretrained=False, **kwargs):
144 | """VGG 16-layer model (configuration "D")
145 |
146 | Args:
147 | pretrained (bool): If True, returns a model pre-trained on ImageNet
148 | """
149 | if pretrained:
150 | kwargs['init_weights'] = False
151 | model = VGG(make_layers(cfg['D']), **kwargs)
152 | if pretrained:
153 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
154 | return model
155 |
156 |
157 | def vgg16_bn(pretrained=False, **kwargs):
158 | """VGG 16-layer model (configuration "D") with batch normalization
159 |
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | if pretrained:
164 | kwargs['init_weights'] = False
165 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
166 | if pretrained:
167 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
168 | return model
169 |
170 |
171 | def vgg19(pretrained=False, **kwargs):
172 | """VGG 19-layer model (configuration "E")
173 |
174 | Args:
175 | pretrained (bool): If True, returns a model pre-trained on ImageNet
176 | """
177 | if pretrained:
178 | kwargs['init_weights'] = False
179 | model = VGG(make_layers(cfg['E']), **kwargs)
180 | if pretrained:
181 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
182 | return model
183 |
184 |
185 | def vgg19_bn(pretrained=False, **kwargs):
186 | """VGG 19-layer model (configuration 'E') with batch normalization
187 |
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | if pretrained:
192 | kwargs['init_weights'] = False
193 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
194 | if pretrained:
195 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
196 | return model
197 |
--------------------------------------------------------------------------------
/sample_models/WideResNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | __all__ = ["WideResNet"]
6 |
7 |
8 | class BasicBlock(nn.Module):
9 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
10 | super(BasicBlock, self).__init__()
11 | self.bn1 = nn.BatchNorm2d(in_planes)
12 | self.relu1 = nn.ReLU(inplace=True)
13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(out_planes)
16 | self.relu2 = nn.ReLU(inplace=True)
17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
18 | padding=1, bias=False)
19 | self.droprate = drop_rate
20 | self.equalInOut = (in_planes == out_planes)
21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
22 | padding=0, bias=False) or None
23 | if self.droprate > 0:
24 | self.dropout = nn.Dropout(p=self.droprate)
25 |
26 | def forward(self, x):
27 | if not self.equalInOut:
28 | x = self.relu1(self.bn1(x))
29 | else:
30 | out = self.relu1(self.bn1(x))
31 |
32 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
33 | if self.droprate > 0:
34 | out = self.dropout(out)
35 | out = self.conv2(out)
36 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
37 |
38 |
39 | class NetworkBlock(nn.Module):
40 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
41 | super(NetworkBlock, self).__init__()
42 | self.layer = self._make_layer(
43 | block, in_planes, out_planes, nb_layers, stride, dropRate)
44 |
45 | @staticmethod
46 | def _make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate):
47 | layers = []
48 | for i in range(nb_layers):
49 | layers.append(block(in_planes if i == 0 else out_planes,
50 | out_planes, stride if i == 0 else 1, dropRate))
51 | return nn.Sequential(*layers)
52 |
53 | def forward(self, x):
54 | return self.layer(x)
55 |
56 |
57 | class WideResNet(nn.Module):
58 | def __init__(self, depth=10, num_classes=1000, widen_factor=1, drop_rate=0.0):
59 | super(WideResNet, self).__init__()
60 | n_channels = [16, 16 * widen_factor,
61 | 32 * widen_factor, 64 * widen_factor]
62 | assert ((depth - 4) % 6 == 0)
63 | n = int((depth - 4) / 6)
64 | block = BasicBlock
65 | # 1st conv before any network block
66 | self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1,
67 | padding=1, bias=False)
68 | # 1st block
69 | self.block1 = NetworkBlock(
70 | n, n_channels[0], n_channels[1], block, 1, drop_rate)
71 | # 2nd block
72 | self.block2 = NetworkBlock(
73 | n, n_channels[1], n_channels[2], block, 2, drop_rate)
74 | # 3rd block
75 | self.block3 = NetworkBlock(
76 | n, n_channels[2], n_channels[3], block, 2, drop_rate)
77 | # global average pooling and classifier
78 | self.bn1 = nn.BatchNorm2d(n_channels[3])
79 | self.relu = nn.ReLU(inplace=True)
80 | self.avg_pool = nn.AvgPool2d(8)
81 | self.fc = nn.Linear(n_channels[3], num_classes)
82 |
83 | self.nChannels = n_channels[3]
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 | elif isinstance(m, nn.Linear):
93 | m.bias.data.zero_()
94 |
95 | def forward(self, x):
96 | out = self.conv1(x)
97 | out = self.block1(out)
98 | out = self.block2(out)
99 | out = self.block3(out)
100 | out = self.relu(self.bn1(out))
101 | out = self.avg_pool(out)
102 | out = out.view(-1, self.nChannels)
103 | return self.fc(out)
104 |
--------------------------------------------------------------------------------
/sample_models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .AlexNet import *
3 | from .VGG import *
4 | from .ResNet import *
5 | from .LeNet import *
6 | from .DenseNet import *
7 | from .GoogleNet import *
8 | from .WideResNet import *
9 | from .Inception import *
10 | from .SqueezeNet import *
11 | from .amoebaNet import AmoebaNet_D
12 | from .torchGpipe_ref_models import amoebanetd, resnet101 as torchgpipe_resnet101
13 |
--------------------------------------------------------------------------------
/sample_models/amoebaNet/__init__.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Optional, List
3 |
4 | import torch
5 | from torch import nn
6 | from torch import Tensor
7 | from .genotype import Genotype, amoebanetd_genotype
8 | from .utils import Conv_3x3, Conv_7x1_1x7, Conv_Cell, FactorizedReduce, Pool_Operation
9 |
10 | __all__ = ["AmoebaNet_D"]
11 |
12 | # based on torchGpipe implementation https://github.com/kakaobrain/torchgpipe/blob/master/examples/amoebanet/__init__.py
13 |
14 |
15 | def conv_1x1(channel: int, stride: int, affine: bool) -> Conv_Cell:
16 | return Conv_Cell(channel, channel, 1, stride, 0, affine, use_relu=True)
17 |
18 |
19 | def avg_pool_3x3(channel: int, stride: int, affine: bool) -> Pool_Operation:
20 | return Pool_Operation('avg', 3, channel, stride, affine)
21 |
22 |
23 | def max_pool_2x2(channel: int, stride: int, affine: bool) -> Pool_Operation:
24 | return Pool_Operation('max', 2, channel, stride, affine)
25 |
26 |
27 | def skip_connect(channel: int, stride: int, affine: bool) -> nn.Module:
28 | if stride == 1:
29 | return nn.Identity()
30 | return FactorizedReduce(channel, channel, affine)
31 |
32 |
33 | class Classifier(nn.Module):
34 | def __init__(self, channel_prev: int, num_classes: int):
35 | super().__init__()
36 |
37 | self.global_pooling = nn.AvgPool2d(7)
38 | self.classifier = nn.Linear(channel_prev, num_classes)
39 |
40 | def forward(self, x: Tensor) -> Tensor:
41 | s1 = self.global_pooling(x)
42 | y = self.classifier(s1.view(s1.size(0), -1))
43 | return y
44 |
45 |
46 | class Stem(nn.Module):
47 | def __init__(self, channel: int):
48 | super(Stem, self).__init__()
49 | self.conv_cell = Conv_Cell(3, channel, 3, 2, 1, False)
50 |
51 | def forward(self, x: Tensor) -> Tensor:
52 | out = self.conv_cell(x)
53 | return out
54 |
55 |
56 | OPS = {
57 | 'skip_connect': skip_connect,
58 | 'avg_pool_3x3': avg_pool_3x3,
59 | 'max_pool_2x2': max_pool_2x2,
60 | 'conv_7x1_1x7': Conv_7x1_1x7,
61 | 'conv_1x1____': conv_1x1,
62 | 'conv_3x3____': Conv_3x3,
63 | }
64 |
65 |
66 | class AmoebaNet_D(nn.Module):
67 | """an AmoebaNet-D model for ImageNet."""
68 |
69 | def __init__(self, num_classes: int = 10, num_layers: int = 4,
70 | num_filters: int = 512,
71 | genotype: Optional[Genotype] = None):
72 | super(AmoebaNet_D, self).__init__()
73 |
74 | genotype = amoebanetd_genotype if genotype is None else genotype
75 | assert isinstance(genotype, Genotype)
76 | channel = num_filters // 4
77 |
78 | channel_prev_prev, channel_prev, channel_curr = channel, channel, channel
79 | cells = []
80 |
81 | # reduction
82 | channel_curr *= 2
83 | reduction_prev = False
84 | reduction = True
85 | cell = Amoeba_Cell(genotype, channel_prev_prev,
86 | channel_prev, channel_curr, reduction, reduction_prev)
87 | multiplier = len(cell.concat_indices)
88 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
89 | cells.append(cell)
90 |
91 | # reduction
92 | channel_curr *= 2
93 | reduction_prev = True
94 | reduction = True
95 | cell = Amoeba_Cell(genotype, channel_prev_prev,
96 | channel_prev, channel_curr, reduction, reduction_prev)
97 | multiplier = len(cell.concat_indices)
98 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
99 | cells.append(cell)
100 |
101 | # not reduction
102 | reduction_prev = True
103 | reduction = False
104 | for _ in range(num_layers):
105 | cell = Amoeba_Cell(genotype, channel_prev_prev,
106 | channel_prev, channel_curr, reduction, reduction_prev)
107 | multiplier = len(cell.concat_indices)
108 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
109 | cells.append(cell)
110 | reduction_prev = False
111 |
112 | # reduction
113 | channel_curr *= 2
114 | reduction_prev = False
115 | reduction = True
116 | cell = Amoeba_Cell(genotype, channel_prev_prev,
117 | channel_prev, channel_curr, reduction, reduction_prev)
118 | multiplier = len(cell.concat_indices)
119 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
120 | cells.append(cell)
121 |
122 | # not reduction
123 | reduction_prev = True
124 | reduction = False
125 | for _ in range(num_layers):
126 | cell = Amoeba_Cell(genotype, channel_prev_prev,
127 | channel_prev, channel_curr, reduction, reduction_prev)
128 | multiplier = len(cell.concat_indices)
129 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
130 | cells.append(cell)
131 | reduction_prev = False
132 |
133 | # reduction
134 | channel_curr *= 2
135 | reduction_prev = False
136 | reduction = True
137 | cell = Amoeba_Cell(genotype, channel_prev_prev,
138 | channel_prev, channel_curr, reduction, reduction_prev)
139 | multiplier = len(cell.concat_indices)
140 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
141 | cells.append(cell)
142 |
143 | # not reduction
144 | reduction_prev = True
145 | reduction = False
146 | for _ in range(num_layers):
147 | cell = Amoeba_Cell(genotype, channel_prev_prev,
148 | channel_prev, channel_curr, reduction, reduction_prev)
149 | multiplier = len(cell.concat_indices)
150 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
151 | cells.append(cell)
152 | reduction_prev = False
153 |
154 | self.stem = Stem(channel)
155 | self.cells = nn.Sequential(*cells)
156 | self.classifier = Classifier(channel_prev, num_classes)
157 |
158 | def forward(self, x: Tensor) -> Tensor:
159 | out = self.stem(x)
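# each Amoeba_Cell consumes and produces a pair (prev_prev, prev) of feature maps,
# so the stem output seeds both positions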
160 | out = self.cells((out, out))
161 | return self.classifier(out[1])
162 |
163 |
164 | class Amoeba_Cell(nn.Module):
165 | def __init__(self, genotype: Genotype,
166 | channel_prev_prev: int, channel_prev: int, channel: int,
167 | reduction: bool, reduction_prev: bool):
168 | super(Amoeba_Cell, self).__init__()
169 |
170 | preprocess0 = nn.Sequential()
171 | if reduction_prev:
172 | preprocess0 = FactorizedReduce(channel_prev_prev, channel)
173 | elif channel_prev_prev != channel:
174 | preprocess0 = Conv_Cell(channel_prev_prev, channel, 1, 1, 0, True)
175 |
176 | preprocess0: nn.Module = preprocess0
177 | preprocess1 = Conv_Cell(channel_prev, channel, 1, 1, 0, True)
178 |
179 | if reduction:
180 | op_names, indices = zip(*genotype.reduce)
181 | concat = genotype.reduce_concat
182 | else:
183 | op_names, indices = zip(*genotype.normal)
184 | concat = genotype.normal_concat
185 |
186 | ops = []
187 | for name, index in zip(op_names, indices):
188 | if reduction and index < 2:
189 | stride = 2
190 | else:
191 | stride = 1
192 | op = OPS[name](channel, stride, True)
193 | ops.append((op, index))
194 |
195 | self.preprocess0 = preprocess0
196 | self.preprocess1 = preprocess1
197 |
198 | layers = []
199 | assert (len(ops) % 2) == 0
200 | for i in range(len(ops) // 2):
201 | op0, i0 = ops[i * 2]
202 | op1, i1 = ops[i * 2 + 1]
203 | layers.extend([
204 | InputOne(op0, i=i0, insert=2 + i),
205 | # Output: x..., op0(x[i0]), skip]
206 |
207 | InputOne(op1, i=i1, insert=2 + i + 1),
208 | # Output: x..., op0(x[i0]), op1(x[i1]), skip
209 |
210 | MergeTwo(2 + i, 2 + i + 1),
211 | # Output: x..., op0(x[i0]) + op1(x[i1]), skip
212 | ])
213 | self.layers = nn.Sequential(*layers)
214 |
215 | self.concat_indices = concat
216 |
217 | assert len(concat) > 0 and all(i < (3 + (len(ops) // 2) - 1)
218 | for i in concat)
219 |
220 | def forward(self, xs):
221 | preprocessed = self.preprocess(xs)
222 | # preprocess(x0),preprocess(x1),x1
223 | out = preprocessed
224 | out = self.layers(out)
225 | # x,........,skip
226 | reduced = self.reduce_channels(self.concat_indices, out)
227 | # skip,concat
228 | return reduced
229 |
230 | def preprocess(self, xs):
231 | x0, x1 = xs
232 | return self.preprocess0(x0), self.preprocess1(x1), x1
233 |
234 | def reduce_channels(self, indices, xs):
235 | # indices = 4,5,6
236 | # x0,x1,x2,x3,x4,x5,x6,x7
237 | # x7,x0,x1,x2,x3,x4,x5,x6
238 | # x7,concat(x4,x5,x6)
239 | return xs[-1], torch.cat([xs[i] for i in indices], dim=1)
240 |
241 |
242 | Tensors = Tuple[Tensor, ...]
243 |
244 |
245 | class Hack(nn.Module):
246 |
247 | def forward(self, *args, **kwargs):
248 | raise NotImplementedError()
249 |
250 |
251 | class InputOne(Hack):
252 | """Picks one tensor for the underlying module input::
253 | a -----> a
254 | b --f--> f(b)
255 | c -----> c
256 | """
257 |
258 | def __init__(self, module: nn.Module, i: int, insert: Optional[int] = None):
259 | super().__init__()
260 | self.module = module
261 | self.i = i
262 | self.insert = insert
263 |
264 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
265 | i = self.i
266 |
270 | input = tensors[i]
271 | output = self.module(input)
272 |
273 | if not isinstance(output, tuple):
274 | output = (output,)
275 |
276 | if self.insert is None:
277 | # Replace with the input.
278 | return tensors[:i] + output + tensors[i + 1:]
279 |
280 | return tensors[:self.insert] + output + tensors[self.insert:]
281 |
282 |
283 | class MergeTwo(Hack):
284 | """Merges the last two tensors and replace them with the result::
285 | a -----> a
286 | b --+--> b+c
287 | c --+
288 | """
289 |
290 | def __init__(self, i: int, j: int):
291 | super().__init__()
292 | self.i = i
293 | self.j = j
294 |
295 | def forward(self, *tensors: Tensors) -> Tensors: # type: ignore
296 | if len(tensors) > 1:
297 | return sum(tensors)
298 |
299 | tensors = tensors[0]
300 | i = self.i
301 | j = self.j
302 | # Set the initial value as the first tensor
303 | # to type as 'Tensor' instead of 'Union[Tensor, int]'.
304 | merged = sum(tensors[i + 1:j + 1], tensors[i])
305 |
306 | return tensors[:i] + (merged,) + tensors[j + 1:]
307 |
--------------------------------------------------------------------------------
/sample_models/amoebaNet/genotype.py:
--------------------------------------------------------------------------------
1 | from typing import List, NamedTuple, Tuple
2 |
3 | __all__ = ['amoebanetd_genotype']
4 |
5 |
6 | class Genotype(NamedTuple):
7 | normal: List[Tuple[str, int]]
8 | normal_concat: List[int]
9 | reduce: List[Tuple[str, int]]
10 | reduce_concat: List[int]
11 |
12 |
13 | # The AmoebaNet-D genotype is based on the 'Regularized Evolution for Image Classifier
14 | # Architecture Search' paper (https://arxiv.org/pdf/1802.01548.pdf).
15 | amoebanetd_genotype = Genotype(
16 | normal=[
17 | ('avg_pool_3x3', 0),
18 | ('conv_1x1____', 0),
19 | ('skip_connect', 2),
20 | ('avg_pool_3x3', 2),
21 | ('skip_connect', 0),
22 | ('conv_7x1_1x7', 1),
23 | ('conv_1x1____', 1),
24 | ('conv_7x1_1x7', 1),
25 | ('avg_pool_3x3', 0),
26 | ('conv_1x1____', 3),
27 | ],
28 | normal_concat=[4, 5, 6],
29 | reduce=[
30 | ('max_pool_2x2', 1),
31 | ('avg_pool_3x3', 1),
32 | ('conv_3x3____', 0),
33 | ('skip_connect', 2),
34 | ('conv_7x1_1x7', 2),
35 | ('avg_pool_3x3', 2),
36 | ('avg_pool_3x3', 2),
37 | ('conv_1x1____', 3),
38 | ('skip_connect', 3),
39 | ('max_pool_2x2', 0),
40 | ],
41 | reduce_concat=[4, 5, 6]
42 | )
43 |
--------------------------------------------------------------------------------
/sample_models/amoebaNet/utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | from torch import nn
5 | from torch import Tensor
6 |
7 |
8 | class Pool_Operation(nn.Module):
9 | def __init__(self, pool_type: str, kernel_size: int, channel: int, stride: int, affine: bool):
10 | super(Pool_Operation, self).__init__()
11 | assert pool_type in ['avg', 'max']
12 |
13 | if pool_type == 'avg':
14 | self.pool = nn.AvgPool2d(kernel_size, stride=stride,
15 | padding=1, count_include_pad=False)
16 | else:
17 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=0)
18 |
19 | self.conv_cell = Conv_Cell(channel, channel, 1,
20 | 1, 0, affine, use_relu=False)
21 |
22 | def forward(self, x: Tensor) -> Tensor:
23 | out = self.pool(x)
24 | return self.conv_cell(out)
25 |
26 |
27 | class Conv_Cell(nn.Module):
28 | def __init__(self, in_channels: int, out_channels: int, kernel, stride, padding, affine: bool, use_relu: bool = True):
29 | super(Conv_Cell, self).__init__()
30 | self.relu = nn.ReLU(inplace=False) if use_relu else None
31 | self.conv = nn.Conv2d(in_channels, out_channels,
32 | kernel, stride=stride, padding=padding, bias=False)
33 | self.norm = nn.BatchNorm2d(out_channels, affine=affine)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | out = x if self.relu is None else self.relu(x)
37 | out = self.conv(out)
38 | return self.norm(out)
39 |
40 |
41 | class Conv_3x3(nn.Module):
42 | def __init__(self, channel: int, stride: int, affine: bool):
43 | super(Conv_3x3, self).__init__()
44 |
45 | self.conv1_1x1 = Conv_Cell(channel, channel//4, 1,
46 | 1, 0, affine)
47 | self.conv2_3x3 = Conv_Cell(channel//4, channel//4, 3,
48 | (stride, stride), 1, affine)
49 | self.conv3_1x1 = Conv_Cell(channel//4, channel, 1,
50 | 1, 0, affine)
51 |
52 | def forward(self, x: Tensor) -> Tensor:
53 | out = self.conv1_1x1(x)
54 | out = self.conv2_3x3(out)
55 | return self.conv3_1x1(out)
56 |
57 |
58 | class Conv_7x1_1x7(nn.Module):
59 | def __init__(self, channel: int, stride: int, affine: bool):
60 | super(Conv_7x1_1x7, self).__init__()
61 |
62 | self.conv1_1x1 = Conv_Cell(channel, channel//4, 1,
63 | 1, 0, affine)
64 |
65 | self.conv2_1x7 = Conv_Cell(channel//4, channel//4, (1, 7),
66 | (1, stride), (0, 3), affine)
67 |
68 | self.conv3_7x1 = Conv_Cell(channel//4, channel//4, (7, 1),
69 | (stride, 1), (3, 0), affine)
70 |
71 | self.conv4_1x1 = Conv_Cell(channel//4, channel, 1,
72 | 1, 0, affine)
73 |
74 | def forward(self, x: Tensor) -> Tensor:
75 | out = self.conv1_1x1(x)
76 | out = self.conv2_1x7(out)
77 | out = self.conv3_7x1(out)
78 | return self.conv4_1x1(out)
79 |
80 |
81 | class FactorizedReduce(nn.Module):
82 | """Operation Factorized reduce"""
83 |
84 | def __init__(self, in_planes: int, out_planes: int, affine: bool = True):
85 | super().__init__()
86 | self.relu = nn.ReLU(inplace=False)
87 | self.conv_1 = nn.Conv2d(in_planes, out_planes //
88 | 2, 1, stride=2, padding=0, bias=False)
89 | self.conv_2 = nn.Conv2d(in_planes, out_planes //
90 | 2, 1, stride=2, padding=0, bias=False)
91 | self.bn = nn.BatchNorm2d(out_planes, affine=affine)
92 |
93 | def forward(self, x: Tensor) -> Tensor:
94 | x = self.relu(x)
95 | branch1 = self.conv_1(x)
96 | branch2 = self.conv_2(x[:, :, 1:, 1:])
97 | out = torch.cat([branch1, branch2], dim=1)
98 | out = self.bn(out)
99 | return out
100 |
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet101
2 | from .amoebanet_d import amoebanetd
3 |
--------------------------------------------------------------------------------
/sample_models/torchGpipe_ref_models/amoebanet_d/__init__.py:
--------------------------------------------------------------------------------
1 | """An AmoebaNet-D implementation but using :class:`nn.Sequential`. :func:`amoebanetd`
2 | returns a :class:`nn.Sequential`.
3 | """
4 | from collections import OrderedDict
5 | from typing import Tuple
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from ..flatten_sequential import flatten_sequential
11 | from .genotype import Genotype, amoebanetd_genotype
12 | from .surgery import Concat, FirstAnd, InputOne, MergeTwo, Shift, Twin, TwinLast
13 |
14 | __all__ = ['amoebanetd']
15 |
16 |
17 | class Identity(nn.Module):
18 | """Through the tensor::
19 | x ---> x
20 | """
21 |
22 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
23 | return x
24 |
25 |
26 | class ReLUConvBN(nn.Module):
27 | """Operation ReLU + Conv2d + BatchNorm2d"""
28 |
29 | def __init__(self, in_planes: int, out_planes: int, kernel_size: int, stride: int,
30 | padding: int, affine: bool = True):
31 | super().__init__()
32 | self.op = nn.Sequential(
33 | nn.ReLU(inplace=False),
34 | nn.Conv2d(in_planes, out_planes, kernel_size,
35 | stride=stride, padding=padding, bias=False),
36 | nn.BatchNorm2d(out_planes, affine=affine)
37 | )
38 |
39 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
40 | return self.op(x)
41 |
42 |
43 | class FactorizedReduce(nn.Module):
44 | """Operation Factorized reduce"""
45 |
46 | def __init__(self, in_planes: int, out_planes: int, affine: bool = True):
47 | super().__init__()
48 | self.relu = nn.ReLU(inplace=False)
49 | self.conv_1 = nn.Conv2d(in_planes, out_planes //
50 | 2, 1, stride=2, padding=0, bias=False)
51 | self.conv_2 = nn.Conv2d(in_planes, out_planes //
52 | 2, 1, stride=2, padding=0, bias=False)
53 | self.bn = nn.BatchNorm2d(out_planes, affine=affine)
54 |
55 | def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
56 | x = self.relu(x)
57 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
58 | out = self.bn(out)
59 | return out
60 |
61 |
62 | def skip_connect(channel: int, stride: int, affine: bool) -> nn.Module:
63 | if stride == 1:
64 | return Identity()
65 | return FactorizedReduce(channel, channel, affine=affine)
66 |
67 |
68 | def avg_pool_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential:
69 | return nn.Sequential(
70 | nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
71 | nn.Conv2d(channel, channel, (1, 1), stride=(
72 | 1, 1), padding=(0, 0), bias=False),
73 | nn.BatchNorm2d(channel, affine=affine))
74 |
75 |
76 | def max_pool_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential:
77 | return nn.Sequential(
78 | nn.MaxPool2d(3, stride=stride, padding=1),
79 | nn.Conv2d(channel, channel, (1, 1), stride=(
80 | 1, 1), padding=(0, 0), bias=False),
81 | nn.BatchNorm2d(channel, affine=affine))
82 |
83 |
84 | def max_pool_2x2(channel: int, stride: int, affine: bool) -> nn.Sequential:
85 | return nn.Sequential(
86 | nn.MaxPool2d(2, stride=stride, padding=0),
87 | nn.Conv2d(channel, channel, (1, 1), stride=(
88 | 1, 1), padding=(0, 0), bias=False),
89 | nn.BatchNorm2d(channel, affine=affine))
90 |
91 |
92 | def conv_7x1_1x7(channel: int, stride: int, affine: bool) -> nn.Sequential:
93 | return nn.Sequential(
94 | nn.ReLU(inplace=False),
95 | nn.Conv2d(channel, channel // 4, (1, 1),
96 | stride=(1, 1), padding=(0, 0), bias=False),
97 | nn.BatchNorm2d(channel // 4, affine=affine),
98 | nn.ReLU(inplace=False),
99 | nn.Conv2d(channel // 4, channel // 4, (1, 7),
100 | stride=(1, stride), padding=(0, 3), bias=False),
101 | nn.BatchNorm2d(channel // 4, affine=affine),
102 | nn.ReLU(inplace=False),
103 | nn.Conv2d(channel // 4, channel // 4, (7, 1),
104 | stride=(stride, 1), padding=(3, 0), bias=False),
105 | nn.BatchNorm2d(channel // 4, affine=affine),
106 | nn.ReLU(inplace=False),
107 | nn.Conv2d(channel // 4, channel, (1, 1),
108 | stride=(1, 1), padding=(0, 0), bias=False),
109 | nn.BatchNorm2d(channel, affine=affine))
110 |
111 |
112 | def conv_1x1(channel: int, stride: int, affine: bool) -> nn.Sequential:
113 | return nn.Sequential(
114 | nn.ReLU(inplace=False),
115 | nn.Conv2d(channel, channel, (1, 1), stride=(
116 | stride, stride), padding=(0, 0), bias=False),
117 | nn.BatchNorm2d(channel, affine=affine))
118 |
119 |
120 | def conv_3x3(channel: int, stride: int, affine: bool) -> nn.Sequential:
121 | return nn.Sequential(
122 | nn.ReLU(inplace=False),
123 | nn.Conv2d(channel, channel // 4, (1, 1),
124 | stride=(1, 1), padding=(0, 0), bias=False),
125 | nn.BatchNorm2d(channel // 4, affine=affine),
126 | nn.ReLU(inplace=False),
127 | nn.Conv2d(channel // 4, channel // 4, (3, 3),
128 | stride=(stride, stride), padding=(1, 1), bias=False),
129 | nn.BatchNorm2d(channel // 4, affine=affine),
130 | nn.ReLU(inplace=False),
131 | nn.Conv2d(channel // 4, channel, (1, 1),
132 | stride=(1, 1), padding=(0, 0), bias=False),
133 | nn.BatchNorm2d(channel, affine=affine))
134 |
135 |
136 | OPS = {
137 | 'skip_connect': skip_connect,
138 | 'avg_pool_3x3': avg_pool_3x3,
139 | 'max_pool_3x3': max_pool_3x3,
140 | 'max_pool_2x2': max_pool_2x2,
141 | 'conv_7x1_1x7': conv_7x1_1x7,
142 | 'conv_1x1____': conv_1x1,
143 | 'conv_3x3____': conv_3x3,
144 | }
145 |
146 |
147 | class Classifier(nn.Module):
148 |
149 | def __init__(self, channel_prev: int, num_classes: int):
150 | super().__init__()
151 |
152 | self.global_pooling = nn.AvgPool2d(7)
153 | self.classifier = nn.Linear(channel_prev, num_classes)
154 |
155 |     def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:  # type: ignore
156 |         s1 = self.global_pooling(x[1])  # x = (skip, concat); classify the concat branch
157 | y = self.classifier(s1.view(s1.size(0), -1))
158 |
159 | return y
160 |
161 |
162 | def make_stem(channel: int) -> nn.Sequential:
163 | return nn.Sequential(
164 | nn.ReLU(inplace=False),
165 | nn.Conv2d(3, channel, 3, stride=2, padding=1, bias=False),
166 | nn.BatchNorm2d(channel),
167 | Twin(),
168 | )
169 |
170 |
171 | def make_cell(genotype: Genotype,
172 | channel_prev_prev: int, channel_prev: int, channel: int,
173 | reduction: bool, reduction_prev: bool
174 | ) -> Tuple[nn.Sequential, int]:
175 |
176 | preprocess0: nn.Module = nn.Sequential()
177 |
178 | if reduction_prev:
179 | preprocess0 = FactorizedReduce(channel_prev_prev, channel)
180 | elif channel_prev_prev != channel:
181 | preprocess0 = ReLUConvBN(channel_prev_prev, channel, 1, 1, 0)
182 | preprocess1 = ReLUConvBN(channel_prev, channel, 1, 1, 0)
183 |
184 | if reduction:
185 | op_names, indices = zip(*genotype.reduce)
186 | concat = genotype.reduce_concat
187 | else:
188 | op_names, indices = zip(*genotype.normal)
189 | concat = genotype.normal_concat
190 |
191 | ops = []
192 | for name, index in zip(op_names, indices):
193 | if reduction and index < 2:
194 | stride = 2
195 | else:
196 | stride = 1
197 |
198 | op = OPS[name](channel, stride, True)
199 | ops.append((op, index))
200 |
201 | layers = [
202 | InputOne(preprocess0, i=0),
203 | TwinLast(),
204 | InputOne(preprocess1, i=1),
205 | # Output: (preprocess0(x[0]), preprocess1(x[1]), x[1])
206 | # The last tensor x[1] is passed until the cell output.
207 | # The comments below call x[1] "skip".
208 | ]
209 |
210 | for i in range(len(ops) // 2):
211 | op0, i0 = ops[i*2]
212 | op1, i1 = ops[i*2 + 1]
213 |
214 | layers.extend([
215 | InputOne(op0, i=i0, insert=2+i),
216 |             # Output: x..., op0(x[i0]), skip
217 |
218 | InputOne(op1, i=i1, insert=2 + i + 1),
219 | # Output: x..., op0(x[i0]), op1(x[i1]), skip
220 |
221 | MergeTwo(2 + i, 2 + i + 1),
222 | # Output: x..., op0(x[i0]) + op1(x[i1]), skip
223 | ])
224 |
225 | layers.extend([
226 | # Move skip to the head.
227 | Shift(),
228 | # Output: skip, x...
229 |
230 | FirstAnd(Concat(concat)),
231 | # Output: skip, concat(x...)
232 | ])
233 |
234 | multiplier = len(concat)
235 |
236 | return nn.Sequential(*layers), multiplier
237 |
238 |
239 | def amoebanetd(num_classes: int = 10,
240 | num_layers: int = 4,
241 | num_filters: int = 512,
242 | ) -> nn.Sequential:
243 | """Builds an AmoebaNet-D model for ImageNet."""
244 | channel = num_filters // 4
245 |
246 | def make_layer(channel: int,
247 | num_layers: int,
248 | genotype: Genotype,
249 | ) -> Tuple[nn.Sequential, int]:
250 | n = num_layers
251 | channel_prev_prev, channel_prev, channel_curr = channel, channel, channel
252 | cells = []
253 |
254 | reduction_prev = False
255 | reduction = True
256 | channel_curr *= 2
257 | cell, multiplier = make_cell(genotype, channel_prev_prev,
258 | channel_prev, channel_curr, reduction, reduction_prev)
259 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
260 | cells.append(cell)
261 |
262 | reduction_prev = True
263 | reduction = True
264 | channel_curr *= 2
265 | cell, multiplier = make_cell(genotype, channel_prev_prev,
266 | channel_prev, channel_curr, reduction, reduction_prev)
267 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
268 | cells.append(cell)
269 |
270 | reduction = False
271 | reduction_prev = True
272 | for _ in range(n):
273 | cell, multiplier = make_cell(genotype, channel_prev_prev,
274 | channel_prev, channel_curr, reduction, reduction_prev)
275 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
276 | cells.append(cell)
277 | reduction_prev = False
278 |
279 | reduction_prev = False
280 | reduction = True
281 | channel_curr *= 2
282 | cell, multiplier = make_cell(genotype, channel_prev_prev,
283 | channel_prev, channel_curr, reduction, reduction_prev)
284 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
285 | cells.append(cell)
286 |
287 | reduction = False
288 | reduction_prev = True
289 | for _ in range(n):
290 | cell, multiplier = make_cell(genotype, channel_prev_prev,
291 | channel_prev, channel_curr, reduction, reduction_prev)
292 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
293 | cells.append(cell)
294 | reduction_prev = False
295 |
296 | reduction_prev = False
297 | reduction = True
298 | channel_curr *= 2
299 | cell, multiplier = make_cell(genotype, channel_prev_prev,
300 | channel_prev, channel_curr, reduction, reduction_prev)
301 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
302 | cells.append(cell)
303 |
304 | reduction = False
305 | reduction_prev = True
306 | for _ in range(n):
307 | cell, multiplier = make_cell(genotype, channel_prev_prev,
308 | channel_prev, channel_curr, reduction, reduction_prev)
309 | channel_prev_prev, channel_prev = channel_prev, multiplier * channel_curr
310 | cells.append(cell)
311 | reduction_prev = False
312 |
313 | return nn.Sequential(*cells), channel_prev
314 |
315 | cells, channel_prev = make_layer(channel, num_layers, amoebanetd_genotype)
316 |
317 | model = nn.Sequential(OrderedDict([
318 | ('stem', make_stem(channel)),
319 | ('cells', cells),
320 | ('fin', Classifier(channel_prev, num_classes))
321 | ]))
322 |
323 | return flatten_sequential(model)
324 |
--------------------------------------------------------------------------------
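
Note: a quick way to sanity-check the builder above is a CPU forward pass. This is a minimal sketch, not part of the repository; the import path follows this repository's layout, and the 224x224 input matches the classifier's 7x7 average pooling (the stem plus four stride-2 reduction cells bring 224 down to 7).

    import torch
    from sample_models.torchGpipe_ref_models.amoebanet_d import amoebanetd

    model = amoebanetd(num_classes=10, num_layers=4, num_filters=512)
    x = torch.randn(2, 3, 224, 224)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 10])
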
/sample_models/torchGpipe_ref_models/amoebanet_d/genotype.py:
--------------------------------------------------------------------------------
1 | from typing import List, NamedTuple, Tuple
2 |
3 | __all__ = ['amoebanetd_genotype']
4 |
5 |
6 | class Genotype(NamedTuple):
7 | normal: List[Tuple[str, int]]
8 | normal_concat: List[int]
9 | reduce: List[Tuple[str, int]]
10 | reduce_concat: List[int]
11 |
12 |
13 | # The AmoebaNet-D genotype is based on the 'Regularized Evolution for Image Classifier
14 | # Architecture Search' paper (https://arxiv.org/pdf/1802.01548.pdf).
15 | amoebanetd_genotype = Genotype(
16 | normal=[
17 | ('max_pool_3x3', 0),
18 | ('conv_1x1____', 0),
19 | ('skip_connect', 2),
20 | ('max_pool_3x3', 2),
21 | ('skip_connect', 0),
22 | ('conv_7x1_1x7', 1),
23 | ('conv_1x1____', 1),
24 | ('conv_7x1_1x7', 1),
25 | ('avg_pool_3x3', 0),
26 | ('conv_1x1____', 3),
27 | ],
28 | normal_concat=[4, 5, 6],
29 | reduce=[
30 | ('max_pool_2x2', 1),
31 | ('max_pool_3x3', 1),
32 | ('conv_3x3____', 0),
33 | ('skip_connect', 2),
34 | ('conv_7x1_1x7', 2),
35 | ('max_pool_3x3', 2),
36 | ('avg_pool_3x3', 2),
37 | ('conv_1x1____', 3),
38 | ('skip_connect', 3),
39 | ('max_pool_2x2', 0),
40 | ],
41 | reduce_concat=[4, 5, 6]
42 | )
43 |
--------------------------------------------------------------------------------
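
Note: make_cell (in __init__.py above) consumes a Genotype by unzipping each op list into names and input indices. A small illustration using the genotype defined above:

    op_names, indices = zip(*amoebanetd_genotype.normal)
    # op_names -> ('max_pool_3x3', 'conv_1x1____', 'skip_connect', ...)
    # indices  -> (0, 0, 2, ...)  # which earlier tensor each op reads
    # Ops 2k and 2k+1 are applied and summed (via MergeTwo) to form node 2+k;
    # normal_concat=[4, 5, 6] selects which nodes feed the cell output.
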
/sample_models/torchGpipe_ref_models/amoebanet_d/surgery.py:
--------------------------------------------------------------------------------
1 | """Utility modules for breaking a complex module into sequential.
2 | """
3 | from typing import List, Optional, Tuple
4 |
5 | import torch
6 | from torch import nn
7 |
8 | __all__ = ['Concat', 'FirstAnd', 'InputOne',
9 | 'MergeTwo', 'Shift', 'Twin', 'TwinLast']
10 |
11 |
12 | Tensors = Tuple[torch.Tensor, ...]
13 |
14 |
15 | class Hack(nn.Module):
16 |
17 | def forward(self, *args, **kwargs):
18 | raise NotImplementedError()
19 |
20 |
21 | class Twin(Hack):
22 | """Duplicates the tensor::
23 | ┌──────┐
24 | a --│ Twin │--> a
25 | │ '--│--> a
26 | └──────┘
27 | """
28 |
29 | def forward(self, tensor: torch.Tensor) -> Tensors: # type: ignore
30 | return tensor, tensor
31 |
32 |
33 | class TwinLast(Hack):
34 | """Duplicates the last tensor::
35 | a -----> a
36 | b -----> b
37 | c --+--> c
38 | +--> c'
39 | """
40 |
41 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
42 | return tensors + (tensors[-1],)
43 |
44 |
45 | class InputOne(Hack):
46 |     """Applies the underlying module to one selected tensor::
47 | a -----> a
48 | b --f--> f(b)
49 | c -----> c
50 | """
51 |
52 | def __init__(self, module: nn.Module, i: int, insert: Optional[int] = None):
53 | super().__init__()
54 | self.module = module
55 | self.i = i
56 | self.insert = insert
57 |
58 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
59 | i = self.i
60 |
64 | input = tensors[i]
65 | output = self.module(input)
66 |
67 | if not isinstance(output, tuple):
68 | output = (output,)
69 |
70 | if self.insert is None:
71 |             # Replace the i-th input with the module's output.
72 | return tensors[:i] + output + tensors[i+1:]
73 |
74 | return tensors[:self.insert] + output + tensors[self.insert:]
75 |
76 |
77 | class Shift(Hack):
78 |     """Moves the last tensor to the front::
79 | +--> c
80 | a --|--> a
81 | b --|--> b
82 | c --+
83 | """
84 |
85 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
86 | return (tensors[-1],) + tensors[:-1]
87 |
88 |
89 | class MergeTwo(Hack):
90 |     """Sums the tensors at indices i..j and replaces them with the result::
91 | a -----> a
92 | b --+--> b+c
93 | c --+
94 | """
95 |
96 | def __init__(self, i: int, j: int):
97 | super().__init__()
98 | self.i = i
99 | self.j = j
100 |
101 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
102 | i = self.i
103 | j = self.j
104 |
105 |         # Use the first tensor as the initial value so the sum is typed
106 |         # 'Tensor' instead of 'Union[Tensor, int]'.
107 | merged = sum(tensors[i+1:j+1], tensors[i])
108 |
109 | return tensors[:i] + (merged,) + tensors[j+1:]
110 |
111 |
112 | class FirstAnd(Hack):
113 |     """Skips the first tensor and applies the underlying module to the
114 |     remaining tensors::
115 | a -----> a
116 | b --+--> f(b, c)
117 | c --+
118 | """
119 |
120 | def __init__(self, module: nn.Module):
121 | super().__init__()
122 | self.module = module
123 |
124 | def forward(self, tensors: Tensors) -> Tensors: # type: ignore
125 | output = self.module(tensors[1:])
126 | if not isinstance(output, tuple):
127 | output = (output,)
128 | return (tensors[0],) + output
129 |
130 |
131 | class Concat(Hack):
132 |     """Concatenates the tensors at the given indices along the channel dim::
133 | a --+
134 | b --+--> concat(a, b, c)
135 | c --+
136 | """
137 |
138 | def __init__(self, indices: List):
139 | super().__init__()
140 | self.indices = indices
141 |
142 | def forward(self, tensors: Tensors) -> torch.Tensor: # type: ignore
143 | return torch.cat([tensors[i] for i in self.indices], dim=1)
144 |
--------------------------------------------------------------------------------
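
Note: the Hack modules above let tuple-valued data flow through a plain nn.Sequential. A toy sketch (not part of the repository; the import path follows this repository's layout and the shapes are arbitrary):

    import torch
    from torch import nn
    from sample_models.torchGpipe_ref_models.amoebanet_d.surgery import (
        Twin, InputOne, MergeTwo)

    pipeline = nn.Sequential(
        Twin(),                    # x            -> (x, x)
        InputOne(nn.ReLU(), i=0),  # (x, x)       -> (relu(x), x)
        MergeTwo(0, 1),            # (relu(x), x) -> (relu(x) + x,)
    )
    out, = pipeline(torch.randn(2, 3))  # a residual ReLU in sequential form
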
/sample_models/torchGpipe_ref_models/flatten_sequential.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Iterator, Tuple
3 |
4 | from torch import nn
5 |
6 | __all__ = ['flatten_sequential']
7 |
8 |
9 | def flatten_sequential(module: nn.Sequential) -> nn.Sequential:
10 | """Flattens a nested sequential module."""
11 | if not isinstance(module, nn.Sequential):
12 | raise TypeError('not sequential')
13 |
14 | return nn.Sequential(OrderedDict(_flatten_sequential(module)))
15 |
16 |
17 | def _flatten_sequential(module: nn.Sequential) -> Iterator[Tuple[str, nn.Module]]:
18 | for name, child in module.named_children():
19 | # flatten_sequential child sequential layers only.
20 | if isinstance(child, nn.Sequential):
21 | for sub_name, sub_child in _flatten_sequential(child):
22 | yield ('%s_%s' % (name, sub_name), sub_child)
23 | else:
24 | yield (name, child)
25 |
--------------------------------------------------------------------------------
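
Note: nested child names are joined with an underscore, one level per nesting. A minimal sketch of the resulting names (the `nested` model here is hypothetical):

    from collections import OrderedDict
    import torch.nn as nn
    from sample_models.torchGpipe_ref_models.flatten_sequential import flatten_sequential

    nested = nn.Sequential(OrderedDict([
        ('block', nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(3, 8, 3)),
            ('relu', nn.ReLU()),
        ]))),
        ('fc', nn.Linear(8, 2)),
    ]))
    flat = flatten_sequential(nested)
    print([name for name, _ in flat.named_children()])
    # ['block_conv', 'block_relu', 'fc']
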
/sample_models/torchGpipe_ref_models/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | """A ResNet implementation but using :class:`nn.Sequential`. :func:`resnet101`
2 | returns a :class:`nn.Sequential` instead of ``ResNet``.
3 | This code is transformed from :mod:`torchvision.models.resnet`.
4 | """
5 | from collections import OrderedDict
6 | from typing import Any, List
7 |
8 | from torch import Tensor
9 | import torch.nn as nn
10 |
11 | from .bottleneck import bottleneck
12 | from ..flatten_sequential import flatten_sequential
13 |
14 | __all__ = ['resnet101']
15 |
16 |
17 | class Flat(nn.Module):
18 |     """Flattens each sample in the batch into a 1-d tensor."""
19 |
20 | def forward(self, x: Tensor): # type: ignore
21 | return x.view(x.size(0), -1)
22 |
23 |
24 | def build_resnet(layers: List[int],
25 | num_classes: int = 1000,
26 | ) -> nn.Sequential:
27 | """Builds a ResNet as a simple sequential model.
28 | Note:
29 | The implementation is copied from :mod:`torchvision.models.resnet`.
30 | """
31 | inplanes = 64
32 |
33 | def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
34 | nonlocal inplanes
35 |
36 | downsample = None
37 | if stride != 1 or inplanes != planes * 4:
38 | downsample = nn.Sequential(
39 | nn.Conv2d(inplanes, planes * 4,
40 | kernel_size=1, stride=stride, bias=False),
41 | nn.BatchNorm2d(planes * 4),
42 | )
43 |
44 | layers = []
45 | layers.append(bottleneck(inplanes, planes, stride, downsample))
46 | inplanes = planes * 4
47 | for _ in range(1, blocks):
48 | layers.append(bottleneck(inplanes, planes))
49 |
50 | return nn.Sequential(*layers)
51 |
52 | # Build ResNet as a sequential model.
53 | model = nn.Sequential(OrderedDict([
54 | ('conv1', nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)),
55 | ('bn1', nn.BatchNorm2d(64)),
56 | ('relu', nn.ReLU()),
57 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
58 |
59 | ('layer1', make_layer(64, layers[0])),
60 | ('layer2', make_layer(128, layers[1], stride=2)),
61 | ('layer3', make_layer(256, layers[2], stride=2)),
62 | ('layer4', make_layer(512, layers[3], stride=2)),
63 |
64 | ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
65 | ('flat', Flat()),
66 | ('fc', nn.Linear(512 * 4, num_classes)),
67 | ]))
68 |
69 | # Flatten nested sequentials.
70 | model = flatten_sequential(model)
71 |
72 | # Initialize weights for Conv2d and BatchNorm2d layers.
73 | def init_weight(m: nn.Module) -> None:
74 | if isinstance(m, nn.Conv2d):
75 | assert isinstance(m.kernel_size, tuple)
76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
77 |
78 | m.weight.requires_grad = False
79 |         m.weight.normal_(0, (2. / n) ** 0.5)  # He init: std = sqrt(2/n)
80 | m.weight.requires_grad = True
81 |
82 | elif isinstance(m, nn.BatchNorm2d):
83 | m.weight.requires_grad = False
84 | m.weight.fill_(1)
85 | m.weight.requires_grad = True
86 |
87 | m.bias.requires_grad = False
88 | m.bias.zero_()
89 | m.bias.requires_grad = True
90 |
91 | model.apply(init_weight)
92 |
93 | return model
94 |
95 |
96 | def resnet101(**kwargs: Any) -> nn.Sequential:
97 | """Constructs a ResNet-101 model."""
98 | return build_resnet([3, 4, 23, 3], **kwargs)
99 |
--------------------------------------------------------------------------------
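
Note: a minimal smoke test for the sequential ResNet-101 above (a sketch, assuming the standard 224x224 input; the import path follows this repository's layout):

    import torch
    from sample_models.torchGpipe_ref_models.resnet import resnet101

    model = resnet101(num_classes=1000)
    y = model(torch.randn(1, 3, 224, 224))
    print(y.shape)  # expected: torch.Size([1, 1000])
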
/sample_models/torchGpipe_ref_models/resnet/bottleneck.py:
--------------------------------------------------------------------------------
1 | """A ResNet bottleneck implementation but using :class:`nn.Sequential`."""
2 | from collections import OrderedDict
3 | from typing import Dict, Optional, Tuple, Union
4 |
5 | from torch import Tensor
6 | import torch.nn as nn
7 |
8 | __all__ = ['bottleneck']
9 |
10 | Tensors = Tuple[Tensor, ...]
11 | TensorOrTensors = Union[Tensor, Tensors]
12 |
13 |
14 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 |
25 | class Twin(nn.Module):
26 | # ┌──────┐
27 | # a --│ Twin │--> a
28 | # │ '--│--> a
29 | # └──────┘
30 | def forward(self, # type: ignore
31 | tensor: Tensor,
32 | ) -> Tuple[Tensor, Tensor]:
33 | return tensor, tensor
34 |
35 |
36 | class Gutter(nn.Module):
37 | # ┌───────────┐
38 | # a --│ Gutter[M] │--> M(a)
39 | # b --│-----------│--> b
40 | # └───────────┘
41 | def __init__(self, module: nn.Module):
42 | super().__init__()
43 | self.module = module
44 |
45 | def forward(self, # type: ignore
46 | input_and_skip: Tuple[Tensor, Tensor],
47 | ) -> Tuple[Tensor, Tensor]:
48 | input, skip = input_and_skip
49 | output = self.module(input)
50 | return output, skip
51 |
52 |
53 | class Residual(nn.Module):
54 | """A residual block for ResNet."""
55 |
56 | def __init__(self, downsample: Optional[nn.Module] = None):
57 | super().__init__()
58 | self.downsample = downsample
59 |
60 | def forward(self, # type: ignore
61 | input_and_identity: Tuple[Tensor, Tensor],
62 | ) -> Tensor:
63 | input, identity = input_and_identity
64 | if self.downsample is not None:
65 | identity = self.downsample(identity)
66 | return input + identity
67 |
68 |
69 | def bottleneck(inplanes: int,
70 | planes: int,
71 | stride: int = 1,
72 | downsample: Optional[nn.Module] = None,
73 | ) -> nn.Sequential:
74 |     """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`."""
75 | layers: Dict[str, nn.Module] = OrderedDict()
76 | layers['twin'] = Twin()
77 |
78 | layers['conv1'] = Gutter(conv1x1(inplanes, planes))
79 | layers['bn1'] = Gutter(nn.BatchNorm2d(planes))
80 | layers['conv2'] = Gutter(conv3x3(planes, planes, stride))
81 | layers['bn2'] = Gutter(nn.BatchNorm2d(planes))
82 | layers['conv3'] = Gutter(conv1x1(planes, planes * 4))
83 | layers['bn3'] = Gutter(nn.BatchNorm2d(planes * 4))
84 | layers['residual'] = Residual(downsample)
85 | layers['relu'] = nn.ReLU()
86 |
87 | return nn.Sequential(layers)
88 |
--------------------------------------------------------------------------------
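
Note: the block threads a skip tensor alongside the main path: Twin duplicates the input, each Gutter applies its conv/bn to the first element only, and Residual adds the (optionally downsampled) skip back. A small sketch (the downsample mirrors the one built in build_resnet; conv1x1 is the module-level helper above):

    import torch
    import torch.nn as nn
    from sample_models.torchGpipe_ref_models.resnet.bottleneck import (
        bottleneck, conv1x1)

    down = nn.Sequential(conv1x1(64, 256), nn.BatchNorm2d(256))
    block = bottleneck(inplanes=64, planes=64, stride=1, downsample=down)
    y = block(torch.randn(1, 64, 56, 56))
    print(y.shape)  # expected: torch.Size([1, 256, 56, 56])
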
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alondj/Pytorch-Gpipe/7c9bf892fe65ffe3efdb8cb00a7052b665d41f75/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_delayed_norm.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from itertools import chain
3 |
4 | import pytest
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 |
9 | from pytorch_Gpipe.delayedNorm import DelayedBatchNorm
10 |
11 | NUM_MICRO_BATCHES = 4
12 |
13 | # taken from torchGpipe repo
14 |
15 |
16 | def tilt_dist(x):
17 | # Tilt variance by channel.
18 | rgb = x.transpose(0, 1)
19 | rgb[0] *= 1
20 | rgb[1] *= 10
21 | rgb[2] *= 100
22 |
23 | # Tilt mean by single batch.
24 | for i, single in enumerate(x):
25 | single += 2**i
26 |
27 | return x
28 |
29 |
30 | def chunked_forward(model, x, micro_batches=NUM_MICRO_BATCHES):
31 | output_micro_batches = []
32 |
33 | for chunk in x.chunk(micro_batches):
34 | output_micro_batches.append(model(chunk))
35 |
36 | return torch.cat(output_micro_batches)
37 |
38 |
39 | @pytest.mark.parametrize('micro_batches', [1, 4])
40 | @pytest.mark.parametrize('x_requires_grad', [True, False])
41 | def test_transparency(micro_batches, x_requires_grad):
42 | bn = nn.BatchNorm2d(3)
43 | dbn = DelayedBatchNorm.convertBatchNorm(
44 | deepcopy(bn), num_micro_batches=micro_batches)
45 |
46 | x1 = torch.rand(16, 3, 224, 224)
47 | x1 = tilt_dist(x1)
48 | x2 = x1.clone()
49 | x1.requires_grad = x_requires_grad
50 | x2.requires_grad = x_requires_grad
51 |
52 | output1 = chunked_forward(bn, x1, micro_batches=micro_batches)
53 | output2 = chunked_forward(dbn, x2, micro_batches=micro_batches)
54 |
55 | assert torch.allclose(output1, output2, atol=1e-4)
56 |
57 | output1.mean().backward()
58 | output2.mean().backward()
59 |
60 | assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)
61 |
62 | if x_requires_grad:
63 | assert x1.grad is not None
64 | assert x2.grad is not None
65 | assert torch.allclose(x1.grad, x2.grad, atol=1e-4)
66 |
67 |
68 | @pytest.mark.parametrize('momentum', [0.1, None])
69 | def test_running_stats(momentum):
70 | bn = nn.BatchNorm2d(3, momentum=momentum)
71 | dbn = DelayedBatchNorm.convertBatchNorm(
72 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES)
73 |
74 | x = torch.rand(16, 3, 224, 224)
75 | x = tilt_dist(x)
76 |
77 | bn(x)
78 | chunked_forward(dbn, x)
79 |
80 | assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
81 | assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)
82 |
83 |
84 | def test_convert():
85 | bn = nn.BatchNorm2d(3, track_running_stats=False)
86 | bn = DelayedBatchNorm.convertBatchNorm(
87 | bn, num_micro_batches=NUM_MICRO_BATCHES)
88 | assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False
89 |
90 | dbn = DelayedBatchNorm(3, num_micro_batches=NUM_MICRO_BATCHES)
91 | dbn_again = DelayedBatchNorm.convertBatchNorm(
92 | dbn, num_micro_batches=NUM_MICRO_BATCHES)
93 | assert dbn.weight is dbn_again.weight
94 | assert dbn.bias is dbn_again.bias
95 | assert dbn.running_mean is dbn_again.running_mean
96 | assert dbn.running_var is dbn_again.running_var
97 |
98 |
99 | def test_eval():
100 | bn = nn.BatchNorm2d(3)
101 | dbn = DelayedBatchNorm.convertBatchNorm(
102 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES)
103 |
104 | x = torch.rand(16, 3, 224, 224)
105 | x = tilt_dist(x)
106 |
107 | bn(x)
108 | chunked_forward(dbn, x)
109 |
110 | bn.eval()
111 | dbn.eval()
112 |
113 | assert torch.allclose(bn(x), dbn(x), atol=1e-4)
114 |
115 |
116 | def test_optimize():
117 | bn = nn.BatchNorm2d(3)
118 | dbn = DelayedBatchNorm.convertBatchNorm(
119 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES)
120 |
121 | opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)
122 |
123 | for i in range(5):
124 | x = torch.rand(16, 3, 224, 224)
125 | x = tilt_dist(x)
126 |
127 | # train
128 | y = bn(x)
129 | a = y.sum()
130 | a.backward()
131 |
132 | y = chunked_forward(dbn, x)
133 | b = y.sum()
134 | b.backward()
135 |
136 | opt.step()
137 |
138 | # eval
139 | bn.eval()
140 | dbn.eval()
141 |
142 | with torch.no_grad():
143 | assert torch.allclose(bn(x), dbn(x), atol=1e-1 * (10**i))
144 |
145 |
146 | def test_conv_bn():
147 | bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
148 | dbn = DelayedBatchNorm.convertBatchNorm(
149 | deepcopy(bn), num_micro_batches=NUM_MICRO_BATCHES)
150 |
151 | x = torch.rand(16, 3, 224, 224)
152 | x = tilt_dist(x)
153 |
154 | opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)
155 |
156 | # 1st step
157 | a = bn(x)
158 | b = chunked_forward(dbn, x)
159 |
160 | # Outputs are different. (per-mini-batch vs. per-micro-batch)
161 | assert not torch.allclose(a, b)
162 |
163 | a.sum().backward()
164 | b.sum().backward()
165 | opt.step()
166 | opt.zero_grad()
167 |
168 | # Conv layers are also trained differently because of their different outputs.
169 | assert not torch.allclose(bn[0].weight, dbn[0].weight)
170 |
171 | # But BNs track identical running stats.
172 | assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
173 | assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)
174 |
175 | # 2nd step
176 | a = bn(x)
177 | b = chunked_forward(dbn, x)
178 | a.sum().backward()
179 | b.sum().backward()
180 |
181 | # BNs can't track identical running stats due to the different conv layers.
182 | assert not torch.allclose(
183 | bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
184 | assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)
185 |
186 |
187 | def test_x_requiring_grad():
188 | dbn = DelayedBatchNorm(3, num_micro_batches=NUM_MICRO_BATCHES)
189 |
190 | x = torch.rand(16, 3, 224, 224, requires_grad=True)
191 | x = tilt_dist(x)
192 |
193 | chunked_forward(dbn, x)
194 |
195 | assert not dbn.sum.requires_grad
196 | assert dbn.sum.grad_fn is None
197 |
--------------------------------------------------------------------------------
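
Note: the motivation behind these tests — plain BatchNorm normalizes each micro-batch with its own statistics, so chunked and full-batch outputs diverge, which is what DelayedBatchNorm corrects for. A quick standalone demonstration of the effect (shapes are arbitrary):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3)
    x = torch.rand(16, 3, 8, 8)
    full = bn(x)
    chunked = torch.cat([bn(c) for c in x.chunk(4)])
    print(torch.allclose(full, chunked))  # False in general
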
/tests/test_metis_lib_bindings.py:
--------------------------------------------------------------------------------
1 | from pytorch_Gpipe.METIS import METIS_partition, mdbglvl_et
2 |
3 |
4 | # -------------------------------------------------------------------------
5 | # basic tests
6 | # -------------------------------------------------------------------------
7 |
8 |
9 | def example_adjlist():
10 | return [[1, 2, 3, 4], [0], [0], [0], [0, 5], [4, 6], [13, 5, 7],
11 | [8, 6], [9, 10, 11, 12, 7], [8], [8], [8], [8], [14, 6], [13, 15],
12 | [16, 17, 18, 14], [15], [15], [15]]
13 |
14 |
15 | def test_1():
16 | adjlist = example_adjlist()
17 |
18 | print("Testing k-way cut")
19 | parts, cuts = METIS_partition(adjlist, 3, algorithm="metis",
20 | dbglvl=mdbglvl_et.METIS_DBG_ALL)
21 | assert cuts == 2
22 | assert set(parts) == set([0, 1, 2])
23 |
24 | print("Testing recursive cut")
25 | parts, cuts = METIS_partition(adjlist, 3, algorithm="metis_recursive",
26 | dbglvl=mdbglvl_et.METIS_DBG_ALL)
27 | assert cuts == 2
28 | assert set(parts) == set([0, 1, 2])
29 |
31 |
32 |
33 | def test_2():
34 | nVertices = 6
35 | nParts = 2
36 |
37 | # Indexes of starting points in adjacent array
38 |     # Indices of starting points in the adjacency array (CSR format)
39 |
40 | # Adjacent vertices in consecutive index order
41 | adjv = [1, 3, 0, 4, 2, 1, 5, 0, 4, 3, 1, 5, 4, 2]
42 |
43 | adjlist = [adjv[adj_idx[i]:adj_idx[i+1]] for i in range(nVertices)]
44 |
45 | # Weights of vertices
46 |     # if all weights are equal they can be omitted (None)
47 | nodew = [i*nVertices for i in range(nVertices)]
48 |
49 | # int ret = METIS_PartGraphRecursive( & nVertices, & nWeights, xadj, adjncy,
50 | # NULL, NULL, NULL, & nParts, NULL,
51 | # NULL, NULL, & objval, part)
52 |
53 | parts, _ = METIS_partition(
54 | adjlist, nParts, algorithm="metis", nodew=nodew, contig=1)
55 |
56 | assert len(set(parts)) == nParts
57 |
--------------------------------------------------------------------------------
/tests/test_networkx.py:
--------------------------------------------------------------------------------
1 |
2 | import itertools
3 |
4 |
5 | import networkx as nx
6 |
7 | import nxmetis
8 | from nxmetis import exceptions
9 | from nxmetis import metis
10 | from nxmetis import types
11 | import pytest
12 |
13 |
14 | def make_cycle(n):
15 | xadj = list(range(0, 2 * n + 1, 2))
16 | adjncy = list(
17 | itertools.chain.from_iterable(
18 | zip(itertools.chain([n - 1], range(n - 1)),
19 | itertools.chain(range(1, n), [0]))))
20 | return xadj, adjncy
21 |
22 |
23 | def assert_equal(a, b):
24 | assert a == b
25 |
26 |
27 | def assert_not_equal(a, b):
28 | assert not (a == b)
29 |
30 |
31 | class TestMetis(object):
32 |
33 | def setup_method(self, test_method):
34 | self.node_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
35 | 1, 2, 3, 4, 5, 6]
36 | self.G = nx.Graph()
37 | nx.add_path(self.G, self.node_list)
38 | self.G.add_edge(self.node_list[-1], self.node_list[0])
39 |
40 | def test_node_nested_dissection_unweighted(self):
41 |
42 | node_ordering = nxmetis.node_nested_dissection(self.G)
43 | assert_equal(len(self.G), len(node_ordering))
44 | assert_equal(set(self.G), set(node_ordering))
45 |
46 |         # Exercise the package's ability to handle self-loops:
47 |         # METIS crashes on self-loops; networkx-metis should not.
48 | self.G.add_edge(1, 1)
49 | self.G.add_edge('a', 'a')
50 | node_ordering = nxmetis.node_nested_dissection(self.G)
51 | assert_equal(len(self.G), len(node_ordering))
52 | assert_equal(set(self.G), set(node_ordering))
53 |
54 | def test_partition(self):
55 | partition = nxmetis.partition(self.G, 4)
56 |         # Every node chosen from one part of the partitioned graph must be
57 |         # adjacent to at least one other node in the same part; this
58 |         # verifies that each part forms a connected chain of nodes.
59 | parts = partition[1] # List containing partitioned node lists
60 |
61 | assert_equal(partition[0], 4)
62 | assert_equal(len(partition[1]), 4)
63 |
64 | for part in parts:
65 | assert_not_equal(0, len(part)) # Non-empty set
66 | assert_equal(
67 | len(part), len(set(part))) # Duplicate-free
68 | assert (nx.is_connected(self.G.subgraph(part))) # Connected
69 |
70 | # Disjoint sets
71 | for part1, part2 in itertools.combinations(parts, 2):
72 | assert_equal(set(), set(part1) & set(part2))
73 |
74 | # These parts must be exhaustive with the node list of the Graph
75 | parts_combined = parts[0] + parts[1] + parts[2] + parts[3]
76 | assert_equal(set(parts_combined), set(self.G))
77 |
78 | def test_vertex_separator(self):
79 | sep, part1, part2 = nxmetis.vertex_separator(self.G)
80 |
81 | # The two separator nodes must not be present in the
82 | # two bisected chains
83 | assert (sep[0] not in part1)
84 | assert (sep[0] not in part2)
85 | assert (sep[1] not in part1)
86 | assert (sep[1] not in part2)
87 |
88 | # There should be two different separator nodes
89 | assert_equal(len(sep), 2)
90 | assert_not_equal(sep[0], sep[1])
91 |
92 | # The lists should be exhaustive with the node list of the Graph
93 | assert_equal(set(sep) | set(part1) | set(part2),
94 | set(self.G))
95 |
96 | # The parts must be disjoint sets
97 | assert_equal(set(), set(part1) & set(part2))
98 |
99 | # Non-empty set
100 | assert_not_equal(len(part1), 0)
101 | assert_not_equal(len(part2), 0)
102 |
103 | # Duplicate-free
104 | assert_equal(len(part1), len(set(part1)))
105 | assert_equal(len(part2), len(set(part2)))
106 |
107 | # Connected
108 | assert (nx.is_connected(self.G.subgraph(part1)))
109 | assert (nx.is_connected(self.G.subgraph(part2)))
110 |
111 | # def test_MetisOptions(self):
112 | # n = 16
113 | # xadj, adjncy = make_cycle(n)
114 | # options = types.MetisOptions(niter=-2)
115 | # nose.tools.assert_raises_regexp(exceptions.MetisError,
116 | # 'Input Error: Incorrect niter.',
117 | # metis.part_graph, xadj, adjncy, 2,
118 | # options=options)
119 |
--------------------------------------------------------------------------------