├── TorchTensorOpsBench ├── README.md ├── run.sh └── torch_tensor_ops_bench.py ├── fp16util.py ├── README.md ├── shufflenet.py ├── shufflenet_v2.py ├── xception.py └── micro_benchmarking_pytorch.py /TorchTensorOpsBench/README.md: -------------------------------------------------------------------------------- 1 | To run the microbenchmark for an op: 2 | ``` 3 | python torch_tensor_ops_bench.py --op <op_name> 4 | ``` 5 | 6 | The script also takes optional arguments: 7 | ``` 8 | --dtype [=fp32 | fp16 | bf16] 9 | --device [=cuda | cpu] 10 | --input1-dim dims of the first input separated by '-', default "64-1024-1024" 11 | --input2-dim dims of the second input (binary ops) separated by '-', default "64-1024-1024" 12 | --op-type [=None | binary (for a binary op) | sparse (sparse add_; --input2-dim is then the number of sparse rows)] 13 | --inplace use the in-place variant of the op (e.g. add_) 14 | --num-iters / --num-warmup-iters timed / warmup iterations, default 20 / 5 15 | --run-predefined / --run-model-ops run the built-in op lists instead of --op (see run.sh) 16 | --enable-profiling print an autograd profiler table after the timed loop 17 | --append-dtype include the dtype in each printed op label 18 | ``` 19 | -------------------------------------------------------------------------------- /TorchTensorOpsBench/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # run model ops in fp32, fp16 and bf16 4 | printf "########## Running model ops with fp32 type ###########\n" 5 | python3 torch_tensor_ops_bench.py --run-model-ops --dtype fp32 |& tee model_ops_fp32.log 6 | printf "\n########## Running model ops with fp16 type ###########\n" 7 | python3 torch_tensor_ops_bench.py --run-model-ops --dtype fp16 |& tee model_ops_fp16.log 8 | printf "\n########## Running model ops with bf16 type ###########\n" 9 | python3 torch_tensor_ops_bench.py --run-model-ops --dtype bf16 |& tee model_ops_bf16.log 10 | 11 | # run predefined ops with generic tensor size of 64-1024-1024 12 | printf "\n########## Running pre-defined ops with fp32 type ###########\n" 13 | python3 torch_tensor_ops_bench.py --run-predefined --dtype fp32 |& tee predefined_ops_fp32.log 14 | printf "\n########## Running pre-defined ops with fp16 type ###########\n" 15 | python3 torch_tensor_ops_bench.py --run-predefined --dtype fp16 |& tee predefined_ops_fp16.log 16 | 17 | printf "Done\n" 18 | 19 | -------------------------------------------------------------------------------- /fp16util.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | enable_miopen = (os.getenv("DISABLE_MIOPEN") == None) 6 | 7 | class tofp16(nn.Module): 8 | def __init__(self): 9 | super(tofp16, self).__init__() 10 | 11 | def forward(self, input): 12 | return input.half() 13 | 14 | def copy_in_params(net, params): 15 | net_params = list(net.parameters()) 16 | for i in range(len(params)): 17 | net_params[i].data.copy_(params[i].data) 18 | 19 | def set_grad(params, params_with_grad): 20 | 21 | for param, param_w_grad in zip(params, params_with_grad): 22 | if param.grad is None: 23 | param.grad = torch.nn.Parameter(param.data.new().resize_(*param.data.size())) 24 | param.grad.data.copy_(param_w_grad.grad.data) 25 | 26 | def get_param_copy(net): 27 | param_copy = [param.clone().type(torch.cuda.FloatTensor).detach() for param in net.parameters()] 28 | for param in param_copy: 29 | param.requires_grad=True 30 | return param_copy 31 | 32 | def BN_convert_float(module): 33 | if (enable_miopen): 34 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 35 | module.float() 36 | for child in module.children(): 37 | BN_convert_float(child) 38 | return module 39 | 40 | def network_to_half(network): 41 | return nn.Sequential(tofp16(), BN_convert_float(network.half())) 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-micro-benchmarking 2 | We supply a small microbenchmarking script for PyTorch training on ROCm. 3 | 4 | To execute: 5 | `python micro_benchmarking_pytorch.py --network [--batch-size ] [--iterations ] [--fp16 <0 or 1> ] [--distributed_dataparallel] [--device_ids ] ` 6 | 7 | Possible network names are: `alexnet`, `densenet121`, `inception_v3`, `resnet50`, `resnet101`, `SqueezeNet`, `vgg16` etc. 
8 | 9 | Defaults are 10 training iterations, `fp16` off (i.e., 0), and a batch size of 64. 10 | 11 | For multi-GPU (mGPU) runs, use one of the following methods. 12 | - `torchrun`: It spawns one sub-process per GPU and sets `world_size` and `rank` accordingly. `torchrun` also defaults to using distributed dataparallel. 13 | - `--distributed_dataparallel`: Uses torch.nn.parallel.DistributedDataParallel to run multiple processes per node. However, each invocation of the script starts only one process (for one GPU), so the per-GPU processes must be launched manually. See the example below. 14 | 15 | _NOTE_: The `--distributed_dataparallel` option will be deprecated in the future, as this path can now be exercised with `torchrun`. 16 | _NOTE_: When comparing `--distributed_dataparallel` performance with `torchrun`, multiply the `--batch-size` passed to `torchrun` by the number of processes it launches (`--nproc-per-node`). With `torchrun` the given batch size is split into mini-batches, one per process, whereas with `--distributed_dataparallel` it is not: each process runs with whatever batch size the user provides.
17 | 18 | Examples: 19 | - for a 1-GPU resnet50 run: 20 | ``` 21 | python3 micro_benchmarking_pytorch.py --network resnet50 22 | ``` 23 | 24 | - for a 2-GPU run on a single node using `torchrun`: 25 | ``` 26 | torchrun --nproc-per-node 2 micro_benchmarking_pytorch.py --network resnet50 --batch-size 128 27 | 28 | ``` 29 | 30 | - for a 2-GPU run on a single node using `--distributed_dataparallel`: 31 | ``` 32 | python3 micro_benchmarking_pytorch.py --device_ids=0 --network resnet50 --distributed_dataparallel --rank 0 --world-size 2 --dist-backend nccl --dist-url tcp://127.0.0.1:4332 --batch-size 64 & 33 | python3 micro_benchmarking_pytorch.py --device_ids=1 --network resnet50 --distributed_dataparallel --rank 1 --world-size 2 --dist-backend nccl --dist-url tcp://127.0.0.1:4332 --batch-size 64 & 34 | ``` 35 | 36 | 37 | To run the FlopsProfiler (requires `deepspeed.profiling.flops_profiler` to be importable): 38 | `python micro_benchmarking_pytorch.py --network resnet50 --amp-opt-level=2 --batch-size=256 --iterations=20 --flops-prof-step 10` 39 | 40 | ## Performance tuning 41 | If performance on a specific card and/or model is found to be lacking, typically some gains can be made by tuning MIOpen. To do so, `export MIOPEN_FIND_ENFORCE=3` before running the model. This takes some time whenever untuned configurations are encountered, and the results are written to a local performance database. More information on this can be found in the [MIOpen documentation](https://rocm.github.io/MIOpen/doc/html/perfdatabase.html). 42 | 43 | ## PyTorch 2.0 44 | The `--compile` option opens up PyTorch 2.0 capabilities and comes with several options of its own. Here are some notes from upstream: 45 | ``` 46 | Optimizes given model/function using TorchDynamo and specified backend.
47 | 48 | Args: 49 | model (Callable): Module/function to optimize 50 | fullgraph (bool): If True, the whole model must be captured as a single graph (no graph breaks) 51 | dynamic (bool): Use dynamic shape tracing 52 | backend (str or Callable): backend to be used 53 | mode (str): Can be either "default", "reduce-overhead" or "max-autotune" 54 | options (dict): A dictionary of options to pass to the backend. 55 | disable (bool): Turn torch.compile() into a no-op for testing 56 | 57 | Example:: 58 | 59 | @torch.compile(options={"matmul-padding": True}, fullgraph=True) 60 | def foo(x): 61 | return torch.sin(x) + torch.cos(x) 62 | ``` 63 | 64 | When `--compile` is given, these options can also be set from the command line with the `--compileContext` flag. Here are a few examples: 65 | 66 | ```bash 67 | python micro_benchmarking_pytorch.py --network resnet50 --compile # default run 68 | ``` 69 | 70 | ```bash 71 | python micro_benchmarking_pytorch.py --network resnet50 --compile --compileContext "{'mode': 'max-autotune', 'fullgraph': 'True'}" 72 | ``` 73 | 74 | ```bash 75 | python micro_benchmarking_pytorch.py --network resnet50 --compile --compileContext "{'options': {'static-memory': 'True', 'matmul-padding': 'True'}}" 76 | ``` 77 | Note: `mode` and `options` cannot be passed together.
78 | -------------------------------------------------------------------------------- /shufflenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import numpy as np 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class ShufflenetUnit(nn.Module): 15 | expansion = 4 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, flag=False): 17 | super(ShufflenetUnit, self).__init__() 18 | self.downsample = downsample 19 | group_num = 3 20 | self.flag = flag 21 | if self.flag: 22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, groups=1, bias=False) 23 | else: 24 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, groups=group_num, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | 31 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, groups=group_num, bias=False) 32 | self.bn3 = nn.BatchNorm2d(planes * 4) 33 | self.relu = nn.ReLU(inplace=True) 34 | 35 | def _shuffle(self, features, g): 36 | channels = features.size()[1] 37 | index = torch.from_numpy(np.asarray([i for i in range(channels)])) 38 | index = index.view(-1, g).t().contiguous() 39 | index = index.view(-1).cuda() 40 | features = features[:, index] 41 | return features 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | if not self.flag: 51 | out = self._shuffle(out, 3) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 
61 | out = torch.cat((out, residual), 1) 62 | else: 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | class ShuffleNet(nn.Module): 69 | inplanes = 24 70 | def __init__(self, block, layers, num_classes=1000): 71 | super(ShuffleNet, self).__init__() 72 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=24, kernel_size=3, 73 | padding=1, stride=2, bias=False) 74 | self.bn1 = nn.BatchNorm2d(24) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2) 77 | 78 | self.stage2 = self._make_layer(block, 240, layers[0], True) 79 | self.stage3 = self._make_layer(block, 480, layers[1], False) 80 | self.stage4 = self._make_layer(block, 960, layers[2], False) 81 | 82 | self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1) 83 | self.fc = nn.Linear(960, num_classes) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, flag): 94 | downsample = nn.Sequential( 95 | nn.AvgPool2d(kernel_size=3, stride=2,padding=1) 96 | ) 97 | 98 | inner_plane = (planes - self.inplanes) / 4 99 | layers = [] 100 | layers.append(block(self.inplanes, inner_plane, 2, downsample, flag=flag)) 101 | self.inplanes = planes 102 | for i in range(blocks): 103 | layers.append(block(planes, planes/4)) 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self,x): 108 | x = self.conv1(x) 109 | x = self.bn1(x) 110 | x = self.relu(x) 111 | x = self.maxpool(x) 112 | 113 | x = self.stage2(x) 114 | x = self.stage3(x) 115 | x = self.stage4(x) 116 | 117 | x = self.globalpool(x) 118 | x = x.view(x.size(0), -1) 119 | x = self.fc(x) 120 | 121 | return x 122 | 123 | 124 | def shufflenet(): 125 | model = ShuffleNet(ShufflenetUnit, [3, 7, 3]) 126 | return model 127 | 128 | if __name__=="__main__": 129 | model = shufflenet() 130 | print(model) 131 | -------------------------------------------------------------------------------- /shufflenet_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import math 7 | import numpy as np 8 | 9 | def conv3x3(in_channels, out_channels, stride, padding=1, groups=1): 10 | """3x3 convolution""" 11 | return nn.Conv2d(in_channels, out_channels, 12 | kernel_size=3, stride=stride, padding=padding, 13 | groups=groups, 14 | bias=False) 15 | 16 | def conv1x1(in_channels, out_channels, stride=1): 17 | """1x1 convolution""" 18 | return nn.Conv2d(in_channels, out_channels, 19 | kernel_size=1, stride=stride,padding=0, 20 | bias=False) 21 | 22 | class ShufflenetUnit(nn.Module): 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(ShufflenetUnit, self).__init__() 25 | 
self.downsample = downsample 26 | 27 | if not self.downsample: #---if not downsample, then channel split, so the channel become half 28 | inplanes = inplanes // 2 29 | planes = planes // 2 30 | 31 | self.conv1x1_1 = conv1x1(in_channels=inplanes, out_channels=planes) 32 | self.conv1x1_1_bn = nn.BatchNorm2d(planes) 33 | 34 | self.dwconv3x3 = conv3x3(in_channels=planes, out_channels=planes, stride=stride, groups=planes) 35 | self.dwconv3x3_bn= nn.BatchNorm2d(planes) 36 | 37 | self.conv1x1_2 = conv1x1(in_channels=planes, out_channels=planes) 38 | self.conv1x1_2_bn = nn.BatchNorm2d(planes) 39 | 40 | self.relu = nn.ReLU(inplace=True) 41 | 42 | def _channel_split(self, features, ratio=0.5): 43 | """ 44 | ratio: c'/c, default value is 0.5 45 | """ 46 | size = features.size()[1] 47 | split_idx = int(size * ratio) 48 | return features[:,:split_idx], features[:,split_idx:] 49 | 50 | def _channel_shuffle(self, features, g=2): 51 | channels = features.size()[1] 52 | index = torch.from_numpy(np.asarray([i for i in range(channels)])) 53 | index = index.view(-1, g).t().contiguous() 54 | index = index.view(-1).cuda() 55 | features = features[:, index] 56 | return features 57 | 58 | def forward(self, x): 59 | if self.downsample: 60 | #x1 = x.clone() #----deep copy x, so where x2 is modified, x1 not be affected 61 | x1 = x 62 | x2 = x 63 | else: 64 | x1, x2 = self._channel_split(x) 65 | 66 | #----right branch----- 67 | x2 = self.conv1x1_1(x2) 68 | x2 = self.conv1x1_1_bn(x2) 69 | x2 = self.relu(x2) 70 | 71 | x2 = self.dwconv3x3(x2) 72 | x2 = self.dwconv3x3_bn(x2) 73 | 74 | x2 = self.conv1x1_2(x2) 75 | x2 = self.conv1x1_2_bn(x2) 76 | x2 = self.relu(x2) 77 | 78 | #---left branch------- 79 | if self.downsample: 80 | x1 = self.downsample(x1) 81 | 82 | x = torch.cat([x1, x2], 1) 83 | x = self._channel_shuffle(x) 84 | return x 85 | 86 | class ShuffleNet(nn.Module): 87 | def __init__(self, feature_dim, layers_num, num_classes=1000): 88 | super(ShuffleNet, self).__init__() 89 | dim1, dim2, 
dim3, dim4, dim5 = feature_dim 90 | self.conv1 = conv3x3(in_channels=3, out_channels=dim1, 91 | stride=2, padding=1) 92 | self.bn1 = nn.BatchNorm2d(dim1) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 95 | 96 | self.stage2 = self._make_layer(dim1, dim2, layers_num[0]) 97 | self.stage3 = self._make_layer(dim2, dim3, layers_num[1]) 98 | self.stage4 = self._make_layer(dim3, dim4, layers_num[2]) 99 | 100 | self.conv5 = conv1x1(in_channels=dim4, out_channels=dim5) 101 | self.globalpool = nn.AvgPool2d(kernel_size=7, stride=1) 102 | self.fc = nn.Linear(dim5, num_classes) 103 | 104 | """ 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | """ 113 | 114 | def _make_layer(self, dim1, dim2, blocks_num): 115 | half_channel = dim2 // 2 116 | downsample = nn.Sequential( 117 | conv3x3(in_channels=dim1, out_channels=dim1, stride=2, padding=1, groups=dim1), 118 | nn.BatchNorm2d(dim1), 119 | conv1x1(in_channels=dim1, out_channels=half_channel), 120 | nn.BatchNorm2d(half_channel), 121 | nn.ReLU(inplace=True) 122 | ) 123 | 124 | layers = [] 125 | layers.append(ShufflenetUnit(dim1, half_channel, stride=2, downsample=downsample)) 126 | for i in range(blocks_num): 127 | layers.append(ShufflenetUnit(dim2, dim2, stride=1)) 128 | 129 | return nn.Sequential(*layers) 130 | 131 | def forward(self,x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | #print("x0.size:\t", x.size()) 136 | x = self.maxpool(x) 137 | #print("x1.size:\t", x.size()) 138 | x = self.stage2(x) 139 | #print("x2.size:\t", x.size()) 140 | x = self.stage3(x) 141 | #print("x3.size:\t", x.size()) 142 | x = self.stage4(x) 143 | #print("x4.size:\t", x.size()) 144 | 145 | x = self.conv5(x) 146 | #print("x5.size:\t", 
x.size()) 147 | x = self.globalpool(x) 148 | #print("x6.size:\t", x.size()) 149 | 150 | x = x.view(-1, 1024) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | features = { 156 | "0.5x":[24, 48, 96, 192, 1024], 157 | "1x":[24, 116, 232, 464, 1024], 158 | "1.5x":[24, 176, 352, 704, 1024], 159 | "2x":[24, 244, 488, 976, 2048] 160 | } 161 | 162 | def shufflenet(): 163 | model = ShuffleNet(features["1x"], [3, 7, 3]) 164 | return model 165 | 166 | if __name__=="__main__": 167 | model = shufflenet().cuda() 168 | print(model) 169 | x = torch.rand((1,3,224,224)) 170 | x = Variable(x).cuda() 171 | x = model(x) 172 | -------------------------------------------------------------------------------- /xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | from torch.nn import init 6 | import torch 7 | 8 | __all__ = ['xception'] 9 | 10 | model_urls = { 11 | 'xception':'https://www.dropbox.com/s/1hplpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1' 12 | } 13 | 14 | 15 | class SeparableConv2d(nn.Module): 16 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): 17 | super(SeparableConv2d,self).__init__() 18 | 19 | self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 20 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 21 | 22 | def forward(self,x): 23 | x = self.conv1(x) 24 | x = self.pointwise(x) 25 | return x 26 | 27 | 28 | class Block(nn.Module): 29 | def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): 30 | super(Block, self).__init__() 31 | 32 | if out_filters != in_filters or strides!=1: 33 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 34 | self.skipbn = nn.BatchNorm2d(out_filters) 35 | else: 36 | self.skip=None 
37 | 38 | self.relu = nn.ReLU(inplace=True) 39 | rep=[] 40 | 41 | filters=in_filters 42 | if grow_first: 43 | rep.append(self.relu) 44 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 45 | rep.append(nn.BatchNorm2d(out_filters)) 46 | filters = out_filters 47 | 48 | for i in range(reps-1):  # together with the single in->out conv this gives `reps` separable convs 49 | rep.append(self.relu) 50 | rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) 51 | rep.append(nn.BatchNorm2d(filters)) 52 | 53 | if not grow_first: 54 | rep.append(self.relu) 55 | rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) 56 | rep.append(nn.BatchNorm2d(out_filters)) 57 | 58 | if not start_with_relu:  # drop the leading ReLU 59 | rep = rep[1:] 60 | else: 61 | rep[0] = nn.ReLU(inplace=False)  # must not be in-place: forward() still reads inp for the shortcut 62 | 63 | if strides != 1: 64 | rep.append(nn.MaxPool2d(3,strides,1)) 65 | self.rep = nn.Sequential(*rep) 66 | 67 | def forward(self,inp):  # rep(inp) plus the shortcut (projected when self.skip is set, identity otherwise) 68 | x = self.rep(inp) 69 | 70 | if self.skip is not None: 71 | skip = self.skip(inp) 72 | skip = self.skipbn(skip) 73 | else: 74 | skip = inp 75 | 76 | x+=skip 77 | return x 78 | 79 | 80 | 81 | class Xception(nn.Module): 82 | """ 83 | Xception optimized for the ImageNet dataset, as specified in 84 | https://arxiv.org/pdf/1610.02357.pdf 85 | """ 86 | def __init__(self, num_classes=1000): 87 | """ Constructor 88 | Args: 89 | num_classes: number of classes 90 | """ 91 | super(Xception, self).__init__() 92 | 93 | 94 | self.num_classes = num_classes 95 | 96 | self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) 97 | self.bn1 = nn.BatchNorm2d(32) 98 | self.relu = nn.ReLU(inplace=True) 99 | 100 | self.conv2 = nn.Conv2d(32,64,3,bias=False) 101 | self.bn2 = nn.BatchNorm2d(64) 102 | #do relu here 103 | 104 | self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) 105 | self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) 106 | self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) 107 | 108 |
self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)  # blocks 4-11: eight identical 728->728, stride-1 blocks 109 | self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) 110 | self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) 111 | self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) 112 | 113 | self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) 114 | self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) 115 | self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) 116 | self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) 117 | 118 | self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)  # 728 -> 1024 with stride 2 119 | 120 | self.conv3 = SeparableConv2d(1024,1536,3,1,1) 121 | self.bn3 = nn.BatchNorm2d(1536) 122 | 123 | #do relu here 124 | self.conv4 = SeparableConv2d(1536,2048,3,1,1) 125 | self.bn4 = nn.BatchNorm2d(2048) 126 | 127 | self.fc = nn.Linear(2048, num_classes) 128 | 129 | 130 | 131 | #------- init weights -------- 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 135 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | #----------------------------- 140 | 141 | 142 | 143 | 144 | 145 | def forward(self, x): 146 | x = self.conv1(x) 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | 150 | x = self.conv2(x) 151 | x = self.bn2(x) 152 | x = self.relu(x) 153 | 154 | x = self.block1(x) 155 | x = self.block2(x) 156 | x = self.block3(x) 157 | x = self.block4(x) 158 | x = self.block5(x) 159 | x = self.block6(x) 160 | x = self.block7(x) 161 | x = self.block8(x) 162 | x = self.block9(x) 163 | x = self.block10(x) 164 | x = self.block11(x) 165 | x = self.block12(x) 166 | 167 | x = self.conv3(x) 168 | x = self.bn3(x) 169 | x = self.relu(x) 170 | 171 | x = self.conv4(x) 172 | x = self.bn4(x) 173 | x = self.relu(x) 174 | 175 | x = F.adaptive_avg_pool2d(x, (1, 1))  # pool to 1x1 so the fc input size does not depend on the input resolution 176 | x = x.view(x.size(0), -1) 177 | x = self.fc(x) 178 | 179 | return x 180 | 181 | 182 | 183 | def xception(pretrained=False,**kwargs):  # NOTE(review): `pretrained` is accepted but ignored; model_urls is never used to load weights 184 | model = Xception(**kwargs) 185 | return model 186 | 187 | 188 | if __name__=="__main__": 189 | model = xception() 190 | print(model) 191 | 192 | -------------------------------------------------------------------------------- /TorchTensorOpsBench/torch_tensor_ops_bench.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | import time 6 | import argparse 7 | import sys 8 | 9 | dtype_map = {"fp32" : torch.float32,  # --dtype value -> torch dtype 10 | "fp16" : torch.float16, 11 | "bf16" : torch.bfloat16} 12 | 13 | # important sizes 14 | # BERT, DLRM 15 | # list of ops 16 | # [(op_name, op_type, inp1_dim, inp2_dim)] 17 | model_ops = [  # unary/reduction entries omit inp2_dim (3-tuples) 18 | ('add', 'binary', '32-128-1024', '32-128-1024'), 19 | ('add', 'binary', '32-16-128-128','32-1-1-128'), 20 | ('add', 'binary', '64-128-1024', '64-128-1024'), 21 | ('add', 'binary', '64-16-128-128', '64-1-1-128'), 22 | ('add', 'binary', '4-512-1024', '4-512-1024'), 23 | ('add', 'binary',
'4-16-512-512', '4-1-1-512'), 24 | ('add', 'binary', '8-512-1024', '8-512-1024'), 25 | ('add', 'binary', '8-16-512-512', '8-1-1-512'), 26 | ('add_', 'binary', '32-128-1024', '1024'), 27 | ('add_', 'binary', '64-128-1024', '1024'), 28 | ('add_', 'binary', '4-512-1024', '1024'), 29 | ('add_', 'binary', '8-512-1024', '1024'), 30 | ('add_', 'binary', '512-13', '512-13'), 31 | ('add_', 'binary', '512-512', '512-512'), 32 | ('add_', 'binary', '256-512', '256-512'), 33 | ('add_', 'binary', '256-256', '256-256'), 34 | ('add_', 'binary', '128-256', '128-256'), 35 | ('add_', 'binary', '128-128', '128-128'), 36 | ('add_', 'binary', '1024-480', '1024-480'), 37 | ('add_', 'binary', '1024-1024', '1024-1024'), 38 | ('add_', 'binary', '512-1024', '512-1024'), 39 | ('add_', 'binary', '1-256', '1-256'), 40 | ('add_', 'binary', '1', '1'), 41 | ('div', 'binary', '32-16-128-128', '1'), 42 | ('div', 'binary', '64-16-128-128', '1'), 43 | ('div', 'binary', '4-16-512-512', '1'), 44 | ('div', 'binary', '8-16-512-512', '1'), 45 | ('sum', 'reduction', '32-128-1024'), 46 | ('sum', 'reduction', '64-128-1024'), 47 | ('sum', 'reduction', '4-512-1024'), 48 | ('sum', 'reduction', '8-512-1024'), 49 | ('add_', 'sparse', '32709138-128', '851968'),  # for 'sparse' the 4th field is the number of sparse rows, not a shape 50 | ('relu_', 'unary', '32768-512'), 51 | ('relu_', 'unary', '32768-256'), 52 | ('relu_', 'unary', '32768-128'), 53 | ('relu_', 'unary', '32768-1024'), 54 | ] 55 | 56 | # initial set of ops.
57 | # TODO: add more ops to this list 58 | binary_ops = ['add', 'mul', 'div', 'sub', 'eq'] 59 | unary_ops = ['exp', 'relu', 'tanh', 'sqrt'] 60 | reduction_ops = ['sum', 'prod', 'norm', 'max', 'mean', 'std', 'var', 'argmax', 'argmin'] 61 | predefined_ops = binary_ops + unary_ops + reduction_ops 62 | 63 | def time_wrap(use_gpu): 64 | if use_gpu: 65 | torch.cuda.synchronize() 66 | return time.time() 67 | 68 | def benchmark(op_str, args): 69 | device = torch.device(args.device) 70 | dtype = dtype_map[args.dtype] 71 | input1_dim = [int(dim) for dim in args.input1_dim.split("-")] 72 | input1 = torch.randn(input1_dim, device=device, dtype=dtype) 73 | sparse_str = "(sparse)" if args.op_type == "sparse" else "" 74 | dtype_str = "(" + args.dtype + ")" if args.append_dtype else "" 75 | op_meta = op_str + sparse_str + dtype_str + "(" + args.input1_dim 76 | op_args = [] 77 | 78 | if args.op_type == 'sparse': 79 | last_dim = input1_dim[-1] 80 | num_indices = int(args.input2_dim) 81 | indices = torch.cuda.LongTensor(np.random.uniform(0, input1_dim[0], [1, num_indices])) 82 | values = torch.cuda.FloatTensor(np.random.uniform(-1, 1, [num_indices, last_dim])) 83 | values = values.to(dtype=dtype) 84 | input2 = torch.cuda.sparse.FloatTensor(indices, values, size=input1_dim) 85 | op_args.append(input2) 86 | op_meta += "," + str(num_indices) + "-" + str(last_dim) 87 | 88 | if op_str in binary_ops or args.op_type == 'binary': 89 | assert args.input2_dim, "input2_dim should be set for binary op - {}".format(op_str) 90 | input2_dim = [int(dim) for dim in args.input2_dim.split("-")] 91 | input2 = torch.randn(input2_dim, device=device, dtype=dtype) 92 | op_meta += "," + args.input2_dim 93 | op_args.append(input2) 94 | 95 | op_meta += ")" 96 | args.op_meta = op_meta 97 | if args.inplace and ((op_str in binary_ops) or (op_str in unary_ops)) and op_str[-1] != '_': 98 | op_str += '_' #inplace 99 | 100 | op = getattr(input1, op_str) 101 | 102 | try: 103 | # warmup iterations 104 | for _ in 
range(args.num_warmup_iters): 105 | op(*op_args) 106 | 107 | # main iterations 108 | with torch.autograd.profiler.profile(enabled=args.enable_profiling, use_cuda=args.use_gpu) as prof: 109 | start_time = time_wrap(args.use_gpu) 110 | for _ in range(args.num_iters): 111 | op(*op_args) 112 | end_time = time_wrap(args.use_gpu) 113 | time_per_iter = 1000.0*(end_time - start_time)/args.num_iters 114 | print("{:45} : {:.2f} ms/it".format(op_meta, time_per_iter)) 115 | 116 | if args.enable_profiling: 117 | if args.use_gpu: 118 | print(prof.key_averages().table(sort_by="cuda_time_total")) 119 | else: 120 | print(prof.key_averages().table(sort_by="cpu_time_total")) 121 | except RuntimeError as e: 122 | raise RuntimeError("{} operator failed with error: {}".format(op_meta, str(e))) 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("--device", default='cuda', required=False, type=str) 127 | parser.add_argument("--dtype", default='fp32', required=False, type=str, choices=['fp32', 'fp16', 'bf16']) 128 | parser.add_argument("--input1-dim", default="64-1024-1024", type=str, required=False) 129 | parser.add_argument("--input2-dim", default="64-1024-1024", type=str, required=False) 130 | parser.add_argument("--op", default='add', required=False, type=str) 131 | parser.add_argument("--op-type", default=None, required=False, type=str) 132 | parser.add_argument("--inplace", default=False, action="store_true") 133 | parser.add_argument("--run-predefined", default=False, action="store_true") 134 | parser.add_argument("--run-model-ops", default=False, action="store_true") 135 | parser.add_argument("--num-iters", default=20, type=int, required=False) 136 | parser.add_argument("--num-warmup-iters", default=5, type=int, required=False) 137 | parser.add_argument("--enable-profiling", action="store_true", default=False) 138 | parser.add_argument("--append-dtype", action="store_true", default=False) 139 | 140 | args = parser.parse_args() 
141 | args.use_gpu = True if 'cuda' in args.device else False 142 | 143 | print("========= Milliseconds per iteration for PyTorch operators with {} dtype =========".format(args.dtype)) 144 | if args.run_predefined: 145 | for op_str in predefined_ops: 146 | benchmark(op_str, args) 147 | elif args.run_model_ops: 148 | assert (not args.inplace), "inplace should not be set when running model ops" 149 | for op_info in model_ops: 150 | if len(op_info) == 4: 151 | op_str, args.op_type, args.input1_dim, args.input2_dim = op_info 152 | else: 153 | op_str, args.op_type, args.input1_dim = op_info 154 | args.input2_dim = None 155 | benchmark(op_str, args) 156 | else: 157 | benchmark(args.op, args) 158 | -------------------------------------------------------------------------------- /micro_benchmarking_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import random 4 | import time 5 | import argparse 6 | import os 7 | import sys 8 | import ast 9 | import copy 10 | import math 11 | import torch.nn as nn 12 | import torch.multiprocessing as mp 13 | from fp16util import network_to_half, get_param_copy 14 | from shufflenet import shufflenet 15 | from shufflenet_v2 import shufflenet as shufflenet_v2 16 | from xception import xception 17 | import csv 18 | import json 19 | from torch.amp import autocast, GradScaler 20 | from torch.optim.lr_scheduler import LambdaLR 21 | 22 | try: 23 | import torch._dynamo 24 | torch._dynamo.config.verbose=True 25 | HAVE_DYNAMO = True 26 | except: 27 | HAVE_DYNAMO = False 28 | 29 | IS_PT2 = hasattr(torch, "compile") 30 | 31 | is_torchrun = False 32 | if "LOCAL_RANK" in os.environ: 33 | # this indicates we're using torchrun 34 | is_torchrun = True 35 | 36 | try: 37 | import apex 38 | HAVE_APEX = True 39 | except: 40 | HAVE_APEX = False 41 | 42 | def xform(m: nn.Module) -> nn.Module: 43 | m = m.cuda() 44 | m.to(memory_format=torch.channels_last) 45 | return m 46 | 47 
def weight_init(m):
    """Initialise a single module in place; used as ``network.apply(weight_init)``.

    * ``nn.Conv2d``: weight ~ N(0, sqrt(2 / n)) with
      ``n = kernel_h * kernel_w * out_channels``; bias (if present) zeroed.
    * ``nn.BatchNorm2d``: weight filled with 1 and bias zeroed, each only when
      the layer actually has it (``affine=False`` layers have neither).

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()


def _register_optional(registry, namespace, names):
    """Add ``getattr(namespace, name)`` to *registry* under *name* when it exists.

    Each builder is probed independently, so one model missing from an older
    torchvision does not prevent the remaining ones from being registered.
    """
    for name in names:
        builder = getattr(namespace, name, None)
        if builder is not None:
            registry[name] = builder


# num_classes=1000
models = {
    "alexnet" : torchvision.models.alexnet,
    "densenet121" : torchvision.models.densenet121,
    "densenet161" : torchvision.models.densenet161,
    "densenet169" : torchvision.models.densenet169,
    "densenet201" : torchvision.models.densenet201,
    "googlenet" : torchvision.models.googlenet,
    "inception_v3" : torchvision.models.inception_v3,
    "mnasnet0_5" : torchvision.models.mnasnet0_5,
    "mnasnet0_75" : torchvision.models.mnasnet0_75,
    "mnasnet1_0" : torchvision.models.mnasnet1_0,
    "mnasnet1_3" : torchvision.models.mnasnet1_3,
    "mobilenet_v2" : torchvision.models.mobilenet_v2,
    "resnet18" : torchvision.models.resnet18,
    "resnet34" : torchvision.models.resnet34,
    "resnet50" : torchvision.models.resnet50,
    "resnet101" : torchvision.models.resnet101,
    "resnet152" : torchvision.models.resnet152,
    "resnext50" : torchvision.models.resnext50_32x4d,
    "resnext50_32x4d" : torchvision.models.resnext50_32x4d,
    "resnext101" : torchvision.models.resnext101_32x8d,
    "resnext101_32x8d" : torchvision.models.resnext101_32x8d,
    "shufflenet" : shufflenet,
    "shufflenet_v2" : shufflenet_v2,
    "shufflenet_v2_x05" : torchvision.models.shufflenet_v2_x0_5,
    "shufflenet_v2_x10" : torchvision.models.shufflenet_v2_x1_0,
    "shufflenet_v2_x15" : torchvision.models.shufflenet_v2_x1_5,
    "shufflenet_v2_x20" : torchvision.models.shufflenet_v2_x2_0,
    "shufflenet_v2_x0_5" : torchvision.models.shufflenet_v2_x0_5,
    "shufflenet_v2_x1_0" : torchvision.models.shufflenet_v2_x1_0,
    "shufflenet_v2_x1_5" : torchvision.models.shufflenet_v2_x1_5,
    "shufflenet_v2_x2_0" : torchvision.models.shufflenet_v2_x2_0,
    "SqueezeNet" : torchvision.models.squeezenet1_0,
    "squeezenet1_0" : torchvision.models.squeezenet1_0,
    "SqueezeNet1.1" : torchvision.models.squeezenet1_1,
    "squeezenet1_1" : torchvision.models.squeezenet1_1,
    "vgg11" : torchvision.models.vgg11,
    "vgg13" : torchvision.models.vgg13,
    "vgg16" : torchvision.models.vgg16,
    "vgg19" : torchvision.models.vgg19,
    "vgg11_bn" : torchvision.models.vgg11_bn,
    "vgg13_bn" : torchvision.models.vgg13_bn,
    "vgg16_bn" : torchvision.models.vgg16_bn,
    "vgg19_bn" : torchvision.models.vgg19_bn,
    "wide_resnet50_2" : torchvision.models.wide_resnet50_2,
    "wide_resnet101_2" : torchvision.models.wide_resnet101_2,
    "xception" : xception,
}

# newer torchvision models, for backwards compat: each one is registered only
# if the installed torchvision provides it.
_register_optional(models, torchvision.models, (
    "swin_t", "swin_s", "swin_b",
    "swin_v2_t", "swin_v2_s", "swin_v2_b",
    "vit_b_16", "vit_b_32", "vit_l_16", "vit_l_32", "vit_h_14",
    "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
    "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7",
    "maxvit_t",
    "mobilenet_v3_large", "mobilenet_v3_small",
))

# segmentation models, num_classes=21
segmentation_models = {
    "fcn_resnet50" : torchvision.models.segmentation.fcn_resnet50,
    "fcn_resnet101" : torchvision.models.segmentation.fcn_resnet101,
    "deeplabv3_resnet50" : torchvision.models.segmentation.deeplabv3_resnet50,
    "deeplabv3_resnet101" : torchvision.models.segmentation.deeplabv3_resnet101,
}

# newer torchvision segmentation models, for backwards compat. Builders are
# stored as plain callables (never wrapped in a tuple) so get_network can call them.
_register_optional(segmentation_models, torchvision.models.segmentation, (
    "deeplabv3_mobilenet_v3_large",
    "lraspp_mobilenet_v3_large",
))


def get_network_names():
    """Return every selectable network name (classification + segmentation), sorted."""
    return sorted([*models, *segmentation_models])


def get_network(net, params):
    """Instantiate network *net* with default constructor arguments on CUDA.

    ``params.nhwc`` selects the channels_last layout via ``xform``.
    Prints an error and exits the process for an unknown name.
    """
    if net in models:
        builder = models[net]
    elif net in segmentation_models:
        builder = segmentation_models[net]
    else:
        print("ERROR: not a supported model '%s'" % net)
        sys.exit(1)

    # aux_logits=False only used by inception_v3
    kwargs = {"aux_logits": False} if net == "inception_v3" else {}
    network = builder(**kwargs)
    if params.nhwc:
        return xform(network)
    return network.to(device="cuda")


def _extract_logits(out):
    """Return the tensor to feed the loss from a model's raw output.

    * objects with a ``.logits`` attribute (HuggingFace model outputs, and
      torchvision GoogLeNet's training-mode namedtuple) -> ``out.logits``;
    * dicts holding an ``"out"`` entry (torchvision segmentation models
      return such a dict) -> ``out["out"]``;
    * anything else is returned unchanged.
    """
    if hasattr(out, "logits"):
        return out.logits
    if isinstance(out, dict) and "out" in out:
        return out["out"]
    return out


def _compute_loss(out, target, params):
    """Cross-entropy loss of a model output against *target*, computed on CUDA."""
    logits = _extract_logits(out)
    loss_fn = torch.nn.CrossEntropyLoss().to(device="cuda")
    if params.nhwc:
        loss_fn = loss_fn.to(memory_format=torch.channels_last)
    return loss_fn(logits, target)


def _start_flops_profiler(network):
    """Create and start a DeepSpeed FlopsProfiler attached to *network*.

    The import is done here instead of relying on the name bound in the
    ``__main__`` block: processes started by ``mp.spawn`` import this file as
    ``__mp_main__`` and never execute that block.
    """
    from deepspeed.profiling.flops_profiler import FlopsProfiler
    prof = FlopsProfiler(network)
    prof.start_profile()
    return prof


def _stop_flops_profiler(prof, profile_step):
    """Print the model profile gathered by *prof* and stop profiling."""
    prof.print_model_profile(profile_step=profile_step)
    prof.end_profile()


def forwardbackward(inp, optimizer, network, params, target, scaler, step=0, opt_step=1, flops_prof_step=0):
    """Run one training step: forward, cross-entropy loss, backward, and possibly an optimizer update.

    Gradients accumulate over ``opt_step`` consecutive steps. They are zeroed
    when ``step % opt_step == 0``, and the optimizer steps (then zeroes)
    when ``(step + 1) % opt_step == 0``.
    With ``params.amp`` the forward/loss run under ``autocast('cuda')`` and
    the loss is scaled through *scaler*.
    A non-zero ``flops_prof_step`` profiles this whole step (fwd + bwd) with
    DeepSpeed's FlopsProfiler and prints the profile.
    """
    if step % opt_step == 0:
        optimizer.zero_grad()
    prof = _start_flops_profiler(network) if flops_prof_step else None

    if params.amp:
        # AMP: autocast forward + loss, scaled backward.
        with autocast('cuda'):
            loss = _compute_loss(network(inp), target, params)
        scaler.scale(loss).backward()
        if (step + 1) % opt_step == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    else:
        # Not use amp (autocast and scaler)
        loss = _compute_loss(network(inp), target, params)
        loss.backward()
        if (step + 1) % opt_step == 0:
            optimizer.step()
            optimizer.zero_grad()

    if prof is not None:
        # End profiler here to profile both fwd and bwd passes
        _stop_flops_profiler(prof, flops_prof_step)


def forward(inp, optimizer, network, params, target=None, scaler=None, step=0, opt_step=1, flops_prof_step=0):
    """Run one inference step under ``torch.no_grad()``, optionally under AMP autocast.

    The signature mirrors ``forwardbackward`` so the two are interchangeable;
    optimizer/target/scaler/step/opt_step are unused.
    A non-zero ``flops_prof_step`` profiles the forward pass with DeepSpeed's
    FlopsProfiler. The profiler is now stopped and printed before returning;
    previously that code sat after the return and never ran.
    Returns ``out.logits`` when the output has that attribute, else the raw output.
    """
    prof = _start_flops_profiler(network) if flops_prof_step else None

    # Run the forward pass
    with torch.no_grad():  # Disable gradient calculation
        if params.amp:
            with autocast('cuda'):
                out = network(inp)
        else:
            out = network(inp)

    if prof is not None:
        # End profiler here to profile the forward pass
        _stop_flops_profiler(prof, flops_prof_step)

    if hasattr(out, 'logits'):
        return out.logits
    return out


def rendezvous(distributed_parameters):
    """Join the torch.distributed process group.

    *distributed_parameters* must hold the keys 'dist_backend', 'dist_url',
    'rank' and 'world_size'.
    """
    print("Initializing process group...")
    torch.distributed.init_process_group(backend=distributed_parameters['dist_backend'],
                                         init_method=distributed_parameters['dist_url'],
                                         rank=distributed_parameters['rank'],
                                         world_size=distributed_parameters['world_size'])
    print("Rendezvous complete. Created process group...")


def run_benchmarking_wrapper(params):
    """Normalise the CLI arguments, derive the distributed setup and launch ``run_benchmarking``.

    * clamps ``--flops-prof-step`` into ``[0, iterations - 1]``;
    * parses ``--device_ids`` ("0,1,...") into a list of ints, or None;
    * fills ``params.distributed_parameters`` from the torchrun environment
      or from ``--rank/--world-size/--dist-backend/--dist-url``. Under
      torchrun, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are used with
      the "nccl" backend. RANK is the global rank; it falls back to
      LOCAL_RANK if RANK is unset.
    * sets ``params.ngpus``. It then runs a single process (plain or torchrun)
      or spawns one process per device for ``--distributed_dataparallel``.
    """
    params.flops_prof_step = max(0, min(params.flops_prof_step, params.iterations - 1))
    if params.device_ids:
        params.device_ids = [int(x) for x in params.device_ids.split(",")]
    else:
        params.device_ids = None

    dist = {}
    if is_torchrun:
        # The process-group rank must be the *global* rank (RANK); LOCAL_RANK
        # repeats on every node and would collide in multi-node runs.
        dist['rank'] = int(os.environ.get("RANK", os.environ["LOCAL_RANK"]))
        dist['world_size'] = int(os.environ["WORLD_SIZE"])
        dist['dist_backend'] = "nccl"
        dist['dist_url'] = 'tcp://' + os.environ["MASTER_ADDR"] + ":" + os.environ["MASTER_PORT"]
    else:
        dist['rank'] = params.rank
        dist['world_size'] = params.world_size
        dist['dist_backend'] = params.dist_backend
        dist['dist_url'] = params.dist_url
    params.distributed_parameters = dist

    # Some arguments are required for distributed_dataparallel
    if params.distributed_dataparallel:
        assert all(dist[key] is not None for key in ('rank', 'world_size', 'dist_backend', 'dist_url')), \
            "rank, world-size, dist-backend and dist-url are required arguments for distributed_dataparallel"

    if is_torchrun:
        params.ngpus = dist['world_size']
    elif params.distributed_dataparallel:
        params.ngpus = len(params.device_ids) if params.device_ids else torch.cuda.device_count()
    else:
        params.ngpus = 1

    if is_torchrun:
        # The device index is the node-local rank, not the global rank.
        run_benchmarking(int(os.environ["LOCAL_RANK"]), params)
    elif params.distributed_dataparallel:
        # Assumption below that each process launched with --distributed_dataparallel has the same number of devices visible/specified
        dist['world_size'] = params.ngpus * dist['world_size']
        dist['rank'] = params.ngpus * dist['rank']
        mp.spawn(run_benchmarking, nprocs=params.ngpus, args=(params,))
    else:
        run_benchmarking(0, params)


def _compile_network(network, compile_context):
    """Wrap *network* with ``torch.compile``.

    *compile_context* is an optional string holding a Python dict literal,
    parsed with ``ast.literal_eval``. Its keys override the defaults for
    mode/dynamic/fullgraph/backend/options/disable, and option values are
    coerced to bool.
    Raises RuntimeError when both "mode" and "options" are given. Prints an
    error and exits when torch.compile is unavailable.
    """
    compile_ctx = {"mode": None,
                   "dynamic": False,
                   "fullgraph": False,
                   "backend": "inductor",
                   "options": None,
                   "disable": False}
    options = None  # needed for internal pytorch checks
    if compile_context:
        compile_ctx.update(ast.literal_eval(compile_context))
    if compile_ctx["mode"] is not None and compile_ctx["options"] is not None:
        raise RuntimeError("Cannot specify mode and options simultaneously")
    if compile_ctx["options"] is not None:
        options = {compiler_pass: bool(value)
                   for compiler_pass, value in compile_ctx["options"].items()}
    if not IS_PT2:
        print("ERROR: requested torch.compile but this isn't pytorch 2.x")
        sys.exit(1)
    return torch.compile(network,
                         mode=compile_ctx["mode"],
                         dynamic=bool(compile_ctx["dynamic"]),
                         fullgraph=bool(compile_ctx["fullgraph"]),
                         backend=compile_ctx["backend"],
                         options=options,
                         disable=compile_ctx["disable"])


def _make_optimizer(network, total_epochs, run_fp16):
    """Build the SGD optimizer (MLPerf settings) and its poly-decay ``LambdaLR`` scheduler.

    With *run_fp16* the optimizer holds the FP32 parameter copies returned by
    ``get_param_copy`` rather than the network's own parameters.
    Returns ``(optimizer, scheduler)``.
    """
    param_copy = network.parameters()
    if run_fp16:
        param_copy = get_param_copy(network)
    ## MLPerf Setting
    sgd_opt_base_learning_rate = 0.01
    sgd_opt_end_learning_rate = 1e-4
    sgd_opt_learning_rate_decay_poly_power = 2
    sgd_opt_weight_decay = 0.0001
    sgd_opt_momentum = 0.9
    opt_learning_rate_warmup_epochs = 5

    optimizer = torch.optim.SGD(param_copy, lr=sgd_opt_base_learning_rate,
                                momentum=sgd_opt_momentum, weight_decay=sgd_opt_weight_decay)

    def poly_decay(epoch):
        # Linear warm-up, then polynomial decay towards the end learning rate,
        # expressed as a multiplier of the base learning rate.
        if epoch < opt_learning_rate_warmup_epochs:
            return float(epoch + 1) / opt_learning_rate_warmup_epochs
        progress = (epoch - opt_learning_rate_warmup_epochs) / (total_epochs - opt_learning_rate_warmup_epochs)
        poly = (1 - progress) ** sgd_opt_learning_rate_decay_poly_power
        return (sgd_opt_end_learning_rate + (sgd_opt_base_learning_rate - sgd_opt_end_learning_rate) * poly) / sgd_opt_base_learning_rate

    scheduler = LambdaLR(optimizer, lr_lambda=poly_decay)
    return optimizer, scheduler


def _make_batch(net, batch_size, params, run_fp16):
    """Build a random CUDA input batch and a matching random target for *net*.

    Images are 3x299x299 for inception_v3 and 3x224x224 otherwise, cast to
    half when *run_fp16* is set and laid out channels_last when ``params.nhwc``.
    Classification targets are ``(batch,)`` ids in [0, 1000). Segmentation
    targets are per-pixel ids in [0, 21) with shape ``(batch, H, W)``; the
    torchvision segmentation models upsample their output to the input size.
    """
    image_size = 299 if net == "inception_v3" else 224
    inp = torch.randn(batch_size, 3, image_size, image_size, device="cuda")
    if run_fp16:
        inp = inp.half()
    if params.nhwc:
        inp = inp.to(memory_format=torch.channels_last)

    if net in segmentation_models:
        # number of classes is 21 for segmentation
        target = torch.randint(0, 21, (batch_size, image_size, image_size), device="cuda")
    else:
        # number of classes is 1000 for imagenet
        target = torch.randint(0, 1000, (batch_size,), device="cuda")
    return inp, target


def _save_result_json(result, output_dir, output_file):
    """Write *result* as indented JSON to ``<output_dir>/<output_file>.json``, creating the directory if needed."""
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/{output_file}.json", "w") as f:
        json.dump(result, f, indent=2)


def run_benchmarking(local_rank, params):
    """Benchmark one process: build the model, warm up, time ``params.iterations`` steps and report.

    local_rank: index of this process on its node. It is the ``mp.spawn``
        index, or LOCAL_RANK under torchrun. It selects the CUDA device, and
        for --distributed_dataparallel it offsets the rank.
    params: argparse namespace already completed by ``run_benchmarking_wrapper``
        (``device_ids``, ``ngpus``, ``distributed_parameters``).

    Prints a summary. Single-process runs and rank 0 of distributed runs also
    write ``<output_dir>/<output_file>.json`` and append a row to the CSV summary.
    """
    device_ids = params.device_ids
    ngpus = params.ngpus
    net = params.network
    run_fp16 = params.fp16
    run_amp = params.amp
    distributed_dataparallel = params.distributed_dataparallel
    distributed_parameters = params.distributed_parameters
    batch_size = params.batch_size
    kineto = params.kineto
    iterations = params.iterations
    autograd_profiler = params.autograd_profiler
    flops_prof_step = params.flops_prof_step

    if is_torchrun:
        torch.cuda.set_device("cuda:%d" % local_rank)
    elif device_ids:
        assert ngpus == len(device_ids)
        torch.cuda.set_device("cuda:%d" % device_ids[local_rank])
    else:
        torch.cuda.set_device("cuda:0")

    network = get_network(net, params)
    if "shufflenet" == net:
        network.apply(weight_init)

    if run_fp16:
        network = network_to_half(network)

    if params.compile:
        network = _compile_network(network, params.compileContext)

    # The scheduler must still be constructed: LambdaLR applies poly_decay(0)
    # to the optimizer's learning rate on creation. It is never stepped during
    # the benchmark.
    optimizer, _scheduler = _make_optimizer(network, params.iterations, run_fp16)

    if is_torchrun or distributed_dataparallel:
        if is_torchrun:
            devices_to_run_on = [local_rank]
        else:
            distributed_parameters['rank'] += local_rank
            devices_to_run_on = [(device_ids[local_rank] if device_ids else local_rank)]
        rendezvous(distributed_parameters)
        print("INFO: Rank {} running distributed_dataparallel on devices: {}".format(distributed_parameters['rank'], str(devices_to_run_on)))
        network = torch.nn.parallel.DistributedDataParallel(network, device_ids=devices_to_run_on)
        # --batch-size is the total; each process gets its share.
        batch_size = batch_size // ngpus

    inp, target = _make_batch(net, batch_size, params, run_fp16)

    if params.mode == "training":
        forward_fn = forwardbackward
        network.train()
    else:
        forward_fn = forward
        network.eval()

    scaler = GradScaler('cuda')
    ## warmup.
    print("INFO: running forward and backward for warmup.")
    for _ in range(2):
        forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=0, opt_step=params.opt_step)

    time.sleep(1)
    torch.cuda.synchronize()

    ## benchmark.
    print("INFO: running the benchmark..")
    if kineto:
        from torch.profiler import schedule, profile, ProfilerActivity, record_function
        profiler_schedule = schedule(
            skip_first=0,
            wait=1,
            warmup=2,
            active=5,
            repeat=1,
        )

        def trace_ready_callback(prof):
            # Only rank 0 (or the single process) exports the chrome trace.
            rank = 0
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                rank = torch.distributed.get_rank()
            if rank == 0:
                print("----------- Trace Ready -----------")
                prof.export_chrome_trace(f"{params.profiler_output}.json")

        tm = time.time()
        with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=profiler_schedule,
                on_trace_ready=trace_ready_callback) as prof:
            for i in range(iterations):
                with record_function(f"iteration {i}"):
                    forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=i, opt_step=params.opt_step)
                prof.step()
            torch.cuda.synchronize()
            # Stop the clock before profiler teardown and table formatting so
            # neither is counted as benchmark time.
            tm2 = time.time()
        print(prof.key_averages().table(sort_by="cuda_time_total"))
    else:
        tm = time.time()
        with torch.autograd.profiler.emit_nvtx(enabled=autograd_profiler):
            for i in range(iterations):
                # Only the (clamped) --flops-prof-step iteration is profiled;
                # flops_prof_step=0 means "no profiling" inside forward_fn.
                forward_fn(inp, optimizer, network, params, target, scaler=scaler, step=i,
                           opt_step=params.opt_step,
                           flops_prof_step=i if i == flops_prof_step else 0)
        torch.cuda.synchronize()
        tm2 = time.time()

    time_per_batch = (tm2 - tm) / iterations

    if run_fp16:
        dtype = 'FP16'
    elif run_amp:
        dtype = 'AMP: PyTorch Native Automatic Mixed Precision'
    else:
        dtype = 'FP32'

    result = None
    if not params.output_dir:
        params.output_dir = "."

    print("OK: finished running benchmark..")
    print("--------------------SUMMARY--------------------------")
    print("Microbenchmark for network : {}".format(net))
    if distributed_dataparallel or is_torchrun:
        print("--------This process: rank " + str(distributed_parameters['rank']) + "--------")
        print("Num devices: 1")
    else:
        print("Num devices: {}".format(ngpus))
        result = {
            "Name": params.output_file,
            "GPUs": 1,
            "Mini batch size [img]": batch_size,
            "Mini batch size [img/gpu]": batch_size,
            "Throughput [img/sec]": batch_size / time_per_batch,
            "Time per mini-batch": time_per_batch
        }
        _save_result_json(result, params.output_dir, params.output_file)
    print("Dtype: {}".format(dtype))
    print("Mini batch size [img] : {}".format(batch_size))
    print("Throughput [img/sec] : {}".format(batch_size/time_per_batch))
    print("Time per mini-batch : {}".format(time_per_batch))

    if (distributed_dataparallel or is_torchrun) and distributed_parameters['rank'] == 0:
        world_size = distributed_parameters['world_size']
        print("")
        print("--------Overall (all ranks) (assuming same num/type devices for each rank)--------")
        print("Num devices: {}".format(world_size))
        print("Dtype: {}".format(dtype))
        print("Mini batch size [img] : {}".format(batch_size*world_size))
        print("Throughput [img/sec] : {}".format(batch_size*world_size/time_per_batch))
        print("Time per mini-batch : {}".format(time_per_batch))
        result = {
            "Name": params.output_file,
            "GPUs": world_size,
            "Mini batch size [img]": batch_size * world_size,
            "Mini batch size [img/gpu]": batch_size,
            "Throughput [img/sec]": batch_size * world_size / time_per_batch,
            "Time per mini-batch": time_per_batch
        }
        _save_result_json(result, params.output_dir, params.output_file)

    csv_filename = f"{params.output_dir}/benchmark_summary.csv"
    if params.csv_file:
        csv_filename = params.csv_file
    if result:
        file_exists = os.path.isfile(csv_filename)
        with open(csv_filename, "a", newline='') as csvfile:
            writer = csv.writer(csvfile)
            if not file_exists:
                writer.writerow(result.keys())
            writer.writerow(result.values())
        print(f"Benchmark result saved to {csv_filename}")


def main():
    """Entry point: benchmark using a private deep copy of the parsed CLI ``args``."""
    run_benchmarking_wrapper(copy.deepcopy(args))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--network", type=str, choices=get_network_names(), required=True, help="Network to run.")
    parser.add_argument("--batch-size" , type=int, required=False, default=64, help="Batch size (will be split among devices used by this invocation)")
    parser.add_argument("--iterations", type=int, required=False, default=20, help="Iterations")
    parser.add_argument("--flops-prof-step", type=int, required=False, default=0, help="The flops profiling step")
    parser.add_argument("--kineto", action='store_true', required=False, help="Turn kineto profiling on")
    parser.add_argument("--autograd_profiler", action='store_true', required=False, help="Use PyTorch autograd (old) profiler")
    parser.add_argument("--fp16", type=int, required=False, default=0,help="FP16 mixed precision benchmarking")
    parser.add_argument("--distributed_dataparallel", action='store_true', required=False, help="Use torch.nn.parallel.DistributedDataParallel api to run on multiple processes/nodes. The multiple processes need to be launched manually, this script will only launch ONE process per invocation. Either use --distributed_dataparallel and manually launch multiple processes or launch this script with `torchrun`")
    parser.add_argument("--device_ids", type=str, required=False, default=None, help="Comma-separated list (no spaces) to specify which HIP devices (0-indexed) to run distributedDataParallel api on. Might need to use HIP_VISIBLE_DEVICES to limit visiblity of devices to different processes.")
    parser.add_argument("--rank", type=int, required=False, default=None, help="Rank of this process. Required for --distributed_dataparallel")
    parser.add_argument("--world-size", type=int, required=False, default=None, help="Total number of ranks/processes. Required for --distributed_dataparallel")
    parser.add_argument("--dist-backend", type=str, required=False, default=None, help="Backend used for distributed training. Can be one of 'nccl' or 'gloo'. Required for --distributed_dataparallel")
    parser.add_argument("--dist-url", type=str, required=False, default=None, help="url used for rendezvous of processes in distributed training. Needs to contain IP and open port of master rank0 eg. 'tcp://172.23.2.1:54321'. Required for --distributed_dataparallel")
    parser.add_argument("--compile", action='store_true', required=False, help="use pytorch 2.0")
    # A string holding a Python dict literal; parsed with ast.literal_eval.
    parser.add_argument("--compileContext", type=str, default=None, required=False, help="additional compile options")
    parser.add_argument("--amp", action='store_true', default=False, required=False, help="Automatic mixed precision benchmarking")
    parser.add_argument("--csv-file", type=str, default=None, required=False, help="assign output csv file name.")
    parser.add_argument("--mode", type=str, choices=['training', 'inference'], default="training", help="Select mode: training or inference")
    parser.add_argument("--nhwc", action='store_true', default=False, help="Use nhwc format")
    parser.add_argument("--opt-step", type=int, required=False, default=1, help="Optimizer update step")
    parser.add_argument("--output-dir", type=str, default="", help="assign output directory name.")
    parser.add_argument("--output-file", type=str, default="", help="assign output file name.")
    parser.add_argument("--profiler-output", type=str, default="", help="assign profiler output name.")

    args = parser.parse_args()

    # Both values are used as divisors (step % opt_step, elapsed / iterations).
    if args.iterations < 1:
        parser.error("--iterations must be >= 1")
    if args.opt_step < 1:
        parser.error("--opt-step must be >= 1")

    if args.flops_prof_step:
        # Fail fast if DeepSpeed's profiler is unavailable; the profiler itself
        # is imported where it is used so spawned workers can reach it too.
        try:
            from deepspeed.profiling.flops_profiler import FlopsProfiler  # noqa: F401
        except Exception:
            print("ERROR: You must install (or copy) deepspeed.profiling to use --flops-prof-step")
            sys.exit(1)

    main()