├── .gitignore ├── README.md ├── example.py ├── example_images ├── Charles_Bronson │ ├── Charles_Bronson_0002.jpg │ └── Charles_Bronson_0003.jpg ├── Colin_Prescot │ └── Colin_Prescot_0001.jpg ├── Dino_de_Laurentis │ └── Dino_de_Laurentis_0002.jpg ├── Emma_Thompson │ └── Emma_Thompson_0003.jpg ├── Francis_Ricciardone │ └── Francis_Ricciardone_0001.jpg ├── Fujio_Cho │ └── Fujio_Cho_0006.jpg ├── Gloria_Macapagal_Arroyo │ ├── Gloria_Macapagal_Arroyo_0043.jpg │ └── Gloria_Macapagal_Arroyo_0044.jpg └── Gwyneth_Paltrow │ └── Gwyneth_Paltrow_0006.jpg ├── hmax.py └── universal_patch_set.mat /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | ~* 3 | __pycache__ 4 | *.pyc 5 | C2_output.mat 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/143711096.svg)](https://zenodo.org/badge/latestdoi/143711096) 2 | 3 | PyTorch implementation of HMAX 4 | ============================== 5 | 6 | PyTorch implementation of the HMAX model that closely follows the 7 | MATLAB implementation by The Laboratory for Computational Cognitive 8 | Neuroscience: 9 | 10 | http://maxlab.neuro.georgetown.edu/hmax.html 11 | 12 | The S and C units of the HMAX model can almost be mapped directly onto 13 | TorchVision's Conv2d and MaxPool2d layers, where channels are used to store the 14 | filters for different orientations. However, HMAX also implements multiple 15 | scales, which doesn't map nicely onto the existing TorchVision functionality. 16 | Therefore, each scale has its own Conv2d layer; these layers are executed in parallel. 17 | 18 | Here is a schematic overview of the network architecture: 19 | 20 | layers consisting of units with increasing scale 21 | S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 22 | \ / \ / \ / \ / \ / \ / \ / \ / 23 | C1 C1 C1 C1 C1 C1 C1 C1 24 | \ \ \ | / / / / 25 | ALL-TO-ALL CONNECTIVITY 26 | / / / | \ \ \ \ 27 | S2 S2 S2 S2 S2 S2 S2 S2 28 | | | | | | | | | 29 | C2 C2 C2 C2 C2 C2 C2 C2 30 | 31 | 32 | Installation 33 | ============ 34 | 35 | This script depends on the [NumPy, SciPy](https://www.scipy.org), [PyTorch and 36 | TorchVision](https://pytorch.org) packages. 37 | 38 | 39 | Clone the repository somewhere and run the `example.py` script: 40 | 41 | git clone https://github.com/wmvanvliet/pytorch_hmax 42 | python example.py 43 | 44 | 45 | Usage 46 | ===== 47 | 48 | See the `example.py` script to learn how to run the model on the 10 example images. 49 | 50 | 51 | Explanation of the output 52 | ========================= 53 | 54 | The model's `get_all_layers` method returns a 4-tuple: `s1`, `c1`, `s2`, `c2`. 55 | Here is a detailed explanation of the dimensions of each of these variables: 56 | 57 | `s1` 58 | ---- 59 | These are the first simple units in the model; they perform a 2D convolution with Gabor filters. There are 4 Gabor filters, oriented at 90, -45, 0 and 45 degrees. Each filter is defined at 17 different scales (the largest scale is computed but not used by the later layers). The `s1` variable is a list of length 17, containing the output at each scale. Each element is a NumPy array of shape `#images x #rotations x image_height x image_width` that is the result of the convolution operation. 60 | 61 | `c1` 62 | ---- 63 | The output of the `s1` units is processed by the `c1` units, which perform a maxpool operation. This is done at 8 scales (each pooling across a different number of pixels). The `c1` variable is a list of length 8, containing the output at each `c1` scale. Each element is a NumPy array of shape `#images x #rotations x height x width`.
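As a quick sanity check, the shapes described here can be inspected directly after running the model (a minimal sketch that assumes the `model`, `dataloader` and `device` variables from `example.py`):

    for X, y in dataloader:
        s1, c1, s2, c2 = model.get_all_layers(X.to(device))

    print(len(s1), s1[0].shape)  # 17 scales, each #images x 4 x height x width
    print(len(c1), c1[0].shape)  # 8 scales, each #images x 4 x height x width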
64 | 65 | `s2` 66 | ---- 67 | The output of the `c1` units is processed by the `s2` units, which perform 2D convolution again (not with Gabor filters this time, but pre-trained filters loaded from the `universal_patch_set.mat` file). This is done at 8 patch scales, each operating on all 8 scales of the `c1` output. The `s2` variable is a list of length 8, containing the output at each scale. Each element is again a list of length 8, matching the 8 scales of the `c1` units. The elements of this list are NumPy arrays of shape `#images x #filters x height x width` containing the convolution output. 68 | 69 | `c2` 70 | ---- 71 | The output of the `s2` units is processed by the `c2` units, which pool globally over image positions and `c1` scales by taking the minimum value for each `s2` filter. The `c2` variable is a list of length 8, containing the output at each `s2` scale. Each element is a NumPy array of shape `#images x #filters` containing the result of this pooling operation. 72 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run the HMAX model on the example images. 3 | 4 | Authors: Marijn van Vliet 5 | """ 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torchvision import datasets, transforms 9 | import pickle 10 | 11 | import hmax 12 | 13 | # Initialize the model with the universal patch set 14 | print('Constructing model') 15 | model = hmax.HMAX('./universal_patch_set.mat') 16 | 17 | # A folder with example images 18 | example_images = datasets.ImageFolder( 19 | './example_images/', 20 | transform=transforms.Compose([ 21 | transforms.Grayscale(), 22 | transforms.ToTensor(), 23 | transforms.Lambda(lambda x: x * 255), 24 | ]) 25 | ) 26 | 27 | # A dataloader that will run through all example images in one batch 28 | dataloader = DataLoader(example_images, batch_size=10) 29 | 30 | # Determine whether there is a compatible GPU available 31 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 32 | 33 | # Run the model on the example images 34 | print('Running model on', device) 35 | model = model.to(device) 36 | for X, y in dataloader: 37 | s1, c1, s2, c2 = model.get_all_layers(X.to(device)) 38 | 39 | print('Saving output of all layers to: output.pkl') 40 | with open('output.pkl', 'wb') as f: 41 | pickle.dump(dict(s1=s1, c1=c1, s2=s2, c2=c2), f) 42 | print('[done]') 43 | -------------------------------------------------------------------------------- /example_images/Charles_Bronson/Charles_Bronson_0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Charles_Bronson/Charles_Bronson_0002.jpg -------------------------------------------------------------------------------- /example_images/Charles_Bronson/Charles_Bronson_0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Charles_Bronson/Charles_Bronson_0003.jpg -------------------------------------------------------------------------------- /example_images/Colin_Prescot/Colin_Prescot_0001.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Colin_Prescot/Colin_Prescot_0001.jpg -------------------------------------------------------------------------------- /example_images/Dino_de_Laurentis/Dino_de_Laurentis_0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Dino_de_Laurentis/Dino_de_Laurentis_0002.jpg -------------------------------------------------------------------------------- /example_images/Emma_Thompson/Emma_Thompson_0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Emma_Thompson/Emma_Thompson_0003.jpg -------------------------------------------------------------------------------- /example_images/Francis_Ricciardone/Francis_Ricciardone_0001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Francis_Ricciardone/Francis_Ricciardone_0001.jpg -------------------------------------------------------------------------------- /example_images/Fujio_Cho/Fujio_Cho_0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Fujio_Cho/Fujio_Cho_0006.jpg -------------------------------------------------------------------------------- /example_images/Gloria_Macapagal_Arroyo/Gloria_Macapagal_Arroyo_0043.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Gloria_Macapagal_Arroyo/Gloria_Macapagal_Arroyo_0043.jpg -------------------------------------------------------------------------------- /example_images/Gloria_Macapagal_Arroyo/Gloria_Macapagal_Arroyo_0044.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Gloria_Macapagal_Arroyo/Gloria_Macapagal_Arroyo_0044.jpg -------------------------------------------------------------------------------- /example_images/Gwyneth_Paltrow/Gwyneth_Paltrow_0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/example_images/Gwyneth_Paltrow/Gwyneth_Paltrow_0006.jpg -------------------------------------------------------------------------------- /hmax.py: -------------------------------------------------------------------------------- 1 | # encoding: utf8 2 | r""" 3 | PyTorch implementation of the HMAX model of human vision. For more information 4 | about HMAX, check: 5 | 6 | http://maxlab.neuro.georgetown.edu/hmax.html 7 | 8 | The S and C units of the HMAX model can almost be mapped directly onto 9 | TorchVision's Conv2d and MaxPool2d layers, where channels are used to store the 10 | filters for different orientations. 
However, HMAX also implements multiple 11 | scales, which doesn't map nicely onto the existing TorchVision functionality. 12 | Therefore, each scale has its own Conv2d layer; these layers are executed in parallel. 13 | 14 | Here is a schematic overview of the network architecture: 15 | 16 | layers consisting of units with increasing scale 17 | S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 S1 18 | \ / \ / \ / \ / \ / \ / \ / \ / 19 | C1 C1 C1 C1 C1 C1 C1 C1 20 | \ \ \ | / / / / 21 | ALL-TO-ALL CONNECTIVITY 22 | / / / | \ \ \ \ 23 | S2 S2 S2 S2 S2 S2 S2 S2 24 | | | | | | | | | 25 | C2 C2 C2 C2 C2 C2 C2 C2 26 | 27 | This implementation tries to follow the original MATLAB implementation by 28 | Maximilian Riesenhuber as closely as possible: 29 | https://maxlab.neuro.georgetown.edu/hmax.html 30 | 31 | Author: Marijn van Vliet 32 | 33 | References 34 | ---------- 35 | 36 | .. [1] Riesenhuber, Maximilian, and Tomaso Poggio. “Hierarchical Models of 37 | Object Recognition in Cortex.” Nature Neuroscience 2, no. 11 (1999): 38 | 1019–25. https://doi.org/10.1038/14819. 39 | .. [2] Serre, T, M Kouh, C Cadieu, U Knoblich, Gabriel Kreiman, and T Poggio. 40 | “A Theory of Object Recognition: Computations and Circuits in the 41 | Feedforward Path of the Ventral Stream in Primate Visual Cortex.” 42 | Artificial Intelligence, no. December (2005): 1–130. 43 | https://doi.org/10.1.1.207.9279. 44 | .. [3] Serre, Thomas, Aude Oliva, and Tomaso Poggio. “A Feedforward 45 | Architecture Accounts for Rapid Categorization.” Proceedings of the 46 | National Academy of Sciences 104, no. 15 (April 10, 2007): 6424–29. 47 | https://doi.org/10.1073/pnas.0700622104. 48 | .. [4] Serre, Thomas, and Maximilian Riesenhuber. “Realistic Modeling of 49 | Simple and Complex Cell Tuning in the HMAX Model, and Implications for 50 | Invariant Object Recognition in Cortex.” CBCL Memo, no. 239 (2004). 51 | .. [5] Serre, Thomas, Lior Wolf, Stanley Bileschi, Maximilian Riesenhuber, 52 | and Tomaso Poggio. “Robust Object Recognition with Cortex-like 53 | Mechanisms.” IEEE Trans Pattern Anal Mach Intell 29, no. 3 (2007): 54 | 411–26. https://doi.org/10.1109/TPAMI.2007.56. 55 | """ 56 | import numpy as np 57 | from scipy.io import loadmat 58 | import torch 59 | from torch import nn 60 | 61 | 62 | def gabor_filter(size, wavelength, orientation): 63 | """Create a single Gabor filter. 64 | 65 | Parameters 66 | ---------- 67 | size : int 68 | The size of the filter, measured in pixels. The filter is square, hence 69 | only a single number (either width or height) needs to be specified. 70 | wavelength : float 71 | The wavelength of the grating in the filter, relative to half the 72 | size of the filter. For example, a wavelength of 2 will generate a 73 | Gabor filter with a grating that contains exactly one wave. This 74 | determines the "tightness" of the filter. 75 | orientation : float 76 | The orientation of the grating in the filter, in degrees. 77 | 78 | Returns 79 | ------- 80 | filt : ndarray, shape (size, size) 81 | The filter weights. 82 | """ 83 | lambda_ = size * 2.
/ wavelength 84 | sigma = lambda_ * 0.8 85 | gamma = 0.3 # spatial aspect ratio: 0.23 < gamma < 0.92 86 | theta = np.deg2rad(orientation + 90) 87 | 88 | # Generate Gabor filter 89 | x, y = np.mgrid[:size, :size] - (size // 2) 90 | rotx = x * np.cos(theta) + y * np.sin(theta) 91 | roty = -x * np.sin(theta) + y * np.cos(theta) 92 | filt = np.exp(-(rotx**2 + gamma**2 * roty**2) / (2 * sigma ** 2)) 93 | filt *= np.cos(2 * np.pi * rotx / lambda_) 94 | filt[np.sqrt(x**2 + y**2) > (size / 2)] = 0 95 | 96 | # Normalize the filter 97 | filt = filt - np.mean(filt) 98 | filt = filt / np.sqrt(np.sum(filt ** 2)) 99 | 100 | return filt 101 | 102 | 103 | class S1(nn.Module): 104 | """A layer of S1 units with different orientations but the same scale. 105 | 106 | The S1 units are at the bottom of the network. They are exposed to the raw 107 | pixel data of the image. Each S1 unit is a Gabor filter, which detects 108 | edges in a certain orientation. They are implemented as PyTorch Conv2d 109 | modules, where each channel is loaded with a Gabor filter in a specific 110 | orientation. 111 | 112 | Parameters 113 | ---------- 114 | size : int 115 | The size of the filters, measured in pixels. The filters are square, 116 | hence only a single number (either width or height) needs to be 117 | specified. 118 | wavelength : float 119 | The wavelength of the grating in the filter, relative to the half the 120 | size of the filter. For example, a wavelength of 2 will generate a 121 | Gabor filter with a grating that contains exactly one wave. This 122 | determines the "tightness" of the filter. 123 | orientations : list of float 124 | The orientations of the Gabor filters, in degrees. 125 | """ 126 | def __init__(self, size, wavelength, orientations=[90, -45, 0, 45]): 127 | super().__init__() 128 | self.num_orientations = len(orientations) 129 | self.size = size 130 | 131 | # Use PyTorch's Conv2d as a base object. Each "channel" will be an 132 | # orientation. 133 | self.gabor = nn.Conv2d(1, self.num_orientations, size, 134 | padding='same', bias=False) 135 | 136 | # The original HMAX code has a rather unique approach to padding during 137 | # convolution. First, the convolution is performed with padding='same', 138 | # and then the borders of the result are replaced with zeros. The 139 | # computation of the border width is as follows: 140 | self.padding_left = (size + 1) // 2 141 | self.padding_right = (size - 1) // 2 142 | self.padding_top = (size + 1) // 2 143 | self.padding_bottom = (size - 1) // 2 144 | 145 | # Fill the Conv2d filter weights with Gabor kernels: one for each 146 | # orientation 147 | for channel, orientation in enumerate(orientations): 148 | self.gabor.weight.data[channel, 0] = torch.Tensor( 149 | gabor_filter(size, wavelength, orientation)) 150 | 151 | # A convolution layer filled with ones. This is used to normalize the 152 | # result in the forward method. 
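# (In the forward method below, the rectified Gabor responses are divided by
# the square root of this all-ones kernel convolved with the squared input
# image, i.e. by the local image energy under the filter.)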
153 | self.uniform = nn.Conv2d(1, 4, size, padding=size // 2, bias=False) 154 | nn.init.constant_(self.uniform.weight, 1) 155 | 156 | # Since everything is pre-computed, no gradient is required 157 | for p in self.parameters(): 158 | p.requires_grad = False 159 | 160 | def forward(self, img): 161 | """Apply Gabor filters, take absolute value, and normalize.""" 162 | s1_output = torch.abs(self.gabor(img)) 163 | s1_output[:, :, :, :self.padding_left] = 0 164 | s1_output[:, :, :, -self.padding_right:] = 0 165 | s1_output[:, :, :self.padding_top, :] = 0 166 | s1_output[:, :, -self.padding_bottom:, :] = 0 167 | norm = torch.sqrt(self.uniform(img ** 2)) 168 | norm.data[norm == 0] = 1 # To avoid divide by zero 169 | s1_output /= norm 170 | return s1_output 171 | 172 | 173 | class C1(nn.Module): 174 | """A layer of C1 units with different orientations but the same scale. 175 | 176 | Each C1 unit pools over the S1 units that are assigned to it. 177 | 178 | Parameters 179 | ---------- 180 | size : int 181 | Size of the MaxPool2d operation being performed by this C1 layer. 182 | """ 183 | def __init__(self, size): 184 | super().__init__() 185 | self.size = size 186 | self.local_pool = nn.MaxPool2d(size, stride=size // 2, 187 | padding=size // 2) 188 | 189 | # Since everything is pre-computed, no gradient is required 190 | for p in self.parameters(): 191 | p.requires_grad = False 192 | 193 | def forward(self, s1_outputs): 194 | """Max over scales, followed by a MaxPool2d operation.""" 195 | s1_outputs = torch.cat([out.unsqueeze(0) for out in s1_outputs], 0) 196 | 197 | # Pool over all scales 198 | s1_output, _ = torch.max(s1_outputs, dim=0) 199 | 200 | # Pool over local (c1_space x c1_space) neighbourhood 201 | c1_output = self.local_pool(s1_output) 202 | 203 | # We need to shift the output after the convolution by 1 pixel to 204 | # exactly match the wonky MATLAB implementation. 205 | c1_output = torch.roll(c1_output, (-1, -1), dims=(2, 3)) 206 | c1_output[:, :, -1, :] = 0 207 | c1_output[:, :, :, -1] = 0 208 | 209 | return c1_output 210 | 211 | 212 | class S2(nn.Module): 213 | """A layer of S2 units with different orientations but the same scale. 214 | 215 | The activation of these units is computed by taking the distance between 216 | the output of the C layer below and a set of predefined patches. This 217 | distance is computed as: 218 | 219 | d = sqrt( (w - p)^2 ) 220 | = sqrt( w^2 - 2pw + p^2 ) (in the forward pass, the w^2 term is computed by convolving the squared C1 output with an all-ones kernel, the pw term by the main convolution layer, and the p^2 term is precomputed in `patches_sum_sq`) 221 | 222 | Parameters 223 | ---------- 224 | patches : ndarray, shape (n_patches, n_orientations, size, size) 225 | The precomputed patches to load into the weights of this layer. 226 | activation : 'gaussian' | 'euclidean' 227 | Which activation function to use for the units. In the PNAS paper, a 228 | Gaussian curve is used ('gaussian'), whereas the MATLAB 229 | implementation of The Laboratory for Computational Cognitive 230 | Neuroscience uses the euclidean distance ('euclidean', the default). 231 | sigma : float 232 | The sharpness of the tuning (sigma in eqn 1 of [1]_). Defaults to 1. 233 | Only used when using gaussian activation. 234 | 235 | References 236 | ---------- 237 | 238 | .. [1] Serre, Thomas, Aude Oliva, and Tomaso Poggio. “A Feedforward 239 | Architecture Accounts for Rapid Categorization.” Proceedings of the 240 | National Academy of Sciences 104, no. 15 (April 10, 2007): 6424–29. 241 | https://doi.org/10.1073/pnas.0700622104.
242 | """ 243 | def __init__(self, patches, activation='euclidean', sigma=1): 244 | super().__init__() 245 | self.activation = activation 246 | self.sigma = sigma 247 | 248 | num_patches, num_orientations, size, _ = patches.shape 249 | 250 | # Main convolution layer 251 | self.conv = nn.Conv2d(in_channels=num_orientations, 252 | out_channels=num_orientations * num_patches, 253 | kernel_size=size, 254 | padding=size // 2, 255 | groups=num_orientations, 256 | bias=False) 257 | self.conv.weight.data = torch.Tensor( 258 | patches.transpose(1, 0, 2, 3).reshape(num_orientations * num_patches, 259 | 1, size, size)) 260 | 261 | # A convolution layer filled with ones. This is used for the distance 262 | # computation 263 | self.uniform = nn.Conv2d(1, 1, size, padding=size // 2, bias=False) 264 | nn.init.constant_(self.uniform.weight, 1) 265 | 266 | # This is also used for the distance computation 267 | self.patches_sum_sq = nn.Parameter( 268 | torch.Tensor((patches ** 2).sum(axis=(1, 2, 3)))) 269 | 270 | self.num_patches = num_patches 271 | self.num_orientations = num_orientations 272 | self.size = size 273 | 274 | # No gradient required for this layer 275 | for p in self.parameters(): 276 | p.requires_grad = False 277 | 278 | def forward(self, c1_outputs): 279 | s2_outputs = [] 280 | for c1_output in c1_outputs: 281 | conv_output = self.conv(c1_output) 282 | conv_output = conv_output[:, :, 1:, 1:] 283 | 284 | # Unstack the orientations 285 | conv_output_size = conv_output.shape[3] 286 | conv_output = conv_output.view( 287 | -1, self.num_orientations, self.num_patches, conv_output_size, 288 | conv_output_size) 289 | 290 | # Pool over orientations 291 | conv_output = conv_output.sum(dim=1) 292 | 293 | # Compute distance 294 | c1_sq = self.uniform( 295 | torch.sum(c1_output ** 2, dim=1, keepdim=True)) 296 | c1_sq = c1_sq[:, :, 1:, 1:] 297 | dist = c1_sq - 2 * conv_output 298 | dist += self.patches_sum_sq[None, :, None, None] 299 | 300 | # Apply activation function 301 | if self.activation == 'gaussian': 302 | dist = torch.exp(- 1 / (2 * self.sigma ** 2) * dist) 303 | elif self.activation == 'euclidean': 304 | dist[dist < 0] = 0 # Negative values should never occur 305 | torch.sqrt_(dist) 306 | else: 307 | raise ValueError("activation parameter should be either " 308 | "'gaussian' or 'euclidean'.") 309 | 310 | s2_outputs.append(dist) 311 | return s2_outputs 312 | 313 | 314 | class C2(nn.Module): 315 | """A layer of C2 units operating on a layer of S2 units.""" 316 | def forward(self, s2_outputs): 317 | """Take the minimum value of the underlying S2 units.""" 318 | mins = [s2.min(dim=3)[0] for s2 in s2_outputs] 319 | mins = [m.min(dim=2)[0] for m in mins] 320 | mins = torch.cat([m[:, None, :] for m in mins], 1) 321 | return mins.min(dim=1)[0] 322 | 323 | 324 | class HMAX(nn.Module): 325 | """The full HMAX model. 326 | 327 | Use the `get_all_layers` method to obtain the activations for all layers. 328 | 329 | If you are only interested in the final output (=C2 layer), use the model 330 | as any other PyTorch module: 331 | 332 | model = HMAX(universal_patch_set) 333 | output = model(img) 334 | 335 | Parameters 336 | ---------- 337 | universal_patch_set : str 338 | Filename of the .mat file containing the universal patch set. 339 | s2_act : 'gaussian' | 'euclidean' 340 | The activation function for the S2 units. Defaults to 'euclidean'. 341 | 342 | Returns 343 | ------- 344 | c2_output : list of Tensors, shape (batch_size, num_patches) 345 | For each scale, the output of the C2 units. 
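Similarly, the activations of every layer can be obtained as NumPy arrays through the `get_all_layers` method:

    model = HMAX(universal_patch_set)
    s1, c1, s2, c2 = model.get_all_layers(img)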
346 | """ 347 | def __init__(self, universal_patch_set, s2_act='euclidean'): 348 | super().__init__() 349 | 350 | # S1 layers, consisting of units with increasing size 351 | self.s1_units = [ 352 | S1(size=7, wavelength=4), 353 | S1(size=9, wavelength=3.95), 354 | S1(size=11, wavelength=3.9), 355 | S1(size=13, wavelength=3.85), 356 | S1(size=15, wavelength=3.8), 357 | S1(size=17, wavelength=3.75), 358 | S1(size=19, wavelength=3.7), 359 | S1(size=21, wavelength=3.65), 360 | S1(size=23, wavelength=3.6), 361 | S1(size=25, wavelength=3.55), 362 | S1(size=27, wavelength=3.5), 363 | S1(size=29, wavelength=3.45), 364 | S1(size=31, wavelength=3.4), 365 | S1(size=33, wavelength=3.35), 366 | S1(size=35, wavelength=3.3), 367 | S1(size=37, wavelength=3.25), 368 | S1(size=39, wavelength=3.20), # Unused as far as I can tell 369 | ] 370 | 371 | # Explicitly add the S1 units as submodules of the model 372 | for s1 in self.s1_units: 373 | self.add_module('s1_%02d' % s1.size, s1) 374 | 375 | # Each C1 layer pools across two S1 layers 376 | self.c1_units = [ 377 | C1(size=8), 378 | C1(size=10), 379 | C1(size=12), 380 | C1(size=14), 381 | C1(size=16), 382 | C1(size=18), 383 | C1(size=20), 384 | C1(size=22), 385 | ] 386 | 387 | # Explicitly add the C1 units as submodules of the model 388 | for c1 in self.c1_units: 389 | self.add_module('c1_%02d' % c1.size, c1) 390 | 391 | # Read the universal patch set for the S2 layer 392 | m = loadmat(universal_patch_set) 393 | patches = [patch.reshape(shape[[2, 1, 0, 3]]).transpose(3, 0, 2, 1) 394 | for patch, shape in zip(m['patches'][0], m['patchSizes'].T)] 395 | 396 | # One S2 layer for each patch scale, operating on all C1 layers 397 | self.s2_units = [S2(patches=scale_patches, activation=s2_act) 398 | for scale_patches in patches] 399 | 400 | # Explicitly add the S2 units as submodules of the model 401 | for i, s2 in enumerate(self.s2_units): 402 | self.add_module('s2_%d' % i, s2) 403 | 404 | # One C2 layer operating on each scale 405 | self.c2_units = [C2() for s2 in self.s2_units] 406 | 407 | # Explicitly add the C2 units as submodules of the model 408 | for i, c2 in enumerate(self.c2_units): 409 | self.add_module('c2_%d' % i, c2) 410 | 411 | def run_all_layers(self, img): 412 | """Compute the activation for each layer. 413 | 414 | Parameters 415 | ---------- 416 | img : Tensor, shape (batch_size, 1, height, width) 417 | A batch of images to run through the model 418 | 419 | Returns 420 | ------- 421 | s1_outputs : List of Tensors, shape (batch_size, num_orientations, height, width) 422 | For each scale, the output of the layer of S1 units. 423 | c1_outputs : List of Tensors, shape (batch_size, num_orientations, height, width) 424 | For each scale, the output of the layer of C1 units. 425 | s2_outputs : List of lists of Tensors, shape (batch_size, num_patches, height, width) 426 | For each C1 scale and each patch scale, the output of the layer of 427 | S2 units. 428 | c2_outputs : List of Tensors, shape (batch_size, num_patches) 429 | For each patch scale, the output of the layer of C2 units. 
430 | """ # noqa 431 | s1_outputs = [s1(img) for s1 in self.s1_units] 432 | 433 | # Each C1 layer pools across two S1 layers 434 | c1_outputs = [] 435 | for c1, i in zip(self.c1_units, range(0, len(self.s1_units), 2)): 436 | c1_outputs.append(c1(s1_outputs[i:i+2])) 437 | 438 | s2_outputs = [s2(c1_outputs) for s2 in self.s2_units] 439 | c2_outputs = [c2(s2) for c2, s2 in zip(self.c2_units, s2_outputs)] 440 | 441 | return s1_outputs, c1_outputs, s2_outputs, c2_outputs 442 | 443 | def forward(self, img): 444 | """Run through everything and concatenate the output of the C2s.""" 445 | c2_outputs = self.run_all_layers(img)[-1] 446 | c2_outputs = torch.cat( 447 | [c2_out[:, None, :] for c2_out in c2_outputs], 1) 448 | return c2_outputs 449 | 450 | def get_all_layers(self, img): 451 | """Get the activation for all layers as NumPy arrays. 452 | 453 | Parameters 454 | ---------- 455 | img : Tensor, shape (batch_size, 1, height, width) 456 | A batch of images to run through the model 457 | 458 | Returns 459 | ------- 460 | s1_outputs : List of arrays, shape (batch_size, num_orientations, height, width) 461 | For each scale, the output of the layer of S1 units. 462 | c1_outputs : List of arrays, shape (batch_size, num_orientations, height, width) 463 | For each scale, the output of the layer of C1 units. 464 | s2_outputs : List of lists of arrays, shape (batch_size, num_patches, height, width) 465 | For each C1 scale and each patch scale, the output of the layer of 466 | S2 units. 467 | c2_outputs : List of arrays, shape (batch_size, num_patches) 468 | For each patch scale, the output of the layer of C2 units. 469 | """ # noqa 470 | s1_out, c1_out, s2_out, c2_out = self.run_all_layers(img) 471 | return ( 472 | [s1.cpu().detach().numpy() for s1 in s1_out], 473 | [c1.cpu().detach().numpy() for c1 in c1_out], 474 | [[s2_.cpu().detach().numpy() for s2_ in s2] for s2 in s2_out], 475 | [c2.cpu().detach().numpy() for c2 in c2_out], 476 | ) 477 | -------------------------------------------------------------------------------- /universal_patch_set.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wmvanvliet/pytorch_hmax/bf7eda90d232b57ad23c44f2d5acf4b91eaa23b2/universal_patch_set.mat --------------------------------------------------------------------------------
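After `example.py` has been run, the pickled activations can be inspected along the lines of the following sketch (it assumes only the `output.pkl` file that `example.py` writes to the working directory):

    import pickle

    with open('output.pkl', 'rb') as f:
        layers = pickle.load(f)

    c2 = layers['c2']            # list with one array per patch scale
    print(len(c2), c2[0].shape)  # 8 scales, each #images x #filters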