├── test ├── __init__.py └── test_resnet_configuration.py ├── pytorch ├── __init__.py ├── models │ ├── __init__.py │ ├── convs.py │ └── resnet.py └── example.py ├── README.md ├── LICENSE ├── .github └── workflows │ └── test_workflow.yml ├── .gitignore └── environment.yml /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "Conv", 3 | "Seq1", 4 | "Seq2", 5 | "Seq3", 6 | "ResNet18", 7 | "ResNet34", 8 | "ResNet50", 9 | "ResNet101", 10 | "ResNet152", 11 | ] 12 | 13 | from .convs import * 14 | from .resnet import * 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nas-as-program-transformation-exploration 2 | ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/jack-willturner/nas-as-program-transformation-exploration/CI) 3 | ![LGTM Grade](https://img.shields.io/lgtm/grade/python/github/jack-willturner/nas-as-program-transformation-exploration) 4 | 5 | The code for our paper "[Neural Architecture Search as Program Transformation Exploration](https://arxiv.org/abs/2102.06599)". 
6 | -------------------------------------------------------------------------------- /test/test_resnet_configuration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch.models import * 3 | 4 | 5 | def test_default_resnet(): 6 | configs = [ 7 | [{"conv": Conv, "stride": 1}, {"conv": Conv, "stride": 1}], 8 | [{"conv": Conv, "stride": 2}, {"conv": Conv, "stride": 1}], 9 | [{"conv": Conv, "stride": 2}, {"conv": Conv, "stride": 1}], 10 | [{"conv": Conv, "stride": 2}, {"conv": Conv, "stride": 1}], 11 | ] 12 | 13 | net = ResNet18(configs) 14 | y = net(torch.randn(1, 3, 32, 32)) 15 | assert y is not None 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jack Turner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/test_workflow.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | # Controls when the action will run. 4 | on: 5 | # Triggers the workflow on push or pull request events but only for the main branch 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 15 | jobs: 16 | # This workflow contains a single job called "build" 17 | build: 18 | # The type of runner that the job will run on 19 | runs-on: ubuntu-latest 20 | 21 | # Steps represent a sequence of tasks that will be executed as part of the job 22 | steps: 23 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 24 | - uses: actions/checkout@v2 25 | 26 | - name: Install conda environment 27 | uses: conda-incubator/setup-miniconda@v2.1.1 28 | with: 29 | activate-environment: loops 30 | environment-file: environment.yml 31 | python-version: 3.8 32 | 33 | - name: GitHub Action for pytest 34 | uses: cclauss/GitHub-Action-for-pytest@0.5.0 35 | -------------------------------------------------------------------------------- /pytorch/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models import * 4 | 5 | seed = 0 6 | 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | 10 | if torch.cuda.is_available(): 11 | torch.backends.cudnn.deterministic = True 12 
| torch.backends.cudnn.benchmark = False 13 | 14 | 15 | def gen_random_net_config(): 16 | res34, strides = [3, 4, 6, 3], [1, 2, 2, 2] 17 | configs = [] 18 | 19 | for block, stride in zip(res34, strides): 20 | configs_ = [] 21 | for i, layer in enumerate(range(block)): 22 | subconfig = {} 23 | conv = np.random.choice([Conv, Seq1, Seq2, Seq3]) 24 | subconfig["conv"] = conv 25 | subconfig["stride"] = stride if i == 0 else 1 26 | if conv == Seq1: 27 | sf = np.random.choice([1, 2, 4, 8]) 28 | subconfig["split_factor"] = sf 29 | subconfig["groups"] = np.random.choice([1, 2, 4, 8], sf) 30 | elif conv == Seq2: 31 | subconfig["unroll_factor"] = np.random.choice([1, 2, 4, 8, 16]) 32 | subconfig["unrollconv_groups"] = np.random.choice([1, 2, 3, 4]) 33 | elif conv == Seq3: 34 | sf = np.random.choice([1, 2, 4, 8]) 35 | subconfig["split_factor"] = sf 36 | 37 | configs_.append(subconfig) 38 | configs.append(configs_) 39 | return configs 40 | 41 | 42 | invalid = 0 43 | for i in range(100): 44 | try: 45 | model = ResNet34(gen_random_net_config()) 46 | 47 | test = torch.randn((1, 3, 32, 32)) 48 | model(test) 49 | except Exception as e: 50 | print(e) 51 | invalid += 1 52 | 53 | print(f"{100-invalid}/100 configs were valid") 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject 
date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac OS 132 | .DS_Store -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: loops 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_llvm 9 | - blas=1.0=mkl 10 | - bzip2=1.0.8=h7f98852_4 11 | - ca-certificates=2021.1.19=h06a4308_1 12 | - certifi=2020.12.5=py38h06a4308_0 13 | - cudatoolkit=11.1.1=h6406543_8 14 | - cycler=0.10.0=py38_0 15 | - dbus=1.13.18=hb2f20db_0 16 | - expat=2.3.0=h2531618_2 17 | - ffmpeg=4.3=hf484d3e_0 18 | - fontconfig=2.13.1=h6c09931_0 19 | - freetype=2.10.4=h0708190_1 20 | - glib=2.68.0=h36276a3_0 21 | - gmp=6.2.1=h58526e2_0 22 | - gnutls=3.6.13=h85f3911_1 23 | - gst-plugins-base=1.14.0=h8213a91_2 24 | - gstreamer=1.14.0=h28cd5cc_2 25 | - icu=58.2=he6710b0_3 26 | - jpeg=9b=h024ee3a_2 27 | - kiwisolver=1.3.1=py38h2531618_0 28 | - lame=3.100=h7f98852_1001 29 | - lcms2=2.12=h3be6417_0 30 | - ld_impl_linux-64=2.33.1=h53a641e_7 31 | - libffi=3.3=he6710b0_2 32 | - libgcc-ng=9.3.0=h2828fa1_18 33 | - libgfortran-ng=7.3.0=hdf63c60_0 34 | - libiconv=1.16=h516909a_0 35 | - libpng=1.6.37=h21135ba_2 36 | - libstdcxx-ng=9.3.0=h6de172a_18 37 | - libtiff=4.1.0=h2733197_1 38 | - libuuid=1.0.3=h1bed415_2 39 | - libuv=1.41.0=h7f98852_0 40 | - 
libxcb=1.14=h7b6447c_0 41 | - libxml2=2.9.10=hb55368b_3 42 | - llvm-openmp=11.1.0=h4bd325d_1 43 | - lz4-c=1.9.3=h9c3ff4c_0 44 | - matplotlib=3.3.4=py38h06a4308_0 45 | - matplotlib-base=3.3.4=py38h62a2d02_0 46 | - mkl=2020.4=h726a3e6_304 47 | - mkl-service=2.3.0=py38h1e0a361_2 48 | - mkl_fft=1.3.0=py38h5c078b8_1 49 | - mkl_random=1.2.0=py38hc5bc63f_1 50 | - ncurses=6.2=he6710b0_1 51 | - nettle=3.6=he412f7d_0 52 | - ninja=1.10.2=h4bd325d_0 53 | - numpy=1.19.2=py38h54aff64_0 54 | - numpy-base=1.19.2=py38hfa32c7d_0 55 | - olefile=0.46=pyh9f0ad1d_1 56 | - openh264=2.1.1=h780b84a_0 57 | - openssl=1.1.1k=h27cfd23_0 58 | - pandas=1.2.3=py38ha9443f7_0 59 | - pcre=8.44=he6710b0_0 60 | - pillow=8.2.0=py38he98fc37_0 61 | - pip=21.0.1=py38h06a4308_0 62 | - pyparsing=2.4.7=pyhd3eb1b0_0 63 | - pyqt=5.9.2=py38h05f1152_4 64 | - pytest=6.2.3=py38h578d9bd_0 65 | - python=3.8.8=hdb3f193_4 66 | - python-dateutil=2.8.1=pyhd3eb1b0_0 67 | - python_abi=3.8=1_cp38 68 | - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0 69 | - pytz=2021.1=pyhd3eb1b0_0 70 | - qt=5.9.7=h5867ecd_1 71 | - readline=8.1=h27cfd23_0 72 | - scipy=1.6.2=py38h91f5cce_0 73 | - seaborn=0.11.1=pyhd3eb1b0_0 74 | - setuptools=52.0.0=py38h06a4308_0 75 | - sip=4.19.13=py38he6710b0_0 76 | - six=1.15.0=pyh9f0ad1d_0 77 | - sqlite=3.35.4=hdfb4753_0 78 | - tk=8.6.10=hbc83047_0 79 | - torchaudio=0.8.1=py38 80 | - torchvision=0.9.1=py38_cu111 81 | - tornado=6.1=py38h27cfd23_0 82 | - tqdm=4.59.0=pyhd3eb1b0_1 83 | - typing_extensions=3.7.4.3=py_0 84 | - wheel=0.36.2=pyhd3eb1b0_0 85 | - xz=5.2.5=h7b6447c_0 86 | - zlib=1.2.11=h7b6447c_3 87 | - zstd=1.4.9=ha95c52a_0 88 | - pip: 89 | - colorama==0.4.4 90 | - commonmark==0.9.1 91 | - gputil==1.4.0 92 | - pygments==2.8.1 93 | - rich==10.1.0 94 | - pytest==5.4.0 95 | -------------------------------------------------------------------------------- /pytorch/models/convs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
__all__ = ["Conv", "Seq1", "Seq2", "Seq3"]

# Each of the convolutions in this file corresponds to one of the sequences in
# section 7.3 of the paper. They all share the constructor signature
#     (in_channels, out_channels, kernel_size, stride, bias, padding=1, args=None)
# so a ResNet block can swap one for another; ``args`` is the per-layer config
# dict carrying the variant-specific hyper-parameters.


class ConvModule(nn.Module):
    """Common base class for the convolution variants in this file."""

    def _cache_sizes(self, x, convs):
        """Record the problem size of every conv in ``convs`` for input ``x``.

        Resets ``self._sizecache`` and appends one entry per conv:
        ``[N, H, W, CO, CI, KH, KW, stride, pad, G]``. Only the first element
        of each ``stride``/``padding`` pair is recorded.

        NOTE(review): the convs are chained here (each one sees the previous
        one's output). That does not match variants whose convs all read the
        same input (e.g. ``Seq1``). Confirm the intended use.
        """
        self._sizecache = []

        for conv in convs:
            N, CI, H, W = x.size()
            CO, KH, KW, stride, pad, G = (
                conv.out_channels,
                conv.kernel_size[0],
                conv.kernel_size[1],
                conv.stride[0],
                conv.padding[0],
                conv.groups,
            )
            self._sizecache.append([N, H, W, CO, CI, KH, KW, stride, pad, G])
            x = conv(x)


class Conv(ConvModule):
    """Untransformed baseline: a single ``nn.Conv2d``.

    ``args`` is accepted only for signature compatibility and is ignored.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, bias, padding=1, args=None
    ):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=bias,
            padding=padding,
        )

    def forward(self, x):
        return self.conv(x)


class Seq1(ConvModule):
    """Split the output channels over ``args["split_factor"]`` grouped convs.

    Every sub-conv reads the full input and produces
    ``out_channels // split_factor`` channels, using ``args["groups"][i]``
    groups (so ``groups`` needs at least ``split_factor`` entries). The
    results are concatenated along dim 1.

    NOTE: when ``split_factor`` does not divide ``out_channels``, the output
    has fewer than ``out_channels`` channels.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, bias, padding=1, args=None
    ):
        super(Seq1, self).__init__()
        sf = args["split_factor"]
        groups = args["groups"]
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels,
                    out_channels // sf,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    padding=padding,
                    groups=groups[i],
                )
                for i in range(sf)
            ]
        )

    def forward(self, x):
        outs = [conv(x) for conv in self.convs]
        return torch.cat(outs, dim=1)


