├── LICENSE
├── README.md
├── ats
│   ├── core
│   │   ├── __pycache__
│   │   │   ├── ats_layer.cpython-36.pyc
│   │   │   ├── expectation.cpython-36.pyc
│   │   │   └── sampling.cpython-36.pyc
│   │   ├── ats_layer.py
│   │   ├── expectation.py
│   │   └── sampling.py
│   ├── data
│   │   └── from_tensors.py
│   └── utils
│       ├── __init__.py
│       ├── layers.py
│       ├── logging.py
│       └── regularizers.py
├── colon_cancer.py
├── dataset
│   ├── colon_cancer
│   │   └── .gitkeep
│   ├── colon_cancer_dataset.py
│   ├── mega_mnist
│   │   └── .gitkeep
│   ├── mega_mnist_dataset.py
│   ├── speed_limits_dataset.py
│   └── traffic_data
│       └── .gitkeep
├── mega_mnist.py
├── models
│   ├── attention_model.py
│   ├── classifier.py
│   └── feature_model.py
├── requirements.txt
├── speed_limits.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 SURFsara
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Attention Sampling - Pytorch
3 | This is a PyTorch implementation of the paper: ["Processing Megapixel Images with Deep Attention-Sampling Models"](https://arxiv.org/abs/1905.03711). This repository is based on the [original repository](https://github.com/idiap/attention-sampling) belonging to this paper, which is written in TensorFlow.
4 | 
5 | ## Porting to PyTorch
6 | The code from the original repository has been rewritten into a PyTorch 1.4.0 implementation. The most difficult part was rewriting the functions that extract the patches from the high resolution image. The original version uses dedicated C/C++ code for this; I have done it in native Python, which is probably less efficient and slower because it requires a nested for-loop. I experimented with performing the patch extraction in parallel, but this adds so much overhead that it is actually slower.
7 | 
8 | Furthermore, I hope I implemented the part where the expectation is calculated correctly. It uses a custom `backward()` function and I hope there are no bugs in it.
9 | 
10 | ## Performance
11 | This code repository has been tested on two of the tasks mentioned in the original paper: the Mega-MNIST and the traffic sign detection task. A qualitative analysis of the results shows they are comparable to the original work; however, a quantitative analysis shows the errors are higher in this code base.
A couple of users have alerted me that they cannot reproduce the results from the original paper using this code base, so I suspect there may still be a couple of bugs in this work. If you intend to use it, beware; any help finding these bugs will be greatly appreciated. Experiments can be run by running `mega_mnist.py` and `speed_limits.py`.
12 | 
13 | ## Installation
14 | Dependencies can be found inside the `requirements.txt` file. To install, run `pip3 install -r requirements.txt`. This code repository defaults to running on a GPU if one is available. It has been tested on both CPU and GPU. A minimal usage sketch showing how the main pieces fit together is included further down in this README.
15 | 
16 | ## Questions and contributions
17 | If you have any questions about the code or methods used in this repository, you can reach out to joris.mollinga@surf.nl. If you find bugs in this code (which is quite possible), please contact me or file an issue. If you want to contribute to this code by making it more efficient (for example, the patch extraction procedure is quite inefficient), please contact me or submit a pull request.
18 | 
19 | ## Research
20 | If this repository has helped you in your research, we would appreciate an acknowledgement in your publication.
21 | 
22 | # Acknowledgement
23 | This project has received funding from the European Union’s Horizon 2020 research and innovation programme under grant agreement No 825292, better known as the ExaMode project. The objectives of the ExaMode project are:
24 | 1. Weakly-supervised knowledge discovery for exascale medical data.
25 | 2. Develop extreme scale analytic tools for heterogeneous exascale multimodal and multimedia data.
26 | 3. Healthcare & industry decision-making adoption of extreme-scale analysis and prediction tools.
27 | 
28 | For more information on the ExaMode project, please visit www.examode.eu.
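
## Usage example
The snippet below is a minimal sketch of how the pieces of this repository fit together, mirroring the setup in `mega_mnist.py` and meant to be run from the repository root. The tensor sizes are illustrative placeholders, not the real Megapixel MNIST dimensions.

```python
import torch
import torch.nn.functional as F

from ats.core.ats_layer import ATSModel
from models.attention_model import AttentionModelMNIST
from models.feature_model import FeatureModelMNIST
from models.classifier import ClassificationHead

# Assemble the attention, feature and classification sub-models
attention_model = AttentionModelMNIST(squeeze_channels=True, softmax_smoothing=1e-4)
feature_model = FeatureModelMNIST(in_channels=1)
classification_head = ClassificationHead(in_channels=32, num_classes=10)
ats_model = ATSModel(attention_model, feature_model, classification_head,
                     n_patches=10, patch_size=50)

# A batch of high resolution images and their low resolution views
x_high = torch.rand(2, 1, 500, 500)  # illustrative size, not the real one
x_low = F.interpolate(x_high, scale_factor=0.2)

# The model returns the predictions, the attention map, the sampled patches
# and the low resolution input (the last two are mainly used for logging)
y, attention_map, patches, _ = ats_model(x_low, x_high)
print(y.shape, attention_map.shape, patches.shape)
```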
29 | 30 | ![enter image description here](https://www.examode.eu/wp-content/uploads/2018/11/horizon.jpg) ![enter image description here](https://www.examode.eu/wp-content/uploads/2018/11/flag_yellow.png) 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /ats/core/__pycache__/ats_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sara-nl/attention-sampling-pytorch/47b7722e180e6179327c571f3a846f72f06ee182/ats/core/__pycache__/ats_layer.cpython-36.pyc -------------------------------------------------------------------------------- /ats/core/__pycache__/expectation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sara-nl/attention-sampling-pytorch/47b7722e180e6179327c571f3a846f72f06ee182/ats/core/__pycache__/expectation.cpython-36.pyc -------------------------------------------------------------------------------- /ats/core/__pycache__/sampling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sara-nl/attention-sampling-pytorch/47b7722e180e6179327c571f3a846f72f06ee182/ats/core/__pycache__/sampling.cpython-36.pyc -------------------------------------------------------------------------------- /ats/core/ats_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import pdb 4 | 5 | from ..data.from_tensors import FromTensors 6 | from .sampling import sample 7 | from .expectation import Expectation 8 | 9 | 10 | class SamplePatches(nn.Module): 11 | """SamplePatches samples from a high resolution image using an attention 12 | map. The layer expects the following inputs when called `x_low`, `x_high`, 13 | `attention`. `x_low` corresponds to the low resolution view of the image 14 | which is used to derive the mapping from low resolution to high. `x_high` 15 | is the tensor from which we extract patches. `attention` is an attention 16 | map that is computed from `x_low`. 17 | Arguments 18 | --------- 19 | n_patches: int, how many patches should be sampled 20 | patch_size: int, the size of the patches to be sampled (squared) 21 | receptive_field: int, how large is the receptive field of the attention 22 | network. It is used to map the attention to high 23 | resolution patches. 
24 |         replace: bool, whether we should sample with replacement or without
25 |         use_logits: bool, whether or not logits are used in the attention map
26 |     """
27 | 
28 |     def __init__(self, n_patches, patch_size, receptive_field=0, replace=False,
29 |                  use_logits=False, **kwargs):
30 |         self._n_patches = n_patches
31 |         self._patch_size = (patch_size, patch_size)
32 |         self._receptive_field = receptive_field
33 |         self._replace = replace
34 |         self._use_logits = use_logits
35 | 
36 |         super(SamplePatches, self).__init__(**kwargs)
37 | 
38 |     def compute_output_shape(self, input_shape):
39 |         """ Legacy function kept from the Keras implementation """
40 |         shape_low, shape_high, shape_att = input_shape
41 | 
42 |         # Figure out the shape of the patches
43 |         patch_shape = (shape_high[1], *self._patch_size)
44 | 
45 |         patches_shape = (shape_high[0], self._n_patches, *patch_shape)
46 | 
47 |         # Sampled attention
48 |         att_shape = (shape_high[0], self._n_patches)
49 | 
50 |         return [patches_shape, att_shape]
51 | 
52 |     def forward(self, x_low, x_high, attention):
53 |         sample_space = attention.shape[1:]
54 |         samples, sampled_attention = sample(
55 |             self._n_patches,
56 |             attention,
57 |             sample_space,
58 |             replace=self._replace,
59 |             use_logits=self._use_logits
60 |         )
61 | 
62 |         offsets = torch.zeros_like(samples).float()
63 |         if self._receptive_field > 0:
64 |             offsets = offsets + self._receptive_field / 2
65 | 
66 |         # Get the patches from the high resolution data; the patch extraction
67 |         # below expects channels-last tensors
68 |         x_low = x_low.permute(0, 2, 3, 1)
69 |         x_high = x_high.permute(0, 2, 3, 1)
70 |         assert x_low.shape[-1] == x_high.shape[-1], "Channels should be last for now"
71 |         patches, _ = FromTensors([x_low, x_high], None).patches(
72 |             samples,
73 |             offsets,
74 |             sample_space,
75 |             torch.Tensor([x_low.shape[1:-1]]).view(-1) - self._receptive_field,
76 |             self._patch_size,
77 |             0,
78 |             1
79 |         )
80 | 
81 |         return [patches, sampled_attention]
82 | 
83 | 
84 | class ATSModel(nn.Module):
85 |     """ Attention sampling model that performs the entire process of calculating the
86 |     attention map, sampling the patches, calculating the features of the patches,
87 |     computing the expectation and classifying the features.
88 |     Arguments
89 |     ---------
90 |     attention_model: pytorch model that calculates the attention map given a low
91 |                      resolution input image
92 |     feature_model: pytorch model that takes the patches and calculates features
93 |                    of the patches
94 |     classifier: pytorch model that can do a classification into the number of
95 |                 classes for the specific problem
96 |     n_patches: int, the number of patches to sample
97 |     patch_size: int, the patch size (squared)
98 |     receptive_field: int, how large is the receptive field of the attention network.
99 |                      It is used to map the attention to high resolution patches.
100 |         replace: bool, whether to sample with or without replacement
101 |         use_logits: bool, whether to use logits when sampling
102 |     """
103 | 
104 |     def __init__(self, attention_model, feature_model, classifier, n_patches, patch_size, receptive_field=0,
105 |                  replace=False, use_logits=False):
106 |         super(ATSModel, self).__init__()
107 | 
108 |         self.attention_model = attention_model
109 |         self.feature_model = feature_model
110 |         self.classifier = classifier
111 | 
112 |         self.sampler = SamplePatches(n_patches, patch_size, receptive_field, replace, use_logits)
113 |         self.expectation = Expectation(replace=replace)
114 | 
115 |         self.patch_size = patch_size
116 |         self.n_patches = n_patches
117 | 
118 |     def forward(self, x_low, x_high):
119 |         # First we compute our attention map
120 |         attention_map = self.attention_model(x_low)
121 | 
122 |         # Then we sample patches based on the attention
123 |         patches, sampled_attention = self.sampler(x_low, x_high, attention_map)
124 | 
125 |         # We compute the features of the sampled patches
126 |         channels = patches.shape[2]
127 |         patches_flat = patches.view(-1, channels, self.patch_size, self.patch_size)
128 |         patch_features = self.feature_model(patches_flat)
129 |         dims = patch_features.shape[-1]
130 |         patch_features = patch_features.view(-1, self.n_patches, dims)
131 | 
132 |         sample_features = self.expectation(patch_features, sampled_attention)
133 | 
134 |         y = self.classifier(sample_features)
135 | 
136 |         return y, attention_map, patches, x_low
137 | 
--------------------------------------------------------------------------------
/ats/core/expectation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from ..utils import expand_many, to_float32
5 | 
6 | 
7 | class ExpectWithReplacement(torch.autograd.Function):
8 |     """ Custom pytorch layer for calculating the expectation of the sampled patches
9 |     with replacement.
10 |     """
11 |     @staticmethod
12 |     def forward(ctx, weights, attention, features):
13 | 
14 |         axes = [-1] * (len(features.shape) - 2)
15 |         wf = expand_many(weights, axes)
16 | 
17 |         F = torch.sum(wf * features, dim=1)
18 | 
19 |         ctx.save_for_backward(weights, attention, features, F)
20 |         return F
21 | 
22 |     @staticmethod
23 |     def backward(ctx, grad_output):
24 |         weights, attention, features, F = ctx.saved_tensors
25 |         axes = [-1] * (len(features.shape) - 2)
26 |         wf = expand_many(weights, axes)
27 | 
28 |         grad = torch.unsqueeze(grad_output, 1)
29 | 
30 |         # Gradient w.r.t. the attention
31 |         ga = grad * features
32 |         ga = torch.sum(ga, axis=list(range(2, len(ga.shape))))
33 |         ga = ga * weights / attention
34 | 
35 |         # Gradient w.r.t. the features
36 |         gf = wf * grad
37 | 
38 |         return None, ga, gf
39 | 
40 | 
41 | class ExpectWithoutReplacement(torch.autograd.Function):
42 |     """ Custom pytorch layer for calculating the expectation of the sampled patches
43 |     without replacement.
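    Each sampled feature is reweighted by the probability mass that was still
    available when it was drawn (1 - cumsum of the attention), following the
    estimator derived in the paper for sampling without replacement.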
44 | """ 45 | 46 | @staticmethod 47 | def forward(ctx, weights, attention, features): 48 | # Reshape the passed weights and attention in feature compatible shapes 49 | axes = [-1] * (len(features.shape) - 2) 50 | wf = expand_many(weights, axes) 51 | af = expand_many(attention, axes) 52 | 53 | # Compute how much of the probablity mass was available for each sample 54 | pm = 1 - torch.cumsum(attention, axis=1) 55 | pmf = expand_many(pm, axes) 56 | 57 | # Compute the features 58 | Fa = af * features 59 | Fpm = pmf * features 60 | Fa_cumsum = torch.cumsum(Fa, axis=1) 61 | F_estimator = Fa_cumsum + Fpm 62 | 63 | F = torch.sum(wf * F_estimator, axis=1) 64 | 65 | ctx.save_for_backward(weights, attention, features, pm, pmf, Fa, Fpm, Fa_cumsum, F_estimator) 66 | 67 | return F 68 | 69 | @staticmethod 70 | def backward(ctx, grad_output): 71 | weights, attention, features, pm, pmf, Fa, Fpm, Fa_cumsum, F_estimator = ctx.saved_tensors 72 | device = weights.device 73 | 74 | axes = [-1] * (len(features.shape) - 2) 75 | wf = expand_many(weights, axes) 76 | af = expand_many(attention, axes) 77 | 78 | N = attention.shape[1] 79 | probs = attention / pm 80 | probsf = expand_many(probs, axes) 81 | grad = torch.unsqueeze(grad_output, 1) 82 | 83 | # Gradient wrt to the attention 84 | ga1 = F_estimator / probsf 85 | ga2 = ( 86 | torch.cumsum(features, axis=1) - 87 | expand_many(to_float32(torch.arange(0, N, device=device)), [0] + axes) * features 88 | ) 89 | ga = grad * (ga1 + ga2) 90 | ga = torch.sum(ga, axis=list(range(2, len(ga.shape)))) 91 | ga = ga * weights 92 | 93 | # Gradient wrt to the features 94 | gf = expand_many(to_float32(torch.arange(N-1, -1, -1, device=device)), [0] + axes) 95 | gf = pmf + gf * af 96 | gf = wf * gf 97 | gf = gf * grad 98 | 99 | return None, ga, gf 100 | 101 | 102 | class Expectation(nn.Module): 103 | """ Approximate the expectation of all the features under the attention 104 | distribution (and its gradient) given a sampled set. 105 | 106 | Arguments 107 | --------- 108 | attention: Tensor of shape (B, N) containing the attention values that 109 | correspond to the sampled features 110 | features: Tensor of shape (B, N, ...) containing the sampled features 111 | replace: bool describing if we sampled with or without replacement 112 | weights: Tensor of shape (B, N) or None to weigh the samples in case of 113 | multiple samplings of the same position. 
If None it defaults
114 |              to torch.ones(B, N)
115 |     """
116 | 
117 |     def __init__(self, replace=False):
118 |         super(Expectation, self).__init__()
119 |         self._replace = replace
120 | 
121 |         self.E = ExpectWithReplacement() if replace else ExpectWithoutReplacement()
122 | 
123 |     def forward(self, features, attention, weights=None):
124 |         if weights is None:
125 |             weights = torch.ones_like(attention) / float(attention.shape[1])
126 | 
127 |         return self.E.apply(weights, attention, features)
128 | 
--------------------------------------------------------------------------------
/ats/core/sampling.py:
--------------------------------------------------------------------------------
1 | """Implement sampling from a multinomial distribution on an n-dimensional
2 | tensor."""
3 | import torch
4 | import torch.distributions as dist
5 | 
6 | 
7 | def _sample_with_replacement(logits, n_samples):
8 |     """Sample with replacement using the pytorch categorical distribution op."""
9 |     distribution = dist.categorical.Categorical(logits=logits)
10 |     return distribution.sample(sample_shape=torch.Size([n_samples])).transpose(0, 1)
11 | 
12 | 
13 | def _sample_without_replacement(logits, n_samples):
14 |     """Sample without replacement using the Gumbel-max trick.
15 |     See lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/
16 |     """
17 |     z = -torch.log(-torch.log(torch.rand_like(logits)))
18 |     return torch.topk(logits + z, k=n_samples)[1]
19 | 
20 | 
21 | def unravel_index(index, shape):
22 |     out = []
23 |     for dim in reversed(shape):
24 |         out.append(index % dim)
25 |         index = index // dim
26 |     return torch.stack(tuple(reversed(out)))
27 | 
28 | 
29 | def sample(n_samples, attention, sample_space, replace=False,
30 |            use_logits=False):
31 |     """Sample from the passed-in attention distribution.
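    The attention map is flattened per datapoint, n_samples indices are drawn
    from it (with replacement via a categorical distribution, without
    replacement via the Gumbel-max trick) and the indices are unravelled back
    into sample_space coordinates.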
32 |     Arguments
33 |     ---------
34 |     n_samples: int, the number of samples per datapoint
35 |     attention: tensor, the attention distribution per datapoint (could be
36 |                logits or normalized)
37 |     sample_space: the shape of the attention map, i.e. attention.shape[1:]
38 |     replace: bool, sample with replacement if set to True (defaults to False)
39 |     use_logits: bool, assume the input is logits if set to True (defaults to False)
40 |     """
41 |     # Make sure we have logits and choose replacement or not
42 |     logits = attention if use_logits else torch.log(attention)
43 |     sampling_function = (
44 |         _sample_with_replacement if replace
45 |         else _sample_without_replacement
46 |     )
47 | 
48 |     # Flatten the attention distribution and sample from it
49 |     logits = logits.reshape(-1, sample_space[0]*sample_space[1])
50 |     samples = sampling_function(logits, n_samples)
51 | 
52 |     # Unravel the indices into sample_space
53 |     batch_size = attention.shape[0]
54 |     n_dims = len(sample_space)
55 | 
56 |     # Gather the attention
57 |     attention = attention.view(batch_size, 1, -1).expand(batch_size, n_samples, -1)
58 |     sampled_attention = torch.gather(attention, -1, samples[:, :, None])[:, :, 0]
59 | 
60 |     samples = unravel_index(samples.reshape(-1, ), sample_space)
61 |     samples = torch.reshape(samples.transpose(1, 0), (batch_size, n_samples, n_dims))
62 | 
63 |     return samples, sampled_attention
64 | 
--------------------------------------------------------------------------------
/ats/data/from_tensors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from joblib import Parallel, delayed
4 | 
5 | from ..utils import to_tensor, to_float32, to_int32, expand_many
6 | 
7 | 
8 | def _extract_patch(img_b, coord, patch_size):
9 |     """ Extract a single patch """
10 |     x_start = int(coord[0])
11 |     x_end = x_start + int(patch_size[0])
12 |     y_start = int(coord[1])
13 |     y_end = y_start + int(patch_size[1])
14 | 
15 |     patch = img_b[:, x_start:x_end, y_start:y_end]
16 |     return patch
17 | 
18 | 
19 | def _extract_patches_batch(b, img, offsets, patch_size, num_patches, extract_patch_parallel=False):
20 |     """ Extract patches for a single element of the batch. This function can be called in a for loop or in parallel.
21 |     It returns a tensor of patches of size [num_patches, channels, width, height] """
22 |     patches = []
23 | 
24 |     # Extracting in parallel is more expensive than doing it sequentially.
This I left it in here 25 | if extract_patch_parallel: 26 | num_jobs = min(os.cpu_count(), num_patches) 27 | patches = Parallel(n_jobs=num_jobs)( 28 | delayed(_extract_patch)(img[b], offsets[b, p], patch_size) for p in range(num_patches)) 29 | 30 | else: 31 | # Run extraction sequentially 32 | for p in range(num_patches): 33 | patch = _extract_patch(img[b], offsets[b, p], patch_size) 34 | patches.append(patch) 35 | 36 | return torch.stack(patches) 37 | 38 | 39 | def extract_patches(img, offsets, patch_size, extract_batch_parallel=False): 40 | img = img.permute(0, 3, 1, 2) 41 | 42 | num_patches = offsets.shape[1] 43 | batch_size = img.shape[0] 44 | 45 | # I pad the images with zeros for the cases that a part of the patch falls outside the image 46 | pad_const = int(patch_size[0].item() / 2) 47 | pad_func = torch.nn.ConstantPad2d(pad_const, 0.0) 48 | img = pad_func(img) 49 | 50 | # Add the pad_const to the offsets, because everything is now shifted by pad_const 51 | offsets = offsets + pad_const 52 | 53 | all_patches = [] 54 | 55 | # Extracting in parallel is more expensive than doing it sequentially. This I left it in here 56 | if extract_batch_parallel: 57 | num_jobs = min(os.cpu_count(), batch_size) 58 | all_patches = Parallel(n_jobs=num_jobs)( 59 | delayed(_extract_patches_batch)(b, img, offsets, patch_size, num_patches) for b in range(batch_size)) 60 | 61 | else: 62 | # Run sequentially over the elements in the batch 63 | for b in range(batch_size): 64 | patches = _extract_patches_batch(b, img, offsets, patch_size, num_patches) 65 | all_patches.append(patches) 66 | 67 | return torch.stack(all_patches) 68 | 69 | 70 | class FromTensors: 71 | def __init__(self, xs, y): 72 | """Given input tensors for each level of resolution provide the patches. 73 | Arguments 74 | --------- 75 | xs: list of tensors, one tensor per resolution in ascending 76 | resolutions, namely the lowest resolution is 0 and the highest 77 | is len(xs)-1 78 | y: tensor or list of tensors or None, the targets can be anything 79 | since it is simply returned as is 80 | """ 81 | self._xs = xs 82 | self._y = y 83 | 84 | def targets(self): 85 | # Since the xs were also given to us the y is also given to us 86 | return self._y 87 | 88 | def inputs(self): 89 | # We leave it to the caller to add xs and y to the input list if they 90 | # are placeholders 91 | return [] 92 | 93 | def patches(self, samples, offsets, sample_space, previous_patch_size, 94 | patch_size, fromlevel, tolevel): 95 | device = samples.device 96 | 97 | # Make sure everything is a tensor 98 | sample_space = to_tensor(sample_space, device=device) 99 | previous_patch_size = to_tensor(previous_patch_size, device=device) 100 | patch_size = to_tensor(patch_size, device=device) 101 | shape_from = self._shape(fromlevel) 102 | shape_to = self._shape(tolevel) 103 | 104 | # Compute the scales 105 | scale_samples = self._scale(sample_space, tolevel).to(device) 106 | scale_offsets = self._scale(shape_from, shape_to).to(device) 107 | 108 | # Steps is the offset per pixel of the sample space. Pixel zero should 109 | # be at position steps/2 and the last pixel should be at 110 | # space_available - steps/2. 
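        # For example, for a 100x100 sample space mapped onto a 500x500 high
        # resolution image, scale_offsets is 5, steps is 5 and sample (0, 0)
        # is centred at pixel (2.5, 2.5) before the patch_size / 2 shift.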
111 | space_available = to_float32(previous_patch_size) * scale_offsets 112 | steps = space_available / to_float32(sample_space) 113 | 114 | # Compute the patch start which are also the offsets to be returned 115 | offsets = to_int32(torch.round( 116 | to_float32(offsets) * expand_many(scale_offsets, [0, 0]) + 117 | to_float32(samples) * expand_many(steps, [0, 0]) + 118 | expand_many(steps / 2, [0, 0]) - 119 | expand_many(to_float32(patch_size) / 2, [0, 0]) 120 | )) 121 | 122 | # Extract the patches 123 | patches = extract_patches( 124 | self._xs[tolevel], 125 | offsets, 126 | patch_size 127 | ) 128 | 129 | return patches, offsets 130 | 131 | def data(self, level): 132 | return self._xs[level] 133 | 134 | def _scale(self, shape_from, shape_to): 135 | # Compute the tensor that needs to be multiplied with `shape_from` to 136 | # get `shape_to` 137 | shape_from = to_float32(to_tensor(shape_from)) 138 | shape_to = to_float32(to_tensor(shape_to)) 139 | 140 | return shape_to / shape_from 141 | 142 | def _shape(self, level): 143 | x = self._xs[level] 144 | int_shape = x.shape[1:-1] 145 | if not any(s is None for s in int_shape): 146 | return int_shape 147 | 148 | return x.shape[1:-1] 149 | -------------------------------------------------------------------------------- /ats/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Provide utility functions to the rest of the modules.""" 2 | from functools import partial 3 | 4 | import torch 5 | 6 | 7 | def to_tensor(x, dtype=torch.int32, device=None): 8 | """If x is a Tensor return it as is otherwise return a constant tensor of 9 | type dtype.""" 10 | device = torch.device('cpu') if device is None else device 11 | if torch.is_tensor(x): 12 | return x.to(device) 13 | 14 | return torch.tensor(x, dtype=dtype, device=device) 15 | 16 | 17 | def to_dtype(x, dtype): 18 | """Cast Tensor x to the dtype """ 19 | return x.type(dtype) 20 | 21 | 22 | to_float16 = partial(to_dtype, dtype=torch.float16) 23 | to_float32 = partial(to_dtype, dtype=torch.float32) 24 | to_float64 = partial(to_dtype, dtype=torch.float64) 25 | to_double = to_float64 26 | to_int8 = partial(to_dtype, dtype=torch.int8) 27 | to_int16 = partial(to_dtype, dtype=torch.int16) 28 | to_int32 = partial(to_dtype, dtype=torch.int32) 29 | to_int64 = partial(to_dtype, dtype=torch.int64) 30 | 31 | 32 | def expand_many(x, axes): 33 | """Call expand_dims many times on x once for each item in axes.""" 34 | for ax in axes: 35 | x = torch.unsqueeze(x, ax) 36 | return x 37 | -------------------------------------------------------------------------------- /ats/utils/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SampleSoftmax(nn.Module): 7 | """ Apply softmax to the whole sample not just the last dimension. 
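    The input is flattened per sample, softmax is applied over all spatial
    positions, the resulting distribution is optionally smoothed towards the
    uniform distribution, and the result is reshaped back to the input shape.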
8 | Arguments 9 | --------- 10 | squeeze_channels: bool, if True then squeeze the channel dimension of the input 11 | """ 12 | 13 | def __init__(self, squeeze_channels=False, smooth=0): 14 | self.squeeze_channels = squeeze_channels 15 | self.smooth = smooth 16 | super(SampleSoftmax, self).__init__() 17 | 18 | def forward(self, x): 19 | # Apply softmax to the whole x (per sample) 20 | s = x.shape 21 | x = F.softmax(x.reshape(s[0], -1), dim=-1) 22 | 23 | # Smooth the distribution 24 | if 0 < self.smooth < 1: 25 | x = x * (1 - self.smooth) 26 | x = x + self.smooth / float(x.shape[1]) 27 | 28 | # Finally reshape to the original shape 29 | x = x.reshape(s) 30 | 31 | # Squeeze the channels dimension if set 32 | if self.squeeze_channels: 33 | x = torch.squeeze(x, 1) 34 | 35 | return x 36 | -------------------------------------------------------------------------------- /ats/utils/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision 7 | import matplotlib.pyplot as plt 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | class AttentionSaverTrafficSigns: 12 | """Save the attention maps to monitor model evolution.""" 13 | 14 | def __init__(self, output_directory, ats_model, training_set, opts): 15 | self.dir = output_directory 16 | os.makedirs(self.dir, exist_ok=True) 17 | self.ats_model = ats_model 18 | self.opts = opts 19 | 20 | idxs = training_set.strided(9) 21 | data = [training_set[i] for i in idxs] 22 | self.x_low = torch.stack([d[0] for d in data]).cpu() 23 | self.x_high = torch.stack([d[1] for d in data]).cpu() 24 | self.labels = torch.LongTensor([d[2] for d in data]).numpy() 25 | 26 | self.writer = SummaryWriter(os.path.join(self.dir, opts.run_name), flush_secs=5) 27 | self.on_train_begin() 28 | 29 | def on_train_begin(self): 30 | opts = self.opts 31 | with torch.no_grad(): 32 | _, _, _, x_low = self.ats_model(self.x_low.to(opts.device), self.x_high.to(opts.device)) 33 | x_low = x_low.cpu() 34 | image_list = [x for x in x_low] 35 | 36 | grid = torchvision.utils.make_grid(image_list, nrow=3, normalize=True, scale_each=True) 37 | 38 | self.writer.add_image('original_images', grid, global_step=0, dataformats='CHW') 39 | self.__call__(-1) 40 | 41 | def __call__(self, epoch, losses=None, metrics=None): 42 | opts = self.opts 43 | with torch.no_grad(): 44 | _, att, _, x_low = self.ats_model(self.x_low.to(opts.device), self.x_high.to(opts.device)) 45 | att = att.unsqueeze(1) 46 | att = F.interpolate(att, size=(x_low.shape[-2], x_low.shape[-1])) 47 | att = att.cpu() 48 | 49 | grid = torchvision.utils.make_grid(att, nrow=3, normalize=True, scale_each=True, pad_value=1.) 
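        # normalize=True together with scale_each=True rescales every
        # attention map to its own [0, 1] range, so faint maps remain visible.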
50 | self.writer.add_image('attention_map', grid, epoch, dataformats='CHW') 51 | 52 | if metrics is not None: 53 | train_metrics, test_metrics = metrics 54 | self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch) 55 | self.writer.add_scalar('Accuracy/Test', test_metrics['accuracy'], epoch) 56 | 57 | if losses is not None: 58 | train_loss, test_loss = losses 59 | self.writer.add_scalar('Loss/Train', train_loss, epoch) 60 | self.writer.add_scalar('Loss/Test', test_loss, epoch) 61 | 62 | @staticmethod 63 | def imsave(filepath, x): 64 | if x.shape[-1] == 3: 65 | plt.imshow(x) 66 | plt.savefig(filepath) 67 | else: 68 | plt.imshow(x, cmap='viridis') 69 | plt.savefig(filepath) 70 | 71 | @staticmethod 72 | def reverse_transform(inp): 73 | """ Do a reverse transformation. inp should be a torch tensor of shape [3, H, W] """ 74 | inp = inp.numpy().transpose((1, 2, 0)) 75 | mean = np.array([0.485, 0.456, 0.406]) 76 | std = np.array([0.229, 0.224, 0.225]) 77 | inp = std * inp + mean 78 | inp = np.clip(inp, 0, 1) 79 | inp = (inp * 255).astype(np.uint8) 80 | 81 | return inp 82 | 83 | @staticmethod 84 | def reverse_transform_torch(inp): 85 | """ Do a reverse transformation. inp should be a torch tensor of shape [3, H, W] """ 86 | inp = inp.numpy().transpose((1, 2, 0)) 87 | mean = np.array([0.485, 0.456, 0.406]) 88 | std = np.array([0.229, 0.224, 0.225]) 89 | inp = std * inp + mean 90 | inp = np.clip(inp, 0, 1) 91 | inp = torch.from_numpy(inp).permute(2, 0, 1) 92 | 93 | return inp 94 | 95 | 96 | class AttentionSaverMNIST: 97 | def __init__(self, output_directory, ats_model, dataset, opts): 98 | self.dir = output_directory 99 | os.makedirs(self.dir, exist_ok=True) 100 | self.ats_model = ats_model 101 | self.opts = opts 102 | 103 | idxs = [random.randrange(0, len(dataset)-1) for _ in range(9)] 104 | data = [dataset[i] for i in idxs] 105 | self.x_low = torch.stack([d[0] for d in data]).cpu() 106 | self.x_high = torch.stack([d[1] for d in data]).cpu() 107 | self.label = torch.LongTensor([d[2] for d in data]).numpy() 108 | 109 | self.writer = SummaryWriter(os.path.join(self.dir, opts.run_name), flush_secs=2) 110 | self.__call__(-1) 111 | 112 | def __call__(self, epoch, losses=None, metrics=None): 113 | opts = self.opts 114 | with torch.no_grad(): 115 | _, att, patches, x_low = self.ats_model(self.x_low.to(opts.device), self.x_high.to(opts.device)) 116 | att = att.unsqueeze(1) 117 | att = F.interpolate(att, size=(x_low.shape[-2], x_low.shape[-1])) 118 | att = att.cpu() 119 | 120 | grid = torchvision.utils.make_grid(att, nrow=3, normalize=True, scale_each=True, pad_value=1.) 121 | self.writer.add_image('attention_map', grid, epoch, dataformats='CHW') 122 | 123 | if metrics is not None: 124 | train_metrics, test_metrics = metrics 125 | self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch) 126 | self.writer.add_scalar('Accuracy/Test', test_metrics['accuracy'], epoch) 127 | 128 | if losses is not None: 129 | train_loss, test_loss = losses 130 | self.writer.add_scalar('Loss/Train', train_loss, epoch) 131 | self.writer.add_scalar('Loss/Test', test_loss, epoch) 132 | -------------------------------------------------------------------------------- /ats/utils/regularizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultinomialEntropy(nn.Module): 6 | """Increase or decrease the entropy of a multinomial distribution. 
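    For an attention map x this module returns strength * H(x) / B, where H is
    the Shannon entropy summed over the batch and B is the batch size.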
7 | Arguments 8 | --------- 9 | strength: A float that defines the strength and direction of the 10 | regularizer. A positive number increases the entropy, a 11 | negative number decreases the entropy. 12 | eps: A small float to avoid numerical errors when computing the entropy 13 | """ 14 | 15 | def __init__(self, strength=1, eps=1e-6): 16 | super(MultinomialEntropy, self).__init__() 17 | if strength is None: 18 | self.strength = float(0) 19 | else: 20 | self.strength = float(strength) 21 | self.eps = float(eps) 22 | 23 | def forward(self, x): 24 | logx = torch.log(x + self.eps) 25 | # Formally the minus sign should be here 26 | return - self.strength * torch.sum(x * logx) / float(x.shape[0]) 27 | -------------------------------------------------------------------------------- /colon_cancer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import argparse 6 | import time 7 | 8 | from models.attention_model import AttentionModelColonCancer 9 | from models.feature_model import FeatureModelColonCancer 10 | from models.classifier import ClassificationHead 11 | 12 | from ats.core.ats_layer import ATSModel 13 | from ats.utils.regularizers import MultinomialEntropy 14 | from ats.utils.logging import AttentionSaverTrafficSigns 15 | 16 | from dataset.colon_cancer_dataset import ColonCancerDataset 17 | from train import train, evaluate 18 | 19 | 20 | def main(opts): 21 | train_dataset = ColonCancerDataset('dataset/colon_cancer', train=True) 22 | train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=opts.num_workers, drop_last=False) 23 | 24 | attention_model = AttentionModelColonCancer(squeeze_channels=True, softmax_smoothing=0) 25 | feature_model = FeatureModelColonCancer(in_channels=3, out_channels=500) 26 | classification_head = ClassificationHead(in_channels=500, num_classes=len(train_dataset.CLASSES)) 27 | 28 | ats_model = ATSModel(attention_model, feature_model, classification_head, n_patches=opts.n_patches, 29 | patch_size=opts.patch_size) 30 | ats_model = ats_model.to(opts.device) 31 | optimizer = optim.Adam(ats_model.parameters(), lr=opts.lr) 32 | 33 | logger = AttentionSaverTrafficSigns(opts.output_dir, ats_model, train_dataset, opts) 34 | 35 | criterion = nn.CrossEntropyLoss() 36 | entropy_loss_func = MultinomialEntropy(opts.regularizer_strength) 37 | 38 | for epoch in range(opts.epochs): 39 | train_loss, train_metrics = train(ats_model, optimizer, train_loader, 40 | criterion, entropy_loss_func, opts) 41 | 42 | with torch.no_grad(): 43 | test_loss, test_metrics = evaluate(ats_model, train_loader, criterion, 44 | entropy_loss_func, opts) 45 | 46 | logger(epoch, (train_loss, test_loss), (train_metrics, test_metrics)) 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument("--regularizer_strength", type=float, default=0.01, 52 | help="How strong should the regularization be for the attention") 53 | parser.add_argument("--softmax_smoothing", type=float, default=1e-4, 54 | help="Smoothing for calculating the attention map") 55 | parser.add_argument("--lr", type=float, default=1e-3, help="Set the optimizer's learning rate") 56 | parser.add_argument("--n_patches", type=int, default=10, help="How many patches to sample") 57 | parser.add_argument("--patch_size", type=int, default=27, help="Patch size of a square patch") 58 | 
parser.add_argument("--batch_size", type=int, default=8, help="Choose the batch size for SGD") 59 | parser.add_argument("--epochs", type=int, default=500, help="How many epochs to train for") 60 | parser.add_argument("--decrease_lr_at", type=float, default=250, help="Decrease the learning rate in this epoch") 61 | parser.add_argument("--clipnorm", type=float, default=1, help="Clip the norm of the gradients") 62 | parser.add_argument("--output_dir", type=str, help="An output directory", default='output/colon_cancer') 63 | parser.add_argument('--run_name', type=str, default='run') 64 | parser.add_argument('--num_workers', type=int, default=20, help='Number of workers to use for data loading') 65 | 66 | opts = parser.parse_args() 67 | opts.run_name = f"{opts.run_name}_{time.strftime('%Y%m%dT%H%M%S')}" 68 | opts.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 69 | 70 | main(opts) 71 | -------------------------------------------------------------------------------- /dataset/colon_cancer/.gitkeep: -------------------------------------------------------------------------------- 1 | keep 2 | -------------------------------------------------------------------------------- /dataset/colon_cancer_dataset.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import partial 3 | import hashlib 4 | import os 5 | from PIL import Image 6 | import torch 7 | import urllib.request 8 | from os import path 9 | import sys 10 | import zipfile 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | import numpy as np 15 | import scipy.io 16 | import pdb 17 | import matplotlib.pyplot as plt 18 | import imageio 19 | 20 | 21 | class ColonCancerDataset(Dataset): 22 | 23 | CLASSES = [0, 1] 24 | 25 | def __init__(self, directory, train=True): 26 | cwd = os.getcwd().replace('dataset', '') 27 | directory = path.join(cwd, directory) 28 | 29 | self.data = [os.path.join(directory, x) for x in os.listdir(directory)] 30 | 31 | if train: 32 | self.image_transform = transforms.Compose([transforms.ToPILImage(), 33 | transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), 34 | transforms.ToTensor() 35 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 36 | ]) 37 | 38 | def __len__(self): 39 | return len(self.data) 40 | 41 | def __getitem__(self, i): 42 | 43 | folder_path = self.data[i] 44 | img_id = int(folder_path.split('/')[-1].replace('img', '')) 45 | 46 | mat = scipy.io.loadmat(path.join(folder_path, f'img{img_id}_epithelial.mat'))['detection'] 47 | x_high = imageio.imread(path.join(folder_path, f'img{img_id}.bmp')) 48 | 49 | x_high = self.image_transform(x_high) 50 | x_low = F.interpolate(x_high[None, ...], scale_factor=0.2, mode='bilinear')[0] 51 | 52 | category = int(mat.shape[0] > 0) 53 | return x_low, x_high, category 54 | 55 | def strided(self, N): 56 | """Extract N images almost in equal proportions from each category.""" 57 | order = np.arange(len(self.data)) 58 | np.random.shuffle(order) 59 | idxs = [] 60 | cat = 0 61 | while len(idxs) < N: 62 | for i in order: 63 | _, _, category = self[i] 64 | if cat == category: 65 | idxs.append(i) 66 | cat = (cat + 1) % len(self.CLASSES) 67 | if len(idxs) >= N: 68 | break 69 | return idxs 70 | 71 | 72 | if __name__ == '__main__': 73 | colon_cancer_dataset = ColonCancerDataset('colon_cancer', train=True) 74 | print() 
-------------------------------------------------------------------------------- /dataset/mega_mnist/.gitkeep: -------------------------------------------------------------------------------- 1 | keep 2 | -------------------------------------------------------------------------------- /dataset/mega_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import path 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | 10 | class MNIST(Dataset): 11 | """Load a Megapixel MNIST dataset. See make_mnist.py.""" 12 | 13 | CLASSES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 14 | 15 | def __init__(self, dataset_dir, train=True): 16 | with open(path.join(dataset_dir, "parameters.json")) as f: 17 | self.parameters = json.load(f) 18 | 19 | filename = "train.npy" if train else "test.npy" 20 | N = self.parameters["n_train" if train else "n_test"] 21 | W = self.parameters["width"] 22 | H = self.parameters["height"] 23 | self.scale = self.parameters["scale"] 24 | 25 | self._high_shape = (H, W, 1) 26 | self._low_shape = (int(self.scale*H), int(self.scale*W), 1) 27 | self._data = np.load(path.join(dataset_dir, filename), allow_pickle=True) 28 | self.image_transform = transforms.Normalize([0.5], [0.5]) 29 | 30 | def __len__(self): 31 | return len(self._data) 32 | 33 | def __getitem__(self, i): 34 | if i >= len(self): 35 | raise IndexError() 36 | 37 | # Placeholders 38 | x_high = np.zeros(self._high_shape, dtype=np.float32).ravel() 39 | 40 | # Fill the sparse representations 41 | data = self._data[i] 42 | x_high[data[1][0]] = data[1][1] 43 | 44 | # Reshape to their final shape 45 | x_high = x_high.reshape(self._high_shape) 46 | 47 | x_high = torch.from_numpy(x_high) 48 | x_high = x_high.permute(2, 0, 1) 49 | x_high = self.image_transform(x_high) 50 | x_low = F.interpolate(x_high[None, ...], scale_factor=self.scale)[0] 51 | 52 | label = np.argmax(data[2]) 53 | 54 | return x_low, x_high, label 55 | 56 | 57 | if __name__ == '__main__': 58 | mnist_dataset = MNIST('mega_mnist', train=True) 59 | mnist_dataset[0] -------------------------------------------------------------------------------- /dataset/speed_limits_dataset.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import partial 3 | import hashlib 4 | import os 5 | from PIL import Image 6 | import torch 7 | import urllib.request 8 | from os import path 9 | import sys 10 | import zipfile 11 | import torch.nn.functional as F 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | import numpy as np 15 | 16 | 17 | def check_file(filepath, md5sum): 18 | """Check a file against an md5 hash value. 19 | Returns 20 | ------- 21 | True if the file exists and has the given md5 sum False otherwise 22 | """ 23 | try: 24 | md5 = hashlib.md5() 25 | with open(filepath, "rb") as f: 26 | for chunk in iter(partial(f.read, 4096), b""): 27 | md5.update(chunk) 28 | return md5.hexdigest() == md5sum 29 | except FileNotFoundError: 30 | return False 31 | 32 | 33 | def ensure_dataset_exists(directory, tries=1, progress_file=sys.stderr): 34 | """Ensure that the dataset is downloaded and is correct. 35 | Correctness is checked only against the annotations files. 
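    If the integrity check fails, both Set1 and Set2 of the Swedish Traffic
    Signs dataset are downloaded and extracted into `directory`, retrying up
    to `tries` times before giving up.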
36 | """ 37 | set1_url = ("http://www.isy.liu.se/cvl/research/trafficSigns" 38 | "/swedishSignsSummer/Set1/Set1Part0.zip") 39 | set1_annotations_url = ("http://www.isy.liu.se/cvl/research/trafficSigns" 40 | "/swedishSignsSummer/Set1/annotations.txt") 41 | set1_annotations_md5 = "9106a905a86209c95dc9b51d12f520d6" 42 | set2_url = ("http://www.isy.liu.se/cvl/research/trafficSigns" 43 | "/swedishSignsSummer/Set2/Set2Part0.zip") 44 | set2_annotations_url = ("http://www.isy.liu.se/cvl/research/trafficSigns" 45 | "/swedishSignsSummer/Set2/annotations.txt") 46 | set2_annotations_md5 = "09debbc67f6cd89c1e2a2688ad1d03ca" 47 | 48 | integrity = ( 49 | check_file( 50 | path.join(directory, "Set1", "annotations.txt"), 51 | set1_annotations_md5 52 | ) and check_file( 53 | path.join(directory, "Set2", "annotations.txt"), 54 | set2_annotations_md5 55 | ) 56 | ) 57 | 58 | if integrity: 59 | return 60 | 61 | if tries <= 0: 62 | raise RuntimeError(("Cannot download dataset or dataset download " 63 | "is corrupted")) 64 | 65 | print("Downloading Set1", file=progress_file) 66 | download_file(set1_url, path.join(directory, "Set1.zip"), 67 | progress_file=progress_file) 68 | print("Extracting...", file=progress_file) 69 | with zipfile.ZipFile(path.join(directory, "Set1.zip")) as archive: 70 | archive.extractall(path.join(directory, "Set1")) 71 | print("Getting annotation file", file=progress_file) 72 | download_file( 73 | set1_annotations_url, 74 | path.join(directory, "Set1", "annotations.txt"), 75 | progress_file=progress_file 76 | ) 77 | print("Downloading Set2", file=progress_file) 78 | download_file(set2_url, path.join(directory, "Set2.zip"), 79 | progress_file=progress_file) 80 | print("Extracting...", file=progress_file) 81 | with zipfile.ZipFile(path.join(directory, "Set2.zip")) as archive: 82 | archive.extractall(path.join(directory, "Set2")) 83 | print("Getting annotation file", file=progress_file) 84 | download_file( 85 | set2_annotations_url, 86 | path.join(directory, "Set2", "annotations.txt"), 87 | progress_file=progress_file 88 | ) 89 | 90 | return ensure_dataset_exists( 91 | directory, 92 | tries=tries - 1, 93 | progress_file=progress_file 94 | ) 95 | 96 | 97 | def download_file(url, destination, progress_file=sys.stderr): 98 | """Download a file with progress.""" 99 | response = urllib.request.urlopen(url) 100 | n_bytes = response.headers.get("Content-Length") 101 | if n_bytes == "": 102 | n_bytes = 0 103 | else: 104 | n_bytes = int(n_bytes) 105 | 106 | message = "\rReceived {} / {}" 107 | cnt = 0 108 | with open(destination, "wb") as dst: 109 | while True: 110 | print(message.format(cnt, n_bytes), file=progress_file, 111 | end="", flush=True) 112 | data = response.read(65535) 113 | if len(data) == 0: 114 | break 115 | dst.write(data) 116 | cnt += len(data) 117 | print(file=progress_file) 118 | 119 | 120 | class Sign(namedtuple("Sign", ["visibility", "bbox", "type", "name"])): 121 | """A sign object. 
Useful for making ground truth images as well as making 122 | the dataset.""" 123 | 124 | @property 125 | def x_min(self): 126 | return self.bbox[2] 127 | 128 | @property 129 | def x_max(self): 130 | return self.bbox[0] 131 | 132 | @property 133 | def y_min(self): 134 | return self.bbox[3] 135 | 136 | @property 137 | def y_max(self): 138 | return self.bbox[1] 139 | 140 | @property 141 | def area(self): 142 | return (self.x_max - self.x_min) * (self.y_max - self.y_min) 143 | 144 | @property 145 | def center(self): 146 | return [ 147 | (self.y_max - self.y_min) / 2 + self.y_min, 148 | (self.x_max - self.x_min) / 2 + self.x_min 149 | ] 150 | 151 | @property 152 | def visibility_index(self): 153 | visibilities = ["VISIBLE", "BLURRED", "SIDE_ROAD", "OCCLUDED"] 154 | return visibilities.index(self.visibility) 155 | 156 | def pixels(self, scale, size): 157 | return zip(*( 158 | (i, j) 159 | for i in range(round(self.y_min * scale), round(self.y_max * scale) + 1) 160 | for j in range(round(self.x_min * scale), round(self.x_max * scale) + 1) 161 | if i < round(size[0] * scale) and j < round(size[1] * scale) 162 | )) 163 | 164 | def __lt__(self, other): 165 | if not isinstance(other, Sign): 166 | raise ValueError("Signs can only be compared to signs") 167 | 168 | if self.visibility_index != other.visibility_index: 169 | return self.visibility_index < other.visibility_index 170 | 171 | return self.area > other.area 172 | 173 | 174 | class STS: 175 | """The STS class reads the annotations and creates the corresponding 176 | Sign objects.""" 177 | 178 | def __init__(self, directory, train=True, seed=0): 179 | cwd = os.getcwd().replace('dataset', '') 180 | directory = path.join(cwd, directory) 181 | ensure_dataset_exists(directory) 182 | 183 | self._directory = directory 184 | self._inner = "Set{}".format(1 + ((seed + 1 + int(train)) % 2)) 185 | self._data = self._load_signs(self._directory, self._inner) 186 | 187 | def _load_files(self, directory, inner): 188 | files = set() 189 | with open(path.join(directory, inner, "annotations.txt")) as f: 190 | for l in f: 191 | files.add(l.split(":", 1)[0]) 192 | return sorted(files) 193 | 194 | def _read_bbox(self, parts): 195 | def _float(x): 196 | try: 197 | return float(x) 198 | except ValueError: 199 | if len(x) > 0: 200 | return _float(x[:-1]) 201 | raise 202 | 203 | return [_float(x) for x in parts] 204 | 205 | def _load_signs(self, directory, inner): 206 | with open(path.join(directory, inner, "annotations.txt")) as f: 207 | lines = [l.strip() for l in f] 208 | keys, values = zip(*(l.split(":", 1) for l in lines)) 209 | all_signs = [] 210 | for v in values: 211 | signs = [] 212 | for sign in v.split(";"): 213 | if sign == [""] or sign == "": 214 | continue 215 | parts = [s.strip() for s in sign.split(",")] 216 | if parts[0] == "MISC_SIGNS": 217 | continue 218 | signs.append(Sign( 219 | visibility=parts[0], 220 | bbox=self._read_bbox(parts[1:5]), 221 | type=parts[5], 222 | name=parts[6] 223 | )) 224 | all_signs.append(signs) 225 | images = [path.join(directory, inner, f) for f in keys] 226 | 227 | return list(zip(images, all_signs)) 228 | 229 | def __len__(self): 230 | return len(self._data) 231 | 232 | def __getitem__(self, i): 233 | return self._data[i] 234 | 235 | 236 | class SpeedLimits(Dataset): 237 | """Provide a Keras Sequence for the SpeedLimits dataset which is basically 238 | a filtered version of the STS dataset. 
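    Images are kept if they contain no signs at all, or if their most visible
    speed limit sign (50, 70 or 80 km/h) is annotated as VISIBLE; everything
    else is filtered out.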
239 | Arguments 240 | --------- 241 | directory: str, The directory that the dataset already is or is going 242 | to be downloaded in 243 | train: bool, Select the training or testing sets 244 | seed: int, The prng seed for the dataset 245 | """ 246 | LIMITS = ["50_SIGN", "70_SIGN", "80_SIGN"] 247 | CLASSES = ["EMPTY", *LIMITS] 248 | 249 | def __init__(self, directory, train=True, seed=0): 250 | self._data = self._filter(STS(directory, train, seed)) 251 | if train: 252 | self.image_transform = transforms.Compose([transforms.ColorJitter(0.1, 0.1, 0.1, 0.1), 253 | transforms.RandomAffine(degrees=0, 254 | translate=(100 / 1280, 100 / 960)), 255 | transforms.ToTensor() 256 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 257 | ]) 258 | else: 259 | self.image_transform = transforms.Compose([transforms.ToTensor() 260 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 261 | ]) 262 | 263 | weights = make_weights_for_balanced_classes(self._data, len(self.CLASSES)) 264 | weights = torch.DoubleTensor(weights) 265 | self.sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights)) 266 | 267 | def _filter(self, data): 268 | filtered = [] 269 | for image, signs in data: 270 | signs, acceptable = self._acceptable(signs) 271 | if acceptable: 272 | if not signs: 273 | filtered.append((image, 0)) 274 | else: 275 | filtered.append((image, self.CLASSES.index(signs[0].name))) 276 | return filtered 277 | 278 | def _acceptable(self, signs): 279 | # Keep it as empty 280 | if not signs: 281 | return signs, True 282 | 283 | # Filter just the speed limits and sort them wrt visibility 284 | signs = sorted(s for s in signs if s.name in self.LIMITS) 285 | 286 | # No speed limit but many other signs 287 | if not signs: 288 | return None, False 289 | 290 | # Not visible sign so skip 291 | if signs[0].visibility != "VISIBLE": 292 | return None, False 293 | 294 | return signs, True 295 | 296 | def __len__(self): 297 | return len(self._data) 298 | 299 | def __getitem__(self, i): 300 | image, category = self._data[i] 301 | 302 | x_high = Image.open(image) 303 | x_high = self.image_transform(x_high) 304 | 305 | x_low = F.interpolate(x_high[None, ...], scale_factor=0.3, mode='bilinear')[0] 306 | return x_low, x_high, category 307 | 308 | @property 309 | def image_size(self): 310 | return self[0][0].shape[1:] 311 | 312 | @property 313 | def class_frequencies(self): 314 | """Compute and return the class specific frequencies.""" 315 | freqs = np.zeros(len(self.CLASSES), dtype=np.float32) 316 | for image, category in self._data: 317 | freqs[category] += 1 318 | return freqs / len(self._data) 319 | 320 | def strided(self, N): 321 | """Extract N images almost in equal proportions from each category.""" 322 | order = np.arange(len(self._data)) 323 | np.random.shuffle(order) 324 | idxs = [] 325 | cat = 0 326 | while len(idxs) < N: 327 | for i in order: 328 | image, category = self._data[i] 329 | if cat == category: 330 | idxs.append(i) 331 | cat = (cat + 1) % len(self.CLASSES) 332 | if len(idxs) >= N: 333 | break 334 | return idxs 335 | 336 | 337 | def make_weights_for_balanced_classes(images, num_classes): 338 | count = [0] * num_classes 339 | for item in images: 340 | count[item[1]] += 1 341 | weight_per_class = [0.] 
* num_classes 342 | N = float(sum(count)) 343 | for i in range(num_classes): 344 | weight_per_class[i] = N / float(count[i]) 345 | weight = [0] * len(images) 346 | for idx, val in enumerate(images): 347 | weight[idx] = weight_per_class[val[1]] 348 | return weight 349 | 350 | 351 | def reverse_transform(inp): 352 | """ Do a reverse transformation. inp should be of shape [3, H, W] """ 353 | inp = inp.numpy().transpose((1, 2, 0)) 354 | mean = np.array([0.485, 0.456, 0.406]) 355 | std = np.array([0.229, 0.224, 0.225]) 356 | inp = std * inp + mean 357 | inp = np.clip(inp, 0, 1) 358 | inp = (inp * 255).astype(np.uint8) 359 | 360 | return inp 361 | 362 | 363 | if __name__ == '__main__': 364 | speedlimit_dataset = SpeedLimits('traffic_data') 365 | 366 | speedlimit_dataloader = DataLoader(speedlimit_dataset, shuffle=True, batch_size=4) 367 | 368 | for i, (x_low, x_high, label) in enumerate(speedlimit_dataloader): 369 | print(x_low) 370 | -------------------------------------------------------------------------------- /dataset/traffic_data/.gitkeep: -------------------------------------------------------------------------------- 1 | keep 2 | -------------------------------------------------------------------------------- /mega_mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | import argparse 6 | import time 7 | 8 | from models.attention_model import AttentionModelMNIST 9 | from models.feature_model import FeatureModelMNIST 10 | from models.classifier import ClassificationHead 11 | 12 | from ats.core.ats_layer import ATSModel 13 | from ats.utils.regularizers import MultinomialEntropy 14 | from ats.utils.logging import AttentionSaverMNIST 15 | 16 | from dataset.mega_mnist_dataset import MNIST 17 | from train import train, evaluate 18 | 19 | 20 | def main(opts): 21 | train_dataset = MNIST('dataset/mega_mnist', train=True) 22 | train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=1) 23 | 24 | test_dataset = MNIST('dataset/mega_mnist', train=False) 25 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=opts.batch_size, num_workers=1) 26 | 27 | attention_model = AttentionModelMNIST(squeeze_channels=True, softmax_smoothing=1e-4) 28 | feature_model = FeatureModelMNIST(in_channels=1) 29 | classification_head = ClassificationHead(in_channels=32, num_classes=10) 30 | 31 | ats_model = ATSModel(attention_model, feature_model, classification_head, n_patches=opts.n_patches, patch_size=opts.patch_size) 32 | ats_model = ats_model.to(opts.device) 33 | optimizer = optim.Adam(ats_model.parameters(), lr=opts.lr) 34 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.decrease_lr_at, gamma=0.1) 35 | 36 | logger = AttentionSaverMNIST(opts.output_dir, ats_model, test_dataset, opts) 37 | 38 | criterion = nn.CrossEntropyLoss() 39 | entropy_loss_func = MultinomialEntropy(opts.regularizer_strength) 40 | 41 | for epoch in range(opts.epochs): 42 | train_loss, train_metrics = train(ats_model, optimizer, train_loader, 43 | criterion, entropy_loss_func, opts) 44 | 45 | with torch.no_grad(): 46 | test_loss, test_metrics = evaluate(ats_model, test_loader, criterion, 47 | entropy_loss_func, opts) 48 | 49 | logger(epoch, (train_loss, test_loss), (train_metrics, test_metrics)) 50 | scheduler.step() 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | 
parser.add_argument("--regularizer_strength", type=float, default=0.0001, 56 | help="How strong should the regularization be for the attention") 57 | parser.add_argument("--softmax_smoothing", type=float, default=1e-5, 58 | help="Smoothing for calculating the attention map") 59 | parser.add_argument("--lr", type=float, default=0.001, help="Set the optimizer's learning rate") 60 | parser.add_argument("--n_patches", type=int, default=10, help="How many patches to sample") 61 | parser.add_argument("--patch_size", type=int, default=50, help="Patch size of a square patch") 62 | parser.add_argument("--batch_size", type=int, default=128, help="Choose the batch size for SGD") 63 | parser.add_argument("--epochs", type=int, default=500, help="How many epochs to train for") 64 | parser.add_argument("--decrease_lr_at", type=float, default=1000, help="Decrease the learning rate in this epoch") 65 | parser.add_argument("--clipnorm", type=float, default=1, help="Clip the norm of the gradients") 66 | parser.add_argument("--output_dir", type=str, help="An output directory", default='output/mnist') 67 | parser.add_argument('--run_name', type=str, default='run') 68 | parser.add_argument('--num_workers', type=int, default=20, help='Number of workers to use for data loading') 69 | 70 | opts = parser.parse_args() 71 | opts.run_name = f"{opts.run_name}_{time.strftime('%Y%m%dT%H%M%S')}" 72 | opts.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 73 | 74 | main(opts) 75 | -------------------------------------------------------------------------------- /models/attention_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ats.utils.layers import SampleSoftmax 4 | 5 | 6 | class AttentionModelTrafficSigns(nn.Module): 7 | """ Base class for calculating the attention map of a low resolution image """ 8 | 9 | def __init__(self, 10 | squeeze_channels=False, 11 | softmax_smoothing=0.0): 12 | super(AttentionModelTrafficSigns, self).__init__() 13 | 14 | conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding_mode='valid') 15 | relu1 = nn.ReLU() 16 | 17 | conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding_mode='valid') 18 | relu2 = nn.ReLU() 19 | 20 | conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding_mode='valid') 21 | relu3 = nn.ReLU() 22 | 23 | conv4 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding_mode='valid') 24 | 25 | pool = nn.MaxPool2d(kernel_size=8) 26 | sample_softmax = SampleSoftmax(squeeze_channels, softmax_smoothing) 27 | 28 | self.part1 = nn.Sequential(conv1, relu1, conv2, relu2, conv3, relu3) 29 | self.part2 = nn.Sequential(conv4, pool, sample_softmax) 30 | 31 | def forward(self, x_low): 32 | out = self.part1(x_low) 33 | 34 | out = self.part2(out) 35 | 36 | return out 37 | 38 | 39 | class AttentionModelMNIST(nn.Module): 40 | """ Base class for calculating the attention map of a low resolution image """ 41 | 42 | def __init__(self, 43 | squeeze_channels=False, 44 | softmax_smoothing=0.0): 45 | super(AttentionModelMNIST, self).__init__() 46 | 47 | self.squeeze_channels = squeeze_channels 48 | self.softmax_smoothing = softmax_smoothing 49 | 50 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1, padding_mode='reflect') 51 | self.tanh1 = nn.Tanh() 52 | 53 | self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3, padding=1, padding_mode='reflect') 54 | self.tanh2 = nn.Tanh() 55 | 56 | self.conv3 = 
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class ClassificationHead(nn.Module):
5 |     """ Linear layer mapping pooled feature vectors to class logits """
6 | 
7 |     def __init__(self, in_channels, num_classes):
8 |         super(ClassificationHead, self).__init__()
9 | 
10 |         self.classifier = nn.Linear(in_channels, num_classes)
11 | 
12 |     def forward(self, x):
13 |         return self.classifier(x)
14 | 
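[Editor's note: a minimal usage sketch for the head above; the sizes are made up.]

    import torch

    head = ClassificationHead(in_channels=32, num_classes=6)
    logits = head(torch.randn(4, 32))   # -> (4, 6), fed to nn.CrossEntropyLoss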
--------------------------------------------------------------------------------
/models/feature_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def conv_layer(in_channels, out_channels, kernel, strides, padding=1):
6 |     return nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=strides, padding_mode="zeros", bias=False,
7 |                      padding=padding)
8 | 
9 | 
10 | def batch_norm(filters):
11 |     return nn.BatchNorm2d(filters)
12 | 
13 | 
14 | def relu():
15 |     return nn.ReLU()
16 | 
17 | 
18 | class Block(nn.Module):
19 |     """ Pre-activation residual block with an optional 1x1 projection shortcut """
20 | 
21 |     def __init__(self, in_channels, out_channels, stride, kernel_size, short):
22 |         super(Block, self).__init__()
23 | 
24 |         self.short = short
25 |         self.bn1 = batch_norm(in_channels)
26 |         self.relu1 = relu()
27 |         self.conv1 = conv_layer(in_channels, out_channels, 1, stride, padding=0)
28 | 
29 |         self.conv2 = conv_layer(in_channels, out_channels, kernel_size, stride)
30 |         self.bn2 = batch_norm(out_channels)
31 |         self.relu2 = relu()
32 |         self.conv3 = conv_layer(out_channels, out_channels, kernel_size, 1)
33 | 
34 |     def forward(self, x):
35 |         x = self.bn1(x)
36 |         x = self.relu1(x)
37 | 
38 |         # Project the shortcut with a strided 1x1 convolution only when the
39 |         # spatial size or channel count changes.
40 |         x_short = x
41 |         if self.short:
42 |             x_short = self.conv1(x)
43 | 
44 |         x = self.conv2(x)
45 |         x = self.bn2(x)
46 |         x = self.relu2(x)
47 |         x = self.conv3(x)
48 | 
49 |         out = x + x_short
50 |         return out
51 | 
52 | 
53 | class FeatureModelTrafficSigns(nn.Module):
54 | 
55 |     def __init__(self, in_channels, strides=[1, 2, 2, 2], filters=[32, 32, 32, 32]):
56 |         super(FeatureModelTrafficSigns, self).__init__()
57 | 
58 |         # Copy before popping: the default lists are shared across instances
59 |         # and would otherwise be mutated by pop().
60 |         strides = list(strides)
61 |         filters = list(filters)
62 |         stride_prev = strides.pop(0)
63 |         filters_prev = filters.pop(0)
64 | 
65 |         self.conv1 = conv_layer(in_channels, filters_prev, 3, stride_prev)
66 | 
67 |         module_list = nn.ModuleList()
68 |         for s, f in zip(strides, filters):
69 |             module_list.append(Block(filters_prev, f, s, 3, s != 1 or f != filters_prev))
70 |             stride_prev = s
71 |             filters_prev = f
72 | 
73 |         self.module_list = nn.Sequential(*module_list)
74 | 
75 |         self.bn1 = batch_norm(filters_prev)
76 |         self.relu1 = relu()
77 |         self.pool = nn.AvgPool2d(kernel_size=(13, 13))
78 | 
79 |     def forward(self, x):
80 |         out = self.conv1(x)
81 |         out = self.module_list(out)
82 |         out = self.bn1(out)
83 |         out = self.relu1(out)
84 |         out = self.pool(out)
85 |         out = out.view(out.shape[0], out.shape[1])
86 |         out = F.normalize(out, p=2, dim=-1)
87 |         return out
88 | 
89 | 
90 | class FeatureModelMNIST(nn.Module):
91 | 
92 |     def __init__(self, in_channels):
93 |         super(FeatureModelMNIST, self).__init__()
94 |         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=7)  # use the argument instead of hard-coding 1
95 |         self.relu1 = relu()
96 | 
97 |         self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
98 |         self.relu2 = relu()
99 | 
100 |         self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
101 |         self.relu3 = relu()
102 | 
103 |         self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
104 |         self.relu4 = relu()
105 | 
106 |         self.pool = nn.AvgPool2d(kernel_size=(38, 38))
107 | 
108 |     def forward(self, x):
109 |         out = self.conv1(x)
110 |         out = self.relu1(out)
111 | 
112 |         out = self.conv2(out)
113 |         out = self.relu2(out)
114 | 
115 |         out = self.conv3(out)
116 |         out = self.relu3(out)
117 | 
118 |         out = self.conv4(out)
119 |         out = self.relu4(out)
120 | 
121 |         out = self.pool(out)
122 |         out = out.view(out.shape[0], out.shape[1])
123 |         out = F.normalize(out, p=2, dim=-1)
124 | 
125 |         return out
126 | 
127 | 
128 | class FeatureModelColonCancer(nn.Module):
129 | 
130 |     def __init__(self, in_channels, out_channels):
131 |         super(FeatureModelColonCancer, self).__init__()
132 | 
133 |         self.feature_extractor_part1 = nn.Sequential(
134 |             nn.Conv2d(in_channels, 20, kernel_size=5),
135 |             nn.ReLU(),
136 |             nn.MaxPool2d(2, stride=2),
137 |             nn.Conv2d(20, 50, kernel_size=5),
138 |             nn.ReLU(),
139 |             nn.MaxPool2d(2, stride=2)
140 |         )
141 | 
142 |         self.feature_extractor_part2 = nn.Sequential(
143 |             nn.Linear(50 * 3 * 3, out_channels),
144 |             nn.ReLU(),
145 |         )
146 | 
147 |     def forward(self, x):
148 |         out = self.feature_extractor_part1(x)
149 |         out = out.view(out.shape[0], -1)
150 | 
151 |         out = self.feature_extractor_part2(out)
152 | 
153 |         out = F.normalize(out, p=2, dim=-1)
154 |         return out
155 | 
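[Editor's note: an illustrative sketch of the pre-activation residual Block above; the input size is hypothetical.]

    import torch

    # stride 2 halves the spatial size, so the 1x1 projection shortcut is needed
    block = Block(in_channels=32, out_channels=32, stride=2, kernel_size=3, short=True)
    y = block(torch.randn(1, 32, 64, 64))   # -> (1, 32, 32, 32)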
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | imageio==2.8.0
3 | joblib==0.14.1
4 | kiwisolver==1.2.0
5 | matplotlib==3.2.1
6 | numpy==1.18.2
7 | opencv-python==4.2.0.34
8 | Pillow==7.1.1
9 | protobuf==3.11.3
10 | pyparsing==2.4.7
11 | python-dateutil==2.8.1
12 | scikit-learn==0.22.2.post1
13 | scipy==1.4.1
14 | six==1.14.0
15 | sklearn==0.0
16 | tensorboardX==2.0
17 | torch==1.4.0
18 | torchvision==0.5.0
19 | tqdm==4.45.0
--------------------------------------------------------------------------------
/speed_limits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | import argparse
6 | import time
7 | 
8 | from models.attention_model import AttentionModelTrafficSigns
9 | from models.feature_model import FeatureModelTrafficSigns
10 | from models.classifier import ClassificationHead
11 | 
12 | from ats.core.ats_layer import ATSModel
13 | from ats.utils.regularizers import MultinomialEntropy
14 | from ats.utils.logging import AttentionSaverTrafficSigns
15 | 
16 | from dataset.speed_limits_dataset import SpeedLimits
17 | from train import train, evaluate
18 | 
19 | 
20 | def main(opts):
21 |     train_dataset = SpeedLimits('dataset/traffic_data', train=True)
22 |     train_loader = DataLoader(train_dataset, batch_size=opts.batch_size, shuffle=True, num_workers=opts.num_workers)
23 | 
24 |     test_dataset = SpeedLimits('dataset/traffic_data', train=False)
25 |     test_loader = DataLoader(test_dataset, shuffle=False, batch_size=opts.batch_size, num_workers=opts.num_workers)
26 | 
27 |     attention_model = AttentionModelTrafficSigns(squeeze_channels=True, softmax_smoothing=1e-4)
28 |     feature_model = FeatureModelTrafficSigns(in_channels=3, strides=[1, 2, 2, 2], filters=[32, 32, 32, 32])
29 |     classification_head = ClassificationHead(in_channels=32, num_classes=len(train_dataset.CLASSES))
30 | 
31 |     ats_model = ATSModel(attention_model, feature_model, classification_head, n_patches=opts.n_patches, patch_size=opts.patch_size)
32 |     ats_model = ats_model.to(opts.device)
33 |     optimizer = optim.Adam([{'params': ats_model.attention_model.part1.parameters(), 'weight_decay': 1e-5},
34 |                             {'params': ats_model.attention_model.part2.parameters()},
35 |                             {'params': ats_model.feature_model.parameters()},
36 |                             {'params': ats_model.classifier.parameters()},
37 |                             {'params': ats_model.sampler.parameters()},
38 |                             {'params': ats_model.expectation.parameters()}
39 |                             ], lr=opts.lr)
40 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.decrease_lr_at, gamma=0.1)
41 | 
42 |     logger = AttentionSaverTrafficSigns(opts.output_dir, ats_model, test_dataset, opts)
43 | 
44 |     # Re-weight classes inversely to their frequency so that rare speed limits
45 |     # contribute as much to the loss as common ones.
46 |     class_weights = train_dataset.class_frequencies
47 |     class_weights = torch.from_numpy((1. / len(class_weights)) / class_weights).to(opts.device)
48 | 
49 |     criterion = nn.CrossEntropyLoss(weight=class_weights)
50 |     entropy_loss_func = MultinomialEntropy(opts.regularizer_strength)
51 | 
52 |     for epoch in range(opts.epochs):
53 |         train_loss, train_metrics = train(ats_model, optimizer, train_loader,
54 |                                           criterion, entropy_loss_func, opts)
55 | 
56 |         with torch.no_grad():
57 |             test_loss, test_metrics = evaluate(ats_model, test_loader, criterion,
58 |                                                entropy_loss_func, opts)
59 | 
60 |         logger(epoch, (train_loss, test_loss), (train_metrics, test_metrics))
61 |         scheduler.step()
62 | 
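[Editor's note: a worked example of the class weighting in main() above, using made-up frequencies for two classes.]

    import numpy as np
    import torch

    freqs = np.array([0.9, 0.1])                           # hypothetical class frequencies
    weights = torch.from_numpy((1. / len(freqs)) / freqs)  # -> [0.5556, 5.0000]
    # the rare class is weighted 9x more heavily in the cross-entropy loss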
63 | 
64 | if __name__ == '__main__':
65 |     parser = argparse.ArgumentParser()
66 |     parser.add_argument("--regularizer_strength", type=float, default=0.05,
67 |                         help="How strong should the regularization be for the attention")
68 |     parser.add_argument("--softmax_smoothing", type=float, default=1e-4,
69 |                         help="Smoothing for calculating the attention map")
70 |     parser.add_argument("--lr", type=float, default=0.001, help="Set the optimizer's learning rate")
71 |     parser.add_argument("--n_patches", type=int, default=5, help="How many patches to sample")
72 |     parser.add_argument("--patch_size", type=int, default=100, help="Patch size of a square patch")
73 |     parser.add_argument("--batch_size", type=int, default=32, help="Choose the batch size for SGD")
74 |     parser.add_argument("--epochs", type=int, default=500, help="How many epochs to train for")
75 |     parser.add_argument("--decrease_lr_at", type=int, default=250, help="Epoch at which to decrease the learning rate")
76 |     parser.add_argument("--clipnorm", type=float, default=1, help="Clip the norm of the gradients")
77 |     parser.add_argument("--output_dir", type=str, help="An output directory", default='output/traffic')
78 |     parser.add_argument('--run_name', type=str, default='run')
79 |     parser.add_argument('--num_workers', type=int, default=20, help='Number of workers to use for data loading')
80 | 
81 |     opts = parser.parse_args()
82 |     opts.run_name = f"{opts.run_name}_{time.strftime('%Y%m%dT%H%M%S')}"
83 |     opts.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
84 | 
85 |     main(opts)
86 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from tqdm import tqdm
5 | 
6 | from utils import calc_cls_measures, move_to
7 | 
8 | 
9 | def train(model, optimizer, train_loader, criterion, entropy_loss_func, opts):
10 |     """ Train for a single epoch """
11 | 
12 |     y_probs = np.zeros((0, len(train_loader.dataset.CLASSES)), float)
13 |     y_trues = np.zeros((0,), int)
14 |     losses = []
15 | 
16 |     # Put model in training mode
17 |     model.train()
18 | 
19 |     for i, (x_low, x_high, label) in enumerate(tqdm(train_loader)):
20 |         x_low, x_high, label = move_to([x_low, x_high, label], opts.device)
21 | 
22 |         optimizer.zero_grad()
23 |         y, attention_map, patches, x_low = model(x_low, x_high)
24 | 
25 |         # The entropy term is subtracted from the loss, so higher-entropy
26 |         # (more spread-out) attention maps are rewarded.
27 |         entropy_loss = entropy_loss_func(attention_map)
28 |         loss = criterion(y, label) - entropy_loss
29 |         loss.backward()
30 |         torch.nn.utils.clip_grad_norm_(model.parameters(), opts.clipnorm)
31 |         optimizer.step()
32 | 
33 |         losses.append(loss.item())
34 | 
35 |         y_prob = F.softmax(y, dim=1)
36 |         y_probs = np.concatenate([y_probs, y_prob.detach().cpu().numpy()])
37 |         y_trues = np.concatenate([y_trues, label.cpu().numpy()])
38 | 
39 |     train_loss_epoch = np.round(np.mean(losses), 4)
40 |     metrics = calc_cls_measures(y_probs, y_trues)
41 |     return train_loss_epoch, metrics
42 | 
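[Editor's note: MultinomialEntropy is defined in ats/utils/regularizers.py, which is not reproduced in this listing. A minimal sketch consistent with its use above (the term is subtracted from the loss, so spread-out attention maps are rewarded) could look like this; the function name and the eps default are assumptions:]

    import torch

    def multinomial_entropy(attention_map, strength, eps=1e-8):
        # H(p) = -sum_i p_i * log(p_i), averaged over the batch and scaled
        p = attention_map.flatten(1)
        return strength * (-(p * torch.log(p + eps)).sum(dim=1)).mean()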
43 | 
44 | def evaluate(model, test_loader, criterion, entropy_loss_func, opts):
45 |     """ Evaluate for a single epoch """
46 | 
47 |     y_probs = np.zeros((0, len(test_loader.dataset.CLASSES)), float)
48 |     y_trues = np.zeros((0,), int)
49 |     losses = []
50 | 
51 |     # Put model in eval mode
52 |     model.eval()
53 | 
54 |     for i, (x_low, x_high, label) in enumerate(tqdm(test_loader)):
55 |         x_low, x_high, label = move_to([x_low, x_high, label], opts.device)
56 | 
57 |         y, attention_map, patches, x_low = model(x_low, x_high)
58 | 
59 |         entropy_loss = entropy_loss_func(attention_map)
60 |         loss = criterion(y, label) - entropy_loss
61 | 
62 |         losses.append(loss.item())
63 | 
64 |         y_prob = F.softmax(y, dim=1)
65 |         y_probs = np.concatenate([y_probs, y_prob.detach().cpu().numpy()])
66 |         y_trues = np.concatenate([y_trues, label.cpu().numpy()])
67 | 
68 |     test_loss_epoch = np.round(np.mean(losses), 4)
69 |     metrics = calc_cls_measures(y_probs, y_trues)
70 |     return test_loss_epoch, metrics
71 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import accuracy_score
4 | 
5 | 
6 | def move_to(var, device):
7 |     """ Recursively move (nested) tensors to the given device """
8 |     if isinstance(var, dict):
9 |         return {k: move_to(v, device) for k, v in var.items()}
10 |     elif isinstance(var, list):
11 |         return [move_to(v, device) for v in var]
12 |     elif isinstance(var, tuple):
13 |         # tuple(...) is required; a bare generator expression would be returned unevaluated
14 |         return tuple(move_to(v, device) for v in var)
15 |     return var.to(device)
16 | 
17 | 
18 | def calc_cls_measures(probs, label):
19 |     """Calculate multi-class classification measures (Accuracy)
20 |     :probs: NxC numpy array storing probabilities for each case
21 |     :label: ground truth label
22 |     :returns: a dictionary of accuracy
23 |     """
24 |     label = label.reshape(-1, 1)
25 |     preds = np.argmax(probs, axis=1)
26 |     accuracy = accuracy_score(label, preds)
27 | 
28 |     metric_collects = {'accuracy': accuracy}
29 |     return metric_collects
30 | 
--------------------------------------------------------------------------------
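[Editor's note: small usage sketches for the helpers in utils.py above, with made-up inputs.]

    import numpy as np
    import torch
    from utils import calc_cls_measures, move_to

    # move_to recurses through dicts, lists and tuples of tensors
    batch = move_to([torch.zeros(2, 3), {'y': torch.ones(2)}], torch.device('cpu'))

    # calc_cls_measures compares argmax predictions against the labels
    probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
    print(calc_cls_measures(probs, np.array([0, 1, 1])))   # {'accuracy': 0.666...}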