├── .gitignore ├── LICENSE ├── PyTorchSteerablePyramid ├── README.md ├── __init__.py ├── assets │ ├── coeff.png │ ├── lena.jpg │ ├── patagonia.jpg │ ├── runtime_benchmark.pdf │ └── runtime_benchmark.png ├── examples │ ├── benchmark_runtime.py │ ├── compare_both.py │ ├── example_lena.py │ ├── example_numpy.py │ └── example_pytorch.py ├── license.txt ├── requirements.txt ├── setup.py ├── steerable │ ├── SCFpyr_NumPy.py │ ├── SCFpyr_PyTorch.py │ ├── __init__.py │ ├── math_utils.py │ └── utils.py └── tests │ ├── test_ifft.py │ ├── test_torch_fft.py │ ├── test_torch_fftshift.py │ └── test_torch_fftshift2d.py ├── README.md ├── conf ├── base.yaml ├── dataset │ ├── mvtec.yaml │ └── wft.yaml ├── fe │ ├── color.yaml │ ├── ltem.yaml │ ├── random.yaml │ ├── steerable.yaml │ ├── vgg.yaml │ └── wide.yaml └── sc │ ├── aota.yaml │ ├── fca.yaml │ ├── hist.yaml │ ├── moments.yaml │ └── sww.yaml ├── data_loaders.py ├── datasets └── .gitkeep ├── demo.py ├── environment.yml ├── evaluation ├── additional_util.py ├── evaluate_experiment.py ├── evaluate_multiple_experiments.py ├── generic_util.py ├── print_metrics.py ├── pro_curve_util.py └── roc_curve_util.py ├── example ├── 000.png └── 000_out.png ├── feature_extractors.py ├── main.py ├── op_utils.py ├── outputs └── .gitkeep ├── sc_methods.py └── static └── teaser.png /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/** 2 | !outputs/.gitkeep 3 | datasets/** 4 | !datasets/.gitkeep 5 | .idea 6 | __pycache__ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Andrei-Timotei Ardelean, Tim Weyrich 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/README.md: -------------------------------------------------------------------------------- 1 | # Complex Steerable Pyramid in PyTorch 2 | 3 | This is a PyTorch implementation of the Complex Steerable Pyramid described in [Portilla and Simoncelli (IJCV, 2000)](http://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Portilla99). 4 | 5 | It uses PyTorch's efficient spectral decomposition layers `torch.fft` and `torch.ifft`. Just like a normal convolution layer, the complex steerable pyramid expects a batch of images of shape `[N,C,H,W]` with current support only for grayscale images (`C=1`). 
It returns a `list` structure containing the low-pass, high-pass and intermediate levels of the pyramid for each image in the batch (as `torch.Tensor`). Computing the steerable pyramid is significantly faster on the GPU, as can be observed from the runtime benchmark below. 6 | 7 | 8 | 9 | ## Usage 10 | 11 | In addition to the PyTorch implementation defined in `SCFpyr_PyTorch`, the original NumPy/SciPy version is also included in `SCFpyr_NumPy` for completeness and comparison. As the GPU implementation benefits greatly from parallelization, the `build` method expects a batch of images of shape `[N,1,H,W]`, i.e. `N` grayscale images of size `HxW`. 12 | 13 | ```python 14 | import cv2 15 | import torch 16 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 17 | import steerable.utils as utils 18 | 19 | # Requires PyTorch with MKL when setting to 'cpu' 20 | device = torch.device('cuda:0') 21 | 22 | # Load batch of images [N,1,H,W] 23 | im_batch_numpy = utils.load_image_batch(...) 24 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 25 | 26 | # Initialize Complex Steerable Pyramid 27 | pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device) 28 | 29 | # Decompose entire batch of images 30 | coeff = pyr.build(im_batch_torch) 31 | 32 | # Reconstruct batch of images again 33 | im_batch_reconstructed = pyr.reconstruct(coeff) 34 | 35 | # Visualization 36 | coeff_single = utils.extract_from_batch(coeff, 0) 37 | coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True) 38 | cv2.imshow('Complex Steerable Pyramid', coeff_grid) 39 | cv2.waitKey(0) 40 | ``` 41 | 42 | ## Benchmark 43 | 44 | Performing the CSP decomposition in parallel on the GPU using PyTorch results in a significant speed-up; increasing the batch size further reduces the runtime per example. The plot below shows a comparison between the `scipy` and `torch` implementations as a function of the batch size `N` and the input image size. These results were obtained on a powerful Linux desktop with an NVIDIA Titan X GPU. 45 | 46 | 47 | 48 | ## Installation 49 | 50 | Clone and install: 51 | 52 | ```sh 53 | git clone https://github.com/tomrunia/PyTorchSteerablePyramid.git 54 | cd PyTorchSteerablePyramid 55 | pip install -r requirements.txt 56 | python setup.py install 57 | ``` 58 | 59 | ## Requirements 60 | 61 | - Python 2.7 or 3.6 (other versions might also work) 62 | - Numpy (developed with 1.15.4) 63 | - Scipy (developed with 1.1.0) 64 | - PyTorch >= 0.4.0 (developed with 1.0.0; see note below) 65 | 66 | The steerable pyramid relies on `torch.fft` and `torch.ifft` to perform operations in the spectral domain. At the moment, PyTorch only implements these operations on the GPU or with the MKL back-end on the CPU. Therefore, if you want to run the code on the CPU you might need to compile PyTorch from source with MKL enabled. Use `torch.backends.mkl.is_available()` to check whether MKL is installed. 67 | 
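For example, a minimal device-selection guard along these lines makes the CPU constraint explicit (a sketch; `steerable.utils.get_device` implements a similar fallback):

```python
import torch

from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch

# Prefer the GPU; only fall back to the CPU when an MKL-enabled build is present.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mkl.is_available():
    device = torch.device('cpu')
else:
    raise RuntimeError('Spectral ops require a CUDA device or an MKL-enabled PyTorch build')

pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device)
```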
68 | ## References 69 | 70 | - [J. Portilla and E.P. Simoncelli, Complex Steerable Pyramid (IJCV, 2000)](http://www.cns.nyu.edu/pub/eero/portilla99-reprint.pdf) 71 | - [The Steerable Pyramid](http://www.cns.nyu.edu/~eero/steerpyr/) 72 | - [Official implementation: matlabPyrTools](http://www.cns.nyu.edu/~lcv/software.php) 73 | - [perceptual repository by Dzung Nguyen](https://github.com/andreydung/Steerable-filter) 74 | 75 | ## License 76 | 77 | MIT License 78 | 79 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 80 | 81 | Permission is hereby granted, free of charge, to any person obtaining a copy 82 | of this software and associated documentation files (the "Software"), to deal 83 | in the Software without restriction, including without limitation the rights 84 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 85 | copies of the Software, and to permit persons to whom the Software is 86 | furnished to do so, subject to the following conditions: 87 | 88 | The above copyright notice and this permission notice shall be included in all 89 | copies or substantial portions of the Software. 90 | 91 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 92 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 93 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 94 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 95 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 96 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 97 | SOFTWARE. -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | sys.path.append(os.path.dirname(__file__)) 6 | 7 | from PyTorchSteerablePyramid.steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 8 | 9 | 10 | def extract_steerable_features(image, o=4, m=1, scale_factor=2): 11 | # image of shape 1 x 3 x H x W 12 | assert image.shape[0] == 1 13 | im_batch_torch = image.permute(1, 0, 2, 3) 14 | pyr = SCFpyr_PyTorch(height=m + 2, nbands=o, scale_factor=scale_factor, device=image.device) 15 | coeff = pyr.build(im_batch_torch) 16 | 17 | h, w = im_batch_torch.shape[-2:] 18 | real_features = [torch.stack(c)[..., 0] for c in coeff[1:-1]] 19 | real_features = [torch.nn.functional.interpolate(f, size=(h, w), mode='bilinear') for f in real_features] 20 | real_features = torch.cat(real_features, dim=0).reshape(image.shape[0], -1, h, w) 21 | 22 | return real_features 23 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/assets/coeff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/coeff.png -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/assets/lena.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/lena.jpg -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/assets/patagonia.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/patagonia.jpg -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/assets/runtime_benchmark.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/runtime_benchmark.pdf -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/assets/runtime_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/runtime_benchmark.png -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/examples/benchmark_runtime.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-10 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import time 20 | import argparse 21 | 22 | import numpy as np 23 | import torch 24 | import matplotlib.pyplot as plt 25 | 26 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 27 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 28 | import steerable.utils as utils 29 | 30 | # `cortex.plot` is a personal plotting helper that is not listed in requirements.txt; 31 | # use Matplotlib's default color cycle instead so the script runs as-is. 32 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 39 | parser.add_argument('--batch_sizes', type=str, default='1,8,16,32,64,128,256') 40 | parser.add_argument('--image_sizes', type=str, default='128,256,512') 41 | parser.add_argument('--num_runs', type=int, default=5) 42 | parser.add_argument('--pyr_nlevels', type=int, default=5) 43 | parser.add_argument('--pyr_nbands', type=int, default=4) 44 | parser.add_argument('--pyr_scale_factor', type=int, default=2) 45 | parser.add_argument('--device', type=str, default='cuda:0') 46 | config = parser.parse_args() 47 | 48 | config.batch_sizes = list(map(int, config.batch_sizes.split(','))) 49 | config.image_sizes = list(map(int, config.image_sizes.split(','))) 50 | 51 | device = utils.get_device(config.device) 52 | 53 | ################################################################################ 54 | 55 | pyr_numpy = SCFpyr_NumPy( 56 | height=config.pyr_nlevels, 57 | nbands=config.pyr_nbands, 58 | scale_factor=config.pyr_scale_factor, 59 | ) 60 | 61 | pyr_torch = SCFpyr_PyTorch( 62 | height=config.pyr_nlevels, 63 | nbands=config.pyr_nbands, 64 | scale_factor=config.pyr_scale_factor, 65 | device=device 66 | ) 67 | 
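# Note: CUDA kernels are launched asynchronously, so the wall-clock timings
# below can under-report the GPU work that pyr_torch.build(...) enqueues. A
# sketch of a more accurate pattern for the PyTorch measurement (assuming
# `device` is a CUDA device):
#
#     torch.cuda.synchronize(device)
#     start_time = time.time()
#     coeff = pyr_torch.build(im_batch_torch)
#     torch.cuda.synchronize(device)
#     duration = time.time() - start_time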
68 | ############################################################################ 69 | # Run Benchmark 70 | 71 | durations_numpy = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs)) 72 | durations_torch = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs)) 73 | 74 | for batch_idx, batch_size in enumerate(config.batch_sizes): 75 | 76 | for size_idx, im_size in enumerate(config.image_sizes): 77 | 78 | for run_idx in range(config.num_runs): 79 | 80 | im_batch_numpy = utils.load_image_batch(config.image_file, batch_size, im_size) 81 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 82 | 83 | # NumPy implementation 84 | start_time = time.time() 85 | 86 | for image in im_batch_numpy: 87 | coeff = pyr_numpy.build(image[0,]) 88 | 89 | duration = time.time()-start_time 90 | durations_numpy[batch_idx,size_idx,run_idx] = duration 91 | print('BatchSize: {batch_size} | ImSize: {im_size} | NumPy Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format( 92 | batch_size=batch_size, 93 | im_size=im_size, 94 | curr_run=run_idx+1, 95 | num_run=config.num_runs, 96 | duration=duration 97 | )) 98 | 99 | # PyTorch Implementation 100 | start_time = time.time() 101 | 102 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 103 | coeff = pyr_torch.build(im_batch_torch) 104 | 105 | duration = time.time()-start_time 106 | durations_torch[batch_idx,size_idx,run_idx] = duration 107 | print('BatchSize: {batch_size} | ImSize: {im_size} | Torch Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format( 108 | batch_size=batch_size, 109 | im_size=im_size, 110 | curr_run=run_idx+1, 111 | num_run=config.num_runs, 112 | duration=duration 113 | )) 114 | 115 | np.save('./assets/durations_numpy.npy', durations_numpy) 116 | np.save('./assets/durations_torch.npy', durations_torch) 117 | 118 | ################################################################################ 119 | # Plotting 120 | 121 | durations_numpy = np.load('./assets/durations_numpy.npy') 122 | durations_torch = np.load('./assets/durations_torch.npy') 123 | 124 | for i, num_examples in enumerate(config.batch_sizes): 125 | if num_examples == 8: continue 126 | avg_durations_numpy = np.mean(durations_numpy[i,:], -1) / num_examples 127 | avg_durations_torch = np.mean(durations_torch[i,:], -1) / num_examples 128 | plt.plot(config.image_sizes, avg_durations_numpy, marker='o', linestyle='-', lw=1.4, color=colors[i], label='numpy (N = {})'.format(num_examples)) 129 | plt.plot(config.image_sizes, avg_durations_torch, marker='d', linestyle='--', lw=1.4, color=colors[i], label='pytorch (N = {})'.format(num_examples)) 130 | 131 | plt.title('Runtime Benchmark ({} levels, {} bands, average {numruns} runs)'.format(config.pyr_nlevels, config.pyr_nbands, numruns=config.num_runs)) 132 | plt.xlabel('Image Size (px)') 133 | plt.ylabel('Time per Example (s)') 134 | plt.xlim((100, 550)) 135 | plt.ylim((-0.01, 0.2)) 136 | plt.xticks(config.image_sizes) 137 | plt.legend(ncol=2, loc='upper left') 138 | plt.tight_layout() 139 | 140 | plt.savefig('./assets/runtime_benchmark.png', dpi=600) 141 | plt.savefig('./assets/runtime_benchmark.pdf') 142 | plt.show() 143 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/examples/compare_both.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person 
obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | import cv2 22 | 23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 25 | import steerable.utils as utils 26 | 27 | ################################################################################ 28 | ################################################################################ 29 | # Common 30 | 31 | image_file = './assets/lena.jpg' 32 | image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 33 | image = cv2.resize(image, (200,200)) 34 | 35 | # Number of pyramid levels 36 | pyr_height = 5 37 | 38 | # Number of orientation bands 39 | pyr_nbands = 4 40 | 41 | # Tolerance for error checking 42 | tolerance = 1e-3 43 | 44 | ################################################################################ 45 | # NumPy 46 | 47 | pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2) 48 | coeff_numpy = pyr_numpy.build(image) 49 | reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy) 50 | reconstruction_numpy = reconstruction_numpy.astype(np.uint8) 51 | 52 | print('#'*60) 53 | 54 | ################################################################################ 55 | # PyTorch 56 | 57 | device = torch.device('cuda:0') 58 | 59 | im_batch = torch.from_numpy(image[None,None,:,:]) 60 | im_batch = im_batch.to(device).float() 61 | 62 | pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device) 63 | coeff_torch = pyr_torch.build(im_batch) 64 | reconstruction_torch = pyr_torch.reconstruct(coeff_torch) 65 | reconstruction_torch = reconstruction_torch.cpu().numpy()[0,] 66 | 67 | # Extract first example from the batch and move to CPU 68 | coeff_torch = utils.extract_from_batch(coeff_torch, 0) 69 | 70 | ################################################################################ 71 | # Check correctness 72 | 73 | print('#'*60) 74 | assert len(coeff_numpy) == len(coeff_torch) 75 | 76 | for level, _ in enumerate(coeff_numpy): 77 | 78 | print('Pyramid Level {level}'.format(level=level)) 79 | coeff_level_numpy = coeff_numpy[level] 80 | coeff_level_torch = coeff_torch[level] 81 | 82 | assert isinstance(coeff_level_torch, type(coeff_level_numpy)) 83 | 84 | if isinstance(coeff_level_numpy, np.ndarray): 85 | 86 | # Low- or High-Pass 87 | print(' NumPy. min = {min:.3f}, max = {max:.3f},' 88 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 89 | min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy), 90 | mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy))) 91 | 92 | print(' PyTorch. 
min = {min:.3f}, max = {max:.3f},' 93 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 94 | min=np.min(coeff_level_torch), max=np.max(coeff_level_torch), 95 | mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch))) 96 | 97 | # Check numerical correctness 98 | assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance) 99 | 100 | elif isinstance(coeff_level_numpy, list): 101 | 102 | # Intermediate bands 103 | for band, _ in enumerate(coeff_level_numpy): 104 | 105 | band_numpy = coeff_level_numpy[band] 106 | band_torch = coeff_level_torch[band] 107 | 108 | print(' Orientation Band {}'.format(band)) 109 | print(' NumPy. min = {min:.3f}, max = {max:.3f},' 110 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 111 | min=np.min(band_numpy), max=np.max(band_numpy), 112 | mean=np.mean(band_numpy), std=np.std(band_numpy))) 113 | 114 | print(' PyTorch. min = {min:.3f}, max = {max:.3f},' 115 | ' mean = {mean:.3f}, std = {std:.3f}'.format( 116 | min=np.min(band_torch), max=np.max(band_torch), 117 | mean=np.mean(band_torch), std=np.std(band_torch))) 118 | 119 | # Check numerical correctness 120 | assert np.allclose(band_numpy, band_torch, atol=tolerance) 121 | 122 | ################################################################################ 123 | # Visualize 124 | 125 | coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False) 126 | coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False) 127 | 128 | cv2.imshow('image', image) 129 | cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy)) 130 | cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch)) 131 | cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8)) 132 | cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8)) 133 | 134 | cv2.waitKey(0) 135 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/examples/example_lena.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import cv2 21 | 22 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 23 | import steerable.utils as utils 24 | 25 | ################################################################################ 26 | ################################################################################ 27 | 28 | if __name__ == "__main__": 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 32 | parser.add_argument('--batch_size', type=int, default='1') 33 | parser.add_argument('--image_size', type=int, default='200') 34 | parser.add_argument('--pyr_nlevels', type=int, default='5') 35 | parser.add_argument('--pyr_nbands', type=int, default='4') 36 | parser.add_argument('--pyr_scale_factor', type=int, default='2') 37 | parser.add_argument('--visualize', type=bool, default=True) 38 | config = parser.parse_args() 39 | 40 | ############################################################################ 41 | # Build the complex steerable pyramid 42 | 43 | pyr = SCFpyr_NumPy( 44 | height=config.pyr_nlevels, 45 | nbands=config.pyr_nbands, 46 | scale_factor=config.pyr_scale_factor, 47 | ) 48 | 49 | ############################################################################ 50 | # Create a batch and feed-forward 51 | 52 | image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE) 53 | image = cv2.resize(image, dsize=(200,200)) 54 | 55 | coeff = pyr.build(image) 56 | 57 | grid = utils.make_grid_coeff(coeff, normalize=True) 58 | 59 | cv2.imwrite('./assets/coeff.png', grid) 60 | cv2.imshow('image', image) 61 | cv2.imshow('coeff', grid) 62 | cv2.waitKey(0) 63 | 64 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/examples/example_numpy.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import time 20 | import argparse 21 | import numpy as np 22 | 23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy 24 | import steerable.utils as utils 25 | 26 | ################################################################################ 27 | ################################################################################ 28 | 29 | if __name__ == "__main__": 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 33 | parser.add_argument('--batch_size', type=int, default=1) 34 | parser.add_argument('--image_size', type=int, default=200) 35 | parser.add_argument('--pyr_nlevels', type=int, default=5) 36 | parser.add_argument('--pyr_nbands', type=int, default=4) 37 | parser.add_argument('--pyr_scale_factor', type=int, default=2) 38 | parser.add_argument('--visualize', type=bool, default=True) 39 | config = parser.parse_args() 40 | 41 | ############################################################################ 42 | # Build the complex steerable pyramid 43 | 44 | pyr = SCFpyr_NumPy( 45 | height=config.pyr_nlevels, 46 | nbands=config.pyr_nbands, 47 | scale_factor=config.pyr_scale_factor, 48 | ) 49 | 50 | ############################################################################ 51 | # Create a batch and feed-forward 52 | 53 | start_time = time.time() 54 | 55 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size) 56 | im_batch_numpy = im_batch_numpy.squeeze(1) # no channel dim for NumPy 57 | 58 | # Compute Steerable Pyramid 59 | start_time = time.time() 60 | for image in im_batch_numpy: 61 | coeff = pyr.build(image) 62 | 63 | duration = time.time()-start_time 64 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format( 65 | batch_size=config.batch_size, 66 | duration=duration 67 | )) 68 | 69 | ############################################################################ 70 | # Visualization 71 | 72 | if config.visualize: 73 | import cv2 74 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True) 75 | cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8)) 76 | cv2.imshow('coeff', coeff_grid) 77 | cv2.waitKey(0) 78 | 
 -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/examples/example_pytorch.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import argparse 20 | import time 21 | import numpy as np 22 | import torch 23 | 24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch 25 | import steerable.utils as utils 26 | 27 | ################################################################################ 28 | ################################################################################ 29 | 30 | if __name__ == "__main__": 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') 34 | parser.add_argument('--batch_size', type=int, default=32) 35 | parser.add_argument('--image_size', type=int, default=200) 36 | parser.add_argument('--pyr_nlevels', type=int, default=5) 37 | parser.add_argument('--pyr_nbands', type=int, default=4) 38 | parser.add_argument('--pyr_scale_factor', type=int, default=2) 39 | parser.add_argument('--device', type=str, default='cuda:0') 40 | parser.add_argument('--visualize', type=bool, default=True) 41 | config = parser.parse_args() 42 | 43 | device = utils.get_device(config.device) 44 | 45 | ############################################################################ 46 | # Build the complex steerable pyramid 47 | 48 | pyr = SCFpyr_PyTorch( 49 | height=config.pyr_nlevels, 50 | nbands=config.pyr_nbands, 51 | scale_factor=config.pyr_scale_factor, 52 | device=device 53 | ) 54 | 55 | ############################################################################ 56 | # Create a batch and feed-forward 57 | 58 | start_time = time.time() 59 | 60 | # Load Batch 61 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size) 62 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device) 63 | 64 | # Compute Steerable Pyramid 65 | coeff = pyr.build(im_batch_torch) 66 | 67 | duration = time.time()-start_time 68 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format( 69 | batch_size=config.batch_size, 70 | duration=duration 71 | )) 72 | 73 | ############################################################################ 74 | # Visualization 75 | 76 | # Just extract a single example from the batch 77 | # Also moves the example to CPU and NumPy 78 | coeff = utils.extract_from_batch(coeff, 0) 79 | 80 | if config.visualize: 81 | import cv2 82 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True) 83 | cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8)) 84 | cv2.imshow('coeff', coeff_grid) 85 | cv2.waitKey(0) 86 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/license.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions 
of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | numpy 3 | scipy 4 | torch >= 0.4.0 5 | matplotlib 6 | pillow -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='steerable_pytorch', 5 | version='0.1', 6 | author='Tom Runia', 7 | author_email='tomrunia@gmail.com', 8 | url='https://github.com/tomrunia/PyTorchSteerablePyramid', 9 | description='Complex Steerable Pyramids in PyTorch', 10 | long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.', 11 | license='MIT', 12 | packages=['steerable'], 13 | scripts=[] 14 | ) -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/steerable/SCFpyr_NumPy.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | from scipy.special import factorial 21 | 22 | import steerable.math_utils as math_utils 23 | pointOp = math_utils.pointOp 24 | 25 | ################################################################################ 26 | 27 | class SCFpyr_NumPy(): 28 | ''' 29 | This is a modified version of buildSFpyr, that constructs a 30 | complex-valued steerable pyramid using Hilbert-transform pairs 31 | of filters. Note that the imaginary parts will *not* be steerable. 32 | 33 | Description of this transform appears in: Portilla & Simoncelli, 34 | International Journal of Computer Vision, 40(1):49-71, Oct 2000. 
35 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/ 36 | 37 | Modified code from the perceptual repository: 38 | https://github.com/andreydung/Steerable-filter 39 | 40 | This code looks very similar to the original Matlab code: 41 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m 42 | 43 | Also looks very similar to the original Python code presented here: 44 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py 45 | 46 | ''' 47 | 48 | def __init__(self, height=5, nbands=4, scale_factor=2): 49 | self.nbands = nbands # number of orientation bands 50 | self.height = height # including low-pass and high-pass 51 | self.scale_factor = scale_factor 52 | 53 | # Cache constants 54 | self.lutsize = 1024 55 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize 56 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi 57 | 58 | 59 | ################################################################################ 60 | # Construction of Steerable Pyramid 61 | 62 | def build(self, im): 63 | ''' Decomposes an image into its complex steerable pyramid. 64 | The pyramid typically has ~4 levels and 4-8 orientations. 65 | 66 | Args: 67 | im (np.ndarray): single grayscale image of shape [H,W] 68 | 69 | Returns: 70 | pyramid: list containing np.ndarray objects storing the pyramid 71 | ''' 72 | 73 | assert len(im.shape) == 2, 'Input im must be grayscale' 74 | height, width = im.shape 75 | 76 | # Check whether image size is sufficient for number of levels 77 | if self.height > int(np.floor(np.log2(min(width, height))) - 2): 78 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) 79 | 80 | # Prepare a grid 81 | log_rad, angle = math_utils.prepare_grid(height, width) 82 | 83 | # Radial transition function (a raised cosine in log-frequency): 84 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 85 | Yrcos = np.sqrt(Yrcos) 86 | 87 | YIrcos = np.sqrt(1 - Yrcos**2) 88 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 89 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 90 | 91 | # Shift the zero-frequency component to the center of the spectrum. 
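# The two masks above are power-complementary by construction: Yrcos = sqrt(Y)
# and YIrcos = sqrt(1 - Yrcos**2), so hi0mask**2 + lo0mask**2 == 1 at every
# frequency. Each mask is applied once in build() and once in reconstruct(),
# hence the initial high-/low-pass split loses no energy and can be undone
# exactly during reconstruction.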
92 | imdft = np.fft.fftshift(np.fft.fft2(im)) 93 | 94 | # Low-pass 95 | lo0dft = imdft * lo0mask 96 | 97 | # Recursively build the steerable pyramid 98 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1) 99 | 100 | # High-pass 101 | hi0dft = imdft * hi0mask 102 | hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)) 103 | coeff.insert(0, hi0.real) 104 | return coeff 105 | 106 | 107 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): 108 | 109 | if height <= 1: 110 | 111 | # Low-pass 112 | lo0 = np.fft.ifftshift(lodft) 113 | lo0 = np.fft.ifft2(lo0) 114 | coeff = [lo0.real] 115 | 116 | else: 117 | 118 | Xrcos = Xrcos - np.log2(self.scale_factor) 119 | 120 | #################################################################### 121 | ####################### Orientation bandpass ####################### 122 | #################################################################### 123 | 124 | himask = pointOp(log_rad, Yrcos, Xrcos) 125 | 126 | order = self.nbands - 1 127 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 128 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) 129 | 130 | # Loop through all orientation bands 131 | orientations = [] 132 | for b in range(self.nbands): 133 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands) 134 | banddft = np.power(complex(0, -1), self.nbands - 1) * lodft * anglemask * himask 135 | band = np.fft.ifft2(np.fft.ifftshift(banddft)) 136 | orientations.append(band) 137 | 138 | #################################################################### 139 | ######################## Subsample lowpass ######################### 140 | #################################################################### 141 | 142 | dims = np.array(lodft.shape) 143 | 144 | # Both are tuples of size 2 145 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int) 146 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int) 147 | 148 | # Selection 149 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 150 | angle = angle[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 151 | lodft = lodft[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]] 152 | 153 | # Subsampling in frequency domain 154 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2)) 155 | lomask = pointOp(log_rad, YIrcos, Xrcos) 156 | lodft = lomask * lodft 157 | 158 | #################################################################### 159 | ####################### Recursion next level ####################### 160 | #################################################################### 161 | 162 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1) 163 | coeff.insert(0, orientations) 164 | 165 | return coeff 166 | 167 | ############################################################################ 168 | ########################### RECONSTRUCTION ################################# 169 | ############################################################################ 170 | 171 | def reconstruct(self, coeff): 172 | 173 | if self.nbands != len(coeff[1]): 174 | raise Exception("Unmatched number of orientations") 175 | 176 | height, width = coeff[0].shape 177 | log_rad, angle = math_utils.prepare_grid(height, width) 178 | 179 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 180 | Yrcos = np.sqrt(Yrcos) 181 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 182 | 183 | lo0mask = 
pointOp(log_rad, YIrcos, Xrcos) 184 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 185 | 186 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) 187 | 188 | hidft = np.fft.fftshift(np.fft.fft2(coeff[0])) 189 | outdft = tempdft * lo0mask + hidft * hi0mask 190 | 191 | reconstruction = np.fft.ifftshift(outdft) 192 | reconstruction = np.fft.ifft2(reconstruction) 193 | reconstruction = reconstruction.real 194 | 195 | return reconstruction 196 | 197 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): 198 | 199 | if len(coeff) == 1: 200 | dft = np.fft.fft2(coeff[0]) 201 | dft = np.fft.fftshift(dft) 202 | return dft 203 | 204 | Xrcos = Xrcos - np.log2(self.scale_factor) 205 | 206 | #################################################################### 207 | ####################### Orientation Residue ######################## 208 | #################################################################### 209 | 210 | himask = pointOp(log_rad, Yrcos, Xrcos) 211 | 212 | lutsize = 1024 213 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize 214 | order = self.nbands - 1 215 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 216 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order) 217 | 218 | orientdft = np.zeros(coeff[0][0].shape) 219 | 220 | for b in range(self.nbands): 221 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands) 222 | banddft = np.fft.fft2(coeff[0][b]) 223 | banddft = np.fft.fftshift(banddft) 224 | orientdft = orientdft + np.power(complex(0, 1), order) * banddft * anglemask * himask 225 | 226 | #################################################################### 227 | ########## Lowpass component are upsampled and convoluted ########## 228 | #################################################################### 229 | 230 | dims = np.array(coeff[0][0].shape) 231 | 232 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32) 233 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32) 234 | 235 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]] 236 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]] 237 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 238 | lomask = pointOp(nlog_rad, YIrcos, Xrcos) 239 | 240 | ################################################################################ 241 | 242 | # Recursive call for image reconstruction 243 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle) 244 | 245 | resdft = np.zeros(dims, 'complex') 246 | resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask 247 | 248 | return resdft + orientdft 249 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/steerable/SCFpyr_PyTorch.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | from scipy.special import factorial 22 | 23 | import steerable.math_utils as math_utils 24 | pointOp = math_utils.pointOp 25 | 26 | ################################################################################ 27 | ################################################################################ 28 | 29 | 30 | class SCFpyr_PyTorch(object): 31 | ''' 32 | This is a modified version of buildSFpyr, that constructs a 33 | complex-valued steerable pyramid using Hilbert-transform pairs 34 | of filters. Note that the imaginary parts will *not* be steerable. 35 | 36 | Description of this transform appears in: Portilla & Simoncelli, 37 | International Journal of Computer Vision, 40(1):49-71, Oct 2000. 38 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/ 39 | 40 | Modified code from the perceptual repository: 41 | https://github.com/andreydung/Steerable-filter 42 | 43 | This code looks very similar to the original Matlab code: 44 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m 45 | 46 | Also looks very similar to the original Python code presented here: 47 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py 48 | 49 | ''' 50 | 51 | def __init__(self, height=5, nbands=4, scale_factor=2, device=None): 52 | self.height = height # including low-pass and high-pass 53 | self.nbands = nbands # number of orientation bands 54 | self.scale_factor = scale_factor 55 | self.device = torch.device('cpu') if device is None else device 56 | 57 | # Cache constants 58 | self.lutsize = 1024 59 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize 60 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi 61 | self.complex_fact_construct = np.power(complex(0, -1), self.nbands-1) 62 | self.complex_fact_reconstruct = np.power(complex(0, 1), self.nbands-1) 63 | 64 | ################################################################################ 65 | # Construction of Steerable Pyramid 66 | 67 | def build(self, im_batch): 68 | ''' Decomposes a batch of images into a complex steerable pyramid. 69 | The pyramid typically has ~4 levels and 4-8 orientations. 
70 | 71 | Args: 72 | im_batch (torch.Tensor): Batch of images of shape [N,C,H,W] 73 | 74 | Returns: 75 | pyramid: list containing torch.Tensor objects storing the pyramid 76 | ''' 77 | 78 | assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device) 79 | assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32' 80 | assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]' 81 | assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image' 82 | 83 | im_batch = im_batch.squeeze(1) # flatten channels dim 84 | height, width = im_batch.shape[1], im_batch.shape[2] 85 | 86 | # Check whether image size is sufficient for number of levels 87 | if self.height > int(np.floor(np.log2(min(width, height))) - 2): 88 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) 89 | 90 | # Prepare a grid 91 | log_rad, angle = math_utils.prepare_grid(height, width) 92 | 93 | # Radial transition function (a raised cosine in log-frequency): 94 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 95 | Yrcos = np.sqrt(Yrcos) 96 | 97 | YIrcos = np.sqrt(1 - Yrcos**2) 98 | 99 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 100 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 101 | 102 | # Note that we expand dims to support broadcasting later 103 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) 104 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) 105 | 106 | # Fourier transform (2D) and shifting 107 | batch_dft = torch.view_as_real(torch.fft.fft2(im_batch)) 108 | batch_dft = math_utils.batch_fftshift2d(batch_dft) 109 | 110 | # Low-pass 111 | lo0dft = batch_dft * lo0mask 112 | 113 | # Start recursively building the pyramids 114 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1) 115 | 116 | # High-pass 117 | hi0dft = batch_dft * hi0mask 118 | hi0 = math_utils.batch_ifftshift2d(hi0dft) 119 | # hi0 = torch.ifft(hi0, signal_ndim=2) 120 | hi0 = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(hi0))) 121 | hi0_real = torch.unbind(hi0, -1)[0] 122 | coeff.insert(0, hi0_real) 123 | return coeff 124 | 125 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): 126 | 127 | if height <= 1: 128 | 129 | # Low-pass 130 | lo0 = math_utils.batch_ifftshift2d(lodft) 131 | # lo0 = torch.ifft(lo0, signal_ndim=2) 132 | lo0 = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(lo0))) 133 | lo0_real = torch.unbind(lo0, -1)[0] 134 | coeff = [lo0_real] 135 | 136 | else: 137 | 138 | Xrcos = Xrcos - np.log2(self.scale_factor) 139 | 140 | #################################################################### 141 | ####################### Orientation bandpass ####################### 142 | #################################################################### 143 | 144 | himask = pointOp(log_rad, Yrcos, Xrcos) 145 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) 146 | 147 | order = self.nbands - 1 148 | const = np.power(2.0, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 149 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) # [n,] 150 | 151 | # Loop through all orientation bands 152 | orientations = [] 153 | for b in range(self.nbands): 154 | 155 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands) 156 | anglemask = anglemask[None,:,:,None] # for broadcasting 157 | anglemask = 
torch.from_numpy(anglemask).float().to(self.device) 158 | 159 | # Bandpass filtering 160 | banddft = lodft * anglemask * himask 161 | 162 | # Now multiply with complex number 163 | # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i 164 | banddft = torch.unbind(banddft, -1) 165 | banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1] 166 | banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0] 167 | banddft = torch.stack((banddft_real, banddft_imag), -1) 168 | 169 | band = math_utils.batch_ifftshift2d(banddft) 170 | # band = torch.ifft(band, signal_ndim=2) 171 | band = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(band))) 172 | orientations.append(band) 173 | 174 | #################################################################### 175 | ######################## Subsample lowpass ######################### 176 | #################################################################### 177 | 178 | # Don't consider batch_size and imag/real dim 179 | dims = np.array(lodft.shape[1:3]) 180 | 181 | # Both are tuples of size 2 182 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int) 183 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int) 184 | 185 | # Subsampling indices 186 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] 187 | angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]] 188 | 189 | # Actual subsampling 190 | lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:] 191 | 192 | # Filtering 193 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2)) 194 | lomask = pointOp(log_rad, YIrcos, Xrcos) 195 | lomask = torch.from_numpy(lomask[None,:,:,None]).float() 196 | lomask = lomask.to(self.device) 197 | 198 | # Pointwise multiplication in the frequency domain (= convolution in the spatial domain) 199 | lodft = lomask * lodft 200 | 201 | #################################################################### 202 | ####################### Recursion next level ####################### 203 | #################################################################### 204 | 205 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1) 206 | coeff.insert(0, orientations) 207 | 208 | return coeff 209 | 210 | ############################################################################ 211 | ########################### RECONSTRUCTION ################################# 212 | ############################################################################ 213 | 214 | def reconstruct(self, coeff): 215 | 216 | if self.nbands != len(coeff[1]): 217 | raise Exception("Unmatched number of orientations") 218 | 219 | height, width = coeff[0].shape[1], coeff[0].shape[2] # coeff[0] has shape [N,H,W] 220 | log_rad, angle = math_utils.prepare_grid(height, width) 221 | 222 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) 223 | Yrcos = np.sqrt(Yrcos) 224 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 225 | 226 | lo0mask = pointOp(log_rad, YIrcos, Xrcos) 227 | hi0mask = pointOp(log_rad, Yrcos, Xrcos) 228 | 229 | # Note that we expand dims to support broadcasting later 230 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device) 231 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device) 232 | 233 | # Start recursive reconstruction 234 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) 235 | 236 | hidft = torch.view_as_real(torch.fft.fft2(coeff[0].contiguous())) 237 | hidft = math_utils.batch_fftshift2d(hidft) 238 | 239 | outdft 
= tempdft * lo0mask + hidft * hi0mask 240 | 241 | reconstruction = math_utils.batch_ifftshift2d(outdft) 242 | # reconstruction = torch.ifft(reconstruction, signal_ndim=2) 243 | reconstruction = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(reconstruction))) 244 | reconstruction = torch.unbind(reconstruction, -1)[0] # real 245 | 246 | return reconstruction 247 | 248 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): 249 | 250 | if len(coeff) == 1: 251 | # dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False) 252 | dft = torch.view_as_real(torch.fft.fft2(coeff[0].contiguous())) 253 | dft = math_utils.batch_fftshift2d(dft) 254 | return dft 255 | 256 | Xrcos = Xrcos - np.log2(self.scale_factor) 257 | 258 | #################################################################### 259 | ####################### Orientation Residue ######################## 260 | #################################################################### 261 | 262 | himask = pointOp(log_rad, Yrcos, Xrcos) 263 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device) 264 | 265 | lutsize = 1024 266 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize 267 | order = self.nbands - 1 268 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order)) 269 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order) 270 | 271 | orientdft = torch.zeros_like(coeff[0][0]) 272 | for b in range(self.nbands): 273 | 274 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands) 275 | anglemask = anglemask[None,:,:,None] # for broadcasting 276 | anglemask = torch.from_numpy(anglemask).float().to(self.device) 277 | 278 | # banddft = torch.fft(coeff[0][b], signal_ndim=2) 279 | banddft = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(coeff[0][b]), norm="backward")) 280 | banddft = math_utils.batch_fftshift2d(banddft) 281 | banddft = banddft * anglemask * himask 282 | banddft = torch.unbind(banddft, -1) 283 | banddft_real = self.complex_fact_reconstruct.real*banddft[0] - self.complex_fact_reconstruct.imag*banddft[1] 284 | banddft_imag = self.complex_fact_reconstruct.real*banddft[1] + self.complex_fact_reconstruct.imag*banddft[0] 285 | banddft = torch.stack((banddft_real, banddft_imag), -1) 286 | 287 | orientdft = orientdft + banddft 288 | 289 | #################################################################### 290 | ########## Lowpass component are upsampled and convoluted ########## 291 | #################################################################### 292 | 293 | dims = np.array(coeff[0][0].shape[1:3]) 294 | 295 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32) 296 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32) 297 | 298 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]] 299 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]] 300 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2)) 301 | lomask = pointOp(nlog_rad, YIrcos, Xrcos) 302 | 303 | # Filtering 304 | lomask = pointOp(nlog_rad, YIrcos, Xrcos) 305 | lomask = torch.from_numpy(lomask[None,:,:,None]) 306 | lomask = lomask.float().to(self.device) 307 | 308 | ################################################################################ 309 | 310 | # Recursive call for image reconstruction 311 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle) 312 | 313 | resdft = torch.zeros_like(coeff[0][0]).to(self.device) 314 | resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * 
lomask 315 | 316 | return resdft + orientdft 317 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/steerable/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/steerable/__init__.py -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/steerable/math_utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | 22 | ################################################################################ 23 | ################################################################################ 24 | 25 | def roll_n(X, axis, n): 26 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) 27 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) 28 | front = X[f_idx] 29 | back = X[b_idx] 30 | return torch.cat([back, front], axis) 31 | 32 | def batch_fftshift2d(x): 33 | real, imag = torch.unbind(x, -1) 34 | for dim in range(1, len(real.size())): 35 | n_shift = real.size(dim)//2 36 | if real.size(dim) % 2 != 0: 37 | n_shift += 1 # for odd-sized images 38 | real = roll_n(real, axis=dim, n=n_shift) 39 | imag = roll_n(imag, axis=dim, n=n_shift) 40 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 41 | 42 | def batch_ifftshift2d(x): 43 | real, imag = torch.unbind(x, -1) 44 | for dim in range(len(real.size()) - 1, 0, -1): 45 | real = roll_n(real, axis=dim, n=real.size(dim)//2) 46 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) 47 | return torch.stack((real, imag), -1) # last dim=2 (real&imag) 48 | 49 | ################################################################################ 50 | ################################################################################ 51 | 52 | def prepare_grid(m, n): 53 | x = np.linspace(-(m // 2)/(m / 2), (m // 2)/(m / 2) - (1 - m % 2)*2/m, num=m) 54 | y = np.linspace(-(n // 2)/(n / 2), (n // 2)/(n / 2) - (1 - n % 2)*2/n, num=n) 55 | xv, yv = np.meshgrid(y, x) 56 | angle = np.arctan2(yv, xv) 57 | rad = np.sqrt(xv**2 + yv**2) 58 | rad[m//2][n//2] = rad[m//2][n//2 - 1] 59 | log_rad = np.log2(rad) 60 | return log_rad, angle 61 | 62 | def rcosFn(width, position): 63 | N = 256 # arbitrary 64 | X = np.pi * np.array(range(-N-1, 2))/2/N 65 | Y = np.cos(X)**2 66 | Y[0] = Y[1] 67 | Y[N+2] = Y[N+1] 68 | X = position + 2*width/np.pi*(X + np.pi/4) 69 | return X, Y 70 | 71 | def pointOp(im, Y, X): 72 | out = np.interp(im.flatten(), X, Y) 73 | return np.reshape(out, im.shape) 74 | 75 | def getlist(coeff): 76 | straight =
[bands for scale in coeff[1:-1] for bands in scale] 77 | straight = [coeff[0]] + straight + [coeff[-1]] 78 | return straight 79 | 80 | ################################################################################ 81 | # NumPy reference implementation (fftshift and ifftshift) 82 | 83 | # def fftshift(x, axes=None): 84 | # """ 85 | # Shift the zero-frequency component to the center of the spectrum. 86 | # This function swaps half-spaces for all axes listed (defaults to all). 87 | # Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even. 88 | # Parameters 89 | # """ 90 | # x = np.asarray(x) 91 | # if axes is None: 92 | # axes = tuple(range(x.ndim)) 93 | # shift = [dim // 2 for dim in x.shape] 94 | # shift = [x.shape[ax] // 2 for ax in axes] 95 | # return np.roll(x, shift, axes) 96 | # 97 | # def ifftshift(x, axes=None): 98 | # """ 99 | # The inverse of `fftshift`. Although identical for even-length `x`, the 100 | # functions differ by one sample for odd-length `x`. 101 | # """ 102 | # x = np.asarray(x) 103 | # if axes is None: 104 | # axes = tuple(range(x.ndim)) 105 | # shift = [-(dim // 2) for dim in x.shape] 106 | # shift = [-(x.shape[ax] // 2) for ax in axes] 107 | # return np.roll(x, shift, axes) 108 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/steerable/utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-10 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import numpy as np 21 | import skimage 22 | import matplotlib.pyplot as plt 23 | 24 | import torch 25 | import torchvision 26 | 27 | ################################################################################ 28 | 29 | ToPIL = torchvision.transforms.ToPILImage() 30 | from PIL import Image 31 | Grayscale = torchvision.transforms.Grayscale() 32 | RandomCrop = torchvision.transforms.RandomCrop 33 | 34 | def get_device(device='cuda:0'): 35 | assert isinstance(device, str) 36 | num_cuda = torch.cuda.device_count() 37 | 38 | if 'cuda' in device: 39 | if num_cuda > 0: 40 | # Found CUDA device, use the GPU 41 | return torch.device(device) 42 | # Fallback to CPU 43 | print('No CUDA devices found, falling back to CPU') 44 | device = 'cpu' 45 | 46 | if not torch.backends.mkl.is_available(): 47 | raise NotImplementedError( 48 | 'torch.fft on the CPU requires MKL back-end. 
' + 49 | 'Please recompile your PyTorch distribution.') 50 | return torch.device('cpu') 51 | 52 | def load_image_batch(image_file, batch_size, image_size=200): 53 | if not os.path.isfile(image_file): 54 | raise FileNotFoundError('Image file not found on disk: {}'.format(image_file)) 55 | # im = ToPIL(skimage.io.imread(image_file)) 56 | im = Image.open(image_file) 57 | im = Grayscale(im) 58 | im_batch = np.zeros((batch_size, image_size, image_size), np.float32) 59 | for i in range(batch_size): 60 | im_batch[i] = RandomCrop(image_size)(im) 61 | # insert channels dim and rescale to [0, 1] 62 | return im_batch[:,None,:,:]/255. 63 | 64 | def show_image_batch(im_batch): 65 | assert isinstance(im_batch, torch.Tensor) 66 | im_batch = torchvision.utils.make_grid(im_batch).numpy() 67 | im_batch = np.transpose(im_batch.squeeze(1), (1,2,0)) 68 | plt.imshow(im_batch) 69 | plt.axis('off') 70 | plt.tight_layout() 71 | plt.show() 72 | return im_batch 73 | 74 | def extract_from_batch(coeff_batch, example_idx=0): 75 | ''' 76 | Given the batched Complex Steerable Pyramid, extract the coefficients 77 | for a single example from the batch. Additionally, it converts all 78 | torch.Tensor objects to np.ndarray and creates proper complex 79 | arrays for all the orientation bands. 80 | 81 | Args: 82 | coeff_batch (list): list containing low-pass, high-pass and pyr levels 83 | example_idx (int, optional): Defaults to 0. index in batch to extract 84 | 85 | Returns: 86 | list: list containing low-pass, high-pass and pyr levels as np.ndarray 87 | ''' 88 | if not isinstance(coeff_batch, list): 89 | raise ValueError('Batch of coefficients must be a list') 90 | coeff = [] # coefficient for single example 91 | for coeff_level in coeff_batch: 92 | if isinstance(coeff_level, torch.Tensor): 93 | # Low- or High-Pass 94 | coeff_level_numpy = coeff_level[example_idx].cpu().numpy() 95 | coeff.append(coeff_level_numpy) 96 | elif isinstance(coeff_level, list): 97 | coeff_orientations_numpy = [] 98 | for coeff_orientation in coeff_level: 99 | coeff_orientation_numpy = coeff_orientation[example_idx].cpu().numpy() 100 | coeff_orientation_numpy = coeff_orientation_numpy[:,:,0] + 1j*coeff_orientation_numpy[:,:,1] 101 | coeff_orientations_numpy.append(coeff_orientation_numpy) 102 | coeff.append(coeff_orientations_numpy) 103 | else: 104 | raise ValueError('coeff level must be of type (list, torch.Tensor)') 105 | return coeff 106 | 107 | ################################################################################ 108 | 109 | def make_grid_coeff(coeff, normalize=True): 110 | ''' 111 | Visualization function for building a large image that contains the 112 | low-pass, high-pass and all intermediate levels in the steerable pyramid. 113 | For the complex intermediate bands, the real part is visualized. 114 | 115 | Args: 116 | coeff (list): complex pyramid stored as list containing all levels 117 | normalize (bool, optional): Defaults to True.
Whether to normalize each band 118 | 119 | Returns: 120 | np.ndarray: large image that contains grid of all bands and orientations 121 | ''' 122 | M, N = coeff[1][0].shape 123 | Norients = len(coeff[1]) 124 | out = np.zeros((M * 2 - coeff[-1].shape[0], Norients * N)) 125 | currentx, currenty = 0, 0 126 | 127 | for i in range(1, len(coeff[:-1])): 128 | for j in range(len(coeff[1])): 129 | tmp = coeff[i][j].real 130 | m, n = tmp.shape 131 | if normalize: 132 | tmp = 255 * tmp/tmp.max() 133 | tmp[m-1,:] = 255 134 | tmp[:,n-1] = 255 135 | out[currentx:currentx+m,currenty:currenty+n] = tmp 136 | currenty += n 137 | currentx += coeff[i][0].shape[0] 138 | currenty = 0 139 | 140 | m, n = coeff[-1].shape 141 | out[currentx: currentx+m, currenty: currenty+n] = 255 * coeff[-1]/coeff[-1].max() 142 | out[0,:] = 255 143 | out[:,0] = 255 144 | return out.astype(np.uint8) 145 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/tests/test_ifft.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-07 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | import cv2 22 | 23 | import steerable.math_utils as fft_utils # was: import steerable.fft as fft_utils (the fft helpers live in math_utils in this tree) 24 | import matplotlib.pyplot as plt 25 | 26 | ################################################################################ 27 | 28 | tolerance = 1e-6 29 | 30 | image_file = './assets/lena.jpg' 31 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 32 | im = cv2.resize(im, dsize=(200, 200)) 33 | im = im.astype(np.float64)/255.
# note the np.float64 34 | 35 | ################################################################################ 36 | # NumPy 37 | 38 | fft_numpy = np.fft.fft2(im) 39 | fft_numpy = np.fft.fftshift(fft_numpy) 40 | 41 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy)) 42 | fft_numpy_ang_viz = np.angle(fft_numpy) 43 | 44 | ifft_numpy1 = np.fft.ifftshift(fft_numpy) 45 | ifft_numpy = np.fft.ifft2(ifft_numpy1) 46 | 47 | ################################################################################ 48 | # Torch 49 | 50 | device = torch.device('cpu') 51 | 52 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim 53 | im_torch = im_torch.to(device) 54 | 55 | # fft = complex-to-complex, rfft = real-to-complex 56 | fft_torch = torch.view_as_real(torch.fft.fft2(im_torch)) # was: torch.rfft(im_torch, signal_ndim=2, onesided=False), removed in PyTorch >= 1.8 57 | fft_torch = fft_utils.batch_fftshift2d(fft_torch) 58 | 59 | ifft_torch = fft_utils.batch_ifftshift2d(fft_torch) 60 | ifft_torch = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(ifft_torch))) # was: torch.ifft(ifft_torch, signal_ndim=2, normalized=False) 61 | 62 | ifft_torch_to_numpy = ifft_torch.numpy() 63 | ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2, -1) # complex => real/imag 64 | ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1) 65 | ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j*ifft_torch_to_numpy[1] 66 | all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance) # compare against the NumPy ifft2 result (ifft_numpy1 is still in the frequency domain) 67 | print('ifft all close: ', all_close_ifft) 68 | 69 | fft_torch = fft_torch.cpu().numpy().squeeze() 70 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag 71 | fft_torch = np.squeeze(fft_torch, -1) 72 | fft_torch = fft_torch[0] + 1j*fft_torch[1] 73 | 74 | ifft_torch = ifft_torch.cpu().numpy().squeeze() 75 | ifft_torch = np.split(ifft_torch, 2, -1) # complex => real/imag 76 | ifft_torch = np.squeeze(ifft_torch, -1) 77 | ifft_torch = ifft_torch[0] + 1j*ifft_torch[1] 78 | 79 | fft_torch_mag_viz = np.log10(np.abs(fft_torch)) 80 | fft_torch_ang_viz = np.angle(fft_torch) 81 | 82 | ################################################################################ 83 | # Tolerance checking 84 | 85 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance) 86 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance) 87 | print('fft allclose real: {}'.format(all_close_real)) 88 | print('fft allclose imag: {}'.format(all_close_imag)) 89 | 90 | all_close_real = np.allclose(np.real(ifft_numpy), np.real(ifft_torch), atol=tolerance) 91 | all_close_imag = np.allclose(np.imag(ifft_numpy), np.imag(ifft_torch), atol=tolerance) 92 | print('ifft allclose real: {}'.format(all_close_real)) 93 | print('ifft allclose imag: {}'.format(all_close_imag)) 94 | 95 | ################################################################################ 96 | 97 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12,6)) 98 | 99 | # Plotting NumPy results 100 | ax[0][0].imshow(im, cmap='gray') 101 | 102 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray') 103 | ax[0][1].set_title('NumPy fft magnitude') 104 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray') 105 | ax[0][2].set_title('NumPy fft phase') 106 | ax[0][3].imshow(ifft_numpy.real, cmap='gray') 107 | ax[0][3].set_title('NumPy ifft real') 108 | ax[0][4].imshow(ifft_numpy.imag, cmap='gray') 109 | ax[0][4].set_title('NumPy ifft imag') 110 | 111 | # Plotting PyTorch results 112 | ax[1][0].imshow(im, cmap='gray') 113 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray') 114 | ax[1][1].set_title('PyTorch fft magnitude') 115 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray') 116 |
ax[1][2].set_title('PyTorch fft phase') 117 | ax[1][3].imshow(ifft_torch.real, cmap='gray') 118 | ax[1][3].set_title('PyTorch ifft real') 119 | ax[1][4].imshow(ifft_torch.imag, cmap='gray') 120 | ax[1][4].set_title('PyTorch ifft imag') 121 | 122 | for cur_ax in ax.flatten(): 123 | cur_ax.axis('off') 124 | plt.tight_layout() 125 | plt.show() 126 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/tests/test_torch_fft.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions. 11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import torch 21 | 22 | ################################################################################ 23 | 24 | x = np.linspace(-10, 10, 100) 25 | y = np.cos(np.pi*x) 26 | 27 | ################################################################################ 28 | # NumPy 29 | 30 | y_fft_numpy = np.fft.fft(y) 31 | 32 | ################################################################################ 33 | # Torch 34 | 35 | device = torch.device('cuda:0') 36 | 37 | y_torch = torch.from_numpy(y[None, :]) 38 | y_torch = y_torch.to(device) 39 | 40 | # fft = complex-to-complex, rfft = real-to-complex 41 | y_fft_torch = torch.view_as_real(torch.fft.fft(y_torch)) # was: torch.rfft(y_torch, signal_ndim=1, onesided=False), removed in PyTorch >= 1.8 42 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze() 43 | y_fft_torch = y_fft_torch[:, 0] + 1j*y_fft_torch[:, 1] 44 | 45 | tolerance = 1e-6 46 | all_close = np.allclose(y_fft_numpy, y_fft_torch, atol=tolerance) 47 | print('numpy', y_fft_numpy.shape, y_fft_numpy.dtype) 48 | print('torch', y_fft_torch.shape, y_fft_torch.dtype) 49 | print('Successful: {}'.format(all_close)) 50 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/tests/test_torch_fftshift.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions.
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | import torch 22 | import torch.nn as nn 23 | 24 | from steerable.math_utils import batch_fftshift2d # was: from steerable.fft import fftshift (the fft helpers live in math_utils in this tree) 25 | 26 | ################################################################################ 27 | 28 | x = np.linspace(-10, 10, 100) 29 | y = np.cos(np.pi*x) 30 | 31 | ################################################################################ 32 | # NumPy 33 | 34 | y_fft_numpy = np.fft.fft(y) 35 | y_fft_numpy_shift = np.fft.fftshift(y_fft_numpy) 36 | fft_numpy_real = np.real(y_fft_numpy_shift) 37 | fft_numpy_imag = np.imag(y_fft_numpy_shift) 38 | 39 | ################################################################################ 40 | # Torch 41 | 42 | device = torch.device('cuda:0') 43 | 44 | y_torch = torch.from_numpy(y[None, :]) 45 | y_torch = y_torch.to(device) 46 | 47 | # fft = complex-to-complex, rfft = real-to-complex 48 | y_fft_torch = torch.view_as_real(torch.fft.fft(y_torch)) # was: torch.rfft(y_torch, signal_ndim=1, onesided=False) 49 | y_fft_torch = batch_fftshift2d(y_fft_torch) # was: fftshift(real, imag) 50 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze() 51 | fft_torch_real = y_fft_torch[:,0] 52 | fft_torch_imag = y_fft_torch[:,1] 53 | 54 | tolerance = 1e-6 55 | 56 | all_close_real = np.allclose(fft_numpy_real, fft_torch_real, atol=tolerance) 57 | all_close_imag = np.allclose(fft_numpy_imag, fft_torch_imag, atol=tolerance) 58 | 59 | print('fftshift allclose real: {}'.format(all_close_real)) 60 | print('fftshift allclose imag: {}'.format(all_close_imag)) 61 | 62 | ################################################################################ 63 | 64 | # import cortex.plot (author's plotting package, not publicly available) 65 | # colors = cortex.plot.nature_colors() 66 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] # fallback colour palette 67 | 68 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6)) 69 | 70 | # Plotting NumPy results 71 | ax[0][0].plot(x, y, color=colors[3]) 72 | ax[0][1].plot(np.real(y_fft_numpy), color=colors[0]) 73 | ax[0][1].plot(fft_numpy_real, color=colors[1]) 74 | ax[0][1].set_title('NumPy Real') 75 | ax[0][2].plot(np.imag(y_fft_numpy), color=colors[0]) 76 | ax[0][2].plot(fft_numpy_imag, color=colors[1]) 77 | ax[0][2].set_title('NumPy Imag') 78 | 79 | # Plotting Torch results 80 | ax[1][0].plot(x, y, color=colors[3]) 81 | ax[1][1].plot(fft_torch_real, color=colors[1]) 82 | ax[1][1].set_title('Torch Real') 83 | ax[1][2].plot(fft_torch_imag, color=colors[1]) 84 | ax[1][2].set_title('Torch Imag') 85 | 86 | plt.show() 87 | -------------------------------------------------------------------------------- /PyTorchSteerablePyramid/tests/test_torch_fftshift2d.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Tom Runia 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to conditions.
11 | # 12 | # Author: Tom Runia 13 | # Date Created: 2018-12-04 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | import torch 22 | import torch.nn as nn 23 | import cv2 24 | 25 | from steerable.math_utils import batch_fftshift2d # was: from steerable.fft import fftshift (the fft helpers live in math_utils in this tree) 26 | 27 | ################################################################################ 28 | 29 | image_file = './assets/lena.jpg' 30 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 31 | im = cv2.resize(im, dsize=(200, 200)) 32 | im = im.astype(np.float64)/255. # note the np.float64 33 | 34 | ################################################################################ 35 | # NumPy 36 | 37 | fft_numpy = np.fft.fft2(im) 38 | fft_numpy = np.fft.fftshift(fft_numpy) 39 | 40 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy)) 41 | fft_numpy_ang_viz = np.angle(fft_numpy) 42 | 43 | ################################################################################ 44 | # Torch 45 | 46 | device = torch.device('cuda:0') 47 | 48 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim 49 | im_torch = im_torch.to(device) 50 | 51 | # fft = complex-to-complex, rfft = real-to-complex 52 | fft_torch = torch.view_as_real(torch.fft.fft2(im_torch)) # was: torch.rfft(im_torch, signal_ndim=2, onesided=False) 53 | fft_torch = batch_fftshift2d(fft_torch) # was: fftshift(real, imag) 54 | fft_torch = fft_torch.cpu().numpy().squeeze() 55 | print(fft_torch.shape) 56 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag 57 | fft_torch = np.squeeze(fft_torch, -1) 58 | fft_torch = fft_torch[0] + 1j*fft_torch[1] 59 | 60 | print('fft_torch', fft_torch.shape, fft_torch.dtype) 61 | fft_torch_mag_viz = np.log10(np.abs(fft_torch)) 62 | fft_torch_ang_viz = np.angle(fft_torch) 63 | 64 | tolerance = 1e-6 65 | 66 | print(fft_numpy.dtype, fft_torch.dtype) 67 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance) 68 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance) 69 | 70 | print('fftshift allclose real: {}'.format(all_close_real)) 71 | print('fftshift allclose imag: {}'.format(all_close_imag)) 72 | 73 | ################################################################################ 74 | 75 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6)) 76 | 77 | # Plotting NumPy results 78 | ax[0][0].imshow(im, cmap='gray') 79 | 80 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray') 81 | ax[0][1].set_title('NumPy Magnitude Spectrum') 82 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray') 83 | ax[0][2].set_title('NumPy Phase Spectrum') 84 | 85 | # Plotting PyTorch results 86 | ax[1][0].imshow(im, cmap='gray') 87 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray') 88 | ax[1][1].set_title('Torch Magnitude Spectrum') 89 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray') 90 | ax[1][2].set_title('Torch Phase Spectrum') 91 | 92 | for cur_ax in ax.flatten(): 93 | cur_ax.axis('off') 94 | plt.tight_layout() 95 | plt.show() 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence Analysis (WACV 2024) 2 | 3 | ### [Project Page](https://reality.tf.fau.de/pub/ardelean2024highfidelity.html) | [Paper](https://arxiv.org/abs/2304.06433) 4 | 5 | The official implementation of *High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence
Analysis*. 6 | 7 | ![Teaser](static/teaser.png) 8 | 9 | ## Installation 10 | We implemented our method using PyTorch. 11 | For an easy installation, we provide an environment file that contains all dependencies: 12 | 13 | ``` 14 | conda env create -f environment.yml 15 | conda activate fca 16 | ``` 17 | 18 | ## Demo 19 | We include a minimal script that makes it easy to try our method on sample images. 20 | By default, running the `demo.py` file will compute the anomaly score for the provided image example `example/000.png`. 21 | ``` 22 | python demo.py 23 | ``` 24 | 25 | ## Data 26 | To run the evaluation code, you first have to prepare the desired dataset. 27 | Our data loader assumes the data follows the file structure of the MVTec anomaly detection dataset. 28 | You can download the MVTec AD from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad). 29 | The woven fabric textures dataset (WFT) can be downloaded [here](https://www.mydrive.ch/shares/46066/8338a11f32bb1b7b215c5381abe54ebf/download/420939225-1629955758/textures.zip). 30 | Please extract the data into the `datasets` directory. 31 | To use any other dataset that follows the same file structure, you can simply place the data in the same folder and create a corresponding dataset config file under `conf/dataset`. To understand the required format, please see `conf/dataset/mvtec.yaml`. 32 | 33 | ## Run and evaluate on a dataset 34 | To evaluate the method, use the `main.py` script. We use Hydra to manage the command line interface, which makes it easy to specify the method and dataset configurations. 35 | For example, running our FCA statistics comparison (`sc`) with WideResnet features (`fe`) on the MVTec dataset is done using 36 | ``` 37 | python main.py dataset=mvtec fe=wide sc=fca image_size=original tile_size=[9,9] 38 | ``` 39 | To run the same on the WFT dataset: 40 | ``` 41 | python main.py dataset=wft fe=wide sc=fca image_size=original tile_size=[9,9] 42 | ``` 43 | 44 | We automatically run the evaluation code after computing the anomaly scores for all images in the dataset. 45 | You can inspect the metrics in the JSON file `outputs/{experiment_name}/metrics_{padding}.json` and visualize the predicted anomaly maps under `outputs/{experiment_name}/{object_name}/visualize`. 46 | 47 | 48 | ## Content 49 | This repository contains several options for feature extraction, as described in our paper: plain colors (`color`), random kernels (`random`), the VGG network (`vgg`), the WideResnet network (`wide`), steerable filters (`steerable`), and Laws' texture energy measure (`ltem`). 50 | 51 | For patch statistics comparison, you can opt for one of the following: moments-based (`moments`), histogram-based (`hist`), sample-weighted Wasserstein (`sww`), feature correspondence analysis (`fca`), and our reimplementation of the method of Aota et al. (`aota`).
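The extractors and comparison methods can also be combined programmatically. The following minimal sketch mirrors the provided `demo.py`, pairing the WideResnet extractor with FCA; the image path and the parameter values (tile size `(9, 9)`, `sigma_p=3.0`, `sigma_s=1.0`) are just the illustrative defaults used there:
```python
import numpy as np
import torch
from PIL import Image

from feature_extractors import WideResnetExtractor
from sc_methods import ScFCA

device = torch.device('cuda')

# Load an RGB image as a [1, 3, H, W] tensor scaled to [0, 1]
image = torch.tensor(np.asarray(Image.open("example/000.png")), device=device).permute(2, 0, 1)[None] / 255.0

# Dense per-pixel features, then patch-statistics comparison -> per-pixel anomaly map
features = WideResnetExtractor(device=device)(image)
anomaly_map = ScFCA((9, 9), sigma_p=3.0, sigma_s=1.0)(features)
```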
52 | 53 | To reproduce one of our experiments for design space exploration (for example, histogram-based statistics comparison with VGG features) you can run: 54 | ``` 55 | python main.py dataset=mvtec fe=vgg sc=hist image_size=[256,256] tile_size=[25,25] 56 | ``` 57 | 58 | ## Citation 59 | Should you find our work useful in your research, please cite: 60 | ```BibTeX 61 | @inproceedings{ardelean2023highfidelity, 62 | title = {High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence Analysis}, 63 | author = {Ardelean, Andrei-Timotei and Weyrich, Tim}, 64 | booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 65 | numpages = 11, 66 | year = {2024}, 67 | month = jan, 68 | day = 4, 69 | authorurl = {https://reality.tf.fau.de/pub/ardelean2024highfidelity.html}, 70 | } 71 | ``` 72 | 73 | ## Acknowledgements 74 | This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No 956585. 75 | 76 | ## License 77 | Please see the [LICENSE](LICENSE). -------------------------------------------------------------------------------- /conf/base.yaml: -------------------------------------------------------------------------------- 1 | name: base 2 | device: cuda 3 | image_size: original 4 | tile_size: [9, 9] 5 | 6 | defaults: 7 | - _self_ 8 | - fe: ??? 9 | - sc: ??? 10 | - dataset: mvtec 11 | 12 | 13 | hydra: 14 | job: 15 | chdir: True 16 | run: 17 | dir: ./outputs/${dataset.name}_${fe.name}_${sc.name}_${now:%Y-%m-%d}_${now:%H-%M-%S} -------------------------------------------------------------------------------- /conf/dataset/mvtec.yaml: -------------------------------------------------------------------------------- 1 | name: mvtec 2 | data_manager: 3 | _target_: data_loaders.StandardFormatDataset 4 | data_dir: ${hydra:runtime.cwd}/datasets/mvtec_anomaly_detection 5 | out_dir: ${hydra:runtime.output_dir} 6 | paddings: [100] 7 | objects: ['carpet', 'grid', 'leather', 'tile', 'wood'] 8 | -------------------------------------------------------------------------------- /conf/dataset/wft.yaml: -------------------------------------------------------------------------------- 1 | name: wft 2 | data_manager: 3 | _target_: data_loaders.StandardFormatDataset 4 | data_dir: ${hydra:runtime.cwd}/datasets/wft 5 | out_dir: ${hydra:runtime.output_dir} 6 | paddings: [50] 7 | objects: ['texture_1', 'texture_2'] 8 | -------------------------------------------------------------------------------- /conf/fe/color.yaml: -------------------------------------------------------------------------------- 1 | name: color 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.Colors 5 | post_scale: False 6 | -------------------------------------------------------------------------------- /conf/fe/ltem.yaml: -------------------------------------------------------------------------------- 1 | name: ltem 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.LawTextureEnergyMeasure 5 | mean_patch_size: 15 6 | device: ${device} 7 | -------------------------------------------------------------------------------- /conf/fe/random.yaml: -------------------------------------------------------------------------------- 1 | name: random 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.RandomKernels 5 | projections: [[128,1]] 6 | patch_size: 5 7 | device: ${device} 8 | 
-------------------------------------------------------------------------------- /conf/fe/steerable.yaml: -------------------------------------------------------------------------------- 1 | name: steerable 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.SteerableExtractor 5 | o: 43 6 | m: 1 7 | scale_factor: 1 8 | post_scale: True 9 | -------------------------------------------------------------------------------- /conf/fe/vgg.yaml: -------------------------------------------------------------------------------- 1 | name: vgg 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.VggExtractor 5 | post_scale: True 6 | layers: ['conv_1', 'conv_2'] 7 | device: ${device} 8 | -------------------------------------------------------------------------------- /conf/fe/wide.yaml: -------------------------------------------------------------------------------- 1 | name: wide 2 | tile_size: ${tile_size} 3 | feature_extractor: 4 | _target_: feature_extractors.WideResnetExtractor 5 | post_scale: True 6 | device: ${device} 7 | -------------------------------------------------------------------------------- /conf/sc/aota.yaml: -------------------------------------------------------------------------------- 1 | name: aota 2 | method: 3 | _partial_: true 4 | _target_: sc_methods.sc_aota 5 | tile_size: ${tile_size} 6 | -------------------------------------------------------------------------------- /conf/sc/fca.yaml: -------------------------------------------------------------------------------- 1 | name: FCA 2 | method: 3 | _target_: sc_methods.ScFCA 4 | tile_size: ${tile_size} 5 | chunk_size: 8 6 | sigma_p: 3.0 7 | reference_selection: median 8 | sigma_s: 1.0 9 | -------------------------------------------------------------------------------- /conf/sc/hist.yaml: -------------------------------------------------------------------------------- 1 | name: hist 2 | method: 3 | _partial_: true 4 | _target_: sc_methods.sc_hist 5 | tile_size: ${tile_size} 6 | bins: 10 7 | sigma: 6.0 8 | -------------------------------------------------------------------------------- /conf/sc/moments.yaml: -------------------------------------------------------------------------------- 1 | name: moments 2 | method: 3 | _partial_: true 4 | _target_: sc_methods.sc_moments 5 | tile_size: ${tile_size} 6 | sigma: 6.0 7 | powers: [1.0, 2.0, 3.0, 4.0] 8 | -------------------------------------------------------------------------------- /conf/sc/sww.yaml: -------------------------------------------------------------------------------- 1 | name: sww 2 | method: 3 | _target_: sc_methods.ScSWW 4 | tile_size: ${tile_size} 5 | chunk_size: 8 6 | sigma: 6.0 7 | reference_selection: median 8 | -------------------------------------------------------------------------------- /data_loaders.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import tifffile as tiff 5 | from PIL import Image 6 | import numpy as np 7 | 8 | from torch.utils.data import Dataset 9 | 10 | from evaluation import evaluate_experiment 11 | 12 | 13 | class StandardFormatDataset(Dataset): 14 | def __init__(self, data_dir, out_dir, objects, paddings=(100,), pro_integration_limit=0.3, resize=False): 15 | super(StandardFormatDataset, self).__init__() 16 | self.data_dir = Path(data_dir) 17 | self.out_dir = Path(out_dir) 18 | self.objects = objects 19 | 20 | self.paddings = paddings 21 | self.resize = resize 22 | 
self.pro_integration_limit = pro_integration_limit 23 | 24 | self.in_image_paths = self.get_input_paths() 25 | 26 | def __getitem__(self, index): 27 | in_image_path = self.in_image_paths[index] 28 | return in_image_path 29 | 30 | def __len__(self): 31 | return len(self.in_image_paths) 32 | 33 | def get_input_paths(self): 34 | in_image_paths = [] 35 | for obj in self.objects: 36 | in_obj_path = self.data_dir / obj / "test" 37 | for defect in sorted(in_obj_path.iterdir()): 38 | in_image_paths.extend(sorted(list(defect.iterdir()))) 39 | return in_image_paths 40 | 41 | def save_output(self, age: torch.Tensor, in_image_path): 42 | np_age = age.cpu().numpy() 43 | obj, _, defect, _ = in_image_path.parts[-4:] 44 | stem = in_image_path.stem 45 | out_def_path = self.out_dir / obj / "test" / defect 46 | vis_def_path = self.out_dir / obj / "visualize" / defect 47 | out_def_path.mkdir(parents=True, exist_ok=True) 48 | vis_def_path.mkdir(parents=True, exist_ok=True) 49 | 50 | tiff.imsave(out_def_path / f"{stem}.tiff", np_age) 51 | # Save for visualization 52 | vis = ((np_age - np_age.min()) / (np_age.max() - np_age.min()) * 255) 53 | Image.fromarray(vis.astype(np.uint8)).save(vis_def_path / f"{stem}.jpg") 54 | 55 | def run_evaluation(self): 56 | args = { 57 | 'dataset_base_dir': str(self.data_dir), 58 | 'anomaly_maps_dir': str(self.out_dir), 59 | 'output_dir': str(self.out_dir), 60 | 'evaluated_objects': self.objects, 61 | 'pro_integration_limit': self.pro_integration_limit, 62 | 'resize': self.resize if self.resize else 0, 63 | } 64 | args = type('dict_as_obj', (object,), args) # Object from dict 65 | for padding in self.paddings: 66 | args.padding = padding 67 | evaluate_experiment.main(args) 68 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/datasets/.gitkeep -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import matplotlib.cm as cm 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | from feature_extractors import WideResnetExtractor 7 | from sc_methods import ScFCA 8 | 9 | 10 | def main(): 11 | device = torch.device('cuda') 12 | feature_extractor = WideResnetExtractor(device=device) 13 | s = ScFCA((9, 9), sigma_p=3.0, sigma_s=1.0) 14 | 15 | image = torch.tensor(np.asarray(Image.open("example/000.png")), device=device).permute(2, 0, 1)[None] / 255.0 16 | features = feature_extractor(image) 17 | a = s(features) 18 | 19 | pad = 10 20 | np_im = a.cpu().numpy()[pad:-pad, pad:-pad] 21 | np_im = (np_im - np_im.min()) / (np_im.max() - np_im.min()) 22 | cma = cm.Reds(np_im, bytes=True) 23 | result = Image.fromarray(cma, 'RGBA').convert('RGB') 24 | result.save('example/000_out.png') 25 | result.show() 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: fca 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - _openmp_mutex=5.1 10 | - aom=3.6.0 11 | - blas=1.0 12 | - blosc=1.21.3 13 | - brotli=1.0.9 14 | - brotli-python=1.0.9 15 | - brunsli=0.1 16 | - bzip2=1.0.8 
17 | - c-ares=1.19.1 18 | - ca-certificates=2023.08.22 19 | - certifi=2023.11.17 20 | - cffi=1.16.0 21 | - cfitsio=3.470 22 | - charls=2.2.0 23 | - charset-normalizer=2.0.4 24 | - click=8.1.7 25 | - cloudpickle=2.2.1 26 | - colorama=0.4.6 27 | - cryptography=41.0.7 28 | - cuda-cudart=11.8.89 29 | - cuda-cupti=11.8.87 30 | - cuda-libraries=11.8.0 31 | - cuda-nvrtc=11.8.89 32 | - cuda-nvtx=11.8.86 33 | - cuda-runtime=11.8.0 34 | - cytoolz=0.12.2 35 | - dask-core=2023.11.0 36 | - dav1d=1.2.1 37 | - ffmpeg=4.3 38 | - filelock=3.13.1 39 | - freetype=2.12.1 40 | - fsspec=2023.10.0 41 | - giflib=5.2.1 42 | - gmp=6.2.1 43 | - gmpy2=2.1.2 44 | - gnutls=3.6.15 45 | - idna=3.4 46 | - imagecodecs=2023.1.23 47 | - imageio=2.31.4 48 | - importlib-metadata=6.0.0 49 | - intel-openmp=2023.1.0 50 | - jinja2=3.1.2 51 | - joblib=1.2.0 52 | - jpeg=9e 53 | - jxrlib=1.1 54 | - krb5=1.20.1 55 | - lame=3.100 56 | - lcms2=2.12 57 | - ld_impl_linux-64=2.38 58 | - lerc=3.0 59 | - libaec=1.0.4 60 | - libavif=0.11.1 61 | - libbrotlicommon=1.0.9 62 | - libbrotlidec=1.0.9 63 | - libbrotlienc=1.0.9 64 | - libcublas=11.11.3.6 65 | - libcufft=10.9.0.58 66 | - libcufile=1.8.1.2 67 | - libcurand=10.3.4.101 68 | - libcurl=8.4.0 69 | - libcusolver=11.4.1.48 70 | - libcusparse=11.7.5.86 71 | - libdeflate=1.17 72 | - libedit=3.1.20221030 73 | - libev=4.33 74 | - libffi=3.4.4 75 | - libgcc-ng=11.2.0 76 | - libgfortran-ng=13.2.0 77 | - libgfortran5=13.2.0 78 | - libgomp=11.2.0 79 | - libiconv=1.16 80 | - libidn2=2.3.4 81 | - libjpeg-turbo=2.0.0 82 | - libnghttp2=1.57.0 83 | - libnpp=11.8.0.86 84 | - libnvjpeg=11.9.0.86 85 | - libpng=1.6.39 86 | - libssh2=1.10.0 87 | - libstdcxx-ng=11.2.0 88 | - libtasn1=4.19.0 89 | - libtiff=4.5.1 90 | - libunistring=0.9.10 91 | - libwebp=1.3.2 92 | - libwebp-base=1.3.2 93 | - libzopfli=1.0.3 94 | - llvm-openmp=14.0.6 95 | - locket=1.0.0 96 | - lz4-c=1.9.4 97 | - markupsafe=2.1.1 98 | - mkl=2023.1.0 99 | - mkl-service=2.4.0 100 | - mkl_fft=1.3.8 101 | - mkl_random=1.2.4 102 | - mpc=1.1.0 103 | - mpfr=4.0.2 104 | - mpmath=1.3.0 105 | - ncurses=6.4 106 | - nettle=3.7.3 107 | - networkx=3.1 108 | - numpy=1.26.2 109 | - numpy-base=1.26.2 110 | - openh264=2.1.1 111 | - openjpeg=2.4.0 112 | - openssl=3.0.12 113 | - partd=1.4.1 114 | - pillow=10.0.1 115 | - pip=23.3.1 116 | - pycparser=2.21 117 | - pyopenssl=23.2.0 118 | - pysocks=1.7.1 119 | - python=3.9.18 120 | - pytorch=2.1.2 121 | - pytorch-cuda=11.8 122 | - pytorch-mutex=1.0 123 | - pywavelets=1.4.1 124 | - pyyaml=6.0.1 125 | - readline=8.2 126 | - requests=2.31.0 127 | - scikit-image=0.19.3 128 | - scikit-learn=1.3.0 129 | - scipy=1.11.4 130 | - setuptools=68.2.2 131 | - snappy=1.1.9 132 | - sqlite=3.41.2 133 | - sympy=1.12 134 | - tbb=2021.8.0 135 | - threadpoolctl=2.2.0 136 | - tifffile=2023.2.28 137 | - tk=8.6.12 138 | - toolz=0.12.0 139 | - torchaudio=2.1.2 140 | - torchtriton=2.1.0 141 | - torchvision=0.16.2 142 | - tqdm=4.66.1 143 | - typing_extensions=4.7.1 144 | - tzdata=2023c 145 | - urllib3=1.26.18 146 | - wheel=0.41.2 147 | - xz=5.4.5 148 | - yaml=0.2.5 149 | - zfp=1.0.0 150 | - zipp=3.11.0 151 | - zlib=1.2.13 152 | - zstd=1.5.5 153 | - pip: 154 | - antlr4-python3-runtime==4.9.3 155 | - hydra-core==1.3.2 156 | - omegaconf==2.3.0 157 | - opencv-python==4.8.1.78 158 | - packaging==23.2 159 | -------------------------------------------------------------------------------- /evaluation/additional_util.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import numpy as np 3 | 
4 | 5 | def compute_seg_au_roc(anomaly_maps, ground_truth_maps): 6 | anomaly_scores_flat = np.array(anomaly_maps).ravel() 7 | gt_flat = np.array(ground_truth_maps).ravel() > 0 8 | fpr, tpr, thresholds = metrics.roc_curve(gt_flat.astype(int), anomaly_scores_flat) 9 | # au_roc = metrics.roc_auc_score(gt_flat.astype(int), anomaly_scores_flat) 10 | return fpr, tpr 11 | 12 | 13 | def compute_precision_recall(anomaly_maps, ground_truth_maps): 14 | anomaly_scores_flat = np.array(anomaly_maps).ravel() 15 | gt_flat = np.array(ground_truth_maps).ravel() > 0 16 | precision, recall, thresholds = metrics.precision_recall_curve(gt_flat.astype(int), anomaly_scores_flat) 17 | return precision, recall, thresholds 18 | 19 | 20 | def compute_optimal_f1(anomaly_maps, ground_truth_maps): 21 | precision, recall, thresholds = compute_precision_recall(anomaly_maps, ground_truth_maps) 22 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-8) 23 | ind = np.argmax(f1_score) 24 | return f1_score[ind].item(), thresholds[ind].item() 25 | -------------------------------------------------------------------------------- /evaluation/evaluate_experiment.py: -------------------------------------------------------------------------------- 1 | """Compute evaluation metrics for a single experiment.""" 2 | 3 | __author__ = "Paul Bergmann, David Sattlegger" 4 | __copyright__ = "2021, MVTec Software GmbH" 5 | 6 | import argparse 7 | import json 8 | from os import makedirs, path 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import skimage 13 | import skimage.transform 14 | from PIL import Image 15 | from tqdm import tqdm 16 | 17 | from evaluation import generic_util as util 18 | from evaluation.additional_util import compute_optimal_f1, compute_seg_au_roc 19 | from evaluation.pro_curve_util import compute_pro 20 | from evaluation.roc_curve_util import compute_classification_roc 21 | 22 | 23 | def parse_user_arguments(): 24 | """Parse user arguments for the evaluation of a method on the MVTec AD 25 | dataset. 26 | 27 | Returns: 28 | Parsed user arguments. 29 | """ 30 | parser = argparse.ArgumentParser(description="""Parse user arguments.""") 31 | 32 | parser.add_argument('--anomaly_maps_dir', 33 | required=True, 34 | help="""Path to the directory that contains the anomaly 35 | maps of the evaluated method.""") 36 | 37 | parser.add_argument('--dataset_base_dir', 38 | required=True, 39 | help="""Path to the directory that contains the dataset 40 | images of the MVTec AD dataset.""") 41 | 42 | parser.add_argument('--output_dir', 43 | help="""Path to the directory to store evaluation 44 | results. If no output directory is specified, 45 | the results are not written to drive.""") 46 | 47 | parser.add_argument('--padding', 48 | type=int, 49 | default=100, 50 | help="""How much to cut off the borders of the image before evaluating; 51 | the metrics are computed on m[padding:-padding, padding:-padding]""") 52 | 53 | parser.add_argument('--resize', 54 | type=int, 55 | default=0, 56 | help="""Resize the predictions and ground truth masks before evaluation""") 57 | 58 | parser.add_argument('--pro_integration_limit', 59 | type=float, 60 | default=0.3, 61 | help="""Integration limit to compute the area under 62 | the PRO curve. Must lie within the interval 63 | of (0.0, 1.0].""") 64 | 65 | parser.add_argument('--evaluated_objects', 66 | nargs='+', 67 | help="""List of objects to be evaluated.
By default, 68 | all dataset objects will be evaluated.""", 69 | choices=util.OBJECT_NAMES, 70 | default=util.OBJECT_NAMES) 71 | 72 | args = parser.parse_args() 73 | 74 | # Check that the PRO integration limit is within the valid range. 75 | assert 0.0 < args.pro_integration_limit <= 1.0 76 | 77 | return args 78 | 79 | 80 | def parse_dataset_files(object_name, dataset_base_dir, anomaly_maps_dir): 81 | """Parse the filenames for one object of the MVTec AD dataset. 82 | 83 | Args: 84 | object_name: Name of the dataset object. 85 | dataset_base_dir: Base directory of the MVTec AD dataset. 86 | anomaly_maps_dir: Base directory where anomaly maps are located. 87 | """ 88 | 89 | # Store a list of all ground truth filenames. 90 | gt_filenames = [] 91 | 92 | # Store a list of all corresponding anomaly map filenames. 93 | prediction_filenames = [] 94 | 95 | # Test images are located here. 96 | test_dir = path.join(dataset_base_dir, object_name, 'test') 97 | gt_base_dir = path.join(dataset_base_dir, object_name, 'ground_truth') 98 | anomaly_maps_base_dir = path.join(anomaly_maps_dir, object_name, 'test') 99 | 100 | # List all ground truth and corresponding anomaly images. 101 | for subdir in Path(test_dir).iterdir(): 102 | 103 | # Get paths to all test images in the dataset for this subdir. 104 | test_images = list(subdir.glob('*.png')) + list(subdir.glob('*.jpg')) 105 | test_images = sorted([x.stem for x in test_images]) 106 | 107 | # If subdir is not 'good', derive corresponding GT names. 108 | if subdir.name != 'good': 109 | file_names = (Path(gt_base_dir) / subdir.name).glob('*.png') 110 | gt_filenames.extend(sorted(file_names)) 111 | else: 112 | # No ground truth maps exist for anomaly-free images. 113 | gt_filenames.extend([None] * len(test_images)) 114 | 115 | # Fetch corresponding anomaly maps. 116 | prediction_filenames.extend( 117 | [path.join(anomaly_maps_base_dir, subdir.name, file) 118 | for file in test_images]) 119 | 120 | print(f"Parsed {len(gt_filenames)} ground truth image files.") 121 | 122 | return gt_filenames, prediction_filenames 123 | 124 | 125 | def calculate_au_pro_au_roc(gt_filenames, 126 | prediction_filenames, 127 | integration_limit, 128 | padding=0, 129 | resize=False): 130 | """Compute the area under the PRO curve for a set of ground truth images 131 | and corresponding anomaly images. 132 | 133 | In addition, the function computes the area under the ROC curve for image 134 | level classification. 135 | 136 | Args: 137 | gt_filenames: List of filenames that contain the ground truth images 138 | for a single dataset object. 139 | prediction_filenames: List of filenames that contain the corresponding 140 | anomaly images for each ground truth image. 141 | integration_limit: Integration limit to use when computing the area 142 | under the PRO curve. 143 | padding: Remove borders before evaluation. 144 | resize: Resize images before padding. 145 | Returns: 146 | au_pro: Area under the PRO curve computed up to the given integration 147 | limit. 148 | au_roc: Area under the ROC curve. 149 | pro_curve: PRO curve values for localization (fpr,pro). 150 | roc_curve: ROC curve values for image level classification (fpr,tpr). 151 | """ 152 | # Read all ground truth and anomaly images.
153 | ground_truth = [] 154 | predictions = [] 155 | 156 | print("Read ground truth files and corresponding predictions...") 157 | original_size = np.asarray(Image.open(next(_ for _ in gt_filenames if _ is not None))).shape[-2:] 158 | for (gt_name, pred_name) in tqdm(zip(gt_filenames, prediction_filenames), 159 | total=len(gt_filenames)): 160 | prediction = util.read_tiff(pred_name) 161 | if resize > 0: 162 | prediction = skimage.transform.resize(prediction, (resize, resize)) 163 | elif prediction.shape != original_size: 164 | prediction = skimage.transform.resize(prediction, original_size) 165 | if padding != 0: 166 | prediction = prediction[padding:-padding, padding:-padding] 167 | predictions.append(prediction) 168 | 169 | if gt_name is not None: 170 | gt = np.asarray(Image.open(gt_name)) 171 | if resize > 0: 172 | gt = skimage.transform.resize(gt, (resize, resize)) 173 | if padding != 0: 174 | gt = gt[padding:-padding, padding:-padding] 175 | else: 176 | gt = np.zeros(prediction.shape) 177 | ground_truth.append(gt) 178 | 179 | # Compute the PRO curve. 180 | pro_curve = compute_pro( 181 | anomaly_maps=predictions, 182 | ground_truth_maps=ground_truth) 183 | 184 | # Compute the area under the PRO curve. 185 | au_pro = util.trapezoid( 186 | pro_curve[0], pro_curve[1], x_max=integration_limit) 187 | au_pro /= integration_limit 188 | print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}") 189 | 190 | # Compute the segmentation ROC 191 | seg_roc_curve = compute_seg_au_roc(anomaly_maps=predictions, ground_truth_maps=ground_truth) 192 | au_seg_roc = util.trapezoid(seg_roc_curve[0], seg_roc_curve[1], x_max=1.0) 193 | print(f"AU-ROC segmentation: {au_seg_roc}") 194 | 195 | f1_optimal = compute_optimal_f1(anomaly_maps=predictions, ground_truth_maps=ground_truth) 196 | 197 | # Derive binary labels for each input image: 198 | # (0 = anomaly free, 1 = anomalous). 199 | binary_labels = [int(np.any(x > 0)) for x in ground_truth] 200 | del ground_truth 201 | 202 | if all([x == 1 for x in binary_labels]): 203 | roc_curve, au_roc = [], 1.0 204 | else: 205 | # Compute the classification ROC curve. 206 | roc_curve = compute_classification_roc( 207 | anomaly_maps=predictions, 208 | scoring_function=np.max, 209 | ground_truth_labels=binary_labels) 210 | 211 | # Compute the area under the classification ROC curve. 212 | au_roc = util.trapezoid(roc_curve[0], roc_curve[1]) 213 | print(f"Image-level classification AU-ROC: {au_roc}") 214 | 215 | # Return the evaluation metrics. 216 | return au_pro, au_roc, pro_curve, roc_curve, seg_roc_curve, au_seg_roc, f1_optimal 217 | 218 | 219 | def main(args): 220 | """Calculate the performance metrics for a single experiment on the 221 | MVTec AD dataset. 222 | """ 223 | # Parse user arguments. 224 | padding = args.padding 225 | resize = args.resize 226 | 227 | # Store evaluation results in this dictionary. 228 | evaluation_dict = dict() 229 | 230 | # Keep track of the mean performance measures. 231 | au_pros = [] 232 | au_rocs = [] 233 | au_segs = [] 234 | f1s = [] 235 | 236 | # Evaluate each dataset object separately. 237 | for obj in args.evaluated_objects: 238 | print(f"=== Evaluate {obj} ===") 239 | evaluation_dict[obj] = dict() 240 | 241 | # Parse the filenames of all ground truth and corresponding anomaly 242 | # images for this object. 
243 | gt_filenames, prediction_filenames = \ 244 | parse_dataset_files( 245 | object_name=obj, 246 | dataset_base_dir=args.dataset_base_dir, 247 | anomaly_maps_dir=args.anomaly_maps_dir) 248 | 249 | # Calculate the PRO and ROC curves. 250 | au_pro, au_roc, pro_curve, roc_curve, seg_curve, au_seg, f1_optimal = \ 251 | calculate_au_pro_au_roc( 252 | gt_filenames, 253 | prediction_filenames, 254 | args.pro_integration_limit, padding=padding, resize=resize) 255 | 256 | evaluation_dict[obj]['au_pro'] = au_pro 257 | evaluation_dict[obj]['classification_au_roc'] = au_roc 258 | evaluation_dict[obj]['au_segroc'] = au_seg 259 | evaluation_dict[obj]['f1_optimal_value'] = f1_optimal[0] 260 | evaluation_dict[obj]['f1_optimal_threshold'] = f1_optimal[1] 261 | 262 | # evaluation_dict[obj]['classification_roc_curve_fpr'] = roc_curve[0] 263 | # evaluation_dict[obj]['classification_roc_curve_tpr'] = roc_curve[1] 264 | 265 | # Keep track of the mean performance measures. 266 | au_pros.append(au_pro) 267 | au_rocs.append(au_roc) 268 | au_segs.append(au_seg) 269 | f1s.append(f1_optimal[0]) 270 | 271 | print('\n') 272 | 273 | # Compute the mean of the performance measures. 274 | evaluation_dict['mean_au_pro'] = np.mean(au_pros).item() 275 | evaluation_dict['mean_au_segroc'] = np.mean(au_segs).item() 276 | evaluation_dict['mean_classification_au_roc'] = np.mean(au_rocs).item() 277 | evaluation_dict['mean_f1'] = np.mean(f1s).item() 278 | 279 | # If required, write evaluation metrics to drive. 280 | if args.output_dir is not None: 281 | output_dir = args.output_dir 282 | makedirs(output_dir, exist_ok=True) 283 | 284 | with open(path.join(output_dir, f'metrics_{padding}.json'), 'w') as file: 285 | json.dump(evaluation_dict, file, indent=4) 286 | 287 | print(f"Wrote metrics to {path.join(output_dir, f'metrics_{padding}.json')}") 288 | 289 | 290 | if __name__ == "__main__": 291 | main(parse_user_arguments()) 292 | -------------------------------------------------------------------------------- /evaluation/evaluate_multiple_experiments.py: -------------------------------------------------------------------------------- 1 | """Run the evaluation script for multiple experiments. 2 | 3 | This is a wrapper around evaluate_experiment.py which is called once for each 4 | experiment specified in the config file passed to this script. 5 | """ 6 | 7 | __author__ = "Paul Bergmann, David Sattlegger" 8 | __copyright__ = "2021, MVTec Software GmbH" 9 | 10 | import argparse 11 | import json 12 | import subprocess 13 | from os import path 14 | 15 | 16 | def parse_user_arguments(): 17 | """Parse user arguments. 18 | 19 | Returns: Parsed arguments. 20 | """ 21 | parser = argparse.ArgumentParser(description="""Parse user arguments.""") 22 | 23 | parser.add_argument('--experiment_configs', 24 | default='experiment_configs.json', 25 | help="""Path to the config file that contains the 26 | locations of all experiments that should be 27 | evaluated.""") 28 | 29 | parser.add_argument('--dataset_base_dir', 30 | required=True, 31 | help="""Path to the directory that contains the dataset 32 | images of the MVTec AD dataset.""") 33 | 34 | parser.add_argument('--output_dir', 35 | default='metrics/', 36 | help="""Path to write evaluation results to.""") 37 | 38 | parser.add_argument('--dry_run', 39 | choices=['True', 'False'], 40 | default='False', 41 | help="""If set to 'True', the script is run without 42 | performing the actual evaluations.
Instead, the 43 | experiments to be evaluated are simply printed 44 | to the standard output.""") 45 | 46 | parser.add_argument('--pro_integration_limit', 47 | type=float, 48 | default=0.3, 49 | help="""Integration limit to compute the area under 50 | the PRO curve. Must lie within the interval 51 | of (0.0, 1.0].""") 52 | 53 | return parser.parse_args() 54 | 55 | 56 | def main(): 57 | """Run the evaluation script for multiple experiments.""" 58 | # Parse user arguments. 59 | args = parse_user_arguments() 60 | 61 | # Read the experiment configurations to be evaluated. 62 | with open(args.experiment_configs) as file: 63 | experiment_configs = json.load(file) 64 | 65 | # Call the evaluation script for each experiment separately. 66 | for experiment_id in experiment_configs['anomaly_maps_dirs']: 67 | print(f"=== Evaluate experiment: {experiment_id} ===\n") 68 | 69 | # Anomaly maps for this experiment are located in this directory. 70 | anomaly_maps_dir = path.join( 71 | experiment_configs['exp_base_dir'], 72 | experiment_configs['anomaly_maps_dirs'][experiment_id]) 73 | 74 | # Set up python call for the evaluation script. 75 | call = ['python', 'evaluate_experiment.py', 76 | '--anomaly_maps_dir', anomaly_maps_dir, 77 | '--dataset_base_dir', args.dataset_base_dir, 78 | '--output_dir', path.join(args.output_dir, experiment_id), 79 | '--pro_integration_limit', str(args.pro_integration_limit)] 80 | 81 | # Run evaluation script. 82 | if args.dry_run == 'False': 83 | subprocess.run(call, check=True) 84 | else: 85 | print(f"Would call: {' '.join(call)}\n") 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /evaluation/generic_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utility functions for: 3 | - parsing user arguments. 4 | - computing the area under a curve. 5 | - generating a toy dataset to test the evaluation script. 6 | """ 7 | from bisect import bisect 8 | import os 9 | 10 | import numpy as np 11 | import tifffile as tiff 12 | 13 | OBJECT_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid', 14 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 15 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper'] 16 | 17 | 18 | def trapezoid(x, y, x_max=None): 19 | """ 20 | This function calculates the definite integral of a curve given by 21 | x- and corresponding y-values. In contrast to, e.g., 'numpy.trapz()', 22 | this function allows you to define an upper bound to the integration range by 23 | setting a value x_max. 24 | 25 | Points that do not have a finite x or y value will be ignored with a 26 | warning. 27 | 28 | Args: 29 | x: Samples from the domain of the function to integrate 30 | Need to be sorted in ascending order. May contain the same value 31 | multiple times. In that case, the order of the corresponding 32 | y values will affect the integration with the trapezoidal rule. 33 | y: Values of the function corresponding to x values. 34 | x_max: Upper limit of the integration. The y value at max_x will be 35 | determined by interpolating between its neighbors. Must not lie 36 | outside of the range of x. 37 | 38 | Returns: 39 | Area under the curve. 40 | """ 41 | 42 | x = np.asarray(x) 43 | y = np.asarray(y) 44 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y)) 45 | if not finite_mask.all(): 46 | print("WARNING: Not all x and y values passed to trapezoid(...)" 47 | " are finite.
48 |         x = x[finite_mask]
49 |         y = y[finite_mask]
50 | 
51 |     # Introduce a correction term if x_max is not an element of x.
52 |     correction = 0.
53 |     if x_max is not None:
54 |         if x_max not in x:
55 |             # Get the insertion index that would keep x sorted after
56 |             # np.insert(x, ins, x_max).
57 |             ins = bisect(x, x_max)
58 |             # x_max must be between the minimum and the maximum, so the
59 |             # insertion index cannot be zero or len(x).
60 |             assert 0 < ins < len(x)
61 | 
62 |             # Calculate the correction term, which is the integral between
63 |             # x[ins - 1] and x_max. Since we do not know the exact value
64 |             # of y at x_max, we interpolate between y[ins] and y[ins - 1].
65 |             y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) *
66 |                                      (x_max - x[ins - 1]) /
67 |                                      (x[ins] - x[ins - 1]))
68 |             correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
69 | 
70 |         # Cut off at x_max.
71 |         mask = x <= x_max
72 |         x = x[mask]
73 |         y = y[mask]
74 | 
75 |     # Return the area under the curve using the trapezoidal rule.
76 |     return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
77 | 
78 | 
79 | def read_tiff(file_path_no_ext, exts=('.tif', '.tiff', '.TIF', '.TIFF')):
80 |     """Read a TIFF file from a given path without the TIFF extension.
81 | 
82 |     Args:
83 |         file_path_no_ext: Path to the TIFF file without a file extension.
84 |         exts: TIFF extensions to consider when searching for the file.
85 | 
86 |     Raises:
87 |         FileNotFoundError: The given file path does not exist with any of the
88 |             given extensions.
89 |         IOError: The given file path exists with multiple of the given
90 |             extensions.
91 |     """
92 |     # Get all file paths that exist.
93 |     file_paths = []
94 |     for ext in exts:
95 |         # Make sure the file path does not already end with a tiff extension.
96 |         assert not file_path_no_ext.endswith(ext)
97 |         file_path = file_path_no_ext + ext
98 |         if os.path.exists(file_path):
99 |             file_paths.append(file_path)
100 | 
101 |     if len(file_paths) == 0:
102 |         raise FileNotFoundError('Could not find a file with a TIFF extension'
103 |                                 f' at {file_path_no_ext}')
104 |     elif len(file_paths) > 1:
105 |         raise IOError('Found multiple files with a TIFF extension at'
106 |                       f' {file_path_no_ext}'
107 |                       '\nPlease specify which TIFF extension to use via the'
108 |                       ' `exts` parameter of this function.')
109 | 
110 |     return tiff.imread(file_paths[0])
111 | 
112 | 
113 | def generate_toy_dataset(num_images, image_width, image_height, gt_size):
114 |     """Generate a toy dataset to test the evaluation script.
115 | 
116 |     Args:
117 |         num_images: Number of images that the toy dataset contains.
118 |         image_width: Width of the dataset images in pixels.
119 |         image_height: Height of the dataset images in pixels.
120 |         gt_size: Size of the rectangular ground truth regions that are
121 |             artificially generated on the dataset images.
122 | 
123 |     Returns:
124 |         anomaly_maps: List of numpy arrays that contain random anomaly maps.
125 |         ground_truth_maps: Corresponding list of numpy arrays that specify a
126 |             rectangular ground truth region.
127 |     """
128 |     # Fix a random seed for reproducibility.
129 |     np.random.seed(1338)
130 | 
131 |     # Create synthetic evaluation data with random anomaly scores and
132 |     # simple ground truth maps.
133 |     anomaly_maps = []
134 |     ground_truth_maps = []
135 |     for _ in range(num_images):
136 |         # Sample a random anomaly map.
137 |         anomaly_map = np.random.random((image_height, image_width))
138 | 
139 |         # Construct a fixed ground truth map.
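        # As a concrete illustration (example values, not taken from the code):
        # with image_height = image_width = 300 and gt_size = 10, each map built
        # below is a 300x300 array of zeros whose top-left 10x10 block is set to
        # 1, so all images share one identical rectangular anomalous region.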
140 | ground_truth_map = np.zeros((image_height, image_width)) 141 | ground_truth_map[0:gt_size, 0:gt_size] = 1 142 | 143 | anomaly_maps.append(anomaly_map) 144 | ground_truth_maps.append(ground_truth_map) 145 | 146 | return anomaly_maps, ground_truth_maps 147 | -------------------------------------------------------------------------------- /evaluation/print_metrics.py: -------------------------------------------------------------------------------- 1 | """Print the key metrics of multiple experiments to the standard output. 2 | """ 3 | 4 | __author__ = "Paul Bergmann, David Sattlegger" 5 | __copyright__ = "2021, MVTec Software GmbH" 6 | 7 | import argparse 8 | import json 9 | import os 10 | from os.path import join 11 | 12 | import numpy as np 13 | from tabulate import tabulate 14 | 15 | from generic_util import OBJECT_NAMES 16 | 17 | 18 | def parse_user_arguments(): 19 | """Parse user arguments. 20 | 21 | Returns: Parsed arguments. 22 | """ 23 | parser = argparse.ArgumentParser(description="""Parse user arguments.""") 24 | 25 | parser.add_argument('--metrics_folder', 26 | default="./metrics/", 27 | help="""Path to the folder that contains the evaluation 28 | results.""") 29 | 30 | return parser.parse_args() 31 | 32 | 33 | def extract_table_rows(metrics_folder, metric): 34 | """Extract all rows to create a table that displays a given metric for each 35 | evaluated experiment. 36 | 37 | Args: 38 | metrics_folder: Base folder that contains evaluation results. 39 | metric: Name of the metric to be extracted. Choose between 40 | 'au_pro' for localization and 41 | 'classification_au_roc' for classification. 42 | 43 | Returns: 44 | List of table rows. Each row contains the experiment name and the 45 | extracted metrics for each evaluated object as well as the mean 46 | performance. 47 | """ 48 | assert metric in ['au_pro', 'classification_au_roc'] 49 | 50 | # Iterate each experiment. 51 | exp_ids = os.listdir(metrics_folder) 52 | exp_id_to_json_path = { 53 | exp_id: join(metrics_folder, exp_id, 'metrics.json') 54 | for exp_id in exp_ids 55 | if os.path.exists(join(metrics_folder, exp_id, 'metrics.json')) 56 | } 57 | 58 | # If there is a metrics.json file in the metrics_folder, also print that. 59 | # This is the case when evaluate_experiment.py has been called. 60 | root_metrics_json_path = join(metrics_folder, 'metrics.json') 61 | if os.path.exists(root_metrics_json_path): 62 | exp_id = join(os.path.split(metrics_folder)[-1], 'metrics.json') 63 | exp_id_to_json_path[exp_id] = root_metrics_json_path 64 | 65 | rows = [] 66 | for exp_id, json_path in exp_id_to_json_path.items(): 67 | 68 | # Each row starts with the name of the experiment. 69 | row = [exp_id] 70 | 71 | # Open the metrics file. 72 | with open(json_path) as file: 73 | metrics = json.load(file) 74 | 75 | # Parse performance metrics for each evaluated object if available. 76 | for obj in OBJECT_NAMES: 77 | if obj in metrics: 78 | row.append(np.round(metrics[obj][metric], decimals=3)) 79 | else: 80 | row.append('-') 81 | 82 | # Parse mean performance. 83 | row.append(np.round(metrics['mean_' + metric], decimals=3)) 84 | rows.append(row) 85 | 86 | return rows 87 | 88 | 89 | def main(): 90 | """Print the key metrics of multiple experiments to the standard output. 91 | """ 92 | # Parse user arguments. 93 | args = parse_user_arguments() 94 | 95 | # Create the table rows. One row for each experiment. 
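    # extract_table_rows scans metrics_folder for per-experiment result files.
    # A hypothetical layout that would be picked up (names are illustrative):
    #   metrics/
    #       experiment_a/metrics.json
    #       experiment_b/metrics.json
    #       metrics.json   <- present after a direct evaluate_experiment.py run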
96 |     rows_pro = extract_table_rows(args.metrics_folder, 'au_pro')
97 |     rows_roc = extract_table_rows(args.metrics_folder, 'classification_au_roc')
98 | 
99 |     # Print the localization result table.
100 |     print("\nAU PRO (localization)")
101 |     print(
102 |         tabulate(
103 |             rows_pro, headers=['Experiment'] + OBJECT_NAMES + ['Mean'],
104 |             tablefmt='fancy_grid'))
105 | 
106 |     # Print the classification result table.
107 |     print("\nAU ROC (classification)")
108 |     print(
109 |         tabulate(
110 |             rows_roc, headers=['Experiment'] + OBJECT_NAMES + ['Mean'],
111 |             tablefmt='fancy_grid'))
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     main()
116 | 
--------------------------------------------------------------------------------
/evaluation/pro_curve_util.py:
--------------------------------------------------------------------------------
1 | """Utility function that computes a PRO curve, given pairs of anomaly and
2 | ground truth maps.
3 | 
4 | The PRO curve can also be integrated up to a constant integration limit.
5 | """
6 | import numpy as np
7 | from scipy.ndimage import label
8 | 
9 | 
10 | def compute_pro(anomaly_maps, ground_truth_maps):
11 |     """Compute the PRO curve for a set of anomaly maps with corresponding
12 |     ground truth maps.
13 | 
14 |     Args:
15 |         anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a
16 |             real-valued anomaly score at each pixel.
17 | 
18 |         ground_truth_maps: List of ground truth maps (2D numpy arrays) that
19 |             contain binary-valued ground truth labels for each pixel.
20 |             0 indicates that a pixel is anomaly-free.
21 |             1 indicates that a pixel contains an anomaly.
22 | 
23 |     Returns:
24 |         fprs: numpy array of false positive rates.
25 |         pros: numpy array of corresponding PRO values.
26 |     """
27 | 
28 |     print("Compute PRO curve...")
29 | 
30 |     # Structuring element for computing connected components.
31 |     structure = np.ones((3, 3), dtype=int)
32 | 
33 |     num_ok_pixels = 0
34 |     num_gt_regions = 0
35 | 
36 |     shape = (len(anomaly_maps),
37 |              anomaly_maps[0].shape[0],
38 |              anomaly_maps[0].shape[1])
39 |     fp_changes = np.zeros(shape, dtype=np.uint32)
40 |     assert shape[0] * shape[1] * shape[2] < np.iinfo(fp_changes.dtype).max, \
41 |         'Potential overflow when using np.cumsum(), consider using np.uint64.'
42 | 
43 |     pro_changes = np.zeros(shape, dtype=np.float64)
44 | 
45 |     for gt_ind, gt_map in enumerate(ground_truth_maps):
46 | 
47 |         # Compute the connected components in the ground truth map.
48 |         labeled, n_components = label(gt_map, structure)
49 |         num_gt_regions += n_components
50 | 
51 |         # Compute the mask that gives us all OK pixels.
52 |         ok_mask = labeled == 0
53 |         num_ok_pixels_in_map = np.sum(ok_mask)
54 |         num_ok_pixels += num_ok_pixels_in_map
55 | 
56 |         # Compute by how much the FPR changes when each anomaly score is
57 |         # added to the set of positives.
58 |         # fp_change needs to be normalized later, when we know the final value
59 |         # of num_ok_pixels -> right now it is only the change in the number of
60 |         # false positives.
61 |         fp_change = np.zeros_like(gt_map, dtype=fp_changes.dtype)
62 |         fp_change[ok_mask] = 1
63 | 
64 |         # Compute by how much the PRO changes when each anomaly score is
65 |         # added to the set of positives.
66 |         # pro_change needs to be normalized later, when we know the final value
67 |         # of num_gt_regions.
68 |         pro_change = np.zeros_like(gt_map, dtype=np.float64)
69 |         for k in range(n_components):
70 |             region_mask = labeled == (k + 1)
71 |             region_size = np.sum(region_mask)
72 |             pro_change[region_mask] = 1. / region_size
73 | 
74 |         fp_changes[gt_ind, :, :] = fp_change
75 |         pro_changes[gt_ind, :, :] = pro_change
76 | 
77 |     # Flatten the numpy arrays before sorting.
78 |     anomaly_scores_flat = np.array(anomaly_maps).ravel()
79 |     fp_changes_flat = fp_changes.ravel()
80 |     pro_changes_flat = pro_changes.ravel()
81 | 
82 |     # Sort all anomaly scores.
83 |     print(f"Sort {len(anomaly_scores_flat)} anomaly scores...")
84 |     sort_idxs = np.argsort(anomaly_scores_flat).astype(np.uint32)[::-1]
85 | 
86 |     # Info: np.take(a, ind, out=a) followed by b=a instead of
87 |     # b=a[ind] turned out to be more memory-efficient.
88 |     np.take(anomaly_scores_flat, sort_idxs, out=anomaly_scores_flat)
89 |     anomaly_scores_sorted = anomaly_scores_flat
90 |     np.take(fp_changes_flat, sort_idxs, out=fp_changes_flat)
91 |     fp_changes_sorted = fp_changes_flat
92 |     np.take(pro_changes_flat, sort_idxs, out=pro_changes_flat)
93 |     pro_changes_sorted = pro_changes_flat
94 | 
95 |     del sort_idxs
96 | 
97 |     # Get the (FPR, PRO) curve values.
98 |     np.cumsum(fp_changes_sorted, out=fp_changes_sorted)
99 |     fp_changes_sorted = fp_changes_sorted.astype(np.float32, copy=False)
100 |     np.divide(fp_changes_sorted, num_ok_pixels, out=fp_changes_sorted)
101 |     fprs = fp_changes_sorted
102 | 
103 |     np.cumsum(pro_changes_sorted, out=pro_changes_sorted)
104 |     np.divide(pro_changes_sorted, num_gt_regions, out=pro_changes_sorted)
105 |     pros = pro_changes_sorted
106 | 
107 |     # Merge (FPR, PRO) points that occur together at the same threshold.
108 |     # For those points, only the final (FPR, PRO) point should be kept.
109 |     # That is because that point is the one that takes all changes
110 |     # to the FPR and the PRO at the respective threshold into account.
111 |     # -> keep_mask is True if the subsequent score is different from the
112 |     # score at the respective position, and True for the final element.
113 |     # anomaly_scores_sorted = [7, 4, 4, 4, 3, 1, 1]
114 |     # -> keep_mask = [T, F, F, T, T, F, T]
115 |     keep_mask = np.append(np.diff(anomaly_scores_sorted) != 0, np.True_)
116 |     del anomaly_scores_sorted
117 | 
118 |     fprs = fprs[keep_mask]
119 |     pros = pros[keep_mask]
120 |     del keep_mask
121 | 
122 |     # To mitigate the accumulation of numerical errors from the np.cumsum
123 |     # calls, make sure the curve ends at (1, 1) and contains no values > 1.
124 |     np.clip(fprs, a_min=None, a_max=1., out=fprs)
125 |     np.clip(pros, a_min=None, a_max=1., out=pros)
126 | 
127 |     # Make the fprs and pros start at 0 and end at 1.
128 |     zero = np.array([0.])
129 |     one = np.array([1.])
130 | 
131 |     return np.concatenate((zero, fprs, one)), np.concatenate((zero, pros, one))
132 | 
133 | 
134 | def main():
135 |     """
136 |     Compute the area under the PRO curve for a toy dataset and an algorithm
137 |     that randomly assigns anomaly scores to each pixel. The integration
138 |     limit can be specified.
139 |     """
140 | 
141 |     from generic_util import trapezoid, generate_toy_dataset
142 | 
143 |     integration_limit = 0.3
144 | 
145 |     # Generate a toy dataset.
146 |     anomaly_maps, ground_truth_maps = generate_toy_dataset(
147 |         num_images=200, image_width=500, image_height=300, gt_size=10)
148 | 
149 |     # Compute the PRO curve for this dataset.
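    # compute_pro returns the full (FPR, PRO) curve. The trapezoid call below
    # integrates it only up to the FPR integration limit, and dividing by that
    # limit rescales the area so that a perfect detector scores 1.0 regardless
    # of the chosen limit.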
150 |     all_fprs, all_pros = compute_pro(
151 |         anomaly_maps=anomaly_maps,
152 |         ground_truth_maps=ground_truth_maps)
153 | 
154 |     au_pro = trapezoid(all_fprs, all_pros, x_max=integration_limit)
155 |     au_pro /= integration_limit
156 |     print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}")
157 | 
158 | 
159 | if __name__ == "__main__":
160 |     main()
161 | 
--------------------------------------------------------------------------------
/evaluation/roc_curve_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions that compute a ROC curve and integrate its area from a set
3 | of anomaly maps and corresponding ground truth classification labels.
4 | """
5 | import numpy as np
6 | 
7 | 
8 | def compute_classification_roc(
9 |         anomaly_maps,
10 |         scoring_function,
11 |         ground_truth_labels):
12 |     """Compute the ROC curve for anomaly classification on the image level.
13 | 
14 |     Args:
15 |         anomaly_maps: List of anomaly maps (2D numpy arrays) that contain
16 |             a real-valued anomaly score at each pixel.
17 |         scoring_function: Function that turns anomaly maps into a single
18 |             real-valued anomaly score.
19 | 
20 |         ground_truth_labels: List of integers that indicate the ground truth
21 |             class for each input image. 0 corresponds to an anomaly-free
22 |             sample, while a value != 0 indicates an anomalous sample.
23 | 
24 |     Returns:
25 |         fprs: List of false positive rates.
26 |         tprs: List of corresponding true positive rates.
27 |     """
28 |     assert len(anomaly_maps) == len(ground_truth_labels)
29 | 
30 |     # Compute the anomaly score for each anomaly map.
31 |     anomaly_scores = map(scoring_function, anomaly_maps)
32 |     num_scores = len(anomaly_maps)
33 | 
34 |     # Sort samples by anomaly score. Keep track of the ground truth label.
35 |     sorted_samples = \
36 |         sorted(zip(anomaly_scores, ground_truth_labels), key=lambda x: x[0])
37 | 
38 |     # Compute the number of OK and NOK samples from the ground truth.
39 |     ground_truth_labels_np = np.array(ground_truth_labels)
40 |     num_nok = ground_truth_labels_np[ground_truth_labels_np != 0].size
41 |     num_ok = ground_truth_labels_np[ground_truth_labels_np == 0].size
42 | 
43 |     # Initially, every NOK sample is correctly classified as anomalous
44 |     # (tpr = 1.0), and every OK sample is incorrectly classified as anomalous
45 |     # (fpr = 1.0).
46 |     fprs = [1.0]
47 |     tprs = [1.0]
48 | 
49 |     # Keep track of the current number of false and true positive predictions.
50 |     num_fp = num_ok
51 |     num_tp = num_nok
52 | 
53 |     # Compute new true and false positive rates when successively increasing
54 |     # the threshold.
55 |     next_score = None
56 | 
57 |     for i, (current_score, label) in enumerate(sorted_samples):
58 | 
59 |         if label == 0:
60 |             num_fp -= 1
61 |         else:
62 |             num_tp -= 1
63 | 
64 |         if i < num_scores - 1:
65 |             next_score = sorted_samples[i + 1][0]
66 |         else:
67 |             next_score = None  # end of list
68 | 
69 |         if (next_score != current_score) or (next_score is None):
70 |             fprs.append(num_fp / num_ok)
71 |             tprs.append(num_tp / num_nok)
72 | 
73 |     # Return (FPR, TPR) pairs in increasing order.
74 |     fprs = fprs[::-1]
75 |     tprs = tprs[::-1]
76 | 
77 |     return fprs, tprs
78 | 
79 | 
80 | def main():
81 |     """
82 |     Compute the area under the ROC curve for a toy dataset and an algorithm
83 |     that randomly assigns anomaly scores to each image pixel.
84 |     """
85 | 
86 |     from generic_util import trapezoid, generate_toy_dataset
87 | 
88 |     # Generate a toy dataset.
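    # With gt_size=0 the generated ground truth maps contain no anomalous
    # pixels, so they are discarded here; only the random image-level labels
    # sampled below are needed for the classification ROC.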
89 | anomaly_maps, _ = generate_toy_dataset( 90 | num_images=10000, image_width=30, image_height=30, gt_size=0) 91 | 92 | # Assign a random classification label to each image. 93 | np.random.seed(42) 94 | labels = np.random.randint(2, size=len(anomaly_maps)) 95 | 96 | # Compute the ROC curve. 97 | all_fprs, all_tprs = compute_classification_roc(anomaly_maps=anomaly_maps, 98 | scoring_function=np.max, 99 | ground_truth_labels=labels) 100 | 101 | # Compute the area under the ROC curve. 102 | au_roc = trapezoid(all_fprs, all_tprs) 103 | print(f"AU-ROC: {au_roc}") 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /example/000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/example/000.png -------------------------------------------------------------------------------- /example/000_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/example/000_out.png -------------------------------------------------------------------------------- /feature_extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights 5 | 6 | import op_utils 7 | from op_utils import scale_features, reflect_pad 8 | from PyTorchSteerablePyramid import extract_steerable_features 9 | 10 | 11 | class Colors: 12 | def __init__(self, post_scale=False): 13 | self.post_scale = post_scale 14 | 15 | def __call__(self, image: torch.tensor) -> torch.tensor: 16 | if self.post_scale: 17 | return scale_features(image) 18 | else: 19 | return image 20 | 21 | 22 | class RandomKernels: 23 | def __init__(self, input_c=3, projections=((256, 1),), patch_size=7, device=None): 24 | self.ms_kernels = [self.build_kernels(input_c, num_proj, patch_size, device) for num_proj, _ in projections] 25 | self.scales = [scale for _, scale in projections] 26 | 27 | @staticmethod 28 | def build_kernels(c, num_proj, patch_size, device): 29 | kernels = torch.randn(num_proj, c * patch_size ** 2, device=device) 30 | kernels = kernels / torch.norm(kernels, dim=1, keepdim=True) 31 | kernels = kernels.reshape(num_proj, c, patch_size, patch_size) 32 | return kernels 33 | 34 | def __call__(self, image: torch.tensor) -> torch.tensor: 35 | parts = [] 36 | for kernels, scale in zip(self.ms_kernels, self.scales): 37 | scaled_im = F.interpolate(image, scale_factor=scale, mode='bilinear') 38 | scaled_f = F.conv2d(reflect_pad(scaled_im, patch_size=kernels.shape[-1]), kernels) 39 | parts.append(F.interpolate(scaled_f, size=image.shape[-2:], mode='bilinear')) 40 | return torch.cat(parts, dim=1) 41 | 42 | 43 | class NeuralExtractor: 44 | def __init__(self, post_scale=True, device=None): 45 | self.normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1) 46 | self.normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1) 47 | self.post_scale = post_scale 48 | self.device = device 49 | 50 | def normalize(self, image: torch.tensor): 51 | return (image - self.normalization_mean) / self.normalization_std 52 | 53 | def post(self, features): 54 | if 
self.post_scale: 55 | return scale_features(features) 56 | return features 57 | 58 | 59 | class VggExtractor(NeuralExtractor): 60 | def __init__(self, layers=('conv_1', 'conv_2', 'conv_3'), post_scale=True, device=None): 61 | super(VggExtractor, self).__init__(post_scale, device) 62 | self.cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval() 63 | self.layers = layers 64 | 65 | def get_vgg_features(self, image: torch.tensor) -> torch.tensor: 66 | out = self.normalize(image) 67 | vgg_features = [] 68 | i = 1 69 | for layer in self.cnn.children(): 70 | out = layer(out) 71 | if isinstance(layer, torch.nn.Conv2d) and f'conv_{i}' in self.layers: 72 | vgg_features.append(out.clone().detach()) 73 | i += 1 74 | 75 | return vgg_features 76 | 77 | @staticmethod 78 | def union_features(vgg_features): 79 | hw = vgg_features[0].shape[-2:] 80 | same_size = [F.interpolate(f, size=hw, mode='bilinear') for f in vgg_features] 81 | return torch.cat(same_size, dim=1) 82 | 83 | @torch.no_grad() 84 | def __call__(self, image: torch.tensor) -> torch.tensor: 85 | vgg_features = self.get_vgg_features(image) 86 | union = self.union_features(vgg_features) 87 | return self.post(union) 88 | 89 | 90 | class WideResnetExtractor(NeuralExtractor): 91 | def __init__(self, post_scale=True, device=None): 92 | super(WideResnetExtractor, self).__init__(post_scale, device) 93 | cnn = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V1).eval().to(device) 94 | self.model = torch.nn.Sequential(*list(cnn.children())[:6]) 95 | 96 | @torch.no_grad() 97 | def __call__(self, image: torch.tensor) -> torch.tensor: 98 | features = self.model(self.normalize(image)) 99 | return self.post(features) 100 | 101 | 102 | class SteerableExtractor: 103 | def __init__(self, post_scale=True, o=4, m=1, scale_factor=2): 104 | self.post_scale = post_scale 105 | self.o = o 106 | self.m = m 107 | self.scale_factor = scale_factor 108 | 109 | def __call__(self, image: torch.tensor) -> torch.tensor: 110 | features = extract_steerable_features(image, self.o, self.m, self.scale_factor) 111 | if self.post_scale: 112 | features = scale_features(features) 113 | return features 114 | 115 | 116 | class LawTextureEnergyMeasure: 117 | def __init__(self, mean_patch_size=15, device=None): 118 | self.mean_patch_size = mean_patch_size 119 | # noinspection SpellCheckingInspection 120 | lesr = torch.tensor([[1, 4, 6, 4, 1], 121 | [-1, -2, 0, 2, 1], 122 | [-1, 0, 2, 0, -1], 123 | [1, -4, 6, -4, 1]], dtype=torch.float32, device=device) 124 | outers = torch.einsum('ni,mj->nmij', lesr, lesr) 125 | self.kernels = outers.reshape(outers.shape[0] * outers.shape[1], 1, *outers.shape[-2:]) # 16 x 1 x 5 x 5 126 | 127 | def __call__(self, image: torch.tensor) -> torch.tensor: 128 | image = torch.mean(image, dim=1, keepdim=True) # Grayscale B x 1 x H x W 129 | if self.mean_patch_size != 0: 130 | image = image - op_utils.blur(image, self.mean_patch_size, sigma=20) 131 | energy_maps = F.conv2d(reflect_pad(image, patch_size=5), self.kernels) / 10 132 | features = torch.stack([ 133 | energy_maps[:, 5], # E5E5 134 | energy_maps[:, 10], # S5S5 135 | energy_maps[:, 15], # R5R5 136 | (energy_maps[:, 1] + energy_maps[:, 4]) / 2.0, # L5E5 + E5L5 137 | (energy_maps[:, 2] + energy_maps[:, 8]) / 2.0, # L5S5 + S5L5 138 | (energy_maps[:, 3] + energy_maps[:, 12]) / 2.0, # L5R5 + R5L5 139 | (energy_maps[:, 6] + energy_maps[:, 9]) / 2.0, # E5S5 + S5E5 140 | (energy_maps[:, 7] + energy_maps[:, 13]) / 2.0, # E5R5 + R5E5 141 | (energy_maps[:, 11] + 
energy_maps[:, 14]) / 2.0, # S5R5 + R5S5 142 | ], dim=1) 143 | return features 144 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import hydra 4 | import torch 5 | from hydra.utils import instantiate 6 | from omegaconf import DictConfig, OmegaConf 7 | from tqdm import tqdm 8 | 9 | from op_utils import load_image_tensor 10 | 11 | 12 | def compute_anomaly_map(image: torch.tensor, feature_extractor: Callable, stat_method: Callable): 13 | features = feature_extractor(image) 14 | age = stat_method(features) 15 | return age 16 | 17 | 18 | def run_anomaly_detection(dataset, cfg: DictConfig): 19 | fe = instantiate(cfg.fe.feature_extractor) 20 | sc = instantiate(cfg.sc.method) 21 | for in_image_path in tqdm(dataset): 22 | image = load_image_tensor(in_image_path, cfg.image_size, device=cfg.device) 23 | anomaly_map = compute_anomaly_map(image, fe, sc) 24 | 25 | dataset.save_output(anomaly_map, in_image_path) 26 | 27 | 28 | @hydra.main(version_base=None, config_path="conf", config_name="base") 29 | def my_app(cfg: DictConfig) -> None: 30 | print(OmegaConf.to_yaml(cfg)) 31 | data_manager = instantiate(cfg.dataset.data_manager) 32 | run_anomaly_detection(data_manager, cfg) 33 | data_manager.run_evaluation() 34 | 35 | 36 | if __name__ == "__main__": 37 | my_app() 38 | -------------------------------------------------------------------------------- /op_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | import cv2 5 | import numpy as np 6 | 7 | def scale_features(tensor: torch.tensor) -> torch.tensor: 8 | # tensor shape B x C x H x W 9 | tf = tensor.flatten(start_dim=2) 10 | mini = tf.min(dim=-1).values[..., None, None] 11 | maxi = tf.max(dim=-1).values[..., None, None] 12 | div = (maxi - mini + 1e-8) 13 | return (tensor - mini) / div 14 | 15 | # noinspection PyProtectedMember,PyUnresolvedReferences 16 | def get_gaussian_kernel(device, tile_size, s: float): 17 | return torchvision.transforms._functional_tensor._get_gaussian_kernel2d(tile_size, [s, s], torch.float32, device) 18 | 19 | def blur(image, kernel_size=7, sigma=None): 20 | if sigma is None: 21 | sigma = kernel_size / 4 22 | shape = image.shape 23 | im_b = image[(None,) * (4 - len(shape))] 24 | # noinspection PyUnresolvedReferences 25 | return torchvision.transforms.functional.gaussian_blur(im_b, kernel_size, sigma=sigma).view(shape) 26 | 27 | 28 | def load_image_tensor(path, image_size, device=None): 29 | input_image = cv2.imread(str(path)) 30 | if image_size == "original": 31 | image_size = input_image.shape[:2] 32 | else: 33 | image_size = tuple(image_size) 34 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 35 | return F.interpolate(image_to_tensor(input_image, device=device), size=image_size, mode='bilinear') 36 | 37 | 38 | def image_to_tensor(image, device): 39 | image = torch.tensor(image, device=device, dtype=torch.float32) 40 | image = image.permute(2, 0, 1)[None] / 255 41 | return image 42 | 43 | 44 | def tensor_to_image(image: torch.tensor): 45 | out = image.detach().cpu() 46 | out = (out.squeeze(0).permute(1, 2, 0) * 255) 47 | return out.numpy().astype(np.uint8) 48 | 49 | 50 | def reflect_pad(image, patch_size=7): 51 | p = patch_size // 2 52 | return torch.nn.functional.pad(image, (p, p, p, p), mode='reflect') 53 | 
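54 | 
55 | 
56 | # Minimal usage sketch (illustrative only, not part of the original module).
57 | # The image path is an assumption based on the bundled example folder, and the
58 | # script is assumed to be run from the repository root.
59 | if __name__ == "__main__":
60 |     img = load_image_tensor("example/000.png", image_size=(320, 320))
61 |     padded = reflect_pad(img, patch_size=7)  # pads by 3 pixels on every side
62 |     assert padded.shape[-1] == img.shape[-1] + 6
63 |     blurred = blur(img, kernel_size=7)       # sigma defaults to kernel_size / 4
64 |     out = tensor_to_image(blurred)           # H x W x 3 uint8 numpy image
65 |     print(out.shape, out.dtype)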
--------------------------------------------------------------------------------
/outputs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/outputs/.gitkeep
--------------------------------------------------------------------------------
/sc_methods.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchvision
4 | 
5 | from op_utils import blur, scale_features, reflect_pad, get_gaussian_kernel
6 | 
7 | 
8 | def sc_moments(features, tile_size, powers=(1, 2, 3, 4), sigma=6.0):
9 |     assert features.shape[0] == 1  # Only implemented for one image, but easy to extend...
10 |     ps = features.new_tensor(list(powers))
11 |     b, c, h, w = features.shape
12 |     moments = torch.pow(features[:, :, None, :, :], ps[None, None, :, None, None]).reshape(b, c * len(ps), h, w)
13 |     local_moments = blur(moments, kernel_size=tile_size[0], sigma=sigma)[0]  # C*powers x H x W
14 | 
15 |     def dist(one, many):
16 |         vec_arr = one[:, None, None].expand_as(many)
17 |         loss = torch.mean((vec_arr - many) ** 2, dim=0)
18 |         del vec_arr
19 |         return loss
20 | 
21 |     representative = torch.mean(local_moments.flatten(start_dim=-2), dim=-1)  # C*powers
22 |     return dist(representative, local_moments)
23 | 
24 | 
25 | def sc_hist(features, tile_size, bins=10, sigma=6.0):
26 |     assert features.shape[0] == 1  # Only implemented for one image, but easy to extend...
27 |     scaled = scale_features(features)
28 |     indices = torch.floor(0.999 * bins * scaled).type(torch.int64)
29 |     one_hot = F.one_hot(indices, num_classes=bins)  # pin the bin axis to exactly `bins` classes
30 |     hist = one_hot.permute(0, 1, 4, 2, 3).flatten(start_dim=1, end_dim=2).type(torch.float32)
31 |     aggregated_hist = blur(hist, tile_size[0], sigma=sigma)[0]  # (C*bins) x H x W
32 |     aggregated_hist = aggregated_hist.view(features.shape[1], bins, *features.shape[-2:])  # C x bins x H x W
33 | 
34 |     def dist(one, many):
35 |         # Computes the EMD between two 1D histograms.
36 |         vec_arr = one[:, None, None].expand_as(many)
37 |         c1 = torch.cumsum(vec_arr, dim=0)
38 |         c2 = torch.cumsum(many, dim=0)
39 |         loss = torch.mean(torch.abs(c1 - c2), dim=0)
40 |         del vec_arr
41 |         return loss
42 | 
43 |     representative = torch.mean(aggregated_hist.flatten(start_dim=-2), dim=-1)  # C x bins
44 |     return torch.stack([dist(c_rep, c_hist) for (c_rep, c_hist) in zip(representative, aggregated_hist)]).mean(dim=0)
45 | 
46 | 
47 | def get_reference(reference_type, tile_size, num_ref=1):
48 |     def reference_median(features):
49 |         unf = F.unfold(features, tile_size, stride=tile_size)[0].T.reshape(-1, features.shape[1], *tile_size)
50 |         val, _ = torch.sort(unf.view(-1, features.shape[1], tile_size[0] * tile_size[1]), dim=-1)
51 |         return torch.median(val, dim=0).values  # C x tile**2
52 | 
53 |     def reference_random(features):
54 |         h, w = features.shape[-2:]
55 |         h0 = torch.randint(0, h - tile_size[0], (num_ref,), device=features.device)
56 |         w0 = torch.randint(0, w - tile_size[1], (num_ref,), device=features.device)
57 |         grid = torch.meshgrid(torch.arange(tile_size[0], device=features.device),
58 |                               torch.arange(tile_size[1], device=features.device), indexing='ij')
59 | 
60 |         x_ind = (grid[0].reshape(-1)[None] + h0[:, None]).view(-1)
61 |         y_ind = (grid[1].reshape(-1)[None] + w0[:, None]).view(-1)
62 |         refs = features[..., x_ind, y_ind].reshape(*features.shape[:-2], num_ref, -1)
63 |         refs = torch.sort(refs[0].permute(1, 0, 2), dim=-1).values  # num_ref x C x tile**2
64 |         assert num_ref == 1  # Current public code only allows 1 (random) reference
65 |         return refs[0]
66 | 
67 |     if reference_type == 'median':
68 |         return reference_median
69 |     elif reference_type == 'random':
70 |         return reference_random
71 |     else:
72 |         raise ValueError(f"Unknown reference_type: {reference_type}")
73 | 
74 | 
75 | class ScSWW:
76 |     def __init__(self, tile_size, chunk_size=8, sigma=6.0, reference_selection='median'):
77 |         self.tile_size = tuple(tile_size)
78 |         self.chunk_size = chunk_size
79 |         self.sigma = sigma
80 |         self.ref_selection = get_reference(reference_selection, self.tile_size)
81 |         self.gaussian_kernel = None
82 | 
83 |     def generate_all_sets(self, features):
84 |         b, c, h, w = features.shape
85 |         padded = reflect_pad(features, self.tile_size[0])
86 |         unf = torch.nn.functional.unfold(padded, (self.tile_size[0], padded.shape[-1]), stride=(1, 1))
87 |         unf = unf[0].T.reshape(h, c, self.tile_size[0], padded.shape[-1])
88 |         for i in range(0, len(unf), self.chunk_size):
89 |             chunk = unf[i:i + self.chunk_size]
90 |             unf_2 = torch.nn.functional.unfold(chunk, self.tile_size, stride=(1, 1))
91 |             unf_2 = unf_2.transpose(1, 2).reshape(chunk.shape[0], w, c, -1)
92 |             yield unf_2
93 | 
94 |     def __call__(self, features):
95 |         r_set = self.ref_selection(features)
96 |         generator = self.generate_all_sets(features)
97 |         if self.gaussian_kernel is None:
98 |             self.gaussian_kernel = get_gaussian_kernel(features.device, self.tile_size, self.sigma).reshape(-1)
99 |         parts = []
100 |         for f_set in generator:
101 |             fvalues, ind = torch.sort(f_set, dim=-1)  # h x W x C x tile**2
102 |             vec_arr = r_set[None, None].expand_as(fvalues)
103 |             weights = torch.gather(self.gaussian_kernel.expand_as(fvalues), dim=-1, index=ind)
104 |             loss = (F.l1_loss(fvalues, vec_arr, reduction='none') * weights).sum(dim=-1).mean(dim=-1)
105 |             del vec_arr
106 |             parts.append(loss)
107 |         p = torch.cat(parts, dim=0)
108 |         return p
109 | 
110 | 
111 | class ScFCA(ScSWW):
112 |     def __init__(self, tile_size, chunk_size=8, sigma_p=3.0, reference_selection='median', k_s=5, sigma_s=1.0):
113 |         super(ScFCA, self).__init__(tile_size, chunk_size, sigma_p, reference_selection)
114 |         assert tile_size[0] == tile_size[1]  # Only implemented for square patches
115 |         self.p_size = (tile_size[0] // 2)
116 |         if sigma_s is not None:
117 |             self.local_blur = torchvision.transforms.GaussianBlur(k_s, sigma=sigma_s)
118 |         else:
119 |             self.local_blur = None
120 | 
121 |     def __call__(self, features):
122 |         r_set = self.ref_selection(features)
123 |         wp = features.shape[-1] + 2 * self.p_size
124 |         generator = self.generate_all_sets(features)
125 |         if self.gaussian_kernel is None:
126 |             self.gaussian_kernel = get_gaussian_kernel(features.device, self.tile_size, self.sigma).reshape(-1)
127 |         parts = []
128 |         for f_set in generator:
129 |             fvalues, ind = torch.sort(f_set, dim=-1)  # h x W x C x tile**2
130 |             vec_arr = r_set[None, None].expand_as(fvalues)
131 |             diff = F.l1_loss(fvalues, vec_arr, reduction='none')
132 |             diff_re = torch.gather(diff, dim=-1, index=torch.argsort(ind)).mean(dim=2, keepdim=True)  # h x W x 1 x t**2
133 |             if self.local_blur is not None:
134 |                 diff_re = self.local_blur(diff_re.view(-1, 1, *self.tile_size)).reshape(diff_re.shape)
135 |             diff_re = diff_re * self.gaussian_kernel  # h x W x 1 x tile**2
136 |             diff_re = diff_re.permute(0, 2, 3, 1).reshape(f_set.shape[0], -1, features.shape[-1])  # h x 1*tile**2 x W
137 |             c_fold = F.fold(diff_re, (self.tile_size[0], wp), kernel_size=self.tile_size)  # h x C x tile x WP
138 |             parts.append(c_fold)
139 |         combined = torch.cat(parts, dim=0)  # H x 1 x tile x WP
140 |         folded = F.fold(combined.permute(1, 2, 3, 0).reshape(1, -1, features.shape[-2]),
141 |                         output_size=(wp, wp), kernel_size=(self.tile_size[0], wp))
142 |         folded = folded[0, 0, self.p_size:-self.p_size, self.p_size:-self.p_size]  # Remove the extra padding -> H x W
143 | 
144 |         return folded
145 | 
146 | 
147 | def sc_aota(features, tile_size, sigma=100.0, k=400):
148 |     def one_to_many_dist(values, dist_f):
149 |         distances = []
150 |         for i in range(values.shape[-2]):
151 |             for j in range(values.shape[-1]):
152 |                 distances.append(dist_f(values[:, i, j], values))
153 |         return values.new_tensor(distances).reshape(values.shape[-2:])
154 | 
155 |     local_moments = blur(features, kernel_size=tile_size[0], sigma=sigma)[0]  # C x H x W
156 | 
157 |     def dist(one, many):
158 |         vec_arr = one[:, None, None].expand_as(many)
159 |         loss = ((vec_arr - many) ** 2).sum(dim=0).view(-1)
160 |         del vec_arr
161 |         loss = torch.topk(loss, k=k, largest=False).values.mean()
162 |         return loss
163 | 
164 |     result = one_to_many_dist(local_moments, dist_f=dist)
165 |     return blur(F.interpolate(result[None, None], (320, 320)), kernel_size=25, sigma=4.0)[0, 0]
166 | 
--------------------------------------------------------------------------------
/static/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/static/teaser.png
--------------------------------------------------------------------------------