class Seq2(ConvModule):
    """Compute ``unroll_factor`` output channels densely and the rest grouped.

    ``conv1`` produces the first ``args["unroll_factor"]`` output channels
    from *all* input channels. ``convg1`` produces the remaining
    ``out_channels - unroll_factor`` channels from the input channels after
    the first ``unroll_factor``, with ``args["unrollconv_groups"]`` groups.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, bias, padding=1, args=None
    ):
        super(Seq2, self).__init__()
        self.unroll_factor = args["unroll_factor"]
        g = args["unrollconv_groups"]
        # Honour kernel_size / padding / bias like Conv, Seq1 and Seq3 do
        # (the ResNet blocks pass 3 / 1 / False).
        self.conv1 = nn.Conv2d(
            in_channels,
            self.unroll_factor,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.convg1 = nn.Conv2d(
            in_channels - self.unroll_factor,
            out_channels - self.unroll_factor,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            groups=g,
        )

    def forward(self, x):
        # conv1 deliberately reads every input channel (it was built with the
        # full in_channels); only convg1 works on the trailing slice.
        l_out = self.conv1(x)
        r_out = self.convg1(x[:, self.unroll_factor :, :, :])
        return torch.cat((l_out, r_out), 1)


class Seq3(ConvModule):
    """Tile the input height into ``args["split_factor"]`` strips.

    Each strip is convolved by its own ``nn.Conv2d``, and the outputs are
    concatenated along dim 2. Every strip is padded on its own, so rows at a
    strip border do not see the neighbouring strip (unlike one big conv).

    NOTE(review): when the height is smaller than ``split_factor``, some strips
    are empty and the conv will likely fail.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, bias, padding=1, args=None
    ):
        super(Seq3, self).__init__()
        self.split_factor = args["split_factor"]
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    padding=padding,
                )
                for _ in range(self.split_factor)
            ]
        )

    def forward(self, x):
        H = x.shape[2]
        # Strip i covers rows [i*H//sf, (i+1)*H//sf). When sf divides H this
        # gives equal strips; otherwise the remainder is spread over the
        # strips instead of silently dropping the last H % sf rows.
        bounds = [i * H // self.split_factor for i in range(self.split_factor + 1)]
        outs = [
            conv(x[:, :, start:stop, :])
            for conv, start, stop in zip(self.convs, bounds, bounds[1:])
        ]
        return torch.cat(outs, 2)


# --------------------------------------------------------------------------------
# /pytorch/models/resnet.py
# --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]

# Stride of the first block of each of the four stages in the default configs;
# every later block in a stage uses stride 1.
_STAGE_STRIDES = (1, 2, 2, 2)


def _build_conv(layer_config, in_planes, out_planes, stride):
    """Return the 3x3, padding-1, bias-free conv described by ``layer_config``.

    If the config names a ``"conv"`` class (e.g. ``Conv`` or ``Seq1``), that
    class is instantiated with the whole config passed as ``args``. Otherwise
    a plain ``nn.Conv2d`` is built.
    """
    conv = layer_config.get("conv")
    if conv is None:
        return nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
    return conv(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
        args=layer_config,
    )


def _shortcut(in_planes, out_planes, stride):
    """Return an identity when shapes match, else a strided 1x1 conv + BatchNorm."""
    if stride == 1 and in_planes == out_planes:
        return nn.Sequential()
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_planes),
    )


