├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── pytorch2timeloop ├── __init__.py ├── converter_pytorch.py └── utils │ ├── __init__.py │ ├── binary_elementwise.yaml │ ├── converter.py │ ├── convolution.yaml │ ├── depth_wise_convolution.yaml │ ├── grouped_convolution.yaml │ ├── hooks.py │ ├── interpreter.py │ ├── layer_descriptions.py │ ├── matmul.yaml │ ├── pool.yaml │ └── softmax.yaml ├── requirements.txt ├── setup.py └── test ├── __init__.py ├── test_configs.py ├── test_mobilenet_v2.py ├── test_mobilenet_v3.py ├── test_resnet.py └── test_simple_cnn.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Debug files 3 | debug.py 4 | 5 | # directory 6 | **/.directory 7 | 8 | # Emacs files 9 | *~ 10 | #*# 11 | 12 | #pycharm 13 | **/.idea/ 14 | venv/ 15 | 16 | # jupyter 17 | **/.ipynb_checkpoints/ 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | bin/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | .tox/ 47 | .coverage 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | 52 | # Translations 53 | *.mo 54 | 55 | # Mr Developer 56 | .mr.developer.cfg 57 | .project 58 | .pydevproject 59 | 60 | # Rope 61 | .ropeproject 62 | 63 | # TensorCanvas tmp file 64 | tmp.mp4 65 | 66 | # Others 67 | test.py 68 | workloads/ 69 | 70 | # Test results 71 | .test.tmp 72 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Accelergy Project 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | graft pytorch2timeloop/
2 | global-exclude *.pyc
3 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # pytorch2timeloop-converter
2 | 
3 | Convert PyTorch convolutional neural networks and certain transformer models into Timeloop workload (problem) files.
4 | 
5 | ### Installing the converter
6 | After cloning this repository, run `pip install .` (or the legacy `python setup.py install`) to finish the installation. Note that this converter has been developed and tested with:
7 | 
8 | - python 3.11
9 | - pytorch 2.4.0
10 | - torchvision 0.19.0
11 | - numpy 2.1.0
12 | - pyyaml 5.3.1
13 | 
14 | ### Using the converter
15 | ```python
16 | import torchvision.models as models
17 | import pytorch2timeloop
18 | 
19 | # Define a PyTorch-based neural network model, for example, a pre-defined alexnet from torchvision.
20 | net = models.alexnet()
21 | 
22 | # Define the shape of a single input sample, in the following format:
23 | # (# of channels, height, width)
24 | # For example, the above alexnet expects a 224x224 RGB image:
25 | input_shape = (3, 224, 224)
26 | 
27 | # Define the batch size that will be used for inference.
28 | batch_size = 1
29 | 
30 | # Define the directory names where the timeloop workload yaml files will be stored.
31 | # The yaml files will be stored in ./workloads/alexnet/ in this example.
32 | top_dir = 'workloads'
33 | sub_dir = 'alexnet'
34 | 
35 | # By default, nn.Conv2d modules will be automatically converted, but nn.Linear modules will be ignored.
36 | # Set convert_fc to True to convert nn.Linear modules as well.
37 | # The converter describes each nn.Linear as a convolution-like layer
38 | # (e.g., in_channel=in_features, out_channel=out_features, input_height=1, input_width=1, filter size = 1x1, stride = 1x1, padding = 0x0).
39 | # Set convert_fc to False to ignore nn.Linear layers.
40 | convert_fc = True
41 | 
42 | # Finally, in case there exists a layer that is only used during the training phase, define an identifier for such a layer.
43 | # For example, in torchvision.models.inception_v3, the auxiliary classification layers are not used during inference (e.g., InceptionAux).
44 | # In this case, include a string that can serve as an identifier for such layers (e.g., 'Aux') in exception_module_names.
45 | # For the above alexnet, there is no need to define this.
46 | exception_module_names = []
47 | 
48 | # Now, convert! (Pass the optional flags by keyword; convert_model also accepts `fuse` and `ignored_func` options.)
49 | pytorch2timeloop.convert_model(net, input_shape, batch_size, sub_dir, top_dir, convert_fc=convert_fc, exception_module_names=exception_module_names)
50 | ```
51 | 
52 | ---
53 | 
54 | This code is licensed under the MIT License. It has been modified from the works of Anurag Golla and Alex Moser.
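
### Fused workloads and ignored functions

`convert_model` also accepts two optional keyword arguments that the bundled tests exercise: `fuse=True` writes a single `<model_name>.yaml` containing one fused problem description per layer (instead of one file per layer), and `ignored_func` is a list of traced functions to skip (the tests pass `[torch.flatten]`, since flatten/reshape nodes cannot be expressed as a per-layer problem file in the non-fused output). A minimal sketch, reusing the variables from the example above:

```python
import torch

# Single fused problem file: ./workloads/alexnet/alexnet.yaml
pytorch2timeloop.convert_model(net, input_shape, batch_size, sub_dir, top_dir,
                               fuse=True)

# One problem file per layer, skipping torch.flatten nodes during tracing
pytorch2timeloop.convert_model(net, input_shape, batch_size, sub_dir, top_dir,
                               ignored_func=[torch.flatten])
```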
55 | 56 | 57 | -------------------------------------------------------------------------------- /pytorch2timeloop/__init__.py: -------------------------------------------------------------------------------- 1 | from .converter_pytorch import ( 2 | convert_model, 3 | convert_model_with_sample_input, 4 | ) 5 | -------------------------------------------------------------------------------- /pytorch2timeloop/converter_pytorch.py: -------------------------------------------------------------------------------- 1 | """ Convert Trained PyTorch Models to Workloads """ 2 | 3 | import logging 4 | import os 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import torch 9 | from torch import nn 10 | import torch.fx as fx 11 | 12 | import yaml 13 | 14 | from pytorch2timeloop.utils.interpreter import Converter 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def convert_model_with_sample_input(model: nn.Module, 20 | sample_input: Any, 21 | batch_size: int, 22 | model_name: str, 23 | save_dir: Path, 24 | exception_module_names=[]): 25 | """ 26 | Convert a general PyTorch model to Timeloop problem files. 27 | 28 | Currently, only common CNNs and the BERT transformer from 29 | `transformers` are supported, but it is easy to add support for new 30 | DNNs. See documentation in `utils/hooks.py` for more on supporting 31 | new PyTorch module types. This interface is more general than 32 | `convert_model()` and should be preferred for new code. 33 | 34 | :param model: the PyTorch CNN model 35 | :param sample_input: 36 | :param batch_size: the batch size 37 | :param model_name: the name of the model, which will become the name 38 | of the subdirectory of `save_dir` with the problem files 39 | :param save_dir: the directory to save the output in 40 | :param exception_module_names: a list of fragments of module names 41 | to ignore (can be a prefix, suffix, or infix). 42 | """ 43 | logger.info("converting {} in {} model ...".format("all", model_name)) 44 | 45 | layer_data = _make_summary(model, sample_input) 46 | _convert_from_layer_data(layer_data, model_name, save_dir) 47 | 48 | 49 | def convert_model(model: nn.Module, input_size: tuple, batch_size: int, 50 | model_name: str, save_dir: Path, 51 | fuse=False, convert_fc=False, 52 | ignored_func=None, 53 | exception_module_names=()): 54 | """ 55 | Convert a PyTorch CNN model to Timeloop problem files. 56 | 57 | This is the original interface to this library from 0.1. 58 | The primary difference between it and `convert_model_with_sample_input` 59 | is that it accepts an extra parameter (`convert_fc`) and accepts an 60 | input size parameter as a tuple, rather than a sample input. 61 | 62 | :param model: the PyTorch CNN model 63 | :param input_size: a tuple representing the input size 64 | :param batch_size: the batch size 65 | :param model_name: the name of the model, which will become the name 66 | of the subdirectory of `save_dir` with the problem files 67 | :param save_dir: the directory to save the output in 68 | :param convert_fc: whether to convert fully connected layers 69 | :param exception_module_names: a list of fragments of module names 70 | to ignore (can be a prefix, suffix, or infix). 
71 | """ 72 | logger.info( 73 | "converting {} in {} model ...".format( 74 | "nn.Conv2d" if not convert_fc else "nn.Conv2d and nn.Linear", 75 | model_name 76 | ) 77 | ) 78 | sample_input = torch.rand(2, *input_size).type(torch.FloatTensor) 79 | layer_data = _make_summary(model, sample_input, ignored_func=ignored_func) 80 | _convert_from_layer_data(layer_data, model_name, save_dir, exception_module_names, fuse=fuse) 81 | 82 | 83 | def _convert_from_layer_data(layer_data, model_name, save_dir, exception_module_names=(), fuse=False): 84 | outdir = os.path.join(save_dir, model_name) 85 | if not os.path.exists(outdir): 86 | os.makedirs(outdir) 87 | layer_data = [ 88 | p for p in layer_data if not any( 89 | e.lower() in p.name.lower() or 90 | e.lower() in p.__class__.__name__.lower() 91 | for e in exception_module_names)] 92 | if fuse: 93 | problems = [] 94 | for i in range(0, len(layer_data)): 95 | problem = layer_data[i] 96 | problems.append(problem.to_fused_yaml()) 97 | file_name = model_name + '.yaml' 98 | file_path = os.path.abspath(os.path.join(save_dir, model_name, file_name)) 99 | with open(file_path, 'w') as f: 100 | f.write(yaml.dump( 101 | { 102 | 'problem': problems 103 | } 104 | )) 105 | else: 106 | # make the problem file for each layer 107 | for i in range(0, len(layer_data)): 108 | problem = layer_data[i] 109 | file_name = '[layer' + str(i+1) + ']' + problem.name + '.yaml' 110 | file_path = os.path.abspath(os.path.join(save_dir, model_name, file_name)) 111 | with open(file_path, 'w') as f: 112 | f.write(yaml.dump(problem.to_yaml())) 113 | 114 | logger.info("conversion complete!\n") 115 | 116 | def _make_summary(model, sample_input, ignored_func): 117 | converter = Converter(fx.symbolic_trace(model), ignored_func=ignored_func) 118 | converter.run(sample_input) 119 | return converter.summary 120 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Accelergy-Project/pytorch2timeloop-converter/5de0d6bb4e718619d64673b94a50a8640242a8fe/pytorch2timeloop/utils/__init__.py -------------------------------------------------------------------------------- /pytorch2timeloop/utils/binary_elementwise.yaml: -------------------------------------------------------------------------------- 1 | problem: 2 | shape: 3 | name: "add" 4 | dimensions: [] 5 | data-spaces: 6 | - name: Input1 7 | - name: Input2 8 | - name: Outputs 9 | read-write: True 10 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definitions of forward hooks for various PyTorch layer types, to extract 3 | `LayerDescription`s from them during evaluation. 4 | 5 | For many layer types, such as 2D convolutions and self-attention 6 | mechanisms, the layer itself does not know all of the information needed 7 | to generate a Timeloop workload: for example, a convolutional layer does 8 | not explicitly define its input size. As a result, we need to extract 9 | this information while _evaluating_ the model. 10 | 11 | The easiest mechanism for doing this is a PyTorch forward hook. This 12 | file defines hooks for various layer types, with a primary interface 13 | consisting of the function `hook_for()`, which returns a hook for the 14 | given layer. 
15 | 16 | To add support for a new layer type, add a new hook type and return it 17 | from hook_for() with the appropriate conditions. You may also need to 18 | add a new `LayerDescription` if the layer is very different from the 19 | ones that are already here. 20 | """ 21 | 22 | from functools import singledispatch 23 | import logging 24 | from typing import Optional, Callable, Any 25 | import operator 26 | 27 | import torch 28 | import torch.nn as nn 29 | import transformers.models.distilbert.modeling_distilbert 30 | 31 | from pytorch2timeloop.utils.layer_descriptions import ( 32 | ConvLayerDescription, 33 | MaxPoolLayerDescription, 34 | MatrixMatrixMultiplyLayerDescription, 35 | MatmulFuncDescription 36 | ) 37 | 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | @singledispatch 42 | def generate_description(module, 43 | input: torch.Tensor, 44 | output: torch.Tensor, 45 | name: str, 46 | ifmap_name: str): 47 | raise NotImplementedError(f'not implemented for {type(module)}') 48 | 49 | 50 | @generate_description.register(nn.Conv2d) 51 | def _(module, input, output, name, ifmap_name): 52 | description = ConvLayerDescription( 53 | name=name, 54 | g=module.groups, 55 | m=output.shape[1], 56 | w=input.shape[3], 57 | h=input.shape[2], 58 | c=input.shape[1], 59 | n=input.shape[0], 60 | s=module.kernel_size[1], 61 | r=module.kernel_size[0], 62 | w_pad=module.padding[1], 63 | h_pad=module.padding[0], 64 | w_stride=module.stride[1], 65 | h_stride=module.stride[0], 66 | ifmap_name=ifmap_name, 67 | filter_name=f'{name}_filter', 68 | ofmap_name=f'{name}_out' 69 | ) 70 | return description 71 | 72 | 73 | @generate_description.register(nn.MaxPool2d) 74 | def _(module, input, output, name, ifmap_name): 75 | if isinstance(module.kernel_size, int): 76 | kernel_size = (module.kernel_size, module.kernel_size) 77 | else: 78 | kernel_size = module.kernel_size 79 | if isinstance(module.stride, int): 80 | stride = (module.stride, module.stride) 81 | else: 82 | stride = module.stride 83 | if isinstance(module.padding, int): 84 | padding = (module.padding, module.padding) 85 | else: 86 | padding = module.padding 87 | 88 | description = MaxPoolLayerDescription( 89 | w=input.shape[3], 90 | h=input.shape[2], 91 | c=input.shape[1], 92 | s=kernel_size[1], 93 | r=kernel_size[0], 94 | w_stride=stride[1], 95 | h_stride=stride[0], 96 | w_pad=padding[1], 97 | h_pad=padding[0], 98 | n=input.shape[0], 99 | name=name, 100 | ifmap_name=ifmap_name, 101 | ofmap_name=f'{name}_out' 102 | ) 103 | 104 | return description 105 | 106 | 107 | @generate_description.register(nn.AdaptiveAvgPool2d) 108 | def _(module, input, output, name, ifmap_name): 109 | stride_w = input.shape[-1] // output.shape[-1] 110 | stride_h = input.shape[-2] // output.shape[-2] 111 | kernel_w = input.shape[-1] - (output.shape[-1]-1)*stride_w 112 | kernel_h = input.shape[-2] - (output.shape[-2]-1)*stride_h 113 | 114 | description = MaxPoolLayerDescription( 115 | w=input.shape[3], 116 | h=input.shape[2], 117 | c=input.shape[1], 118 | s=kernel_w, 119 | r=kernel_h, 120 | w_stride=stride_w, 121 | h_stride=stride_h, 122 | w_pad=0, 123 | h_pad=0, 124 | n=input.shape[0], 125 | name=name, 126 | ifmap_name=ifmap_name, 127 | ofmap_name=f'{name}_out' 128 | ) 129 | 130 | return description 131 | 132 | 133 | @generate_description.register(nn.Linear) 134 | def _(module, input, output, name, ifmap_name): 135 | description = ConvLayerDescription( 136 | g=1, 137 | w=1, 138 | h=1, 139 | c=module.in_features, 140 | m=module.out_features, 141 | s=1, 142 | r=1, 143 | 
        w_stride=1,
144 |         h_stride=1,
145 |         w_pad=0,
146 |         h_pad=0,
147 |         n=input.shape[0],
148 |         name=name,
149 |         ifmap_name=ifmap_name,
150 |         filter_name=f'{name}_filter',
151 |         ofmap_name=f'{name}_out'
152 |     )
153 |     return description
154 | 
155 | 
156 | def generate_matmul_func(input1, input2, output,
157 |                          name, input1_name, input2_name):
158 |     if len(input1.shape) == 2 and len(input2.shape) == 2:
159 |         description = MatmulFuncDescription(
160 |             name = name,
161 |             m = input1.shape[0],
162 |             n = input2.shape[1],
163 |             k = input1.shape[1],
164 |             ifmap1_name = input1_name,
165 |             ifmap2_name = input2_name,
166 |             ofmap_name = f'{name}_out',
167 |             extra_dims = tuple()
168 |         )
169 |     elif len(input1.shape) > 2 and input1.shape[:-2] == input2.shape[:-2]:
170 |         description = MatmulFuncDescription(
171 |             name = name,
172 |             m = input1.shape[-2],  # for batched matmul, the matrix dims are the trailing two
173 |             n = input2.shape[-1],
174 |             k = input1.shape[-1],
175 |             ifmap1_name = input1_name,
176 |             ifmap2_name = input2_name,
177 |             ofmap_name = f'{name}_out',
178 |             extra_dims = input1.shape[:-2]
179 |         )
180 |     else:
181 |         raise NotImplementedError(
182 |             f'unimplemented for arg shapes {input1.shape}, {input2.shape}'
183 |         )
184 | 
185 |     return description
186 | 
-------------------------------------------------------------------------------- /pytorch2timeloop/utils/convolution.yaml: --------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | #  * Redistributions of source code must retain the above copyright
7 | #    notice, this list of conditions and the following disclaimer.
8 | #  * Redistributions in binary form must reproduce the above copyright
9 | #    notice, this list of conditions and the following disclaimer in the
10 | #    documentation and/or other materials provided with the distribution.
11 | #  * Neither the name of NVIDIA CORPORATION nor the names of its
12 | #    contributors may be used to endorse or promote products derived
13 | #    from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
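#
# Dimension nomenclature used by this template (matching the converter's
# ConvLayerDescription): N = batch size, C = input channels, M = output
# channels, R/S = filter spatial dimensions, P/Q = output spatial
# dimensions. Strides and dilations enter through the coefficients,
# e.g. R*Wdilation + P*Wstride in the Inputs projection below.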
26 | problem: 27 | shape: 28 | name: "CNN-Layer" 29 | dimensions: [ C, M, R, S, N, P, Q ] 30 | coefficients: 31 | - name: Wstride 32 | default: 1 33 | - name: Hstride 34 | default: 1 35 | - name: Wdilation 36 | default: 1 37 | - name: Hdilation 38 | default: 1 39 | data-spaces: 40 | - name: Weights 41 | projection: 42 | - [ [C] ] 43 | - [ [M] ] 44 | - [ [R] ] 45 | - [ [S] ] 46 | - name: Inputs 47 | projection: 48 | - [ [N] ] 49 | - [ [C] ] 50 | - [ [R, Wdilation], [P, Wstride] ] # SOP form: R*Wdilation + P*Wstride 51 | - [ [S, Hdilation], [Q, Hstride] ] # SOP form: S*Hdilation + Q*Hstride 52 | - name: Outputs 53 | projection: 54 | - [ [N] ] 55 | - [ [M] ] 56 | - [ [Q] ] 57 | - [ [P] ] 58 | read-write: True 59 | instance: 60 | C: 16 61 | M: 32 62 | N: 1 63 | P: 10 64 | Q: 10 65 | R: 5 66 | S: 5 67 | Wdilation: 1 68 | Wstride: 1 69 | Hdilation: 1 70 | Hstride: 1 71 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/depth_wise_convolution.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
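#
# Depth-wise variant of the convolution template: there is no M dimension,
# since each input channel C produces exactly one output channel, so both
# Weights and Outputs are indexed by C.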
26 | problem: 27 | shape: 28 | name: "CNN-Layer" 29 | dimensions: [ C, R, S, N, P, Q ] 30 | coefficients: 31 | - name: Wstride 32 | default: 1 33 | - name: Hstride 34 | default: 1 35 | - name: Wdilation 36 | default: 1 37 | - name: Hdilation 38 | default: 1 39 | data-spaces: 40 | - name: Weights 41 | projection: 42 | - [ [C] ] 43 | - [ [R] ] 44 | - [ [S] ] 45 | - name: Inputs 46 | projection: 47 | - [ [N] ] 48 | - [ [C] ] 49 | - [ [R, Wdilation], [P, Wstride] ] # SOP form: R*Wdilation + P*Wstride 50 | - [ [S, Hdilation], [Q, Hstride] ] # SOP form: S*Hdilation + Q*Hstride 51 | - name: Outputs 52 | projection: 53 | - [ [N] ] 54 | - [ [C] ] 55 | - [ [Q] ] 56 | - [ [P] ] 57 | read-write: True 58 | instance: 59 | C: 16 60 | N: 1 61 | P: 10 62 | Q: 10 63 | R: 5 64 | S: 5 65 | Wdilation: 1 66 | Wstride: 1 67 | Hdilation: 1 68 | Hstride: 1 69 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/grouped_convolution.yaml: -------------------------------------------------------------------------------- 1 | problem: 2 | shape: 3 | name: "CNN-Layer" 4 | dimensions: [ C, M, R, S, N, P, Q ] 5 | coefficients: 6 | - name: Wstride 7 | default: 1 8 | - name: Hstride 9 | default: 1 10 | - name: Wdilation 11 | default: 1 12 | - name: Hdilation 13 | default: 1 14 | data-spaces: 15 | - name: Weights 16 | projection: 17 | - [ [G] ] 18 | - [ [C] ] 19 | - [ [M] ] 20 | - [ [R] ] 21 | - [ [S] ] 22 | - name: Inputs 23 | projection: 24 | - [ [N] ] 25 | - [ [G] ] 26 | - [ [C] ] 27 | - [ [R, Wdilation], [P, Wstride] ] # SOP form: R*Wdilation + P*Wstride 28 | - [ [S, Hdilation], [Q, Hstride] ] # SOP form: S*Hdilation + Q*Hstride 29 | - name: Outputs 30 | projection: 31 | - [ [N] ] 32 | - [ [G] ] 33 | - [ [M] ] 34 | - [ [Q] ] 35 | - [ [P] ] 36 | read-write: True 37 | instance: 38 | G: 2 39 | C: 16 40 | M: 32 41 | N: 1 42 | P: 10 43 | Q: 10 44 | R: 5 45 | S: 5 46 | Wdilation: 1 47 | Wstride: 1 48 | Hdilation: 1 49 | Hstride: 1 50 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/hooks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definitions of forward hooks for various PyTorch layer types, to extract 3 | `LayerDescription`s from them during evaluation. 4 | 5 | For many layer types, such as 2D convolutions and self-attention 6 | mechanisms, the layer itself does not know all of the information needed 7 | to generate a Timeloop workload: for example, a convolutional layer does 8 | not explicitly define its input size. As a result, we need to extract 9 | this information while _evaluating_ the model. 10 | 11 | The easiest mechanism for doing this is a PyTorch forward hook. This 12 | file defines hooks for various layer types, with a primary interface 13 | consisting of the function `hook_for()`, which returns a hook for the 14 | given layer. 15 | 16 | To add support for a new layer type, add a new hook type and return it 17 | from hook_for() with the appropriate conditions. You may also need to 18 | add a new `LayerDescription` if the layer is very different from the 19 | ones that are already here. 
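
Note that the `convert_model()` entry points in `converter_pytorch.py` go
through the torch.fx-based `Converter` in `utils/interpreter.py` (which
dispatches to `generate_description()` in `utils/converter.py`) rather than
through these hooks; the hooks here are a separate, hook-based route to the
same `LayerDescription`s.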
20 | """ 21 | 22 | import logging 23 | from typing import Optional, Callable, Any 24 | import operator 25 | 26 | import torch 27 | import torch.nn as nn 28 | import transformers.models.distilbert.modeling_distilbert 29 | from torch.fx import symbolic_trace 30 | 31 | from pytorch2timeloop.utils.layer_descriptions import ( 32 | DepthWiseConvLayerDescription, 33 | ConvLayerDescription, 34 | MatrixMatrixMultiplyLayerDescription 35 | ) 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def _null_hook(summary, batch_size): 41 | """ 42 | An empty hook, for layers that we want to ignore without error (like ReLU) 43 | """ 44 | def hook(module, input, output): 45 | return 46 | return hook 47 | 48 | 49 | def _conv_hook(summary, batch_size, 50 | name: str=None, ifmap_name: str=None): 51 | """ 52 | A hook for convolutional (including depth-wise convolutional) layers, based on nn.Conv2d. 53 | 54 | :param summary: the summary list we are adding to 55 | :param batch_size: the input batch size 56 | :return: a PyTorch module forward hook to collect a `LayerDescription` about this convolutional layer 57 | """ 58 | if name is None: 59 | name = 'conv_layer' 60 | def hook(module, input, output): 61 | input_shape = input[0].size() 62 | if module.groups > 1 and module.groups == module.in_channels: 63 | # Depth-wise convolution 64 | description = DepthWiseConvLayerDescription( 65 | w=input_shape[2], 66 | h=input_shape[3], 67 | c=module.in_channels, 68 | s=module.kernel_size[0], 69 | r=module.kernel_size[1], 70 | w_stride=module.stride[0], 71 | h_stride=module.stride[1], 72 | w_pad=module.padding[0], 73 | h_pad=module.padding[1], 74 | n=batch_size, 75 | name=name, 76 | ifmap_name=ifmap_name, 77 | filter_name=f'{name}.filter', 78 | ofmap_name=f'{name}.out' 79 | ) 80 | else: 81 | description = ConvLayerDescription( 82 | w=input_shape[2], 83 | h=input_shape[3], 84 | c=module.in_channels, 85 | m=module.out_channels, 86 | s=module.kernel_size[0], 87 | r=module.kernel_size[1], 88 | w_stride=module.stride[0], 89 | h_stride=module.stride[1], 90 | w_pad=module.padding[0], 91 | h_pad=module.padding[1], 92 | n=batch_size, 93 | name=name, 94 | ifmap_name=ifmap_name, 95 | filter_name=f'{name}.filter', 96 | ofmap_name=f'{name}.out' 97 | ) 98 | summary.append(description) 99 | 100 | return hook 101 | 102 | 103 | def _linear_hook(summary, batch_size, name: str=None, ifmap_name: str=None): 104 | """ 105 | A hook for linear (i.e., fully connected) layers, based on nn.Linear. 106 | 107 | :param summary: the summary list we are adding to 108 | :param batch_size: the input batch size 109 | :return: a PyTorch module forward hook to collect a `LayerDescription` about this fully connected layer 110 | """ 111 | if name is None: 112 | name = 'linear' 113 | def hook(module, input, output): 114 | input_size = input[0].size() 115 | assert input_size[1] >= 0 116 | description = ConvLayerDescription( 117 | w=1, 118 | h=1, 119 | c=module.in_features, 120 | m=module.out_features, 121 | s=1, 122 | r=1, 123 | w_stride=1, 124 | h_stride=1, 125 | w_pad=0, 126 | h_pad=0, 127 | n=batch_size, 128 | name=name, 129 | ifmap_name=ifmap_name, 130 | filter_name=f'{name}.filter', 131 | ofmap_name=f'{name}.out' 132 | ) 133 | summary.append(description) 134 | 135 | return hook 136 | 137 | 138 | def _layer_norm_hook(summary, batch_size, 139 | name: str=None, ifmap_name: str=None): 140 | """ 141 | A hook for layer norm layers, based on nn.LayerNorm. 
142 | 143 | :param summary: the summary list we are adding to 144 | :param batch_size: the input batch size 145 | :return: a PyTorch module forward hook to collect a `LayerDescription` about this layer norm layer. 146 | """ 147 | if name is None: 148 | name = 'layer_norm' 149 | def hook(module, input, output): 150 | if module.elementwise_affine: 151 | input_shape = input[0].size() 152 | assert input_shape[1] >= 0 153 | description = ConvLayerDescription( 154 | w=input_shape[2], 155 | h=1, 156 | c=1, 157 | m=1, 158 | s=input_shape[2], 159 | r=1, 160 | w_stride=1, 161 | h_stride=1, 162 | w_pad=0, 163 | h_pad=0, 164 | n=batch_size * input_shape[1], 165 | name=name 166 | ) 167 | summary.append(description) 168 | 169 | return hook 170 | 171 | 172 | def _multihead_self_attention(summary, batch_size, 173 | name: str=None, ifmap_name: str=None): 174 | """ 175 | A hook for multi-head self-attention layers. 176 | 177 | Currently, this is designed only to extract data from the self-attention layer defined in 178 | `transformers.models.bert.modeling_bert.BertSelfAttention`. It should be quite simple to adapt to other 179 | transformers with similar self-attention mechanisms, though. 180 | 181 | :param summary: the summary list we are adding to 182 | :param batch_size: the input batch size 183 | :return: a PyTorch module forward hook to collect a `LayerDescription` about this multi-head self-attention layer 184 | """ 185 | if name is None: 186 | name = 'attention' 187 | def hook(module, input, output): 188 | assert input != () 189 | x = input[0] 190 | head_size = module.attention_head_size 191 | sequence_length = x.shape[1] 192 | scores = MatrixMatrixMultiplyLayerDescription( 193 | m=sequence_length, 194 | k=head_size, 195 | n=sequence_length, 196 | batch_size=batch_size * module.num_attention_heads, 197 | name=f'{name}_scores' 198 | ) 199 | context = MatrixMatrixMultiplyLayerDescription( 200 | m=sequence_length, 201 | k=sequence_length, 202 | n=head_size, 203 | batch_size=batch_size * module.num_attention_heads, 204 | name=f'{name}_context' 205 | ) 206 | summary.append(scores) 207 | summary.append(context) 208 | 209 | return hook 210 | 211 | 212 | """ 213 | Layer types that should be considered "null ops" (i.e., that should not 214 | produce a layer file). 215 | 216 | This can be safely extended to reduce the amount of noise printed to the 217 | terminal when generating layer files. 218 | """ 219 | null_ops = ( 220 | nn.Dropout, 221 | nn.Embedding, 222 | nn.MaxPool2d, 223 | nn.AdaptiveAvgPool2d, 224 | nn.Sequential, 225 | nn.ModuleList, 226 | transformers.models.bert.modeling_bert.BertSelfOutput, 227 | transformers.models.bert.modeling_bert.BertEmbeddings, 228 | transformers.models.bert.modeling_bert.BertIntermediate, 229 | transformers.models.bert.modeling_bert.BertOutput, 230 | transformers.models.bert.modeling_bert.BertAttention, 231 | transformers.models.bert.modeling_bert.BertLayer, 232 | transformers.models.bert.modeling_bert.BertEncoder, 233 | transformers.models.bert.modeling_bert.BertPooler, 234 | transformers.models.bert.modeling_bert.BertModel, 235 | transformers.models.bert.modeling_bert.BertForSequenceClassification, 236 | ) 237 | 238 | 239 | def hook_for(module: nn.Module, summary: list, batch_size: int, 240 | convert_fc=False, name: str=None, module_args: tuple=None) \ 241 | -> Optional[Callable[[nn.Module, Any, Any], None]]: 242 | """ 243 | Return the hook, if any, for the given layer type. 
244 | 245 | The hook will append a `LayerDescription` to the given summary list 246 | when the model containing `module` is executed. 247 | 248 | :param module: a nn.Module to generate a hook for 249 | :param summary: the summary list we are adding to 250 | :param batch_size: the input batch size 251 | :param convert_fc: whether to convert the layer if it is fully 252 | connected 253 | :return: a hook function that can be used with `register_forward_hook()`, 254 | or `None` if it does not exist 255 | """ 256 | if isinstance(module, nn.Linear) and convert_fc: 257 | return _linear_hook(summary, batch_size, name) 258 | elif isinstance(module, nn.Conv2d): 259 | return _conv_hook(summary, 260 | batch_size, 261 | name, 262 | f'{module_args[0].name}.out') 263 | elif isinstance(module, null_ops): 264 | return _null_hook(summary, batch_size) 265 | elif isinstance(module, nn.LayerNorm): 266 | if module.elementwise_affine: 267 | return _layer_norm_hook(summary, batch_size, name) 268 | elif isinstance(module, transformers.models.bert.modeling_bert.BertSelfAttention): 269 | return _multihead_self_attention(summary, batch_size, name) 270 | 271 | logger.warning("unknown module type %s", module.__class__) 272 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/interpreter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import operator 4 | from typing import Dict, Tuple, Union 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import torch.fx as fx 10 | 11 | from .converter import generate_description, generate_matmul_func 12 | from pytorch2timeloop.utils.layer_descriptions import ( 13 | BinaryElementwiseFuncDescription, 14 | SoftmaxFuncDescription, 15 | MaxPoolLayerDescription, 16 | ViewFuncDescription 17 | ) 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Converter(fx.Interpreter): 23 | DEFAULT_BYPASSED_MODULES = ( 24 | nn.BatchNorm2d, 25 | nn.Dropout, 26 | # Elementwise activations 27 | nn.Hardsigmoid, 28 | nn.Hardswish, 29 | nn.ReLU, 30 | nn.ReLU6 31 | ) 32 | 33 | DEFAULT_IGNORED_MODULES = tuple() 34 | 35 | UNARY_ELEMENTWISE_FUNC = [ 36 | math.sqrt, 37 | F.relu, 38 | F.relu6 39 | ] 40 | 41 | BINARY_ELEMENTWISE_FUNC = [ 42 | operator.add, 43 | torch.add, 44 | operator.sub, 45 | torch.sub, 46 | operator.mul, 47 | torch.mul, 48 | operator.truediv, 49 | torch.div 50 | ] 51 | 52 | DEFAULT_IGNORED_FUNC = [] 53 | 54 | SOFTMAX = [ 55 | torch.softmax, 56 | F.softmax 57 | ] 58 | 59 | def __init__(self, module, garbage_collect_values=True, 60 | bypassed_modules=None, ignored_modules=None, 61 | ignored_func=None): 62 | super().__init__(module, garbage_collect_values) 63 | self.name_to_module = dict(module.named_modules()) 64 | self.tensor_sizes = {} 65 | self.summary = [] 66 | 67 | if bypassed_modules is None: 68 | bypassed_modules = Converter.DEFAULT_BYPASSED_MODULES 69 | self.bypassed_modules = bypassed_modules 70 | 71 | if ignored_modules is None: 72 | ignored_modules = Converter.DEFAULT_IGNORED_MODULES 73 | self.ignored_modules = ignored_modules 74 | 75 | if ignored_func is None: 76 | ignored_func = Converter.DEFAULT_IGNORED_FUNC 77 | self.ignored_func = ignored_func 78 | 79 | self.bypassed_arg_remap = {} 80 | 81 | def run_node(self, n): 82 | name = n.name 83 | original_args = n.args 84 | with self._set_current_node(n): 85 | args, kwargs = self.fetch_args_kwargs_from_env(n) 86 | if n.op == 'call_module' or n.op == 
'call_function': 87 | return getattr(self, n.op)(n.target, args, kwargs, name, 88 | original_args) 89 | return getattr(self, n.op)(n.target, args, kwargs) 90 | 91 | def call_module(self, target, args: Tuple, kwargs: Dict, name: str, 92 | original_args: tuple): 93 | result = super().call_module(target, args, kwargs) 94 | module = self.name_to_module[target] 95 | 96 | if isinstance(module, self.ignored_modules): 97 | logger.warning('ignoring module %s[type=%s]', name, module) 98 | return result 99 | 100 | if isinstance(module, self.bypassed_modules): 101 | self.bypassed_arg_remap[f'{name}_out'] = \ 102 | f'{original_args[0].name}_out' 103 | return result 104 | 105 | arg_name = f'{original_args[0].name}_out' 106 | while arg_name in self.bypassed_arg_remap: 107 | arg_name = self.bypassed_arg_remap[arg_name] 108 | 109 | description = generate_description(module, args[0], result, name, 110 | arg_name) 111 | 112 | self.summary.append(description) 113 | 114 | return result 115 | 116 | def call_function(self, target, args, kwargs, name: str, 117 | original_args: tuple): 118 | result = super().call_function(target, args, kwargs) 119 | 120 | arg_names = [] 121 | for arg in original_args: 122 | try: 123 | arg_names.append(f'{arg.name}_out') 124 | except: 125 | arg_names.append(None) 126 | 127 | for i, n in enumerate(arg_names): 128 | if n is not None: 129 | while n in self.bypassed_arg_remap: 130 | n = self.bypassed_arg_remap[n] 131 | arg_names[i] = n 132 | 133 | if target in self.ignored_func: 134 | logger.warning('ignoring func %s[type=%s]', name, target) 135 | pass 136 | elif target in Converter.BINARY_ELEMENTWISE_FUNC: 137 | if isinstance(args[1], torch.Tensor): 138 | description = BinaryElementwiseFuncDescription( 139 | ifmap1_shape = args[0].shape, 140 | ifmap2_shape = args[1].shape, 141 | ofmap_shape = result.shape, 142 | ifmap1_name = arg_names[0], 143 | ifmap2_name = arg_names[1], 144 | ofmap_name = f'{name}_out', 145 | name = name 146 | ) 147 | self.summary.append(description) 148 | elif target == F.adaptive_avg_pool2d: 149 | stride_w = args[0].shape[-1] // result.shape[-1] 150 | stride_h = args[0].shape[-2] // result.shape[-2] 151 | kernel_w = args[0].shape[-1] - (result.shape[-1]-1)*stride_w 152 | kernel_h = args[0].shape[-2] - (result.shape[-2]-1)*stride_h 153 | 154 | description = MaxPoolLayerDescription( 155 | w=args[0].shape[3], 156 | h=args[0].shape[2], 157 | c=args[0].shape[1], 158 | s=kernel_w, 159 | r=kernel_h, 160 | w_stride=stride_w, 161 | h_stride=stride_h, 162 | w_pad=0, 163 | h_pad=0, 164 | n=args[0].shape[0], 165 | name=name, 166 | ifmap_name=arg_names[0], 167 | ofmap_name=f'{name}_out' 168 | ) 169 | self.summary.append(description) 170 | elif target == torch.matmul: 171 | description = generate_matmul_func( 172 | input1 = args[0], 173 | input2 = args[1], 174 | output = result, 175 | name = name, 176 | input1_name = arg_names[0], 177 | input2_name = arg_names[1] 178 | ) 179 | self.summary.append(description) 180 | elif target in Converter.SOFTMAX: 181 | description = SoftmaxFuncDescription( 182 | ifmap_shape = args[0].shape, 183 | ofmap_shape = result.shape, 184 | ifmap_name = arg_names[0], 185 | ofmap_name = f'{name}_out', 186 | name = name, 187 | softmax_dim = kwargs['dim'] 188 | ) 189 | self.summary.append(description) 190 | elif target == torch.flatten: 191 | description = ViewFuncDescription( 192 | name=name, 193 | ifmap_shape=args[0].shape, 194 | ofmap_shape=result.shape, 195 | ifmap_name=arg_names[0], 196 | ofmap_name=f'{name}_out' 197 | ) 198 | 
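            # Note: ViewFuncDescription only has a fused-problem form (its
            # to_yaml() is not implemented), so for per-layer output these
            # nodes are typically skipped via ignored_func=[torch.flatten],
            # as the tests do.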
self.summary.append(description) 199 | elif target in Converter.UNARY_ELEMENTWISE_FUNC: 200 | self.bypassed_arg_remap[f'{name}.out'] = \ 201 | f'{original_args[0].name}.out' 202 | pass 203 | else: 204 | logger.error('unknwown function %s[type=%s]', name, target) 205 | raise NotImplementedError() 206 | 207 | return result 208 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/layer_descriptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | A layer description is a representation of a Timeloop workload as a Python dataclass. 3 | 4 | The dataclass representation is a lot more convenient than YAML when working in Python: for example, we can easily 5 | define helper properties (like p and q for convolutional layers). 6 | 7 | A layer description may use any YAML template (the name of the file that will be used is given by the `problem_template` 8 | attribute). Furthermore, any number of descriptions may use the same template but map the parameters differently. 9 | """ 10 | from functools import reduce 11 | import string 12 | from typing import Optional, Sequence 13 | import pkgutil 14 | 15 | import yaml 16 | from dataclasses import dataclass 17 | 18 | 19 | @dataclass 20 | class LayerDescription: 21 | name: str 22 | 23 | def get_workload(self): 24 | f = pkgutil.get_data("pytorch2timeloop", 25 | f"utils/{self.problem_template}.yaml") 26 | return yaml.load(f, Loader=yaml.SafeLoader) 27 | 28 | def to_yaml(self): 29 | config = self.get_workload() 30 | config['problem']['shape']['name'] = self.name 31 | return config 32 | 33 | 34 | @dataclass 35 | class ConvLayerDescription(LayerDescription): 36 | g: int 37 | m: int 38 | w: int 39 | h: int 40 | c: int 41 | n: int 42 | s: int 43 | r: int 44 | w_pad: int 45 | h_pad: int 46 | w_stride: int 47 | h_stride: int 48 | ifmap_name: str 49 | filter_name: str 50 | ofmap_name: str 51 | 52 | @property 53 | def q(self): 54 | return int((self.w - self.s + 2 * self.w_pad) / self.w_stride) + 1 55 | 56 | @property 57 | def p(self): 58 | return int((self.h - self.r + 2 * self.h_pad) / self.h_stride) + 1 59 | 60 | def to_yaml(self): 61 | # dims = list(map(lambda n: self.name + '_' + n, 'GCMRSNPQ')) 62 | dims = list(map(lambda n: n, 'GCMRSNPQ')) 63 | (dim_G, dim_C, dim_M, dim_R, dim_S, dim_N, dim_P, dim_Q) = dims 64 | 65 | in_channels_per_group = self.c // self.g 66 | out_channels_per_group = self.m // self.g 67 | 68 | config = { 69 | 'shape': { 70 | 'name': self.name, 71 | 'dimensions': dims, 72 | 'coefficients': [ 73 | { 74 | 'name': 'Cgroup', 75 | 'default': in_channels_per_group 76 | }, 77 | { 78 | 'name': 'Mgroup', 79 | 'default': out_channels_per_group 80 | }, 81 | { 82 | 'name': 'Hstride', 83 | 'default': self.h_stride 84 | }, 85 | { 86 | 'name': 'Wstride', 87 | 'default': self.w_stride 88 | } 89 | ], 90 | 'data-spaces': [ 91 | { 92 | 'name': 'Weights', 93 | 'projection': [ 94 | [[dim_G]], 95 | [[dim_C]], 96 | [[dim_M]], 97 | [[dim_R]], 98 | [[dim_S]] 99 | ] 100 | }, 101 | { 102 | 'name': 'Inputs', 103 | 'projection': [ 104 | [[dim_N]], 105 | [[dim_G, 'Cgroup'], [f'{dim_C}']], 106 | [[dim_R], [dim_P, 'Hstride']], 107 | [[dim_S], [dim_Q, 'Wstride']] 108 | ] 109 | }, 110 | { 111 | 'name': 'Outputs', 112 | 'projection': [ 113 | [[dim_N]], 114 | [[dim_G, 'Mgroup'], [dim_M]], 115 | [[dim_P]], 116 | [[dim_Q]] 117 | ], 118 | 'read-write': True 119 | } 120 | ] 121 | }, 122 | 'instance': { 123 | 'G': self.g, 124 | 'C': in_channels_per_group, 125 | 'M': 
out_channels_per_group, 126 | 'N': self.n, 127 | 'R': self.r, 128 | 'S': self.s, 129 | 'P': self.p, 130 | 'Q': self.q 131 | } 132 | } 133 | 134 | return {'problem': config} 135 | 136 | def to_fused_yaml(self): 137 | dims = list(map(lambda n: self.name + '_' + n, 'GCMRSNPQ')) 138 | (dim_G, dim_C, dim_M, dim_R, dim_S, dim_N, dim_P, dim_Q) = dims 139 | 140 | in_channels_per_group = self.c // self.g 141 | out_channels_per_group = self.m // self.g 142 | 143 | config = { 144 | 'shape': { 145 | 'name': self.name, 146 | 'dimensions': dims, 147 | 'data-spaces': [ 148 | { 149 | 'name': self.filter_name, 150 | 'projection': \ 151 | f'[ {dim_G}, {dim_C}, {dim_M}, {dim_R}, {dim_S} ]' 152 | }, 153 | { 154 | 'name': self.ifmap_name, 155 | 'projection': ( 156 | f'[ {dim_N}, ' 157 | f'{dim_G}*{in_channels_per_group} + {dim_C}, ' 158 | f'{dim_R} + {dim_P}*{self.h_stride}, ' 159 | f'{dim_S} + {dim_Q}*{self.w_stride} ]' 160 | ) 161 | }, 162 | { 163 | 'name': self.ofmap_name, 164 | 'projection': ( 165 | f'[ {dim_N}, ' 166 | f'{dim_G}*{out_channels_per_group} + {dim_M}, ' 167 | f'{dim_P}, ' 168 | f'{dim_Q} ]' 169 | ), 170 | 'read-write': True 171 | } 172 | ] 173 | }, 174 | 'instance': ( 175 | f'0 <= {dim_G} < {self.g} and ' 176 | f'0 <= {dim_C} < {in_channels_per_group} and ' 177 | f'0 <= {dim_M} < {out_channels_per_group} and ' 178 | f'0 <= {dim_N} < {self.n} and ' 179 | f'0 <= {dim_P} < {self.p} and ' 180 | f'0 <= {dim_Q} < {self.q} and ' 181 | f'0 <= {dim_R} < {self.r} and ' 182 | f'0 <= {dim_S} < {self.s}' 183 | ) 184 | } 185 | 186 | return config 187 | 188 | 189 | @dataclass 190 | class MaxPoolLayerDescription(LayerDescription): 191 | @property 192 | def q(self): 193 | return int((self.w - self.s + 2 * self.w_pad) / self.w_stride) + 1 194 | 195 | @property 196 | def p(self): 197 | return int((self.h - self.r + 2 * self.h_pad) / self.h_stride) + 1 198 | 199 | problem_template = 'pool' 200 | 201 | w: int 202 | h: int 203 | c: int 204 | n: int 205 | s: int 206 | r: int 207 | w_pad: int 208 | h_pad: int 209 | w_stride: int 210 | h_stride: int 211 | ifmap_name: str 212 | ofmap_name: str 213 | 214 | def to_yaml(self): 215 | config = super().to_yaml() 216 | config['problem']['instance']['R'] = self.r 217 | config['problem']['instance']['S'] = self.s 218 | config['problem']['instance']['P'] = self.p 219 | config['problem']['instance']['Q'] = self.q 220 | config['problem']['instance']['C'] = self.c 221 | config['problem']['instance']['N'] = self.n 222 | config['problem']['instance']['Wstride'] = self.w_stride 223 | config['problem']['instance']['Hstride'] = self.h_stride 224 | 225 | # for dspace in config['problem']['shape']['data-spaces']: 226 | # if dspace['name'] == 'Inputs': 227 | # dspace['name'] = self.ifmap_name 228 | # elif dspace['name'] == 'Outputs': 229 | # dspace['name'] = self.ofmap_name 230 | return config 231 | 232 | def to_fused_yaml(self): 233 | dims = list(map(lambda n: self.name + '_' + n, 'CRSNPQ')) 234 | (dim_C, dim_R, dim_S, dim_N, dim_P, dim_Q) = dims 235 | 236 | config = { 237 | 'shape': { 238 | 'name': self.name, 239 | 'dimensions': dims, 240 | 'data-spaces': [ 241 | { 242 | 'name': self.ifmap_name, 243 | 'projection': ( 244 | f'[ {dim_N}, ' 245 | f'{dim_C}, ' 246 | f'{dim_R} + {dim_P}*{self.h_stride}, ' 247 | f'{dim_S} + {dim_Q}*{self.w_stride} ]' 248 | ) 249 | }, 250 | { 251 | 'name': self.ofmap_name, 252 | 'projection': ( 253 | f'[ {dim_N}, ' 254 | f'{dim_C}, ' 255 | f'{dim_P}, ' 256 | f'{dim_Q} ]' 257 | ), 258 | 'read-write': True 259 | } 260 | ] 261 | }, 262 | 'instance': ( 263 | f'0 
<= {dim_C} < {self.c} and ' 264 | f'0 <= {dim_N} < {self.n} and ' 265 | f'0 <= {dim_P} < {self.p} and ' 266 | f'0 <= {dim_Q} < {self.q} and ' 267 | f'0 <= {dim_R} < {self.r} and ' 268 | f'0 <= {dim_S} < {self.s}' 269 | ) 270 | } 271 | return config 272 | 273 | 274 | @dataclass 275 | class MatrixMatrixMultiplyLayerDescription(LayerDescription): 276 | name: Optional[str] 277 | problem_template = "convolution" 278 | m: int 279 | n: int 280 | k: int 281 | batch_size: int 282 | 283 | def to_yaml(self): 284 | config = super().to_yaml() 285 | config['problem']['instance']['R'] = 1 286 | config['problem']['instance']['S'] = self.k 287 | config['problem']['instance']['P'] = self.m 288 | config['problem']['instance']['Q'] = 1 289 | config['problem']['instance']['C'] = 1 290 | config['problem']['instance']['M'] = self.n 291 | config['problem']['instance']['N'] = self.batch_size 292 | config['problem']['instance']['Wstride'] = 1 293 | config['problem']['instance']['Hstride'] = 1 294 | config['problem']['shape']['name'] = self.name 295 | return config 296 | 297 | 298 | @dataclass 299 | class BinaryElementwiseFuncDescription(LayerDescription): 300 | problem_template='binary_elementwise' 301 | ifmap1_shape: Sequence 302 | ifmap2_shape: Sequence 303 | ofmap_shape: Sequence 304 | ifmap1_name: str 305 | ifmap2_name: str 306 | ofmap_name: str 307 | 308 | def to_yaml(self): 309 | if len(self.ifmap1_shape) < len(self.ofmap_shape): 310 | n_missing_dims = len(self.ofmap_shape) - len(self.ifmap1_shape) 311 | self.ifmap1_shape = tuple( 312 | [1]*n_missing_dims + list(self.ifmap1_shape) 313 | ) 314 | if len(self.ifmap2_shape) < len(self.ofmap_shape): 315 | n_missing_dims = len(self.ofmap_shape) - len(self.ifmap2_shape) 316 | self.ifmap2_shape = tuple( 317 | [1]*n_missing_dims + list(self.ifmap2_shape) 318 | ) 319 | assert(len(self.ifmap1_shape) == len(self.ofmap_shape)) 320 | assert(len(self.ifmap2_shape) == len(self.ofmap_shape)) 321 | 322 | config = super().to_yaml() 323 | 324 | dims = list(string.ascii_uppercase[:len(self.ofmap_shape)]) 325 | 326 | for dspace in config['problem']['shape']['data-spaces']: 327 | if dspace['name'] == 'Input1': 328 | # dspace['name'] = self.ifmap1_name 329 | dspace['projection'] = [] 330 | for d, size in zip(dims, self.ifmap1_shape): 331 | if size > 1: 332 | dspace['projection'].append([[d]]) 333 | elif dspace['name'] == 'Input2': 334 | # dspace['name'] = self.ifmap2_name 335 | dspace['projection'] = [] 336 | for d, size in zip(dims, self.ifmap2_shape): 337 | if size > 1: 338 | dspace['projection'].append([[d]]) 339 | elif dspace['name'] == 'Outputs': 340 | # dspace['name'] = self.ofmap_name 341 | dspace['projection'] = list(map( 342 | lambda d: [[d]], 343 | dims 344 | )) 345 | 346 | config['problem']['shape']['dimensions'] = dims 347 | 348 | config['problem']['instance'] = {} 349 | for dim, size in zip(dims, self.ifmap1_shape): 350 | config['problem']['instance'][dim] = size 351 | 352 | return config 353 | 354 | def to_fused_yaml(self): 355 | if len(self.ifmap1_shape) < len(self.ofmap_shape): 356 | n_missing_dims = len(self.ofmap_shape) - len(self.ifmap1_shape) 357 | self.ifmap1_shape = tuple( 358 | [1]*n_missing_dims + list(self.ifmap1_shape) 359 | ) 360 | if len(self.ifmap2_shape) < len(self.ofmap_shape): 361 | n_missing_dims = len(self.ofmap_shape) - len(self.ifmap2_shape) 362 | self.ifmap2_shape = tuple( 363 | [1]*n_missing_dims + list(self.ifmap2_shape) 364 | ) 365 | assert(len(self.ifmap1_shape) == len(self.ofmap_shape)) 366 | assert(len(self.ifmap2_shape) == 
len(self.ofmap_shape)) 367 | 368 | dims = list(string.ascii_uppercase[:len(self.ofmap_shape)]) 369 | bounds = [] 370 | for dim_name, dim_size in zip(dims, self.ifmap1_shape): 371 | bounds.append(f'0 <= {dim_name} < {dim_size}') 372 | 373 | config = { 374 | 'shape': { 375 | 'name': self.name, 376 | 'dimensions': dims, 377 | 'data-spaces': [ 378 | { 379 | 'name': self.ifmap1_name, 380 | 'projection': '[ ' + ', '.join(dims) + ' ]' 381 | }, 382 | { 383 | 'name': self.ifmap2_name, 384 | 'projection': '[ ' + ', '.join(dims) + ' ]' 385 | }, 386 | { 387 | 'name': self.ofmap_name, 388 | 'projection': '[ ' + ', '.join(dims) + ' ]', 389 | 'read-write': True 390 | } 391 | ] 392 | }, 393 | 'instance': ' and '.join(bounds) 394 | } 395 | 396 | return config 397 | 398 | 399 | @dataclass 400 | class MatmulFuncDescription(LayerDescription): 401 | problem_template = "matmul" 402 | m: int 403 | n: int 404 | k: int 405 | ifmap1_name: str 406 | ifmap2_name: str 407 | ofmap_name: str 408 | extra_dims: Optional[tuple] = None 409 | 410 | def to_yaml(self): 411 | config = super().to_yaml() 412 | 413 | if self.extra_dims is not None: 414 | dims = tuple(string.ascii_uppercase[:len(self.extra_dims)]) 415 | else: 416 | dims = tuple() 417 | self.extra_dims = tuple() 418 | 419 | for dspace in config['problem']['shape']['data-spaces']: 420 | # if dspace['name'] == 'Input1': 421 | # dspace['name'] = self.ifmap1_name 422 | # elif dspace['name'] == 'Input2': 423 | # dspace['name'] = self.ifmap2_name 424 | # elif dspace['name'] == 'Outputs': 425 | # dspace['name'] = self.ofmap_name 426 | proj_dims = list(map(lambda d: [[d]], dims)) 427 | dspace['projection'] = proj_dims + dspace['projection'] 428 | 429 | config['problem']['instance']['K'] = self.k 430 | config['problem']['instance']['M'] = self.m 431 | config['problem']['instance']['N'] = self.n 432 | config['problem']['shape']['name'] = self.name 433 | 434 | for dim, size in zip(dims, self.extra_dims): 435 | config['problem']['instance'][dim] = size 436 | 437 | return config 438 | 439 | 440 | @dataclass 441 | class SoftmaxFuncDescription(LayerDescription): 442 | problem_template = 'softmax' 443 | ifmap_shape: tuple 444 | ofmap_shape: tuple 445 | ifmap_name: str 446 | ofmap_name: str 447 | softmax_dim: int 448 | 449 | def to_yaml(self): 450 | config = super().to_yaml() 451 | 452 | dims = tuple(string.ascii_uppercase[:len(self.ifmap_shape)+1]) 453 | 454 | for dspace in config['problem']['shape']['data-spaces']: 455 | if dspace['name'] == 'Input': 456 | # dspace['name'] = self.ifmap_name 457 | dspace['projection'] = list(map( 458 | lambda d: [[d]], 459 | dims[:-1] 460 | )) 461 | elif dspace['name'] == 'Output': 462 | # dspace['name'] = self.ofmap_name 463 | dspace['projection'] = list(map( 464 | lambda d: [[d]], 465 | dims[:-1] 466 | )) 467 | dspace['projection'][self.softmax_dim] = [[dims[-1]]] 468 | 469 | instance = {} 470 | for dim, size in zip(dims[:-1], self.ifmap_shape): 471 | instance[dim] = size 472 | instance[dims[-1]] = self.ofmap_shape[self.softmax_dim] 473 | config['problem']['instance'] = instance 474 | 475 | return config 476 | 477 | 478 | @dataclass 479 | class ViewFuncDescription(LayerDescription): 480 | problem_template = 'view' 481 | ifmap_shape: tuple 482 | ofmap_shape: tuple 483 | ifmap_name: str 484 | ofmap_name: str 485 | 486 | def to_yaml(self): 487 | raise NotImplementedError('cannot be implemented in old Timeloop spec') 488 | 489 | def to_fused_yaml(self): 490 | product = lambda l: reduce(lambda x, y: x*y, l) 491 | assert(product(self.ifmap_shape) == 
product(self.ofmap_shape)) 492 | 493 | n_ofmap_dims = len(self.ofmap_shape) 494 | ofmap_dims = list(string.ascii_uppercase[:n_ofmap_dims]) 495 | 496 | bounds = [] 497 | for dim_name, dim_size in zip(ofmap_dims, self.ofmap_shape): 498 | bounds.append(f'0 <= {dim_name} < {dim_size}') 499 | 500 | terms = [] 501 | cur_size = 1 502 | for dim, dim_size in reversed(list(zip(ofmap_dims, self.ofmap_shape))): 503 | terms.append(f'{dim}*{cur_size}') 504 | cur_size *= dim_size 505 | linearized_ofmaps = ' + '.join(terms) 506 | 507 | ifmap_terms = [] 508 | cur_size = 1 509 | for dim_size in reversed(self.ifmap_shape): 510 | ifmap_terms.append( 511 | f'floor({linearized_ofmaps}/{cur_size})%{dim_size}' 512 | ) 513 | cur_size *= dim_size 514 | ifmap_terms.reverse() 515 | 516 | config = { 517 | 'shape': { 518 | 'name': self.name, 519 | 'dimensions': ofmap_dims, 520 | 'data-spaces': [ 521 | { 522 | 'name': self.ifmap_name, 523 | 'projection': '[ ' + ', '.join(ifmap_terms) + ' ]' 524 | }, 525 | { 526 | 'name': self.ofmap_name, 527 | 'projection': '[ ' + ', '.join(ofmap_dims) + ' ]', 528 | 'read-write': True 529 | } 530 | ] 531 | }, 532 | 'instance': ' and '.join(bounds) 533 | } 534 | 535 | return config 536 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/matmul.yaml: -------------------------------------------------------------------------------- 1 | problem: 2 | shape: 3 | dimensions: [ M, N, K ] 4 | data-spaces: 5 | - name: Input1 6 | projection: 7 | - [ [M] ] 8 | - [ [K] ] 9 | - name: Input2 10 | projection: 11 | - [ [K] ] 12 | - [ [N] ] 13 | - name: Outputs 14 | projection: 15 | - [ [M] ] 16 | - [ [N] ] 17 | read-write: True 18 | instance: 19 | M: 1 20 | N: 1 21 | K: 1 -------------------------------------------------------------------------------- /pytorch2timeloop/utils/pool.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | problem: 27 | shape: 28 | name: "Pool" 29 | dimensions: [ C, M, R, S, N, P, Q ] 30 | coefficients: 31 | - name: Wstride 32 | default: 1 33 | - name: Hstride 34 | default: 1 35 | - name: Wdilation 36 | default: 1 37 | - name: Hdilation 38 | default: 1 39 | data-spaces: 40 | - name: Inputs 41 | projection: 42 | - [ [N] ] 43 | - [ [C] ] 44 | - [ [R, Wdilation], [P, Wstride] ] # SOP form: R*Wdilation + P*Wstride 45 | - [ [S, Hdilation], [Q, Hstride] ] # SOP form: S*Hdilation + Q*Hstride 46 | - name: Outputs 47 | projection: 48 | - [ [N] ] 49 | - [ [M] ] 50 | - [ [Q] ] 51 | - [ [P] ] 52 | read-write: True 53 | instance: 54 | C: 16 55 | M: 32 56 | N: 1 57 | P: 10 58 | Q: 10 59 | R: 5 60 | S: 5 61 | Wdilation: 1 62 | Wstride: 1 63 | Hdilation: 1 64 | Hstride: 1 65 | -------------------------------------------------------------------------------- /pytorch2timeloop/utils/softmax.yaml: -------------------------------------------------------------------------------- 1 | problem: 2 | shape: 3 | name: "softmax" 4 | dimensions: [] 5 | data-spaces: 6 | - name: Input 7 | - name: Output 8 | read-write: True -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4 2 | torchvision==0.19 3 | numpy==2.1 4 | pyyaml==5.3 5 | transformers==4.26.0 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='pytorch2timeloop', 4 | version='0.2', 5 | url='https://github.com/Accelergy-Project/pytorch2timeloop-converter', 6 | license='MIT', 7 | install_requires=[ 8 | "torch==2.4", 9 | "torchvision==0.19", 10 | "numpy==2.1", 11 | "pyyaml==5.3", 12 | "transformers==4.26.0" 13 | ], 14 | dependency_links=[ 15 | "https://download.pytorch.org/whl/cpu/" 16 | ], 17 | python_requires='>=3.6', 18 | include_package_data=True, 19 | packages=find_packages()) 20 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Accelergy-Project/pytorch2timeloop-converter/5de0d6bb4e718619d64673b94a50a8640242a8fe/test/__init__.py -------------------------------------------------------------------------------- /test/test_configs.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | TMP_TEST_DIR = pathlib.Path(__file__).parent / '.test.tmp' 4 | -------------------------------------------------------------------------------- /test/test_mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torchvision.models import mobilenet_v2 5 | import pytorch2timeloop 6 | 7 | 
from .test_configs import TMP_TEST_DIR 8 | 9 | class TestMobileNetv2(unittest.TestCase): 10 | def setUp(self): 11 | self.net = mobilenet_v2() 12 | self.input_size = (3, 224, 224) 13 | self.batch_size = 1 14 | 15 | def test_mobilenet_v2(self): 16 | pytorch2timeloop.convert_model( 17 | model=self.net, 18 | input_size=self.input_size, 19 | batch_size=self.batch_size, 20 | convert_fc=False, 21 | model_name='mobilenet_v2', 22 | save_dir=TMP_TEST_DIR, 23 | ignored_func=[torch.flatten], 24 | exception_module_names=[] 25 | ) 26 | 27 | def test_mobilenet_v2_fused(self): 28 | pytorch2timeloop.convert_model( 29 | model=self.net, 30 | input_size=self.input_size, 31 | batch_size=self.batch_size, 32 | convert_fc=False, 33 | model_name='mobilenet_v2', 34 | save_dir=TMP_TEST_DIR, 35 | fuse=True, 36 | exception_module_names=[] 37 | ) 38 | -------------------------------------------------------------------------------- /test/test_mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torchvision.models import mobilenet_v3_small 5 | import pytorch2timeloop 6 | 7 | from .test_configs import TMP_TEST_DIR 8 | 9 | class TestMobileNetv3(unittest.TestCase): 10 | def setUp(self): 11 | self.net = mobilenet_v3_small() 12 | self.input_size = (3, 224, 224) 13 | self.batch_size = 1 14 | 15 | def test_mobilenet_v3_small(self): 16 | pytorch2timeloop.convert_model( 17 | model=self.net, 18 | input_size=self.input_size, 19 | batch_size=self.batch_size, 20 | convert_fc=False, 21 | model_name='mobilenet_v3_small', 22 | save_dir=TMP_TEST_DIR, 23 | ignored_func=[torch.flatten], 24 | exception_module_names=[] 25 | ) 26 | 27 | def test_mobilenet_v3_small_fused(self): 28 | pytorch2timeloop.convert_model( 29 | model=self.net, 30 | input_size=self.input_size, 31 | batch_size=self.batch_size, 32 | convert_fc=False, 33 | model_name='mobilenet_v3_small', 34 | save_dir=TMP_TEST_DIR, 35 | fuse=True, 36 | exception_module_names=[] 37 | ) 38 | -------------------------------------------------------------------------------- /test/test_resnet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from torchvision.models import resnet18 5 | import pytorch2timeloop 6 | 7 | from .test_configs import TMP_TEST_DIR 8 | 9 | class TestResnet(unittest.TestCase): 10 | def setUp(self): 11 | self.net = resnet18() 12 | self.input_size = (3, 224, 224) 13 | self.batch_size = 1 14 | 15 | def test_resnet18(self): 16 | pytorch2timeloop.convert_model( 17 | model=self.net, 18 | input_size=self.input_size, 19 | batch_size=self.batch_size, 20 | convert_fc=False, 21 | model_name='resnet18', 22 | save_dir=TMP_TEST_DIR, 23 | ignored_func=[torch.flatten], 24 | exception_module_names=[] 25 | ) 26 | 27 | def test_resnet18_fused(self): 28 | pytorch2timeloop.convert_model( 29 | model=self.net, 30 | input_size=self.input_size, 31 | batch_size=self.batch_size, 32 | convert_fc=False, 33 | model_name='resnet18', 34 | save_dir=TMP_TEST_DIR, 35 | fuse=True, 36 | exception_module_names=[] 37 | ) 38 | -------------------------------------------------------------------------------- /test/test_simple_cnn.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import pytorch2timeloop 6 | 7 | from .test_configs import TMP_TEST_DIR 8 | 9 | class Net(nn.Module): 10 | def __init__(self): 11 | super(Net, 
self).__init__() 12 | self.conv1 = nn.Conv2d(1, 4, 5, padding = 2) 13 | self.pool = nn.MaxPool2d(2, 2) 14 | self.conv2 = nn.Conv2d(4, 8, 5, padding = 2) 15 | self.fc1 = nn.Linear(8 * 7 * 7, 256) 16 | self.fc2 = nn.Linear(256, 10) 17 | 18 | def forward(self, x): 19 | x = self.pool(F.relu(self.conv1(x))) 20 | x = self.pool(F.relu(self.conv2(x))) 21 | x = x.view(-1, 8 * 7 * 7) 22 | x = F.relu(self.fc1(x)) 23 | x = self.fc2(x) 24 | return x 25 | 26 | class TestSimpleCNN(unittest.TestCase): 27 | def setUp(self): 28 | self.net = Net() 29 | self.input_size = (1, 28, 28) 30 | self.batch_size = 1 31 | 32 | def test_simple_cnn_without_fc(self): 33 | pytorch2timeloop.convert_model( 34 | model=self.net, 35 | input_size=self.input_size, 36 | batch_size=self.batch_size, 37 | convert_fc=False, 38 | model_name='simple_cnn_without_fc', 39 | save_dir=TMP_TEST_DIR, 40 | exception_module_names=[] 41 | ) 42 | 43 | def test_simple_cnn_with_fc(self): 44 | pytorch2timeloop.convert_model( 45 | model=self.net, 46 | input_size=self.input_size, 47 | batch_size=self.batch_size, 48 | convert_fc=True, 49 | model_name='simple_cnn_with_fc', 50 | save_dir=TMP_TEST_DIR, 51 | exception_module_names=[] 52 | ) 53 | 54 | class GroupedCNN(nn.Module): 55 | def __init__(self): 56 | super(GroupedCNN, self).__init__() 57 | self.conv1 = nn.Conv2d(4, 8, 5, groups=2, padding=2) 58 | 59 | def forward(self, x): 60 | return self.conv1(x) 61 | 62 | class TestGroupedConv(unittest.TestCase): 63 | def setUp(self): 64 | self.net = GroupedCNN() 65 | self.input_size = (4, 28, 28) 66 | self.batch_size = 1 67 | 68 | def test_grouped_conv(self): 69 | pytorch2timeloop.convert_model( 70 | model=self.net, 71 | input_size=self.input_size, 72 | batch_size=self.batch_size, 73 | convert_fc=False, 74 | model_name='grouped_conv', 75 | save_dir=TMP_TEST_DIR, 76 | exception_module_names=[] 77 | ) --------------------------------------------------------------------------------
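
The unit tests above all drive the converter through the same keyword-argument entry point (`model`, `input_size`, `batch_size`, `convert_fc`, `model_name`, `save_dir`, plus optional `ignored_func`, `fuse`, and `exception_module_names`). The sketch below calls `pytorch2timeloop.convert_model` on the small `Net` module from `test/test_simple_cnn.py` with that same signature; it assumes it is run from the repository root (so the `test` package is importable), and the output directory name is an arbitrary example rather than a path the project prescribes.

```python
# Minimal usage sketch, mirroring the call signature used by the unit tests
# above. Assumes it is run from the repository root; 'workloads-example' is a
# made-up output directory, not something the project defines.
import pathlib

import pytorch2timeloop
from test.test_simple_cnn import Net  # 2 conv + 2 fc layers, expects 1x28x28 inputs

net = Net()
save_dir = pathlib.Path('workloads-example')

pytorch2timeloop.convert_model(
    model=net,
    input_size=(1, 28, 28),   # (channels, height, width) of a single input sample
    batch_size=1,
    convert_fc=True,          # also emit workloads for the nn.Linear layers
    model_name='simple_cnn',
    save_dir=save_dir,        # directory the generated workload YAML files are written under
    exception_module_names=[],
)
```

Passing `save_dir` as a `pathlib.Path` matches how the tests build `TMP_TEST_DIR` in `test/test_configs.py`.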
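
Separately, the projection-building loops at the top of this section (the tail of `pytorch2timeloop/utils/layer_descriptions.py`) can be traced by hand. The snippet below re-runs just those loops outside the class for a hypothetical flatten from an ifmap of shape (2, 3) to a one-dimensional ofmap of shape (6,); the shapes are illustration values only, and `ofmap_shape`/`ifmap_shape` stand in for the `self.ofmap_shape`/`self.ifmap_shape` attributes used in the class.

```python
# Standalone re-trace of the projection-building loops shown above.
# ofmap_shape/ifmap_shape are hypothetical example values standing in for
# self.ofmap_shape/self.ifmap_shape in the layer description class.
import string

ofmap_shape = (6,)     # flattened output, e.g. after a view/flatten
ifmap_shape = (2, 3)   # original input shape

ofmap_dims = list(string.ascii_uppercase[:len(ofmap_shape)])            # ['A']
bounds = [f'0 <= {d} < {s}' for d, s in zip(ofmap_dims, ofmap_shape)]   # ['0 <= A < 6']

# Linearize the ofmap coordinates (innermost dimension varies fastest).
terms, cur_size = [], 1
for dim, dim_size in reversed(list(zip(ofmap_dims, ofmap_shape))):
    terms.append(f'{dim}*{cur_size}')
    cur_size *= dim_size
linearized_ofmaps = ' + '.join(terms)                                    # 'A*1'

# De-linearize back into ifmap coordinates via divide-and-modulo.
ifmap_terms, cur_size = [], 1
for dim_size in reversed(ifmap_shape):
    ifmap_terms.append(f'floor({linearized_ofmaps}/{cur_size})%{dim_size}')
    cur_size *= dim_size
ifmap_terms.reverse()

print(ifmap_terms)            # ['floor(A*1/3)%2', 'floor(A*1/1)%3']
print(' and '.join(bounds))   # 0 <= A < 6
```

With these shapes, the ofmap gets a single dimension `A` bounded by `0 <= A < 6`, and each ifmap coordinate is recovered by integer-dividing and taking the remainder of the linearized ofmap index, which is what the `floor(...)%...` terms in the generated projection express.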