├── bagnets ├── __init__.py ├── kerasnet.py ├── utils.py └── pytorchnet.py ├── LICENSE ├── setup.py ├── README.rst └── .gitignore /bagnets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Wieland Brendel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /bagnets/kerasnet.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.models import load_model 3 | 4 | __all__ = ['bagnet9', 'bagnet17', 'bagnet33'] 5 | 6 | model_urls = { 7 | 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/d413271344758455ac086992beb579e256447839/bagnet8.h5', 8 | 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/d413271344758455ac086992beb579e256447839/bagnet16.h5', 9 | 'bagnet33': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/d413271344758455ac086992beb579e256447839/bagnet32.h5', 10 | } 11 | 12 | def bagnet9(): 13 | model_path = keras.utils.get_file( 14 | 'bagnet8.h5', 15 | model_urls['bagnet9'], 16 | cache_subdir='models', 17 | file_hash='5b70adc7c4ff77d932dbba485a5ea1d333a65e777a45511010f22e304a2fdd69') 18 | 19 | return load_model(model_path) 20 | 21 | def bagnet17(): 22 | model_path = keras.utils.get_file( 23 | 'bagnet16.h5', 24 | model_urls['bagnet17'], 25 | cache_subdir='models', 26 | file_hash='b262dfee15a86c91e6aa21bfd86505ecd20a539f7f7c72439d5b1d352dd98a1d') 27 | 28 | return load_model(model_path) 29 | 30 | def bagnet33(): 31 | model_path = keras.utils.get_file( 32 | 'bagnet32.h5', 33 | model_urls['bagnet33'], 34 | cache_subdir='models', 35 | file_hash='96d8842eec8b8ce5b3bc6a5f4ff3c8c0278df3722c12bc84408e1487811f8f0f') 36 | 37 | return load_model(model_path) 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | from os.path import join, dirname 4 | # We need io.open() (Python 3's default open) to specify file encodings 5 | import io 6 | 7 | try: 8 | # obtain long description from README 9 | 
# Specify encoding to get a unicode type in Python 2 and a str in Python 3 10 | readme_path = join(dirname(__file__), 'README.rst') 11 | with io.open(readme_path, encoding='utf-8') as fr: 12 | README = fr.read() 13 | except IOError: 14 | README = '' 15 | 16 | 17 | install_requires = [ 18 | 'numpy', 19 | 'scipy', 20 | 'setuptools', 21 | ] 22 | 23 | tests_require = [ 24 | 'pytest', 25 | 'pytest-cov', 26 | ] 27 | 28 | setup( 29 | name="bagnets", 30 | version=0.1, 31 | description="Models and pretrained weights for bag-of-local-features models", # noqa: E501 32 | long_description=README, 33 | classifiers=[ 34 | "Development Status :: 3 - Alpha", 35 | "Intended Audience :: Developers", 36 | "Intended Audience :: Science/Research", 37 | "License :: OSI Approved :: MIT License", 38 | "Programming Language :: Python :: 2.7", 39 | "Programming Language :: Python :: 3", 40 | "Programming Language :: Python :: 3.5", 41 | "Programming Language :: Python :: 3.6", 42 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 43 | ], 44 | keywords="", 45 | author="Wieland Brendel", 46 | author_email="wieland.brendel@bethgelab.org", 47 | url="https://github.com/wielandbrendel/bag-of-local-features-models", 48 | license="MIT", 49 | packages=find_packages(), 50 | include_package_data=True, 51 | zip_safe=False, 52 | install_requires=install_requires, 53 | extras_require={ 54 | ':python_version == "2.7"': ['future', 'futures'], 55 | }, 56 | ) 57 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | BagNets 3 | ======= 4 | 5 | In this repository you find the model specification and pretrained weights for the bag-of-local-feature models published in 6 | 7 | | `Approximating CNNs with Bag-of-local-Features models works surprisingly well on ImageNet `__. 
8 | | Wieland Brendel and Matthias Bethge, ICLR 2019 9 | 10 | Installation 11 | ------------ 12 | 13 | .. code-block:: bash 14 | 15 | pip install git+https://github.com/wielandbrendel/bag-of-local-features-models.git 16 | 17 | 18 | Usage 19 | ----- 20 | 21 | The code provides simple means to initialize the models in either PyTorch or Keras. After installation, please use the following 22 | code snippets to load the models: 23 | 24 | .. code-block:: python 25 | 26 | import bagnets.pytorchnet 27 | pytorch_model = bagnets.pytorchnet.bagnet17(pretrained=True) 28 | 29 | .. code-block:: python 30 | 31 | import bagnets.kerasnet 32 | keras_model = bagnets.kerasnet.bagnet17() 33 | 34 | and replace bagnet17 with whatever size you want (available are bagnet9, bagnet17 and bagnet33). The last number refers to the 35 | maximum local patch size that the network can integrate over. 36 | 37 | FAQ 38 | ---- 39 | 40 | * **Do I need to manually split the image into patches?** 41 | 42 | No. You use BagNets just like any other DNN and apply it to the whole image. The BagNets are really similar to ResNets. In a nutshell, we simply replaced most 3x3 convolutions by 1x1 convolutions. This effectively means that the largest receptive fields in the BagNets are of size qxq (where q is smaller than the image size), which is equivalent to splitting the image into individual patches. 43 | 44 | Image Preprocessing 45 | ------------------- 46 | 47 | The models expect inputs with the standard torchvision preprocessing, i.e. 48 | 49 | * with RGB channels 50 | * in the format [channel, x, y] 51 | * loaded with pixel values between 0 and 1 which are then... 52 | * ...normalized by mean and standard deviation, i.e.
for given mean: (M1,...,Mn) and std: (S1,...,Sn) for n channels, the normalization should transform each channel of the input as input[channel] = (input[channel] - mean[channel]) / std[channel] 53 | 54 | The mean and standard deviation are: 55 | 56 | * mean = [0.485, 0.456, 0.406] 57 | * std = [0.229, 0.224, 0.225] 58 | 59 | Citation 60 | -------- 61 | 62 | If you find BagNets useful for your scientific work, please consider citing it 63 | in resulting publications: 64 | 65 | .. code-block:: 66 | 67 | @article{brendel2018bagnets, 68 | title={Approximating CNNs with Bag-of-local-Features models works surprisingly well on ImageNet}, 69 | author={Brendel, Wieland and Bethge, Matthias}, 70 | journal={International Conference on Learning Representations}, 71 | year={2019}, 72 | url={https://openreview.net/pdf?id=SkfMWhAqYQ}, 73 | } 74 | 75 | You can find the paper on OpenReview: https://openreview.net/pdf?id=SkfMWhAqYQ 76 | 77 | Authors 78 | ------- 79 | 80 | * `Wieland Brendel <https://github.com/wielandbrendel>`_ 81 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | worktrees/ 2 | .mypy_cache 3 | Makefile 4 | wip/ 5 | 6 | # Created by https://www.gitignore.io/api/osx,vim,linux,python 7 | 8 | ### Linux ### 9 | *~ 10 | 11 | # temporary files which can be created if a process still has a handle open of a deleted file 12 | .fuse_hidden* 13 | 14 | # KDE directory preferences 15 | .directory 16 | 17 | # Linux trash folder which might appear on any partition or disk 18 | .Trash-* 19 | 20 | # .nfs files are created when an open file is removed but is still being accessed 21 | .nfs* 22 | 23 | ### OSX ### 24 | *.DS_Store 25 | .AppleDouble 26 | .LSOverride 27 | 28 | # Icon must end with two \r 29 | Icon 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Files that might appear in the root of a volume 35 | .DocumentRevisions-V100 36 | .fseventsd 37 | .Spotlight-V100 38 | .TemporaryItems 39 |
.Trashes 40 | .VolumeIcon.icns 41 | .com.apple.timemachine.donotpresent 42 | 43 | # Directories potentially created on remote AFP share 44 | .AppleDB 45 | .AppleDesktop 46 | Network Trash Folder 47 | Temporary Items 48 | .apdisk 49 | 50 | ### Python ### 51 | # Byte-compiled / optimized / DLL files 52 | __pycache__/ 53 | *.py[cod] 54 | *$py.class 55 | 56 | # C extensions 57 | *.so 58 | 59 | # Distribution / packaging 60 | .Python 61 | env/ 62 | build/ 63 | develop-eggs/ 64 | dist/ 65 | downloads/ 66 | eggs/ 67 | .eggs/ 68 | lib/ 69 | lib64/ 70 | parts/ 71 | sdist/ 72 | var/ 73 | wheels/ 74 | *.egg-info/ 75 | .installed.cfg 76 | *.egg 77 | 78 | # PyInstaller 79 | # Usually these files are written by a python script from a template 80 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 81 | *.manifest 82 | *.spec 83 | 84 | # Installer logs 85 | pip-log.txt 86 | pip-delete-this-directory.txt 87 | 88 | # Unit test / coverage reports 89 | htmlcov/ 90 | .tox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *,cover 97 | .hypothesis/ 98 | 99 | # Translations 100 | *.mo 101 | *.pot 102 | 103 | # Django stuff: 104 | *.log 105 | local_settings.py 106 | 107 | # Flask stuff: 108 | instance/ 109 | .webassets-cache 110 | 111 | # Scrapy stuff: 112 | .scrapy 113 | 114 | # Sphinx documentation 115 | docs/_build/ 116 | 117 | # PyBuilder 118 | target/ 119 | 120 | # Jupyter Notebook 121 | .ipynb_checkpoints 122 | 123 | # pyenv 124 | .python-version 125 | 126 | # celery beat schedule file 127 | celerybeat-schedule 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # dotenv 133 | .env 134 | 135 | # virtualenv 136 | .venv 137 | venv/ 138 | ENV/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | ### Vim ### 151 | # swap 152 | [._]*.s[a-v][a-z] 153 | [._]*.sw[a-p] 154 | 
[._]s[a-v][a-z] 155 | [._]sw[a-p] 156 | # session 157 | Session.vim 158 | # temporary 159 | .netrwhist 160 | # auto-generated tag files 161 | tags 162 | 163 | # End of https://www.gitignore.io/api/osx,vim,linux,python 164 | -------------------------------------------------------------------------------- /bagnets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from skimage import feature, transform 4 | 5 | def plot_heatmap(heatmap, original, ax, cmap='RdBu_r', 6 | percentile=99, dilation=0.5, alpha=0.25): 7 | """ 8 | Plots the heatmap on top of the original image 9 | (which is shown by most important edges). 10 | 11 | Parameters 12 | ---------- 13 | heatmap : Numpy Array of shape [X, X] 14 | Heatmap to visualise. 15 | original : Numpy array of shape [X, X, 3] 16 | Original image for which the heatmap was computed. 17 | ax : Matplotlib axis 18 | Axis onto which the heatmap should be plotted. 19 | cmap : Matplotlib color map 20 | Color map for the visualisation of the heatmaps (default: RdBu_r) 21 | percentile : float between 0 and 100 (default: 99) 22 | Extreme values outside of the percentile range are clipped. 23 | This avoids that a single outlier dominates the whole heatmap. 24 | dilation : float 25 | Resizing of the original image. Influences the edge detector and 26 | thus the image overlay. 27 | alpha : float in [0, 1] 28 | Opacity of the overlay image. 
29 | 30 | """ 31 | if len(heatmap.shape) == 3: 32 | heatmap = np.mean(heatmap, 0) 33 | 34 | dx, dy = 0.05, 0.05 35 | xx = np.arange(0.0, heatmap.shape[1], dx) 36 | yy = np.arange(0.0, heatmap.shape[0], dy) 37 | xmin, xmax, ymin, ymax = np.amin(xx), np.amax(xx), np.amin(yy), np.amax(yy) 38 | extent = xmin, xmax, ymin, ymax 39 | cmap_original = plt.get_cmap('Greys_r') 40 | cmap_original.set_bad(alpha=0) 41 | overlay = None 42 | if original is not None: 43 | # Compute edges (to overlay to heatmaps later) 44 | original_greyscale = original if len(original.shape) == 2 else np.mean(original, axis=-1) 45 | in_image_upscaled = transform.rescale(original_greyscale, dilation, mode='constant', 46 | multichannel=False, anti_aliasing=True) 47 | edges = feature.canny(in_image_upscaled).astype(float) 48 | edges[edges < 0.5] = np.nan 49 | edges[:5, :] = np.nan 50 | edges[-5:, :] = np.nan 51 | edges[:, :5] = np.nan 52 | edges[:, -5:] = np.nan 53 | overlay = edges 54 | 55 | abs_max = np.percentile(np.abs(heatmap), percentile) 56 | abs_min = abs_max 57 | 58 | ax.imshow(heatmap, extent=extent, interpolation='none', cmap=cmap, vmin=-abs_min, vmax=abs_max) 59 | if overlay is not None: 60 | ax.imshow(overlay, extent=extent, interpolation='none', cmap=cmap_original, alpha=alpha) 61 | 62 | 63 | def generate_heatmap_pytorch(model, image, target, patchsize): 64 | """ 65 | Generates high-resolution heatmap for a BagNet by decomposing the 66 | image into all possible patches and by computing the logits for 67 | each patch. 68 | 69 | Parameters 70 | ---------- 71 | model : Pytorch Model 72 | This should be one of the BagNets. 73 | image : Numpy array of shape [1, 3, X, X] 74 | The image for which we want to compute the heatmap. 75 | target : int 76 | Class for which the heatmap is computed. 77 | patchsize : int 78 | The size of the receptive field of the given BagNet. 
79 | 80 | """ 81 | import torch 82 | 83 | with torch.no_grad(): 84 | # pad with zeros 85 | _, c, x, y = image.shape 86 | padded_image = np.zeros((c, x + patchsize - 1, y + patchsize - 1)) 87 | padded_image[:, (patchsize-1)//2:(patchsize-1)//2 + x, (patchsize-1)//2:(patchsize-1)//2 + y] = image[0] 88 | image = padded_image[None].astype(np.float32) 89 | 90 | # turn to torch tensor 91 | input = torch.from_numpy(image).cuda() 92 | 93 | # extract patches 94 | patches = input.permute(0, 2, 3, 1) 95 | patches = patches.unfold(1, patchsize, 1).unfold(2, patchsize, 1) 96 | num_rows = patches.shape[1] 97 | num_cols = patches.shape[2] 98 | patches = patches.contiguous().view((-1, 3, patchsize, patchsize)) 99 | 100 | # compute logits for each patch 101 | logits_list = [] 102 | 103 | for batch_patches in torch.split(patches, 1000): 104 | logits = model(batch_patches) 105 | logits = logits[:, target][:, 0] 106 | logits_list.append(logits.data.cpu().numpy().copy()) 107 | 108 | logits = np.hstack(logits_list) 109 | return logits.reshape((224, 224)) -------------------------------------------------------------------------------- /bagnets/pytorchnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | from collections import OrderedDict 5 | from torch.utils import model_zoo 6 | 7 | import os 8 | dir_path = os.path.dirname(os.path.realpath(__file__)) 9 | 10 | __all__ = ['bagnet9', 'bagnet17', 'bagnet33'] 11 | 12 | model_urls = { 13 | 'bagnet9': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet8-34f4ccd2.pth.tar', 14 | 'bagnet17': 'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet16-105524de.pth.tar', 15 | 'bagnet33': 
'https://bitbucket.org/wielandbrendel/bag-of-feature-pretrained-models/raw/249e8fa82c0913623a807d9d35eeab9da7dcc2a8/bagnet32-2ddd53ed.pth.tar', 16 | } 17 | 18 | 19 | class Bottleneck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, kernel_size=1): 23 | super(Bottleneck, self).__init__() 24 | # print('Creating bottleneck with kernel size {} and stride {} with padding {}'.format(kernel_size, stride, (kernel_size - 1) // 2)) 25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, 28 | padding=0, bias=False) # changed padding from (kernel_size - 1) // 2 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 31 | self.bn3 = nn.BatchNorm2d(planes * 4) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x, **kwargs): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv3(out) 48 | out = self.bn3(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | if residual.size(-1) != out.size(-1): 54 | diff = residual.size(-1) - out.size(-1) 55 | residual = residual[:,:,:-diff,:-diff] 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class BagNet(nn.Module): 64 | 65 | def __init__(self, block, layers, strides=[1, 2, 2, 2], kernel3=[0, 0, 0, 0], num_classes=1000, avg_pool=True): 66 | self.inplanes = 64 67 | super(BagNet, self).__init__() 68 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, stride=1, padding=0, 69 | bias=False) 70 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, 71 | bias=False) 72 | self.bn1 = 
nn.BatchNorm2d(64, momentum=0.001) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], kernel3=kernel3[0], prefix='layer1') 75 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], kernel3=kernel3[1], prefix='layer2') 76 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], kernel3=kernel3[2], prefix='layer3') 77 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], kernel3=kernel3[3], prefix='layer4') 78 | self.avgpool = nn.AvgPool2d(1, stride=1) 79 | self.fc = nn.Linear(512 * block.expansion, num_classes) 80 | self.avg_pool = avg_pool 81 | self.block = block 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 86 | m.weight.data.normal_(0, math.sqrt(2. / n)) 87 | elif isinstance(m, nn.BatchNorm2d): 88 | m.weight.data.fill_(1) 89 | m.bias.data.zero_() 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1, kernel3=0, prefix=''): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv2d(self.inplanes, planes * block.expansion, 96 | kernel_size=1, stride=stride, bias=False), 97 | nn.BatchNorm2d(planes * block.expansion), 98 | ) 99 | 100 | layers = [] 101 | kernel = 1 if kernel3 == 0 else 3 102 | layers.append(block(self.inplanes, planes, stride, downsample, kernel_size=kernel)) 103 | self.inplanes = planes * block.expansion 104 | for i in range(1, blocks): 105 | kernel = 1 if kernel3 <= i else 3 106 | layers.append(block(self.inplanes, planes, kernel_size=kernel)) 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.conv1(x) 112 | x = self.conv2(x) 113 | x = self.bn1(x) 114 | x = self.relu(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | 121 | if self.avg_pool: 122 | x = 
nn.AvgPool2d(x.size()[2], stride=1)(x) 123 | x = x.view(x.size(0), -1) 124 | x = self.fc(x) 125 | else: 126 | x = x.permute(0,2,3,1) 127 | x = self.fc(x) 128 | 129 | return x 130 | 131 | def bagnet33(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 132 | """Constructs a Bagnet-33 model. 133 | 134 | Args: 135 | pretrained (bool): If True, returns a model pre-trained on ImageNet 136 | """ 137 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,1], **kwargs) 138 | if pretrained: 139 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet33'])) 140 | return model 141 | 142 | def bagnet17(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 143 | """Constructs a Bagnet-17 model. 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,1,0], **kwargs) 149 | if pretrained: 150 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet17'])) 151 | return model 152 | 153 | def bagnet9(pretrained=False, strides=[2, 2, 2, 1], **kwargs): 154 | """Constructs a Bagnet-9 model. 155 | 156 | Args: 157 | pretrained (bool): If True, returns a model pre-trained on ImageNet 158 | """ 159 | model = BagNet(Bottleneck, [3, 4, 6, 3], strides=strides, kernel3=[1,1,0,0], **kwargs) 160 | if pretrained: 161 | model.load_state_dict(model_zoo.load_url(model_urls['bagnet9'])) 162 | return model 163 | --------------------------------------------------------------------------------