def _default_configs(num_blocks, strides=_STAGE_STRIDES):
    """Build plain per-layer configs: the stage stride on each stage's first block."""
    return [
        [{"stride": stride if i == 0 else 1} for i in range(n)]
        for n, stride in zip(num_blocks, strides)
    ]


class BasicBlock(nn.Module):
    """Two 3x3 convs with a residual shortcut.

    ``layer_config`` is a dict with a required ``"stride"`` (used by the first
    conv and the shortcut) and an optional ``"conv"`` class. The whole dict is
    forwarded to that class as ``args`` for both convs. Without ``"conv"``,
    plain ``nn.Conv2d`` layers are used.
    """

    expansion = 1

    def __init__(self, in_planes, planes, layer_config):
        super(BasicBlock, self).__init__()
        stride = layer_config["stride"]
        self.conv1 = _build_conv(layer_config, in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = _build_conv(layer_config, planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = _shortcut(in_planes, self.expansion * planes, stride)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    """A 1x1 -> 3x3 -> 1x1 bottleneck block with a residual shortcut.

    The third argument may be an int stride (the original interface) or a
    per-layer config dict like ``BasicBlock``'s, which is what
    ``ResNet._make_layer`` passes. A dict supplies ``"stride"`` and, optionally,
    the ``"conv"`` class used for the 3x3 conv.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        if isinstance(stride, dict):
            layer_config = stride
            stride = layer_config["stride"]
        else:
            layer_config = {}
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = _build_conv(layer_config, planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, self.expansion * planes, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        self.shortcut = _shortcut(in_planes, self.expansion * planes, stride)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet with a 3x3 stem and four stages of 64/128/256/512 planes.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``.
        num_blocks: blocks per stage; used only to build plain default configs
            when ``configs`` is None.
        configs: one list of per-layer config dicts per stage. Each dict builds
            one block, so the stage depths come from ``configs``.
        num_classes: size of the final linear layer.
    """

    def __init__(self, block, num_blocks, configs=None, num_classes=10):
        super(ResNet, self).__init__()
        if configs is None:
            configs = _default_configs(num_blocks)
        self.configs = configs
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], configs[0])
        self.layer2 = self._make_layer(block, 128, num_blocks[1], configs[1])
        self.layer3 = self._make_layer(block, 256, num_blocks[2], configs[2])
        self.layer4 = self._make_layer(block, 512, num_blocks[3], configs[3])
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, configs):
        """Stack one ``block`` per entry of ``configs`` (``num_blocks`` is unused)."""
        layers = []
        for layer_config in configs:
            layers.append(block(self.in_planes, planes, layer_config))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. With stage strides (1, 2, 2, 2), a 32x32 input
        # reaches here as 4x4, where this equals avg_pool2d(out, 4); unlike a
        # fixed kernel it also works for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(configs=None):
    """ResNet-18 (BasicBlock, [2, 2, 2, 2]); ``configs=None`` uses plain convs."""
    return ResNet(BasicBlock, [2, 2, 2, 2], configs)


def ResNet34(configs=None):
    """ResNet-34 (BasicBlock, [3, 4, 6, 3]); ``configs=None`` uses plain convs."""
    return ResNet(BasicBlock, [3, 4, 6, 3], configs)


def ResNet50(configs=None):
    """ResNet-50 (Bottleneck, [3, 4, 6, 3]); ``configs=None`` uses plain convs."""
    return ResNet(Bottleneck, [3, 4, 6, 3], configs)


def ResNet101(configs=None):
    """ResNet-101 (Bottleneck, [3, 4, 23, 3]); ``configs=None`` uses plain convs."""
    return ResNet(Bottleneck, [3, 4, 23, 3], configs)


def ResNet152(configs=None):
    """ResNet-152 (Bottleneck, [3, 8, 36, 3]); ``configs=None`` uses plain convs."""
    return ResNet(Bottleneck, [3, 8, 36, 3], configs)
# --------------------------------------------------------------------------------