├── .gitignore
├── LICENSE
├── PyTorchSteerablePyramid
│   ├── README.md
│   ├── __init__.py
│   ├── assets
│   │   ├── coeff.png
│   │   ├── lena.jpg
│   │   ├── patagonia.jpg
│   │   ├── runtime_benchmark.pdf
│   │   └── runtime_benchmark.png
│   ├── examples
│   │   ├── benchmark_runtime.py
│   │   ├── compare_both.py
│   │   ├── example_lena.py
│   │   ├── example_numpy.py
│   │   └── example_pytorch.py
│   ├── license.txt
│   ├── requirements.txt
│   ├── setup.py
│   ├── steerable
│   │   ├── SCFpyr_NumPy.py
│   │   ├── SCFpyr_PyTorch.py
│   │   ├── __init__.py
│   │   ├── math_utils.py
│   │   └── utils.py
│   └── tests
│       ├── test_ifft.py
│       ├── test_torch_fft.py
│       ├── test_torch_fftshift.py
│       └── test_torch_fftshift2d.py
├── README.md
├── conf
│   ├── base.yaml
│   ├── dataset
│   │   ├── mvtec.yaml
│   │   └── wft.yaml
│   ├── fe
│   │   ├── color.yaml
│   │   ├── ltem.yaml
│   │   ├── random.yaml
│   │   ├── steerable.yaml
│   │   ├── vgg.yaml
│   │   └── wide.yaml
│   └── sc
│       ├── aota.yaml
│       ├── fca.yaml
│       ├── hist.yaml
│       ├── moments.yaml
│       └── sww.yaml
├── data_loaders.py
├── datasets
│   └── .gitkeep
├── demo.py
├── environment.yml
├── evaluation
│   ├── additional_util.py
│   ├── evaluate_experiment.py
│   ├── evaluate_multiple_experiments.py
│   ├── generic_util.py
│   ├── print_metrics.py
│   ├── pro_curve_util.py
│   └── roc_curve_util.py
├── example
│   ├── 000.png
│   └── 000_out.png
├── feature_extractors.py
├── main.py
├── op_utils.py
├── outputs
│   └── .gitkeep
├── sc_methods.py
└── static
    └── teaser.png
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/**
2 | !outputs/.gitkeep
3 | datasets/**
4 | !datasets/.gitkeep
5 | .idea
6 | __pycache__
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Andrei-Timotei Ardelean, Tim Weyrich
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/README.md:
--------------------------------------------------------------------------------
1 | # Complex Steerable Pyramid in PyTorch
2 |
3 | This is a PyTorch implementation of the Complex Steerable Pyramid described in [Portilla and Simoncelli (IJCV, 2000)](http://www.cns.nyu.edu/~lcv/pubs/makeAbs.php?loc=Portilla99).
4 |
5 | It uses PyTorch's efficient spectral operations `torch.fft.fft2` and `torch.fft.ifft2`. Just like a normal convolution layer, the complex steerable pyramid expects a batch of images of shape `[N,C,H,W]`, with current support only for grayscale images (`C=1`). It returns a `list` containing the low-pass, high-pass and intermediate levels of the pyramid for each image in the batch (as `torch.Tensor`). Computing the steerable pyramid is significantly faster on the GPU, as can be observed from the runtime benchmark below.
6 |
7 |
8 |
9 | ## Usage
10 |
11 | In addition to the PyTorch implementation defined in `SCFpyr_PyTorch`, the original NumPy version is also included as `SCFpyr_NumPy` for completeness and comparison. As the GPU implementation highly benefits from parallelization, the `build` method expects image batches of shape `[N,1,H,W]` containing a batch of `N` grayscale images of size `HxW`.
12 |
13 | ```python
14 | import cv2
15 | import torch
16 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
17 | import steerable.utils as utils
18 |
19 | # Requires PyTorch with MKL when setting to 'cpu'
20 | device = torch.device('cuda:0')
21 |
22 | # Load batch of images [N,1,H,W]
23 | im_batch_numpy = utils.load_image_batch(...)
24 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
25 |
26 | # Initialize Complex Steerable Pyramid
27 | pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device)
28 |
29 | # Decompose entire batch of images, then reconstruct it from the coefficients
30 | coeff = pyr.build(im_batch_torch)
31 | im_batch_reconstructed = pyr.reconstruct(coeff)
32 |
33 | # Visualization: extract a single example from the batch
34 | coeff_single = utils.extract_from_batch(coeff, 0)
35 | coeff_grid = utils.make_grid_coeff(coeff_single, normalize=True)
36 | cv2.imshow('Complex Steerable Pyramid', coeff_grid)
37 | cv2.waitKey(0)
38 | ```
39 |
40 | ## Benchmark
41 |
42 | Performing the CSP decomposition in parallel on the GPU using PyTorch results in a significant speed-up; increasing the batch size further reduces the time per example. The plot below shows a comparison between the NumPy and PyTorch implementations as a function of the batch size `N` and the input image size. These results were obtained on a Linux desktop with an NVIDIA Titan X GPU.
43 |
44 |
45 |
46 | ## Installation
47 |
48 | Clone and install:
49 |
50 | ```sh
51 | git clone https://github.com/tomrunia/PyTorchSteerablePyramid.git
52 | cd PyTorchSteerablePyramid
53 | pip install -r requirements.txt
54 | python setup.py install
55 | ```
56 |
57 | ## Requirements
58 |
59 | - Python 3.6 or newer
60 | - Numpy (developed with 1.15.4)
61 | - Scipy (developed with 1.1.0)
62 | - PyTorch >= 1.8 (this version of the code uses the `torch.fft` module; see note below)
63 |
64 | The steerable pyramid utilizes `torch.fft.fft2` and `torch.fft.ifft2` to perform operations in the spectral domain. At the moment, PyTorch only implements these operations on the GPU or with the MKL back-end on the CPU. Therefore, if you want to run the code on the CPU you might need to compile PyTorch from source with MKL enabled. Use `torch.backends.mkl.is_available()` to check whether MKL is available.
65 |
66 | ## References
67 |
68 | - [J. Portilla and E.P. Simoncelli, Complex Steerable Pyramid (IJCV, 2000)](http://www.cns.nyu.edu/pub/eero/portilla99-reprint.pdf)
69 | - [The Steerable Pyramid](http://www.cns.nyu.edu/~eero/steerpyr/)
70 | - [Official implementation: matPyrTools](http://www.cns.nyu.edu/~lcv/software.php)
71 | - [perceptual repository by Dzung Nguyen](https://github.com/andreydung/Steerable-filter)
72 |
73 | ## License
74 |
75 | MIT License
76 |
77 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com)
78 |
79 | Permission is hereby granted, free of charge, to any person obtaining a copy
80 | of this software and associated documentation files (the "Software"), to deal
81 | in the Software without restriction, including without limitation the rights
82 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
83 | copies of the Software, and to permit persons to whom the Software is
84 | furnished to do so, subject to the following conditions:
85 |
86 | The above copyright notice and this permission notice shall be included in all
87 | copies or substantial portions of the Software.
88 |
89 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
90 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
91 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
92 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
93 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
94 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
95 | SOFTWARE.
--------------------------------------------------------------------------------
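As a companion to the MKL note in the README above, device selection can be guarded at runtime. A minimal sketch using only standard PyTorch APIs; the `pick_device` helper is illustrative and not part of the repository:

```python
import torch

def pick_device():
    # Prefer CUDA; on the CPU the spectral ops need an MKL-enabled build.
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    assert torch.backends.mkl.is_available(), \
        'CPU execution requires a PyTorch build with MKL'
    return torch.device('cpu')

device = pick_device()
```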
/PyTorchSteerablePyramid/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 |
5 | sys.path.append(os.path.dirname(__file__))
6 |
7 | from PyTorchSteerablePyramid.steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
8 |
9 |
10 | def extract_steerable_features(image, o=4, m=1, scale_factor=2):
11 | # image of shape 1 x 3 x H x W
12 | assert image.shape[0] == 1
13 | im_batch_torch = image.permute(1, 0, 2, 3)  # treat the RGB channels as a batch of grayscale images
14 | pyr = SCFpyr_PyTorch(height=m + 2, nbands=o, scale_factor=scale_factor, device=image.device)
15 | coeff = pyr.build(im_batch_torch)
16 |
17 | h, w = im_batch_torch.shape[-2:]
18 | real_features = [torch.stack(c)[..., 0] for c in coeff[1:-1]]  # real part of each oriented band
19 | real_features = [torch.nn.functional.interpolate(f, size=(h, w), mode='bilinear') for f in real_features]
20 | real_features = torch.cat(real_features, dim=0).reshape(image.shape[0], -1, h, w)
21 |
22 | return real_features
23 |
--------------------------------------------------------------------------------
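For reference, the feature extractor above can be exercised on a dummy RGB image. A small sketch, assuming the repository root is on `PYTHONPATH` and that the default device (CPU with MKL, or CUDA) supports the FFT ops:

```python
import torch
from PyTorchSteerablePyramid import extract_steerable_features

# Dummy RGB image of shape [1, 3, H, W]; H and W must be large enough
# for the m + 2 pyramid levels requested below.
image = torch.randn(1, 3, 128, 128)

features = extract_steerable_features(image, o=4, m=1, scale_factor=2)
print(features.shape)  # torch.Size([1, 12, 128, 128]) -- 3 * o * m band channels
```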
/PyTorchSteerablePyramid/assets/coeff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/coeff.png
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/assets/lena.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/lena.jpg
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/assets/patagonia.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/patagonia.jpg
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/assets/runtime_benchmark.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/runtime_benchmark.pdf
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/assets/runtime_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/assets/runtime_benchmark.png
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/examples/benchmark_runtime.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-10
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import time
20 | import argparse
21 |
22 | import numpy as np
23 | import torch
24 | import matplotlib.pyplot as plt
25 |
26 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
27 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
28 | import steerable.utils as utils
29 |
30 | import cortex.plot
31 | cortex.plot.init_plotting()
32 | colors = cortex.plot.nature_colors()
33 |
34 |
35 | if __name__ == "__main__":
36 |
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
39 | parser.add_argument('--batch_sizes', type=str, default='1,8,16,32,64,128,256')
40 | parser.add_argument('--image_sizes', type=str, default='128,256,512')
41 | parser.add_argument('--num_runs', type=int, default=5)
42 | parser.add_argument('--pyr_nlevels', type=int, default=5)
43 | parser.add_argument('--pyr_nbands', type=int, default=4)
44 | parser.add_argument('--pyr_scale_factor', type=int, default=2)
45 | parser.add_argument('--device', type=str, default='cuda:0')
46 | config = parser.parse_args()
47 |
48 | config.batch_sizes = list(map(int, config.batch_sizes.split(',')))
49 | config.image_sizes = list(map(int, config.image_sizes.split(',')))
50 |
51 | device = utils.get_device(config.device)
52 |
53 | ################################################################################
54 |
55 | pyr_numpy = SCFpyr_NumPy(
56 | height=config.pyr_nlevels,
57 | nbands=config.pyr_nbands,
58 | scale_factor=config.pyr_scale_factor,
59 | )
60 |
61 | pyr_torch = SCFpyr_PyTorch(
62 | height=config.pyr_nlevels,
63 | nbands=config.pyr_nbands,
64 | scale_factor=config.pyr_scale_factor,
65 | device=device
66 | )
67 |
68 | ############################################################################
69 | # Run Benchmark
70 |
71 | durations_numpy = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs))
72 | durations_torch = np.zeros((len(config.batch_sizes), len(config.image_sizes), config.num_runs))
73 |
74 | for batch_idx, batch_size in enumerate(config.batch_sizes):
75 |
76 | for size_idx, im_size in enumerate(config.image_sizes):
77 |
78 | for run_idx in range(config.num_runs):
79 |
80 | im_batch_numpy = utils.load_image_batch(config.image_file, batch_size, im_size)
81 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
82 |
83 | # NumPy implementation
84 | start_time = time.time()
85 |
86 | for image in im_batch_numpy:
87 | coeff = pyr_numpy.build(image[0,])
88 |
89 | duration = time.time()-start_time
90 | durations_numpy[batch_idx,size_idx,run_idx] = duration
91 | print('BatchSize: {batch_size} | ImSize: {im_size} | NumPy Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format(
92 | batch_size=batch_size,
93 | im_size=im_size,
94 | curr_run=run_idx+1,
95 | num_run=config.num_runs,
96 | duration=duration
97 | ))
98 |
99 | # PyTorch Implementation
100 | start_time = time.time()
101 |
102 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
103 | coeff = pyr_torch.build(im_batch_torch)
104 |
105 | duration = time.time()-start_time
106 | durations_torch[batch_idx,size_idx,run_idx] = duration
107 | print('BatchSize: {batch_size} | ImSize: {im_size} | Torch Run {curr_run}/{num_run} | Duration: {duration:.3f} seconds.'.format(
108 | batch_size=batch_size,
109 | im_size=im_size,
110 | curr_run=run_idx+1,
111 | num_run=config.num_runs,
112 | duration=duration
113 | ))
114 |
115 | np.save('./assets/durations_numpy.npy', durations_numpy)
116 | np.save('./assets/durations_torch.npy', durations_torch)
117 |
118 | ################################################################################
119 | # Plotting
120 |
121 | durations_numpy = np.load('./assets/durations_numpy.npy')
122 | durations_torch = np.load('./assets/durations_torch.npy')
123 |
124 | for i, num_examples in enumerate(config.batch_sizes):
125 | if num_examples == 8: continue
126 | avg_durations_numpy = np.mean(durations_numpy[i,:], -1) / num_examples
127 | avg_durations_torch = np.mean(durations_torch[i,:], -1) / num_examples
128 | plt.plot(config.image_sizes, avg_durations_numpy, marker='o', linestyle='-', lw=1.4, color=colors[i], label='numpy (N = {})'.format(num_examples))
129 | plt.plot(config.image_sizes, avg_durations_torch, marker='d', linestyle='--', lw=1.4, color=colors[i], label='pytorch (N = {})'.format(num_examples))
130 |
131 | plt.title('Runtime Benchmark ({} levels, {} bands, average {numruns} runs)'.format(config.pyr_nlevels, config.pyr_nbands, numruns=config.num_runs))
132 | plt.xlabel('Image Size (px)')
133 | plt.ylabel('Time per Example (s)')
134 | plt.xlim((100, 550))
135 | plt.ylim((-0.01, 0.2))
136 | plt.xticks(config.image_sizes)
137 | plt.legend(ncol=2, loc='upper left')
138 | plt.tight_layout()
139 |
140 | plt.savefig('./assets/runtime_benchmark.png', dpi=600)
141 | plt.savefig('./assets/runtime_benchmark.pdf')
142 | plt.show()
143 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/examples/compare_both.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | import cv2
22 |
23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
25 | import steerable.utils as utils
26 |
27 | ################################################################################
28 | ################################################################################
29 | # Common
30 |
31 | image_file = './assets/lena.jpg'
32 | image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
33 | image = cv2.resize(image, (200,200))
34 |
35 | # Number of pyramid levels
36 | pyr_height = 5
37 |
38 | # Number of orientation bands
39 | pyr_nbands = 4
40 |
41 | # Tolerance for error checking
42 | tolerance = 1e-3
43 |
44 | ################################################################################
45 | # NumPy
46 |
47 | pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
48 | coeff_numpy = pyr_numpy.build(image)
49 | reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)
50 | reconstruction_numpy = reconstruction_numpy.astype(np.uint8)
51 |
52 | print('#'*60)
53 |
54 | ################################################################################
55 | # PyTorch
56 |
57 | device = torch.device('cuda:0')
58 |
59 | im_batch = torch.from_numpy(image[None,None,:,:])
60 | im_batch = im_batch.to(device).float()
61 |
62 | pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
63 | coeff_torch = pyr_torch.build(im_batch)
64 | reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
65 | reconstruction_torch = reconstruction_torch.cpu().numpy()[0,]
66 |
67 | # Extract first example from the batch and move to CPU
68 | coeff_torch = utils.extract_from_batch(coeff_torch, 0)
69 |
70 | ################################################################################
71 | # Check correctness
72 |
73 | print('#'*60)
74 | assert len(coeff_numpy) == len(coeff_torch)
75 |
76 | for level, _ in enumerate(coeff_numpy):
77 |
78 | print('Pyramid Level {level}'.format(level=level))
79 | coeff_level_numpy = coeff_numpy[level]
80 | coeff_level_torch = coeff_torch[level]
81 |
82 | assert isinstance(coeff_level_torch, type(coeff_level_numpy))
83 |
84 | if isinstance(coeff_level_numpy, np.ndarray):
85 |
86 | # Low- or High-Pass
87 | print(' NumPy. min = {min:.3f}, max = {max:.3f},'
88 | ' mean = {mean:.3f}, std = {std:.3f}'.format(
89 | min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
90 | mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)))
91 |
92 | print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
93 | ' mean = {mean:.3f}, std = {std:.3f}'.format(
94 | min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
95 | mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)))
96 |
97 | # Check numerical correctness
98 | assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance)
99 |
100 | elif isinstance(coeff_level_numpy, list):
101 |
102 | # Intermediate bands
103 | for band, _ in enumerate(coeff_level_numpy):
104 |
105 | band_numpy = coeff_level_numpy[band]
106 | band_torch = coeff_level_torch[band]
107 |
108 | print(' Orientation Band {}'.format(band))
109 | print(' NumPy. min = {min:.3f}, max = {max:.3f},'
110 | ' mean = {mean:.3f}, std = {std:.3f}'.format(
111 | min=np.min(band_numpy), max=np.max(band_numpy),
112 | mean=np.mean(band_numpy), std=np.std(band_numpy)))
113 |
114 | print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
115 | ' mean = {mean:.3f}, std = {std:.3f}'.format(
116 | min=np.min(band_torch), max=np.max(band_torch),
117 | mean=np.mean(band_torch), std=np.std(band_torch)))
118 |
119 | # Check numerical correctness
120 | assert np.allclose(band_numpy, band_torch, atol=tolerance)
121 |
122 | ################################################################################
123 | # Visualize
124 |
125 | coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False)
126 | coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False)
127 |
128 | cv2.imshow('image', image)
129 | cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy))
130 | cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch))
131 | cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8))
132 | cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8))
133 |
134 | cv2.waitKey(0)
135 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/examples/example_lena.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import cv2
21 |
22 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
23 | import steerable.utils as utils
24 |
25 | ################################################################################
26 | ################################################################################
27 |
28 | if __name__ == "__main__":
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
32 | parser.add_argument('--batch_size', type=int, default=1)
33 | parser.add_argument('--image_size', type=int, default=200)
34 | parser.add_argument('--pyr_nlevels', type=int, default=5)
35 | parser.add_argument('--pyr_nbands', type=int, default=4)
36 | parser.add_argument('--pyr_scale_factor', type=int, default=2)
37 | parser.add_argument('--visualize', type=bool, default=True)
38 | config = parser.parse_args()
39 |
40 | ############################################################################
41 | # Build the complex steerable pyramid
42 |
43 | pyr = SCFpyr_NumPy(
44 | height=config.pyr_nlevels,
45 | nbands=config.pyr_nbands,
46 | scale_factor=config.pyr_scale_factor,
47 | )
48 |
49 | ############################################################################
50 | # Create a batch and feed-forward
51 |
52 | image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE)
53 | image = cv2.resize(image, dsize=(200,200))
54 |
55 | coeff = pyr.build(image)
56 |
57 | grid = utils.make_grid_coeff(coeff, normalize=True)
58 |
59 | cv2.imwrite('./assets/coeff.png', grid)
60 | cv2.imshow('image', image)
61 | cv2.imshow('coeff', grid)
62 | cv2.waitKey(0)
63 |
64 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/examples/example_numpy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import time
20 | import argparse
21 | import numpy as np
22 |
23 | from steerable.SCFpyr_NumPy import SCFpyr_NumPy
24 | import steerable.utils as utils
25 |
26 | ################################################################################
27 | ################################################################################
28 |
29 | if __name__ == "__main__":
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
33 | parser.add_argument('--batch_size', type=int, default=1)
34 | parser.add_argument('--image_size', type=int, default=200)
35 | parser.add_argument('--pyr_nlevels', type=int, default=5)
36 | parser.add_argument('--pyr_nbands', type=int, default=4)
37 | parser.add_argument('--pyr_scale_factor', type=int, default=2)
38 | parser.add_argument('--visualize', type=bool, default=True)
39 | config = parser.parse_args()
40 |
41 | ############################################################################
42 | # Build the complex steerable pyramid
43 |
44 | pyr = SCFpyr_NumPy(
45 | height=config.pyr_nlevels,
46 | nbands=config.pyr_nbands,
47 | scale_factor=config.pyr_scale_factor,
48 | )
49 |
50 | ############################################################################
51 | # Create a batch and feed-forward
52 |
53 | start_time = time.time()
54 |
55 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size)
56 | im_batch_numpy = im_batch_numpy.squeeze(1) # no channel dim for NumPy
57 |
58 | # Compute Steerable Pyramid
59 | start_time = time.time()
60 | for image in im_batch_numpy:
61 | coeff = pyr.build(image)
62 |
63 | duration = time.time()-start_time
64 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format(
65 | batch_size=config.batch_size,
66 | duration=duration
67 | ))
68 |
69 | ############################################################################
70 | # Visualization
71 |
72 | if config.visualize:
73 | import cv2
74 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
75 | cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8))
76 | cv2.imshow('coeff', coeff_grid)
77 | cv2.waitKey(0)
78 |
79 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/examples/example_pytorch.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import time
21 | import numpy as np
22 | import torch
23 |
24 | from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
25 | import steerable.utils as utils
26 |
27 | ################################################################################
28 | ################################################################################
29 |
30 | if __name__ == "__main__":
31 |
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg')
34 | parser.add_argument('--batch_size', type=int, default=32)
35 | parser.add_argument('--image_size', type=int, default=200)
36 | parser.add_argument('--pyr_nlevels', type=int, default=5)
37 | parser.add_argument('--pyr_nbands', type=int, default=4)
38 | parser.add_argument('--pyr_scale_factor', type=int, default=2)
39 | parser.add_argument('--device', type=str, default='cuda:0')
40 | parser.add_argument('--visualize', type=bool, default=True)
41 | config = parser.parse_args()
42 |
43 | device = utils.get_device(config.device)
44 |
45 | ############################################################################
46 | # Build the complex steerable pyramid
47 |
48 | pyr = SCFpyr_PyTorch(
49 | height=config.pyr_nlevels,
50 | nbands=config.pyr_nbands,
51 | scale_factor=config.pyr_scale_factor,
52 | device=device
53 | )
54 |
55 | ############################################################################
56 | # Create a batch and feed-forward
57 |
58 | start_time = time.time()
59 |
60 | # Load Batch
61 | im_batch_numpy = utils.load_image_batch(config.image_file, config.batch_size, config.image_size)
62 | im_batch_torch = torch.from_numpy(im_batch_numpy).to(device)
63 |
64 | # Compute Steerable Pyramid
65 | coeff = pyr.build(im_batch_torch)
66 |
67 | duration = time.time()-start_time
68 | print('Finished decomposing {batch_size} images in {duration:.1f} seconds.'.format(
69 | batch_size=config.batch_size,
70 | duration=duration
71 | ))
72 |
73 | ############################################################################
74 | # Visualization
75 |
76 | # Just extract a single example from the batch
77 | # Also moves the example to CPU and NumPy
78 | coeff = utils.extract_from_batch(coeff, 0)
79 |
80 | if config.visualize:
81 | import cv2
82 | coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
83 | cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8))
84 | cv2.imshow('coeff', coeff_grid)
85 | cv2.waitKey(0)
86 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/license.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tom Runia (tomrunia@gmail.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/requirements.txt:
--------------------------------------------------------------------------------
1 | six
2 | numpy
3 | scipy
4 | torch >= 1.8.0
5 | matplotlib
6 | pillow
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/setup.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from setuptools import setup
3 |
4 | setup(
5 | name='steerable_pytorch',
6 | version='0.1',
7 | author='Tom Runia',
8 | author_email='tomrunia@gmail.com',
9 | url='https://github.com/tomrunia/PyTorchSteerablePyramid',
10 | description='Complex Steerable Pyramids in PyTorch',
11 | long_description='Fast CPU/CUDA implementation of the Complex Steerable Pyramid in PyTorch.',
12 | license='MIT',
13 | packages=['steerable'],
14 | scripts=[]
15 | )
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/steerable/SCFpyr_NumPy.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | from scipy.special import factorial
21 |
22 | import steerable.math_utils as math_utils
23 | pointOp = math_utils.pointOp
24 |
25 | ################################################################################
26 |
27 | class SCFpyr_NumPy():
28 | '''
29 | This is a modified version of buildSFpyr, that constructs a
30 | complex-valued steerable pyramid using Hilbert-transform pairs
31 | of filters. Note that the imaginary parts will *not* be steerable.
32 |
33 | Description of this transform appears in: Portilla & Simoncelli,
34 | International Journal of Computer Vision, 40(1):49-71, Oct 2000.
35 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/
36 |
37 | Modified code from the perceptual repository:
38 | https://github.com/andreydung/Steerable-filter
39 |
40 | This code looks very similar to the original Matlab code:
41 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m
42 |
43 | Also looks very similar to the original Python code presented here:
44 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py
45 |
46 | '''
47 |
48 | def __init__(self, height=5, nbands=4, scale_factor=2):
49 | self.nbands = nbands # number of orientation bands
50 | self.height = height # including low-pass and high-pass
51 | self.scale_factor = scale_factor
52 |
53 | # Cache constants
54 | self.lutsize = 1024
55 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize
56 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi
57 |
58 |
59 | ################################################################################
60 | # Construction of Steerable Pyramid
61 |
62 | def build(self, im):
63 | ''' Decomposes an image into its complex steerable pyramid.
64 | The pyramid typically has ~4 levels and 4-8 orientations.
65 |
66 | Args:
67 | im (np.ndarray): single grayscale image [H,W]
68 |
69 | Returns:
70 | pyramid: list containing np.ndarray objects storing the pyramid
71 | '''
72 |
73 | assert len(im.shape) == 2, 'Input im must be grayscale'
74 | height, width = im.shape
75 |
76 | # Check whether image size is sufficient for number of levels
77 | if self.height > int(np.floor(np.log2(min(width, height))) - 2):
78 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
79 |
80 | # Prepare a grid
81 | log_rad, angle = math_utils.prepare_grid(height, width)
82 |
83 | # Radial transition function (a raised cosine in log-frequency):
84 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
85 | Yrcos = np.sqrt(Yrcos)
86 |
87 | YIrcos = np.sqrt(1 - Yrcos**2)
88 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
89 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
90 |
91 | # Shift the zero-frequency component to the center of the spectrum.
92 | imdft = np.fft.fftshift(np.fft.fft2(im))
93 |
94 | # Low-pass
95 | lo0dft = imdft * lo0mask
96 |
97 | # Recursively build the steerable pyramid
98 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)
99 |
100 | # High-pass
101 | hi0dft = imdft * hi0mask
102 | hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft))
103 | coeff.insert(0, hi0.real)
104 | return coeff
105 |
106 |
107 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
108 |
109 | if height <= 1:
110 |
111 | # Low-pass
112 | lo0 = np.fft.ifftshift(lodft)
113 | lo0 = np.fft.ifft2(lo0)
114 | coeff = [lo0.real]
115 |
116 | else:
117 |
118 | Xrcos = Xrcos - np.log2(self.scale_factor)
119 |
120 | ####################################################################
121 | ####################### Orientation bandpass #######################
122 | ####################################################################
123 |
124 | himask = pointOp(log_rad, Yrcos, Xrcos)
125 |
126 | order = self.nbands - 1
127 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
128 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2)
129 |
130 | # Loop through all orientation bands
131 | orientations = []
132 | for b in range(self.nbands):
133 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
134 | banddft = np.power(complex(0, -1), self.nbands - 1) * lodft * anglemask * himask
135 | band = np.fft.ifft2(np.fft.ifftshift(banddft))
136 | orientations.append(band)
137 |
138 | ####################################################################
139 | ######################## Subsample lowpass #########################
140 | ####################################################################
141 |
142 | dims = np.array(lodft.shape)
143 |
144 | # Both are tuples of size 2
145 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
146 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
147 |
148 | # Selection
149 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
150 | angle = angle[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
151 | lodft = lodft[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
152 |
153 | # Subsampling in frequency domain
154 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
155 | lomask = pointOp(log_rad, YIrcos, Xrcos)
156 | lodft = lomask * lodft
157 |
158 | ####################################################################
159 | ####################### Recursion next level #######################
160 | ####################################################################
161 |
162 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
163 | coeff.insert(0, orientations)
164 |
165 | return coeff
166 |
167 | ############################################################################
168 | ########################### RECONSTRUCTION #################################
169 | ############################################################################
170 |
171 | def reconstruct(self, coeff):
172 |
173 | if self.nbands != len(coeff[1]):
174 | raise Exception("Unmatched number of orientations")
175 |
176 | height, width = coeff[0].shape
177 | log_rad, angle = math_utils.prepare_grid(height, width)
178 |
179 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
180 | Yrcos = np.sqrt(Yrcos)
181 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
182 |
183 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
184 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
185 |
186 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)
187 |
188 | hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
189 | outdft = tempdft * lo0mask + hidft * hi0mask
190 |
191 | reconstruction = np.fft.ifftshift(outdft)
192 | reconstruction = np.fft.ifft2(reconstruction)
193 | reconstruction = reconstruction.real
194 |
195 | return reconstruction
196 |
197 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
198 |
199 | if len(coeff) == 1:
200 | dft = np.fft.fft2(coeff[0])
201 | dft = np.fft.fftshift(dft)
202 | return dft
203 |
204 | Xrcos = Xrcos - np.log2(self.scale_factor)
205 |
206 | ####################################################################
207 | ####################### Orientation Residue ########################
208 | ####################################################################
209 |
210 | himask = pointOp(log_rad, Yrcos, Xrcos)
211 |
212 | lutsize = 1024
213 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
214 | order = self.nbands - 1
215 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
216 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)
217 |
218 | orientdft = np.zeros(coeff[0][0].shape)
219 |
220 | for b in range(self.nbands):
221 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
222 | banddft = np.fft.fft2(coeff[0][b])
223 | banddft = np.fft.fftshift(banddft)
224 | orientdft = orientdft + np.power(complex(0, 1), order) * banddft * anglemask * himask
225 |
226 | ####################################################################
227 | ########## Lowpass component are upsampled and convoluted ##########
228 | ####################################################################
229 |
230 | dims = np.array(coeff[0][0].shape)
231 |
232 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
233 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)
234 |
235 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
236 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
237 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
238 | lomask = pointOp(nlog_rad, YIrcos, Xrcos)
239 |
240 | ################################################################################
241 |
242 | # Recursive call for image reconstruction
243 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)
244 |
245 | resdft = np.zeros(dims, 'complex')
246 | resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask
247 |
248 | return resdft + orientdft
249 |
--------------------------------------------------------------------------------
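A round-trip sanity check for the NumPy pyramid, mirroring what `examples/compare_both.py` verifies; the random input and the printed error are illustrative:

```python
import numpy as np
from steerable.SCFpyr_NumPy import SCFpyr_NumPy

# Decompose a random grayscale image and invert the transform.
image = np.random.rand(200, 200)
pyr = SCFpyr_NumPy(height=5, nbands=4, scale_factor=2)
coeff = pyr.build(image)        # [high-pass, per-level band lists..., low-pass]
recon = pyr.reconstruct(coeff)

# The complex steerable pyramid is approximately invertible,
# so the maximum absolute error should be small.
print(np.abs(recon - image).max())
```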
/PyTorchSteerablePyramid/steerable/SCFpyr_PyTorch.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | from scipy.special import factorial
22 |
23 | import steerable.math_utils as math_utils
24 | pointOp = math_utils.pointOp
25 |
26 | ################################################################################
27 | ################################################################################
28 |
29 |
30 | class SCFpyr_PyTorch(object):
31 | '''
32 | This is a modified version of buildSFpyr, that constructs a
33 | complex-valued steerable pyramid using Hilbert-transform pairs
34 | of filters. Note that the imaginary parts will *not* be steerable.
35 |
36 | Description of this transform appears in: Portilla & Simoncelli,
37 | International Journal of Computer Vision, 40(1):49-71, Oct 2000.
38 | Further information: http://www.cns.nyu.edu/~eero/STEERPYR/
39 |
40 | Modified code from the perceptual repository:
41 | https://github.com/andreydung/Steerable-filter
42 |
43 | This code looks very similar to the original Matlab code:
44 | https://github.com/LabForComputationalVision/matlabPyrTools/blob/master/buildSCFpyr.m
45 |
46 | Also looks very similar to the original Python code presented here:
47 | https://github.com/LabForComputationalVision/pyPyrTools/blob/master/pyPyrTools/SCFpyr.py
48 |
49 | '''
50 |
51 | def __init__(self, height=5, nbands=4, scale_factor=2, device=None):
52 | self.height = height # including low-pass and high-pass
53 | self.nbands = nbands # number of orientation bands
54 | self.scale_factor = scale_factor
55 | self.device = torch.device('cpu') if device is None else device
56 |
57 | # Cache constants
58 | self.lutsize = 1024
59 | self.Xcosn = np.pi * np.array(range(-(2*self.lutsize+1), (self.lutsize+2)))/self.lutsize
60 | self.alpha = (self.Xcosn + np.pi) % (2*np.pi) - np.pi
61 | self.complex_fact_construct = np.power(complex(0, -1), self.nbands-1)
62 | self.complex_fact_reconstruct = np.power(complex(0, 1), self.nbands-1)
63 |
64 | ################################################################################
65 | # Construction of Steerable Pyramid
66 |
67 | def build(self, im_batch):
68 | ''' Decomposes a batch of images into a complex steerable pyramid.
69 | The pyramid typically has ~4 levels and 4-8 orientations.
70 |
71 | Args:
72 | im_batch (torch.Tensor): Batch of images of shape [N,C,H,W]
73 |
74 | Returns:
75 | pyramid: list containing torch.Tensor objects storing the pyramid
76 | '''
77 |
78 | assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device)
79 | assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
80 | assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
81 | assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'
82 |
83 | im_batch = im_batch.squeeze(1) # flatten channels dim
84 | height, width = im_batch.shape[1], im_batch.shape[2]
85 |
86 | # Check whether image size is sufficient for number of levels
87 | if self.height > int(np.floor(np.log2(min(width, height))) - 2):
88 | raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
89 |
90 | # Prepare a grid
91 | log_rad, angle = math_utils.prepare_grid(height, width)
92 |
93 | # Radial transition function (a raised cosine in log-frequency):
94 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
95 | Yrcos = np.sqrt(Yrcos)
96 |
97 | YIrcos = np.sqrt(1 - Yrcos**2)
98 |
99 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
100 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
101 |
102 | # Note that we expand dims to support broadcasting later
103 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
104 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
105 |
106 | # Fourier transform (2D) and shifting
107 | batch_dft = torch.view_as_real(torch.fft.fft2(im_batch))
108 | batch_dft = math_utils.batch_fftshift2d(batch_dft)
109 |
110 | # Low-pass
111 | lo0dft = batch_dft * lo0mask
112 |
113 | # Start recursively building the pyramids
114 | coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)
115 |
116 | # High-pass
117 | hi0dft = batch_dft * hi0mask
118 | hi0 = math_utils.batch_ifftshift2d(hi0dft)
119 | # hi0 = torch.ifft(hi0, signal_ndim=2)
120 | hi0 = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(hi0)))
121 | hi0_real = torch.unbind(hi0, -1)[0]
122 | coeff.insert(0, hi0_real)
123 | return coeff
124 |
125 | def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
126 |
127 | if height <= 1:
128 |
129 | # Low-pass
130 | lo0 = math_utils.batch_ifftshift2d(lodft)
131 | # lo0 = torch.ifft(lo0, signal_ndim=2)
132 | lo0 = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(lo0)))
133 | lo0_real = torch.unbind(lo0, -1)[0]
134 | coeff = [lo0_real]
135 |
136 | else:
137 |
138 | Xrcos = Xrcos - np.log2(self.scale_factor)
139 |
140 | ####################################################################
141 | ####################### Orientation bandpass #######################
142 | ####################################################################
143 |
144 | himask = pointOp(log_rad, Yrcos, Xrcos)
145 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
146 |
147 | order = self.nbands - 1
148 | const = np.power(2.0, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
149 | Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) # [n,]
150 |
151 | # Loop through all orientation bands
152 | orientations = []
153 | for b in range(self.nbands):
154 |
155 | anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
156 | anglemask = anglemask[None,:,:,None] # for broadcasting
157 | anglemask = torch.from_numpy(anglemask).float().to(self.device)
158 |
159 | # Bandpass filtering
160 | banddft = lodft * anglemask * himask
161 |
162 | # Now multiply with complex number
163 | # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
164 | banddft = torch.unbind(banddft, -1)
165 | banddft_real = self.complex_fact_construct.real*banddft[0] - self.complex_fact_construct.imag*banddft[1]
166 | banddft_imag = self.complex_fact_construct.real*banddft[1] + self.complex_fact_construct.imag*banddft[0]
167 | banddft = torch.stack((banddft_real, banddft_imag), -1)
168 |
169 | band = math_utils.batch_ifftshift2d(banddft)
170 | # band = torch.ifft(band, signal_ndim=2)
171 | band = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(band)))
172 | orientations.append(band)
173 |
174 | ####################################################################
175 | ######################## Subsample lowpass #########################
176 | ####################################################################
177 |
178 | # Don't consider batch_size and imag/real dim
179 | dims = np.array(lodft.shape[1:3])
180 |
181 | # Both are tuples of size 2
182 | low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
183 | low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
184 |
185 | # Subsampling indices
186 | log_rad = log_rad[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]
187 | angle = angle[low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1]]
188 |
189 | # Actual subsampling
190 | lodft = lodft[:,low_ind_start[0]:low_ind_end[0],low_ind_start[1]:low_ind_end[1],:]
191 |
192 | # Filtering
193 | YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
194 | lomask = pointOp(log_rad, YIrcos, Xrcos)
195 | lomask = torch.from_numpy(lomask[None,:,:,None]).float()
196 | lomask = lomask.to(self.device)
197 |
198 | # Convolution in spatial domain
199 | lodft = lomask * lodft
200 |
201 | ####################################################################
202 | ####################### Recursion next level #######################
203 | ####################################################################
204 |
205 | coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
206 | coeff.insert(0, orientations)
207 |
208 | return coeff
209 |
210 | ############################################################################
211 | ########################### RECONSTRUCTION #################################
212 | ############################################################################
213 |
214 | def reconstruct(self, coeff):
215 |
216 | if self.nbands != len(coeff[1]):
217 | raise Exception("Unmatched number of orientations")
218 |
219 | height, width = coeff[0].shape[1], coeff[0].shape[2]
220 | log_rad, angle = math_utils.prepare_grid(height, width)
221 |
222 | Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
223 | Yrcos = np.sqrt(Yrcos)
224 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
225 |
226 | lo0mask = pointOp(log_rad, YIrcos, Xrcos)
227 | hi0mask = pointOp(log_rad, Yrcos, Xrcos)
228 |
229 | # Note that we expand dims to support broadcasting later
230 | lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
231 | hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)
232 |
233 | # Start recursive reconstruction
234 | tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)
235 |
236 | hidft = torch.view_as_real(torch.fft.fft2(coeff[0].contiguous()))
237 | hidft = math_utils.batch_fftshift2d(hidft)
238 |
239 | outdft = tempdft * lo0mask + hidft * hi0mask
240 |
241 | reconstruction = math_utils.batch_ifftshift2d(outdft)
242 | # reconstruction = torch.ifft(reconstruction, signal_ndim=2)
243 | reconstruction = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(reconstruction)))
244 | reconstruction = torch.unbind(reconstruction, -1)[0] # real
245 |
246 | return reconstruction
247 |
248 | def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
249 |
250 | if len(coeff) == 1:
251 | # dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
252 | dft = torch.view_as_real(torch.fft.fft2(coeff[0].contiguous()))
253 | dft = math_utils.batch_fftshift2d(dft)
254 | return dft
255 |
256 | Xrcos = Xrcos - np.log2(self.scale_factor)
257 |
258 | ####################################################################
259 | ####################### Orientation Residue ########################
260 | ####################################################################
261 |
262 | himask = pointOp(log_rad, Yrcos, Xrcos)
263 | himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)
264 |
265 | lutsize = 1024
266 | Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
267 | order = self.nbands - 1
268 | const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
269 | Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)
270 |
271 | orientdft = torch.zeros_like(coeff[0][0])
272 | for b in range(self.nbands):
273 |
274 | anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
275 | anglemask = anglemask[None,:,:,None] # for broadcasting
276 | anglemask = torch.from_numpy(anglemask).float().to(self.device)
277 |
278 | # banddft = torch.fft(coeff[0][b], signal_ndim=2)
279 | banddft = torch.view_as_real(torch.fft.fft2(torch.view_as_complex(coeff[0][b]), norm="backward"))
280 | banddft = math_utils.batch_fftshift2d(banddft)
281 | banddft = banddft * anglemask * himask
282 | banddft = torch.unbind(banddft, -1)
283 | banddft_real = self.complex_fact_reconstruct.real*banddft[0] - self.complex_fact_reconstruct.imag*banddft[1]
284 | banddft_imag = self.complex_fact_reconstruct.real*banddft[1] + self.complex_fact_reconstruct.imag*banddft[0]
285 | banddft = torch.stack((banddft_real, banddft_imag), -1)
286 |
287 | orientdft = orientdft + banddft
288 |
289 | ####################################################################
290 | ########## Lowpass component are upsampled and convoluted ##########
291 | ####################################################################
292 |
293 | dims = np.array(coeff[0][0].shape[1:3])
294 |
295 | lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
296 | loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)
297 |
298 | nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
299 | nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
300 | YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
301 |
302 |
303 | # Filtering
304 | lomask = pointOp(nlog_rad, YIrcos, Xrcos)
305 | lomask = torch.from_numpy(lomask[None,:,:,None])
306 | lomask = lomask.float().to(self.device)
307 |
308 | ################################################################################
309 |
310 | # Recursive call for image reconstruction
311 | nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)
312 |
313 | resdft = torch.zeros_like(coeff[0][0]).to(self.device)
314 | resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask
315 |
316 | return resdft + orientdft
317 |
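318 | ################################################################################
319 | # Note on the complex layout used above: tensors carry (real, imag) in their
320 | # last dimension, so conversions go through torch.view_as_complex /
321 | # torch.view_as_real. A minimal round-trip sketch of that convention:
322 | #
323 | # x = torch.randn(1, 8, 8)                         # real image batch
324 | # f = torch.view_as_real(torch.fft.fft2(x))        # shape [1, 8, 8, 2]
325 | # y = torch.fft.ifft2(torch.view_as_complex(f))    # complex, shape [1, 8, 8]
326 | # assert torch.allclose(x, y.real, atol=1e-6)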
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/steerable/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/PyTorchSteerablePyramid/steerable/__init__.py
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/steerable/math_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 |
22 | ################################################################################
23 | ################################################################################
24 |
25 | def roll_n(X, axis, n):
26 | f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim()))
27 | b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim()))
28 | front = X[f_idx]
29 | back = X[b_idx]
30 | return torch.cat([back, front], axis)
31 |
32 | def batch_fftshift2d(x):
33 | real, imag = torch.unbind(x, -1)
34 | for dim in range(1, len(real.size())):
35 | n_shift = real.size(dim)//2
36 | if real.size(dim) % 2 != 0:
37 | n_shift += 1 # for odd-sized images
38 | real = roll_n(real, axis=dim, n=n_shift)
39 | imag = roll_n(imag, axis=dim, n=n_shift)
40 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
41 |
42 | def batch_ifftshift2d(x):
43 | real, imag = torch.unbind(x, -1)
44 | for dim in range(len(real.size()) - 1, 0, -1):
45 | real = roll_n(real, axis=dim, n=real.size(dim)//2)
46 | imag = roll_n(imag, axis=dim, n=imag.size(dim)//2)
47 | return torch.stack((real, imag), -1) # last dim=2 (real&imag)
48 |
49 | ################################################################################
50 | ################################################################################
51 |
52 | def prepare_grid(m, n):
53 | x = np.linspace(-(m // 2)/(m / 2), (m // 2)/(m / 2) - (1 - m % 2)*2/m, num=m)
54 | y = np.linspace(-(n // 2)/(n / 2), (n // 2)/(n / 2) - (1 - n % 2)*2/n, num=n)
55 | xv, yv = np.meshgrid(y, x)
56 | angle = np.arctan2(yv, xv)
57 | rad = np.sqrt(xv**2 + yv**2)
58 | rad[m//2][n//2] = rad[m//2][n//2 - 1]
59 | log_rad = np.log2(rad)
60 | return log_rad, angle
61 |
62 | def rcosFn(width, position):
63 | N = 256  # arbitrary lookup-table resolution
64 | X = np.pi * np.array(range(-N-1, 2))/2/N
65 | Y = np.cos(X)**2
66 | Y[0] = Y[1]
67 | Y[N+2] = Y[N+1]
68 | X = position + 2*width/np.pi*(X + np.pi/4)
69 | return X, Y
70 |
71 | def pointOp(im, Y, X):
72 | out = np.interp(im.flatten(), X, Y)
73 | return np.reshape(out, im.shape)
74 |
75 | def getlist(coeff):
76 | straight = [bands for scale in coeff[1:-1] for bands in scale]
77 | straight = [coeff[0]] + straight + [coeff[-1]]
78 | return straight
79 |
80 | ################################################################################
81 | # NumPy reference implementation (fftshift and ifftshift)
82 |
83 | # def fftshift(x, axes=None):
84 | # """
85 | # Shift the zero-frequency component to the center of the spectrum.
86 | # This function swaps half-spaces for all axes listed (defaults to all).
87 | # Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.
88 | # Parameters
89 | # """
90 | # x = np.asarray(x)
91 | # if axes is None:
92 | # axes = tuple(range(x.ndim))
93 | #
94 | # shift = [x.shape[ax] // 2 for ax in axes]
95 | # return np.roll(x, shift, axes)
96 | #
97 | # def ifftshift(x, axes=None):
98 | # """
99 | # The inverse of `fftshift`. Although identical for even-length `x`, the
100 | # functions differ by one sample for odd-length `x`.
101 | # """
102 | # x = np.asarray(x)
103 | # if axes is None:
104 | # axes = tuple(range(x.ndim))
105 | #
106 | # shift = [-(x.shape[ax] // 2) for ax in axes]
107 | # return np.roll(x, shift, axes)
108 |
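109 | ################################################################################
110 | # Minimal usage sketch (commented, illustrative values): how the helpers above
111 | # combine into the radial highpass/lowpass masks used by the steerable pyramid.
112 | #
113 | # log_rad, angle = prepare_grid(200, 200)
114 | # Xrcos, Yrcos = rcosFn(width=1, position=-0.5)   # raised-cosine transition
115 | # himask = pointOp(log_rad, np.sqrt(Yrcos), Xrcos)        # highpass mask
116 | # lomask = pointOp(log_rad, np.sqrt(1.0 - Yrcos), Xrcos)  # lowpass complement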
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/steerable/utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-10
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import numpy as np
21 | import skimage
22 | import matplotlib.pyplot as plt
23 |
24 | import torch
25 | import torchvision
26 |
27 | ################################################################################
28 |
29 | from PIL import Image
30 | ToPIL = torchvision.transforms.ToPILImage()
31 | Grayscale = torchvision.transforms.Grayscale()
32 | RandomCrop = torchvision.transforms.RandomCrop
33 |
34 | def get_device(device='cuda:0'):
35 | assert isinstance(device, str)
36 | num_cuda = torch.cuda.device_count()
37 |
38 | if 'cuda' in device:
39 | if num_cuda > 0:
40 | # Found CUDA device, use the GPU
41 | return torch.device(device)
42 | # Fallback to CPU
43 | print('No CUDA devices found, falling back to CPU')
44 | device = 'cpu'
45 |
46 | if not torch.backends.mkl.is_available():
47 | raise NotImplementedError(
48 | 'torch.fft on the CPU requires MKL back-end. ' +
49 | 'Please recompile your PyTorch distribution.')
50 | return torch.device('cpu')
51 |
52 | def load_image_batch(image_file, batch_size, image_size=200):
53 | if not os.path.isfile(image_file):
54 | raise FileNotFoundError('Image file not found on disk: {}'.format(image_file))
55 | # im = ToPIL(skimage.io.imread(image_file))
56 | im = Image.open(image_file)
57 | im = Grayscale(im)
58 | im_batch = np.zeros((batch_size, image_size, image_size), np.float32)
59 | for i in range(batch_size):
60 | im_batch[i] = RandomCrop(image_size)(im)
61 | # insert channels dim and rescale
62 | return im_batch[:,None,:,:]/255.
63 |
64 | def show_image_batch(im_batch):
65 | assert isinstance(im_batch, torch.Tensor)
66 | im_batch = torchvision.utils.make_grid(im_batch).numpy()
67 | im_batch = np.transpose(im_batch, (1,2,0))  # make_grid returns [3,H,W]
68 | plt.imshow(im_batch)
69 | plt.axis('off')
70 | plt.tight_layout()
71 | plt.show()
72 | return im_batch
73 |
74 | def extract_from_batch(coeff_batch, example_idx=0):
75 | '''
76 | Given the batched Complex Steerable Pyramid, extract the coefficients
77 | for a single example from the batch. Additionally, it converts all
78 | torch.Tensors to np.ndarrays and creates proper complex-valued
79 | arrays for all the orientation bands.
80 |
81 | Args:
82 | coeff_batch (list): list containing low-pass, high-pass and pyr levels
83 | example_idx (int, optional): Defaults to 0. index in batch to extract
84 |
85 | Returns:
86 | list: list containing low-pass, high-pass and pyr levels as np.ndarray
87 | '''
88 | if not isinstance(coeff_batch, list):
89 | raise ValueError('Batch of coefficients must be a list')
90 | coeff = [] # coefficient for single example
91 | for coeff_level in coeff_batch:
92 | if isinstance(coeff_level, torch.Tensor):
93 | # Low- or High-Pass
94 | coeff_level_numpy = coeff_level[example_idx].cpu().numpy()
95 | coeff.append(coeff_level_numpy)
96 | elif isinstance(coeff_level, list):
97 | coeff_orientations_numpy = []
98 | for coeff_orientation in coeff_level:
99 | coeff_orientation_numpy = coeff_orientation[example_idx].cpu().numpy()
100 | coeff_orientation_numpy = coeff_orientation_numpy[:,:,0] + 1j*coeff_orientation_numpy[:,:,1]
101 | coeff_orientations_numpy.append(coeff_orientation_numpy)
102 | coeff.append(coeff_orientations_numpy)
103 | else:
104 | raise ValueError('coeff level must be of type (list, torch.Tensor)')
105 | return coeff
106 |
107 | ################################################################################
108 |
109 | def make_grid_coeff(coeff, normalize=True):
110 | '''
111 | Visualization function for building a large image that contains the
112 | low-pass, high-pass and all intermediate levels in the steerable pyramid.
113 | For the complex intermediate bands, the real part is visualized.
114 |
115 | Args:
116 | coeff (list): complex pyramid stored as list containing all levels
117 | normalize (bool, optional): Defaults to True. Whether to normalize each band
118 |
119 | Returns:
120 | np.ndarray: large image that contains grid of all bands and orientations
121 | '''
122 | M, N = coeff[1][0].shape
123 | Norients = len(coeff[1])
124 | out = np.zeros((M * 2 - coeff[-1].shape[0], Norients * N))
125 | currentx, currenty = 0, 0
126 |
127 | for i in range(1, len(coeff[:-1])):
128 | for j in range(len(coeff[1])):
129 | tmp = coeff[i][j].real
130 | m, n = tmp.shape
131 | if normalize:
132 | tmp = 255 * tmp/tmp.max()
133 | tmp[m-1,:] = 255
134 | tmp[:,n-1] = 255
135 | out[currentx:currentx+m,currenty:currenty+n] = tmp
136 | currenty += n
137 | currentx += coeff[i][0].shape[0]
138 | currenty = 0
139 |
140 | m, n = coeff[-1].shape
141 | out[currentx: currentx+m, currenty: currenty+n] = 255 * coeff[-1]/coeff[-1].max()
142 | out[0,:] = 255
143 | out[:,0] = 255
144 | return out.astype(np.uint8)
145 |
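146 | ################################################################################
147 | # Minimal usage sketch (commented; assumes steerable.SCFpyr_PyTorch exposes
148 | # SCFpyr_PyTorch(height, nbands, scale_factor, device) whose build() takes an
149 | # [N,1,H,W] tensor, as in the examples/ scripts):
150 | #
151 | # from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
152 | # device = get_device('cuda:0')
153 | # batch = torch.from_numpy(load_image_batch('./assets/lena.jpg', 4, 200)).to(device)
154 | # pyr = SCFpyr_PyTorch(height=5, nbands=4, scale_factor=2, device=device)
155 | # coeff = extract_from_batch(pyr.build(batch), example_idx=0)
156 | # grid = make_grid_coeff(coeff, normalize=True)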
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/tests/test_ifft.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-07
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 | import cv2
22 |
23 | import steerable.math_utils as fft_utils
24 | import matplotlib.pyplot as plt
25 |
26 | ################################################################################
27 |
28 | tolerance = 1e-6
29 |
30 | image_file = './assets/lena.jpg'
31 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
32 | im = cv2.resize(im, dsize=(200, 200))
33 | im = im.astype(np.float64)/255. # note the np.float64
34 |
35 | ################################################################################
36 | # NumPy
37 |
38 | fft_numpy = np.fft.fft2(im)
39 | fft_numpy = np.fft.fftshift(fft_numpy)
40 |
41 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy))
42 | fft_numpy_ang_viz = np.angle(fft_numpy)
43 |
44 | ifft_numpy1 = np.fft.ifftshift(fft_numpy)
45 | ifft_numpy = np.fft.ifft2(ifft_numpy1)
46 |
47 | ################################################################################
48 | # Torch
49 |
50 | device = torch.device('cpu')
51 |
52 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim
53 | im_torch = im_torch.to(device)
54 |
55 | # fft2 = complex-to-complex; view_as_real exposes the (real, imag) last dim
56 | fft_torch = torch.view_as_real(torch.fft.fft2(im_torch))
57 | fft_torch = fft_utils.batch_fftshift2d(fft_torch)
58 |
59 | ifft_torch = fft_utils.batch_ifftshift2d(fft_torch)
60 | ifft_torch = torch.view_as_real(torch.fft.ifft2(torch.view_as_complex(ifft_torch)))
61 |
62 | ifft_torch_to_numpy = ifft_torch.numpy()
63 | ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2, -1) # complex => real/imag
64 | ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1)
65 | ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j*ifft_torch_to_numpy[1]
66 | all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance)
67 | print('ifft all close: ', all_close_ifft)
68 |
69 | fft_torch = fft_torch.cpu().numpy().squeeze()
70 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag
71 | fft_torch = np.squeeze(fft_torch, -1)
72 | fft_torch = fft_torch[0] + 1j*fft_torch[1]
73 |
74 | ifft_torch = ifft_torch.cpu().numpy().squeeze()
75 | ifft_torch = np.split(ifft_torch, 2, -1) # complex => real/imag
76 | ifft_torch = np.squeeze(ifft_torch, -1)
77 | ifft_torch = ifft_torch[0] + 1j*ifft_torch[1]
78 |
79 | fft_torch_mag_viz = np.log10(np.abs(fft_torch))
80 | fft_torch_ang_viz = np.angle(fft_torch)
81 |
82 | ################################################################################
83 | # Tolerance checking
84 |
85 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
86 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)
87 | print('fft allclose real: {}'.format(all_close_real))
88 | print('fft allclose imag: {}'.format(all_close_imag))
89 |
90 | all_close_real = np.allclose(np.real(ifft_numpy), np.real(ifft_torch), atol=tolerance)
91 | all_close_imag = np.allclose(np.imag(ifft_numpy), np.imag(ifft_torch), atol=tolerance)
92 | print('ifft allclose real: {}'.format(all_close_real))
93 | print('ifft allclose imag: {}'.format(all_close_imag))
94 |
95 | ################################################################################
96 |
97 | fig, ax = plt.subplots(nrows=2, ncols=5, figsize=(12,6))
98 |
99 | # Plotting NumPy results
100 | ax[0][0].imshow(im, cmap='gray')
101 |
102 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray')
103 | ax[0][1].set_title('NumPy fft magnitude')
104 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray')
105 | ax[0][2].set_title('NumPy fft spectrum')
106 | ax[0][3].imshow(ifft_numpy.real, cmap='gray')
107 | ax[0][3].set_title('NumPy ifft real')
108 | ax[0][4].imshow(ifft_numpy.imag, cmap='gray')
109 | ax[0][4].set_title('NumPy ifft imag')
110 |
111 | # Plotting PyTorch results
112 | ax[1][0].imshow(im, cmap='gray')
113 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
114 | ax[1][1].set_title('PyTorch fft magnitude')
115 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
116 | ax[1][2].set_title('PyTorch fft phase')
117 | ax[1][3].imshow(ifft_torch.real, cmap='gray')
118 | ax[1][3].set_title('PyTorch ifft real')
119 | ax[1][4].imshow(ifft_torch.imag, cmap='gray')
120 | ax[1][4].set_title('PyTorch ifft imag')
121 |
122 | for cur_ax in ax.flatten():
123 | cur_ax.axis('off')
124 | plt.tight_layout()
125 | plt.show()
126 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/tests/test_torch_fft.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import torch
21 |
22 | ################################################################################
23 |
24 | x = np.linspace(-10, 10, 100)
25 | y = np.cos(np.pi*x)
26 |
27 | ################################################################################
28 | # NumPy
29 |
30 | y_fft_numpy = np.fft.fft(y)
31 |
32 | ################################################################################
33 | # Torch
34 |
35 | device = torch.device('cuda:0')
36 |
37 | y_torch = torch.from_numpy(y[None, :])
38 | y_torch = y_torch.to(device)
39 |
40 | # fft = complex-to-complex; keep the full two-sided spectrum
41 | y_fft_torch = torch.view_as_real(torch.fft.fft(y_torch))
42 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
43 | y_fft_torch = y_fft_torch[:, 0] + 1j*y_fft_torch[:, 1]
44 |
45 | tolerance = 1e-6
46 | all_close = np.allclose(y_fft_numpy, y_fft_torch, atol=tolerance)
47 | print('numpy', y_fft_numpy.shape, y_fft_numpy.dtype)
48 | print('torch', y_fft_torch.shape, y_fft_torch.dtype)
49 | print('Successful: {}'.format(all_close))
50 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/tests/test_torch_fftshift.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 |
24 | from steerable.math_utils import batch_fftshift2d
25 |
26 | ################################################################################
27 |
28 | x = np.linspace(-10, 10, 100)
29 | y = np.cos(np.pi*x)
30 |
31 | ################################################################################
32 | # NumPy
33 |
34 | y_fft_numpy = np.fft.fft(y)
35 | y_fft_numpy_shift = np.fft.fftshift(y_fft_numpy)
36 | fft_numpy_real = np.real(y_fft_numpy_shift)
37 | fft_numpy_imag = np.imag(y_fft_numpy_shift)
38 |
39 | ################################################################################
40 | # Torch
41 |
42 | device = torch.device('cuda:0')
43 |
44 | y_torch = torch.from_numpy(y[None, :])
45 | y_torch = y_torch.to(device)
46 |
47 | # fft = complex-to-complex; keep the full two-sided spectrum
48 | y_fft_torch = torch.view_as_real(torch.fft.fft(y_torch))
49 | y_fft_torch = batch_fftshift2d(y_fft_torch)
50 | y_fft_torch = y_fft_torch.cpu().numpy().squeeze()
51 | fft_torch_real = y_fft_torch[:,0]
52 | fft_torch_imag = y_fft_torch[:,1]
53 |
54 | tolerance = 1e-6
55 |
56 | all_close_real = np.allclose(fft_numpy_real, fft_torch_real, atol=tolerance)
57 | all_close_imag = np.allclose(fft_numpy_imag, fft_torch_imag, atol=tolerance)
58 |
59 | print('fftshift allclose real: {}'.format(all_close_real))
60 | print('fftshift allclose imag: {}'.format(all_close_imag))
61 |
62 | ################################################################################
63 |
64 | # Use matplotlib's default color cycle instead of the private cortex.plot
65 | # helpers; colors[0..3] index into the standard 10-color cycle.
66 | colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
67 |
68 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))
69 |
70 | # Plotting NumPy results
71 | ax[0][0].plot(x, y, color=colors[3])
72 | ax[0][1].plot(np.real(y_fft_numpy), color=colors[0])
73 | ax[0][1].plot(fft_numpy_real, color=colors[1])
74 | ax[0][1].set_title('NumPy Real')
75 | ax[0][2].plot(np.imag(y_fft_numpy), color=colors[0])
76 | ax[0][2].plot(fft_numpy_imag, color=colors[1])
77 | ax[0][2].set_title('NumPy Imag')
78 |
79 | # Plotting Torch results
80 | ax[1][0].plot(x, y, color=colors[3])
81 | ax[1][1].plot(fft_torch_real, color=colors[1])
82 | ax[1][1].set_title('Torch Real')
83 | ax[1][2].plot(fft_torch_imag, color=colors[1])
84 | ax[1][2].set_title('Torch Imag')
85 |
86 | plt.show()
87 |
--------------------------------------------------------------------------------
/PyTorchSteerablePyramid/tests/test_torch_fftshift2d.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2018 Tom Runia
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to conditions.
11 | #
12 | # Author: Tom Runia
13 | # Date Created: 2018-12-04
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 | import cv2
24 |
25 | from steerable.math_utils import batch_fftshift2d
26 |
27 | ################################################################################
28 |
29 | image_file = './assets/lena.jpg'
30 | im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
31 | im = cv2.resize(im, dsize=(200, 200))
32 | im = im.astype(np.float64)/255. # note the np.float64
33 |
34 | ################################################################################
35 | # NumPy
36 |
37 | fft_numpy = np.fft.fft2(im)
38 | fft_numpy = np.fft.fftshift(fft_numpy)
39 |
40 | fft_numpy_mag_viz = np.log10(np.abs(fft_numpy))
41 | fft_numpy_ang_viz = np.angle(fft_numpy)
42 |
43 | ################################################################################
44 | # Torch
45 |
46 | device = torch.device('cuda:0')
47 |
48 | im_torch = torch.from_numpy(im[None,:,:]) # add batch dim
49 | im_torch = im_torch.to(device)
50 |
51 | # fft2 = complex-to-complex; view_as_real exposes the (real, imag) last dim
52 | fft_torch = torch.view_as_real(torch.fft.fft2(im_torch))
53 | fft_torch = batch_fftshift2d(fft_torch)
54 | fft_torch = fft_torch.cpu().numpy().squeeze()
55 | print(fft_torch.shape)
56 | fft_torch = np.split(fft_torch, 2, -1) # complex => real/imag
57 | fft_torch = np.squeeze(fft_torch, -1)
58 | fft_torch = fft_torch[0] + 1j*fft_torch[1]
59 |
60 | print('fft_torch', fft_torch.shape, fft_torch.dtype)
61 | fft_torch_mag_viz = np.log10(np.abs(fft_torch))
62 | fft_torch_ang_viz = np.angle(fft_torch)
63 |
64 | tolerance = 1e-6
65 |
66 | print(fft_numpy.dtype, fft_torch.dtype)
67 | all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
68 | all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)
69 |
70 | print('fftshift allclose real: {}'.format(all_close_real))
71 | print('fftshift allclose imag: {}'.format(all_close_imag))
72 |
73 | ################################################################################
74 |
75 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(10,6))
76 |
77 | # Plotting NumPy results
78 | ax[0][0].imshow(im, cmap='gray')
79 |
80 | ax[0][1].imshow(fft_numpy_mag_viz, cmap='gray')
81 | ax[0][1].set_title('NumPy Magnitude Spectrum')
82 | ax[0][2].imshow(fft_numpy_ang_viz, cmap='gray')
83 | ax[0][2].set_title('NumPy Phase Spectrum')
84 |
85 | # Plotting PyTorch results
86 | ax[1][0].imshow(im, cmap='gray')
87 | ax[1][1].imshow(fft_torch_mag_viz, cmap='gray')
88 | ax[1][1].set_title('Torch Magnitude Spectrum')
89 | ax[1][2].imshow(fft_torch_ang_viz, cmap='gray')
90 | ax[1][2].set_title('Torch Phase Spectrum')
91 |
92 | for cur_ax in ax.flatten():
93 | cur_ax.axis('off')
94 | plt.tight_layout()
95 | plt.show()
96 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence Analysis (WACV 2024)
2 |
3 | ### [Project Page](https://reality.tf.fau.de/pub/ardelean2024highfidelity.html) | [Paper](https://arxiv.org/abs/2304.06433)
4 |
5 | The official implementation of *High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence Analysis*.
6 |
7 | 
8 |
9 | ## Installation
10 | We implemented our method using PyTorch.
11 | For an easy installation, we provide an environment file that contains all dependencies:
12 |
13 | ```
14 | conda env create -f environment.yml
15 | conda activate fca
16 | ```
17 |
18 | ## Demo
19 | We include a minimal script that runs our method on a sample image, making it easy to try.
20 | By default, running the `demo.py` file will compute the anomaly score for the provided image example `example/000.png`.
21 | ```
22 | python demo.py
23 | ```
24 |
25 | ## Data
26 | To run the evaluation code, you have to first prepare the desired dataset.
27 | Our data loader assumes the data follows the file structure of the MVTec anomaly detection dataset.
28 | You can download the MVTec AD from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad).
29 | The woven fabric textures dataset (WFT) can be downloaded [here](https://www.mydrive.ch/shares/46066/8338a11f32bb1b7b215c5381abe54ebf/download/420939225-1629955758/textures.zip).
30 | Please extract the data into the `datasets` directory.
31 | To use any other dataset that follows the same file structure, simply place the data in the same folder and create a corresponding dataset config file under `conf/dataset`. To understand the required format, please see `conf/dataset/mvtec.yaml`.
32 |
33 | ## Run and evaluate on a dataset
34 | To evaluate the method, use the `main.py` script. We use Hydra to manage the command line interface, which makes it easy to specify the method and dataset configurations.
35 | For example, running our FCA statistics comparison (`sc`) with WideResnet features (`fe`) on the MVTec dataset is done using
36 | ```
37 | python main.py dataset=mvtec fe=wide sc=fca image_size=original tile_size=[9,9]
38 | ```
39 | To run the same on the WFT dataset
40 | ```
41 | python main.py dataset=wft fe=wide sc=fca image_size=original tile_size=[9,9]
42 | ```
43 |
44 | We automatically run the evaluation code after computing the anomaly scores for all images in the dataset.
45 | You can inspect the metrics in the JSON file `outputs/{experiment_name}/metrics_{padding}.json` and visualize the predicted anomaly maps under `outputs/{experiment_name}/{object_name}/visualize`.
46 |
47 |
48 | ## Content
49 | This repository contains several options for feature extraction, as described in our paper: plain colors (`color`), random kernels (`random`), VGG-network (`vgg`), WideResnet-network (`wide`), steerable filters (`steerable`), and Laws' texture energy measure (`ltem`).
50 |
51 | For patch statistics comparison, you can opt for one of the following: moments-based (`moments`), histogram-based (`hist`), sample-weighted Wasserstein (`sww`), feature correspondence analysis (`fca`), and our reimplementation of the method of Aota et al. (`aota`).
52 |
53 | To reproduce one of our experiments for design space exploration (for example, histogram-based statistics comparison with VGG features) you can run:
54 | ```
55 | python main.py dataset=mvtec fe=vgg sc=hist image_size=[256,256] tile_size=[25,25]
56 | ```
57 |
58 | ## Citation
59 | Should you find our work useful in your research, please cite:
60 | ```BibTeX
61 | @inproceedings{ardelean2023highfidelity,
62 | title = {High-Fidelity Zero-Shot Texture Anomaly Localization Using Feature Correspondence Analysis},
63 | author = {Ardelean, Andrei-Timotei and Weyrich, Tim},
64 | booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
65 | numpages = 11,
66 | year = {2024},
67 | month = jan,
68 | day = 4,
69 | authorurl = {https://reality.tf.fau.de/pub/ardelean2024highfidelity.html},
70 | }
71 | ```
72 |
73 | ## Acknowledgements
74 | This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No 956585.
75 |
76 | ## License
77 | Please see the [LICENSE](LICENSE).
--------------------------------------------------------------------------------
/conf/base.yaml:
--------------------------------------------------------------------------------
1 | name: base
2 | device: cuda
3 | image_size: original
4 | tile_size: [9, 9]
5 |
6 | defaults:
7 | - _self_
8 | - fe: ???
9 | - sc: ???
10 | - dataset: mvtec
11 |
12 |
13 | hydra:
14 | job:
15 | chdir: True
16 | run:
17 | dir: ./outputs/${dataset.name}_${fe.name}_${sc.name}_${now:%Y-%m-%d}_${now:%H-%M-%S}
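18 |
19 | # Any option above can be overridden from the command line via Hydra, e.g.
20 | # (hypothetical combination of the provided fe/sc/dataset configs):
21 | #   python main.py dataset=wft fe=steerable sc=moments tile_size=[15,15]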
--------------------------------------------------------------------------------
/conf/dataset/mvtec.yaml:
--------------------------------------------------------------------------------
1 | name: mvtec
2 | data_manager:
3 | _target_: data_loaders.StandardFormatDataset
4 | data_dir: ${hydra:runtime.cwd}/datasets/mvtec_anomaly_detection
5 | out_dir: ${hydra:runtime.output_dir}
6 | paddings: [100]
7 | objects: ['carpet', 'grid', 'leather', 'tile', 'wood']
8 |
--------------------------------------------------------------------------------
/conf/dataset/wft.yaml:
--------------------------------------------------------------------------------
1 | name: wft
2 | data_manager:
3 | _target_: data_loaders.StandardFormatDataset
4 | data_dir: ${hydra:runtime.cwd}/datasets/wft
5 | out_dir: ${hydra:runtime.output_dir}
6 | paddings: [50]
7 | objects: ['texture_1', 'texture_2']
8 |
--------------------------------------------------------------------------------
/conf/fe/color.yaml:
--------------------------------------------------------------------------------
1 | name: color
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.Colors
5 | post_scale: False
6 |
--------------------------------------------------------------------------------
/conf/fe/ltem.yaml:
--------------------------------------------------------------------------------
1 | name: ltem
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.LawTextureEnergyMeasure
5 | mean_patch_size: 15
6 | device: ${device}
7 |
--------------------------------------------------------------------------------
/conf/fe/random.yaml:
--------------------------------------------------------------------------------
1 | name: random
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.RandomKernels
5 | projections: [[128,1]]
6 | patch_size: 5
7 | device: ${device}
8 |
--------------------------------------------------------------------------------
/conf/fe/steerable.yaml:
--------------------------------------------------------------------------------
1 | name: steerable
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.SteerableExtractor
5 | o: 43
6 | m: 1
7 | scale_factor: 1
8 | post_scale: True
9 |
--------------------------------------------------------------------------------
/conf/fe/vgg.yaml:
--------------------------------------------------------------------------------
1 | name: vgg
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.VggExtractor
5 | post_scale: True
6 | layers: ['conv_1', 'conv_2']
7 | device: ${device}
8 |
--------------------------------------------------------------------------------
/conf/fe/wide.yaml:
--------------------------------------------------------------------------------
1 | name: wide
2 | tile_size: ${tile_size}
3 | feature_extractor:
4 | _target_: feature_extractors.WideResnetExtractor
5 | post_scale: True
6 | device: ${device}
7 |
--------------------------------------------------------------------------------
/conf/sc/aota.yaml:
--------------------------------------------------------------------------------
1 | name: aota
2 | method:
3 | _partial_: true
4 | _target_: sc_methods.sc_aota
5 | tile_size: ${tile_size}
6 |
--------------------------------------------------------------------------------
/conf/sc/fca.yaml:
--------------------------------------------------------------------------------
1 | name: FCA
2 | method:
3 | _target_: sc_methods.ScFCA
4 | tile_size: ${tile_size}
5 | chunk_size: 8
6 | sigma_p: 3.0
7 | reference_selection: median
8 | sigma_s: 1.0
9 |
--------------------------------------------------------------------------------
/conf/sc/hist.yaml:
--------------------------------------------------------------------------------
1 | name: hist
2 | method:
3 | _partial_: true
4 | _target_: sc_methods.sc_hist
5 | tile_size: ${tile_size}
6 | bins: 10
7 | sigma: 6.0
8 |
--------------------------------------------------------------------------------
/conf/sc/moments.yaml:
--------------------------------------------------------------------------------
1 | name: moments
2 | method:
3 | _partial_: true
4 | _target_: sc_methods.sc_moments
5 | tile_size: ${tile_size}
6 | sigma: 6.0
7 | powers: [1.0, 2.0, 3.0, 4.0]
8 |
--------------------------------------------------------------------------------
/conf/sc/sww.yaml:
--------------------------------------------------------------------------------
1 | name: sww
2 | method:
3 | _target_: sc_methods.ScSWW
4 | tile_size: ${tile_size}
5 | chunk_size: 8
6 | sigma: 6.0
7 | reference_selection: median
8 |
--------------------------------------------------------------------------------
/data_loaders.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import tifffile as tiff
5 | from PIL import Image
6 | import numpy as np
7 |
8 | from torch.utils.data import Dataset
9 |
10 | from evaluation import evaluate_experiment
11 |
12 |
13 | class StandardFormatDataset(Dataset):
14 | def __init__(self, data_dir, out_dir, objects, paddings=(100,), pro_integration_limit=0.3, resize=False):
15 | super(StandardFormatDataset, self).__init__()
16 | self.data_dir = Path(data_dir)
17 | self.out_dir = Path(out_dir)
18 | self.objects = objects
19 |
20 | self.paddings = paddings
21 | self.resize = resize
22 | self.pro_integration_limit = pro_integration_limit
23 |
24 | self.in_image_paths = self.get_input_paths()
25 |
26 | def __getitem__(self, index):
27 | in_image_path = self.in_image_paths[index]
28 | return in_image_path
29 |
30 | def __len__(self):
31 | return len(self.in_image_paths)
32 |
33 | def get_input_paths(self):
34 | in_image_paths = []
35 | for obj in self.objects:
36 | in_obj_path = self.data_dir / obj / "test"
37 | for defect in sorted(in_obj_path.iterdir()):
38 | in_image_paths.extend(sorted(list(defect.iterdir())))
39 | return in_image_paths
40 |
41 | def save_output(self, age: torch.Tensor, in_image_path):
42 | np_age = age.cpu().numpy()
43 | obj, _, defect, _ = in_image_path.parts[-4:]
44 | stem = in_image_path.stem
45 | out_def_path = self.out_dir / obj / "test" / defect
46 | vis_def_path = self.out_dir / obj / "visualize" / defect
47 | out_def_path.mkdir(parents=True, exist_ok=True)
48 | vis_def_path.mkdir(parents=True, exist_ok=True)
49 |
50 | tiff.imwrite(out_def_path / f"{stem}.tiff", np_age)
51 | # Save for visualization
52 | vis = ((np_age - np_age.min()) / (np_age.max() - np_age.min()) * 255)
53 | Image.fromarray(vis.astype(np.uint8)).save(vis_def_path / f"{stem}.jpg")
54 |
55 | def run_evaluation(self):
56 | args = {
57 | 'dataset_base_dir': str(self.data_dir),
58 | 'anomaly_maps_dir': str(self.out_dir),
59 | 'output_dir': str(self.out_dir),
60 | 'evaluated_objects': self.objects,
61 | 'pro_integration_limit': self.pro_integration_limit,
62 | 'resize': self.resize if self.resize else 0,
63 | }
64 | args = type('dict_as_obj', (object,), args)  # expose dict keys as attributes
65 | for padding in self.paddings:
66 | args.padding = padding
67 | evaluate_experiment.main(args)
68 |
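69 | # Minimal usage sketch (hypothetical paths, mirroring conf/dataset/mvtec.yaml):
70 | #
71 | # ds = StandardFormatDataset("datasets/mvtec_anomaly_detection", "outputs/run",
72 | #                            objects=["carpet"], paddings=(100,))
73 | # for in_path in ds:
74 | #     age = ...  # anomaly map (torch.Tensor) computed for the image at in_path
75 | #     ds.save_output(age, in_path)
76 | # ds.run_evaluation()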
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/datasets/.gitkeep
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib.cm as cm
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 |
6 | from feature_extractors import WideResnetExtractor
7 | from sc_methods import ScFCA
8 |
9 |
10 | def main():
11 | device = torch.device('cuda')
12 | feature_extractor = WideResnetExtractor(device=device)
13 | s = ScFCA((9, 9), sigma_p=3.0, sigma_s=1.0)
14 |
15 | image = torch.tensor(np.asarray(Image.open("example/000.png")), device=device).permute(2, 0, 1)[None] / 255.0
16 | features = feature_extractor(image)
17 | a = s(features)
18 |
19 | pad = 10
20 | np_im = a.cpu().numpy()[pad:-pad, pad:-pad]
21 | np_im = (np_im - np_im.min()) / (np_im.max() - np_im.min())
22 | cma = cm.Reds(np_im, bytes=True)
23 | result = Image.fromarray(cma, 'RGBA').convert('RGB')
24 | result.save('example/000_out.png')
25 | result.show()
26 |
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
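31 | # To score a different image, swap the hardcoded path (hypothetical example)
32 | # while keeping the [N,C,H,W] float layout the extractor expects:
33 | #   image = torch.tensor(np.asarray(Image.open("example/my_texture.png")),
34 | #                        device=device).permute(2, 0, 1)[None] / 255.0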
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: fca
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1
9 | - _openmp_mutex=5.1
10 | - aom=3.6.0
11 | - blas=1.0
12 | - blosc=1.21.3
13 | - brotli=1.0.9
14 | - brotli-python=1.0.9
15 | - brunsli=0.1
16 | - bzip2=1.0.8
17 | - c-ares=1.19.1
18 | - ca-certificates=2023.08.22
19 | - certifi=2023.11.17
20 | - cffi=1.16.0
21 | - cfitsio=3.470
22 | - charls=2.2.0
23 | - charset-normalizer=2.0.4
24 | - click=8.1.7
25 | - cloudpickle=2.2.1
26 | - colorama=0.4.6
27 | - cryptography=41.0.7
28 | - cuda-cudart=11.8.89
29 | - cuda-cupti=11.8.87
30 | - cuda-libraries=11.8.0
31 | - cuda-nvrtc=11.8.89
32 | - cuda-nvtx=11.8.86
33 | - cuda-runtime=11.8.0
34 | - cytoolz=0.12.2
35 | - dask-core=2023.11.0
36 | - dav1d=1.2.1
37 | - ffmpeg=4.3
38 | - filelock=3.13.1
39 | - freetype=2.12.1
40 | - fsspec=2023.10.0
41 | - giflib=5.2.1
42 | - gmp=6.2.1
43 | - gmpy2=2.1.2
44 | - gnutls=3.6.15
45 | - idna=3.4
46 | - imagecodecs=2023.1.23
47 | - imageio=2.31.4
48 | - importlib-metadata=6.0.0
49 | - intel-openmp=2023.1.0
50 | - jinja2=3.1.2
51 | - joblib=1.2.0
52 | - jpeg=9e
53 | - jxrlib=1.1
54 | - krb5=1.20.1
55 | - lame=3.100
56 | - lcms2=2.12
57 | - ld_impl_linux-64=2.38
58 | - lerc=3.0
59 | - libaec=1.0.4
60 | - libavif=0.11.1
61 | - libbrotlicommon=1.0.9
62 | - libbrotlidec=1.0.9
63 | - libbrotlienc=1.0.9
64 | - libcublas=11.11.3.6
65 | - libcufft=10.9.0.58
66 | - libcufile=1.8.1.2
67 | - libcurand=10.3.4.101
68 | - libcurl=8.4.0
69 | - libcusolver=11.4.1.48
70 | - libcusparse=11.7.5.86
71 | - libdeflate=1.17
72 | - libedit=3.1.20221030
73 | - libev=4.33
74 | - libffi=3.4.4
75 | - libgcc-ng=11.2.0
76 | - libgfortran-ng=13.2.0
77 | - libgfortran5=13.2.0
78 | - libgomp=11.2.0
79 | - libiconv=1.16
80 | - libidn2=2.3.4
81 | - libjpeg-turbo=2.0.0
82 | - libnghttp2=1.57.0
83 | - libnpp=11.8.0.86
84 | - libnvjpeg=11.9.0.86
85 | - libpng=1.6.39
86 | - libssh2=1.10.0
87 | - libstdcxx-ng=11.2.0
88 | - libtasn1=4.19.0
89 | - libtiff=4.5.1
90 | - libunistring=0.9.10
91 | - libwebp=1.3.2
92 | - libwebp-base=1.3.2
93 | - libzopfli=1.0.3
94 | - llvm-openmp=14.0.6
95 | - locket=1.0.0
96 | - lz4-c=1.9.4
97 | - markupsafe=2.1.1
98 | - mkl=2023.1.0
99 | - mkl-service=2.4.0
100 | - mkl_fft=1.3.8
101 | - mkl_random=1.2.4
102 | - mpc=1.1.0
103 | - mpfr=4.0.2
104 | - mpmath=1.3.0
105 | - ncurses=6.4
106 | - nettle=3.7.3
107 | - networkx=3.1
108 | - numpy=1.26.2
109 | - numpy-base=1.26.2
110 | - openh264=2.1.1
111 | - openjpeg=2.4.0
112 | - openssl=3.0.12
113 | - partd=1.4.1
114 | - pillow=10.0.1
115 | - pip=23.3.1
116 | - pycparser=2.21
117 | - pyopenssl=23.2.0
118 | - pysocks=1.7.1
119 | - python=3.9.18
120 | - pytorch=2.1.2
121 | - pytorch-cuda=11.8
122 | - pytorch-mutex=1.0
123 | - pywavelets=1.4.1
124 | - pyyaml=6.0.1
125 | - readline=8.2
126 | - requests=2.31.0
127 | - scikit-image=0.19.3
128 | - scikit-learn=1.3.0
129 | - scipy=1.11.4
130 | - setuptools=68.2.2
131 | - snappy=1.1.9
132 | - sqlite=3.41.2
133 | - sympy=1.12
134 | - tbb=2021.8.0
135 | - threadpoolctl=2.2.0
136 | - tifffile=2023.2.28
137 | - tk=8.6.12
138 | - toolz=0.12.0
139 | - torchaudio=2.1.2
140 | - torchtriton=2.1.0
141 | - torchvision=0.16.2
142 | - tqdm=4.66.1
143 | - typing_extensions=4.7.1
144 | - tzdata=2023c
145 | - urllib3=1.26.18
146 | - wheel=0.41.2
147 | - xz=5.4.5
148 | - yaml=0.2.5
149 | - zfp=1.0.0
150 | - zipp=3.11.0
151 | - zlib=1.2.13
152 | - zstd=1.5.5
153 | - pip:
154 | - antlr4-python3-runtime==4.9.3
155 | - hydra-core==1.3.2
156 | - omegaconf==2.3.0
157 | - opencv-python==4.8.1.78
158 | - packaging==23.2
159 |
--------------------------------------------------------------------------------
/evaluation/additional_util.py:
--------------------------------------------------------------------------------
1 | from sklearn import metrics
2 | import numpy as np
3 |
4 |
5 | def compute_seg_au_roc(anomaly_maps, ground_truth_maps):
6 | anomaly_scores_flat = np.array(anomaly_maps).ravel()
7 | gt_flat = np.array(ground_truth_maps).ravel() > 0
8 | fpr, tpr, thresholds = metrics.roc_curve(gt_flat.astype(int), anomaly_scores_flat)
9 | # au_roc = metrics.roc_auc_score(gt_flat.astype(int), anomaly_scores_flat)
10 | return fpr, tpr
11 |
12 |
13 | def compute_precision_recall(anomaly_maps, ground_truth_maps):
14 | anomaly_scores_flat = np.array(anomaly_maps).ravel()
15 | gt_flat = np.array(ground_truth_maps).ravel() > 0
16 | precision, recall, thresholds = metrics.precision_recall_curve(gt_flat.astype(int), anomaly_scores_flat)
17 | return precision, recall, thresholds
18 |
19 |
20 | def compute_optimal_f1(anomaly_maps, ground_truth_maps):
21 | precision, recall, thresholds = compute_precision_recall(anomaly_maps, ground_truth_maps)
22 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
23 | ind = np.argmax(f1_score)
24 | return f1_score[ind].item(), thresholds[ind].item()
25 |
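26 | # Minimal usage sketch (random data, purely illustrative):
27 | #
28 | # rng = np.random.default_rng(0)
29 | # preds = [rng.random((64, 64)) for _ in range(4)]
30 | # gts = [(rng.random((64, 64)) > 0.95).astype(np.uint8) for _ in range(4)]
31 | # best_f1, threshold = compute_optimal_f1(preds, gts)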
--------------------------------------------------------------------------------
/evaluation/evaluate_experiment.py:
--------------------------------------------------------------------------------
1 | """Compute evaluation metrics for a single experiment."""
2 |
3 | __author__ = "Paul Bergmann, David Sattlegger"
4 | __copyright__ = "2021, MVTec Software GmbH"
5 |
6 | import argparse
7 | import json
8 | from os import makedirs, path
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import skimage
13 | import skimage.transform
14 | from PIL import Image
15 | from tqdm import tqdm
16 |
17 | from evaluation import generic_util as util
18 | from evaluation.additional_util import compute_optimal_f1, compute_seg_au_roc
19 | from evaluation.pro_curve_util import compute_pro
20 | from evaluation.roc_curve_util import compute_classification_roc
21 |
22 |
23 | def parse_user_arguments():
24 | """Parse user arguments for the evaluation of a method on the MVTec AD
25 | dataset.
26 |
27 | Returns:
28 | Parsed user arguments.
29 | """
30 | parser = argparse.ArgumentParser(description="""Parse user arguments.""")
31 |
32 | parser.add_argument('--anomaly_maps_dir',
33 | required=True,
34 | help="""Path to the directory that contains the anomaly
35 | maps of the evaluated method.""")
36 |
37 | parser.add_argument('--dataset_base_dir',
38 | required=True,
39 | help="""Path to the directory that contains the dataset
40 | images of the MVTec AD dataset.""")
41 |
42 | parser.add_argument('--output_dir',
43 | help="""Path to the directory to store evaluation
44 | results. If no output directory is specified,
45 | the results are not written to drive.""")
46 |
47 | parser.add_argument('--padding',
48 | type=int,
49 | default=100,
50 | help="""How much to cut off the borders of the image before evaluating;
51 | the metrics are computed on m[padding:-padding, padding:-padding]""")
52 |
53 | parser.add_argument('--resize',
54 | type=int,
55 | default=0,
56 | help="""Resize the predictions and ground truth masks before evaluation""")
57 |
58 | parser.add_argument('--pro_integration_limit',
59 | type=float,
60 | default=0.3,
61 | help="""Integration limit to compute the area under
62 | the PRO curve. Must lie within the interval
63 | of (0.0, 1.0].""")
64 |
65 | parser.add_argument('--evaluated_objects',
66 | nargs='+',
67 | help="""List of objects to be evaluated. By default,
68 | all dataset objects will be evaluated.""",
69 | choices=util.OBJECT_NAMES,
70 | default=util.OBJECT_NAMES)
71 |
72 | args = parser.parse_args()
73 |
74 | # Check that the PRO integration limit is within the valid range.
75 | assert 0.0 < args.pro_integration_limit <= 1.0
76 |
77 | return args
78 |
79 |
80 | def parse_dataset_files(object_name, dataset_base_dir, anomaly_maps_dir):
81 | """Parse the filenames for one object of the MVTec AD dataset.
82 |
83 | Args:
84 | object_name: Name of the dataset object.
85 | dataset_base_dir: Base directory of the MVTec AD dataset.
86 | anomaly_maps_dir: Base directory where anomaly maps are located.
87 | """
88 |
89 | # Store a list of all ground truth filenames.
90 | gt_filenames = []
91 |
92 | # Store a list of all corresponding anomaly map filenames.
93 | prediction_filenames = []
94 |
95 | # Test images are located here.
96 | test_dir = path.join(dataset_base_dir, object_name, 'test')
97 | gt_base_dir = path.join(dataset_base_dir, object_name, 'ground_truth')
98 | anomaly_maps_base_dir = path.join(anomaly_maps_dir, object_name, 'test')
99 |
100 | # List all ground truth and corresponding anomaly images.
101 | for subdir in Path(test_dir).iterdir():
102 |
103 | # Get paths to all test images in the dataset for this subdir.
104 | test_images = list(subdir.glob('*.png')) + list(subdir.glob('*.jpg'))
105 | test_images = sorted([x.stem for x in test_images])
106 |
107 | # If subdir is not 'good', derive corresponding GT names.
108 | if subdir.name != 'good':
109 | file_names = (Path(gt_base_dir) / subdir.name).glob('*.png')
110 | gt_filenames.extend(sorted(file_names))
111 | else:
112 | # No ground truth maps exist for anomaly-free images.
113 | gt_filenames.extend([None] * len(test_images))
114 |
115 | # Fetch corresponding anomaly maps.
116 | prediction_filenames.extend(
117 | [path.join(anomaly_maps_base_dir, subdir.name, file)
118 | for file in test_images])
119 |
120 | print(f"Parsed {len(gt_filenames)} ground truth image files.")
121 |
122 | return gt_filenames, prediction_filenames
123 |
124 |
125 | def calculate_au_pro_au_roc(gt_filenames,
126 | prediction_filenames,
127 | integration_limit,
128 | padding=0,
129 | resize=False):
130 | """Compute the area under the PRO curve for a set of ground truth images
131 | and corresponding anomaly images.
132 |
133 | In addition, the function computes the area under the ROC curve for image
134 | level classification.
135 |
136 | Args:
137 | gt_filenames: List of filenames that contain the ground truth images
138 | for a single dataset object.
139 | prediction_filenames: List of filenames that contain the corresponding
140 | anomaly images for each ground truth image.
141 | integration_limit: Integration limit to use when computing the area
142 | under the PRO curve.
143 | padding: Remove borders before evaluation
144 | resize: Resize images before padding
145 | Returns:
146 | au_pro: Area under the PRO curve computed up to the given integration limit.
147 | au_roc: Area under the ROC curve for image-level classification.
148 | pro_curve: PRO curve values for localization (fpr, pro).
149 | roc_curve: ROC curve values for image-level classification (fpr, tpr).
150 | seg_roc_curve, au_seg_roc, f1_optimal: segmentation ROC curve and its area, plus the (best F1, threshold) pair.
151 | """
152 | # Read all ground truth and anomaly images.
153 | ground_truth = []
154 | predictions = []
155 |
156 | print("Read ground truth files and corresponding predictions...")
157 | original_size = np.asarray(Image.open(next(_ for _ in gt_filenames if _ is not None))).shape[-2:]
158 | for (gt_name, pred_name) in tqdm(zip(gt_filenames, prediction_filenames),
159 | total=len(gt_filenames)):
160 | prediction = util.read_tiff(pred_name)
161 | if resize > 0:
162 | prediction = skimage.transform.resize(prediction, (resize, resize))
163 | elif prediction.shape != original_size:
164 | prediction = skimage.transform.resize(prediction, original_size)
165 | if padding != 0:
166 | prediction = prediction[padding:-padding, padding:-padding]
167 | predictions.append(prediction)
168 |
169 | if gt_name is not None:
170 | gt = np.asarray(Image.open(gt_name))
171 | if resize > 0:
172 | gt = skimage.transform.resize(gt, (resize, resize))
173 | if padding != 0:
174 | gt = gt[padding:-padding, padding:-padding]
175 | else:
176 | gt = np.zeros(prediction.shape)
177 | ground_truth.append(gt)
178 |
179 | # Compute the PRO curve.
180 | pro_curve = compute_pro(
181 | anomaly_maps=predictions,
182 | ground_truth_maps=ground_truth)
183 |
184 | # Compute the area under the PRO curve.
185 | au_pro = util.trapezoid(
186 | pro_curve[0], pro_curve[1], x_max=integration_limit)
187 | au_pro /= integration_limit
188 | print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}")
189 |
190 | # Compute the segmentation ROC
191 | seg_roc_curve = compute_seg_au_roc(anomaly_maps=predictions, ground_truth_maps=ground_truth)
192 | au_seg_roc = util.trapezoid(seg_roc_curve[0], seg_roc_curve[1], x_max=1.0)
193 | print(f"AU-ROC segmentation: {au_seg_roc}")
194 |
195 | f1_optimal = compute_optimal_f1(anomaly_maps=predictions, ground_truth_maps=ground_truth)
196 |
197 | # Derive binary labels for each input image:
198 | # (0 = anomaly free, 1 = anomalous).
199 | binary_labels = [int(np.any(x > 0)) for x in ground_truth]
200 | del ground_truth
201 |
202 | if all([x == 1 for x in binary_labels]):
203 | roc_curve, au_roc = [], 1.0
204 | else:
205 | # Compute the classification ROC curve.
206 | roc_curve = compute_classification_roc(
207 | anomaly_maps=predictions,
208 | scoring_function=np.max,
209 | ground_truth_labels=binary_labels)
210 |
211 | # Compute the area under the classification ROC curve.
212 | au_roc = util.trapezoid(roc_curve[0], roc_curve[1])
213 | print(f"Image-level classification AU-ROC: {au_roc}")
214 |
215 | # Return the evaluation metrics.
216 | return au_pro, au_roc, pro_curve, roc_curve, seg_roc_curve, au_seg_roc, f1_optimal
217 |
218 |
219 | def main(args):
220 | """Calculate the performance metrics for a single experiment on the
221 | MVTec AD dataset.
222 | """
223 | # Parse user arguments.
224 | padding = args.padding
225 | resize = args.resize
226 |
227 | # Store evaluation results in this dictionary.
228 | evaluation_dict = dict()
229 |
230 | # Keep track of the mean performance measures.
231 | au_pros = []
232 | au_rocs = []
233 | au_segs = []
234 | f1s = []
235 |
236 | # Evaluate each dataset object separately.
237 | for obj in args.evaluated_objects:
238 | print(f"=== Evaluate {obj} ===")
239 | evaluation_dict[obj] = dict()
240 |
241 | # Parse the filenames of all ground truth and corresponding anomaly
242 | # images for this object.
243 | gt_filenames, prediction_filenames = \
244 | parse_dataset_files(
245 | object_name=obj,
246 | dataset_base_dir=args.dataset_base_dir,
247 | anomaly_maps_dir=args.anomaly_maps_dir)
248 |
249 | # Calculate the PRO and ROC curves.
250 | au_pro, au_roc, pro_curve, roc_curve, seg_curve, au_seg, f1_optimal = \
251 | calculate_au_pro_au_roc(
252 | gt_filenames,
253 | prediction_filenames,
254 | args.pro_integration_limit, padding=padding, resize=resize)
255 |
256 | evaluation_dict[obj]['au_pro'] = au_pro
257 | evaluation_dict[obj]['classification_au_roc'] = au_roc
258 | evaluation_dict[obj]['au_segroc'] = au_seg
259 | evaluation_dict[obj]['f1_optimal_value'] = f1_optimal[0]
260 | evaluation_dict[obj]['f1_optimal_threshold'] = f1_optimal[1]
261 |
262 | # evaluation_dict[obj]['classification_roc_curve_fpr'] = roc_curve[0]
263 | # evaluation_dict[obj]['classification_roc_curve_tpr'] = roc_curve[1]
264 |
265 | # Keep track of the mean performance measures.
266 | au_pros.append(au_pro)
267 | au_rocs.append(au_roc)
268 | au_segs.append(au_seg)
269 | f1s.append(f1_optimal[0])
270 |
271 | print('\n')
272 |
273 | # Compute the mean of the performance measures.
274 | evaluation_dict['mean_au_pro'] = np.mean(au_pros).item()
275 | evaluation_dict['mean_au_segroc'] = np.mean(au_segs).item()
276 | evaluation_dict['mean_classification_au_roc'] = np.mean(au_rocs).item()
277 | evaluation_dict['mean_f1'] = np.mean(f1s).item()
278 |
279 | # If required, write evaluation metrics to drive.
280 | if args.output_dir is not None:
281 | output_dir = args.output_dir
282 | makedirs(output_dir, exist_ok=True)
283 |
284 | with open(path.join(output_dir, f'metrics_{padding}.json'), 'w') as file:
285 | json.dump(evaluation_dict, file, indent=4)
286 |
287 | print(f"Wrote metrics to {path.join(output_dir, f'metrics_{padding}.json')}")
288 |
289 |
290 | if __name__ == "__main__":
291 | main(parse_user_arguments())
292 |
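293 | # Example standalone invocation from the repo root (hypothetical paths):
294 | #   python -m evaluation.evaluate_experiment \
295 | #       --dataset_base_dir datasets/mvtec_anomaly_detection \
296 | #       --anomaly_maps_dir outputs/my_experiment \
297 | #       --output_dir outputs/my_experiment --padding 100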
--------------------------------------------------------------------------------
/evaluation/evaluate_multiple_experiments.py:
--------------------------------------------------------------------------------
1 | """Run the evaluation script for multiple experiments.
2 |
3 | This is a wrapper around evaluate_experiment.py, which is called once for each
4 | experiment specified in the config file passed to this script.
5 | """
6 |
7 | __author__ = "Paul Bergmann, David Sattlegger"
8 | __copyright__ = "2021, MVTec Software GmbH"
9 |
10 | import argparse
11 | import json
12 | import subprocess
13 | from os import path
14 |
15 |
16 | def parse_user_arguments():
17 | """Parse user arguments.
18 |
19 | Returns: Parsed arguments.
20 | """
21 | parser = argparse.ArgumentParser(description="""Parse user arguments.""")
22 |
23 | parser.add_argument('--experiment_configs',
24 | default='experiment_configs.json',
25 | help="""Path to the config file that contains the
26 | locations of all experiments that should be
27 | evaluated.""")
28 |
29 | parser.add_argument('--dataset_base_dir',
30 | required=True,
31 | help="""Path to the directory that contains the dataset
32 | images of the MVTec AD dataset.""")
33 |
34 | parser.add_argument('--output_dir',
35 | default='metrics/',
36 | help="""Path to write evaluation results to.""")
37 |
38 | parser.add_argument('--dry_run',
39 | choices=['True', 'False'],
40 | default='False',
41 | help="""If set to 'True', the script is run without
42 | performing the actual evaluations. Instead, the
43 | experiments to be evaluated are simply printed
44 | to the standard output.""")
45 |
46 | parser.add_argument('--pro_integration_limit',
47 | type=float,
48 | default=0.3,
49 | help="""Integration limit to compute the area under
50 | the PRO curve. Must lie within the interval
51 | of (0.0, 1.0].""")
52 |
53 | return parser.parse_args()
54 |
55 |
56 | def main():
57 | """Run the evaluation script for multiple experiments."""
58 | # Parse user arguments.
59 | args = parse_user_arguments()
60 |
61 | # Read the experiment configurations to be evaluated.
62 | with open(args.experiment_configs) as file:
63 | experiment_configs = json.load(file)
64 |
65 | # Call the evaluation script for each experiment separately.
66 | for experiment_id in experiment_configs['anomaly_maps_dirs']:
67 | print(f"=== Evaluate experiment: {experiment_id} ===\n")
68 |
69 | # Anomaly maps for this experiment are located in this directory.
70 | anomaly_maps_dir = path.join(
71 | experiment_configs['exp_base_dir'],
72 | experiment_configs['anomaly_maps_dirs'][experiment_id])
73 |
74 | # Set up python call for the evaluation script.
75 | call = ['python', 'evaluate_experiment.py',
76 | '--anomaly_maps_dir', anomaly_maps_dir,
77 | '--dataset_base_dir', args.dataset_base_dir,
78 | '--output_dir', path.join(args.output_dir, experiment_id),
79 | '--pro_integration_limit', str(args.pro_integration_limit)]
80 |
81 | # Run evaluation script.
82 | if args.dry_run == 'False':
83 | subprocess.run(call, check=True)
84 | else:
85 | print(f"Would call: {' '.join(call)}\n")
86 |
87 |
88 | if __name__ == "__main__":
89 | main()
90 |
--------------------------------------------------------------------------------
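The script reads exactly two keys from the config file: exp_base_dir and anomaly_maps_dirs. A minimal sketch that writes a matching experiment_configs.json; the experiment ids and paths are hypothetical:

```python
# Writes a hypothetical experiment_configs.json with the two keys this script
# reads; ids and directory names are placeholders.
import json

config = {
    "exp_base_dir": "/path/to/experiments",
    "anomaly_maps_dirs": {
        "fca_wide": "fca_wide/anomaly_maps",
        "fca_vgg": "fca_vgg/anomaly_maps",
    },
}

with open("experiment_configs.json", "w") as file:
    json.dump(config, file, indent=4)
```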
/evaluation/generic_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utility functions for:
3 | - parsing user arguments.
4 | - computing the area under a curve.
5 | - generating a toy dataset to test the evaluation script.
6 | """
7 | from bisect import bisect
8 | import os
9 |
10 | import numpy as np
11 | import tifffile as tiff
12 |
13 | OBJECT_NAMES = ['bottle', 'cable', 'capsule', 'carpet', 'grid',
14 | 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
15 | 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
16 |
17 |
18 | def trapezoid(x, y, x_max=None):
19 | """
20 |     This function calculates the definite integral of a curve given by
21 | x- and corresponding y-values. In contrast to, e.g., 'numpy.trapz()',
22 |     this function allows defining an upper bound on the integration range by
23 | setting a value x_max.
24 |
25 | Points that do not have a finite x or y value will be ignored with a
26 | warning.
27 |
28 | Args:
29 |         x: Samples from the domain of the function to integrate.
30 | Need to be sorted in ascending order. May contain the same value
31 | multiple times. In that case, the order of the corresponding
32 | y values will affect the integration with the trapezoidal rule.
33 | y: Values of the function corresponding to x values.
34 |         x_max: Upper limit of the integration. The y value at x_max will be
35 | determined by interpolating between its neighbors. Must not lie
36 | outside of the range of x.
37 |
38 | Returns:
39 | Area under the curve.
40 | """
41 |
42 | x = np.asarray(x)
43 | y = np.asarray(y)
44 | finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
45 | if not finite_mask.all():
46 | print("WARNING: Not all x and y values passed to trapezoid(...)"
47 | " are finite. Will continue with only the finite values.")
48 | x = x[finite_mask]
49 | y = y[finite_mask]
50 |
51 |     # Introduce a correction term if x_max is not an element of x.
52 | correction = 0.
53 | if x_max is not None:
54 | if x_max not in x:
55 | # Get the insertion index that would keep x sorted after
56 | # np.insert(x, ins, x_max).
57 | ins = bisect(x, x_max)
58 | # x_max must be between the minimum and the maximum, so the
59 | # insertion_point cannot be zero or len(x).
60 | assert 0 < ins < len(x)
61 |
62 | # Calculate the correction term which is the integral between
63 | # the last x[ins-1] and x_max. Since we do not know the exact value
64 | # of y at x_max, we interpolate between y[ins] and y[ins-1].
65 | y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) *
66 | (x_max - x[ins - 1]) /
67 | (x[ins] - x[ins - 1]))
68 | correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
69 |
70 | # Cut off at x_max.
71 | mask = x <= x_max
72 | x = x[mask]
73 | y = y[mask]
74 |
75 | # Return area under the curve using the trapezoidal rule.
76 | return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
77 |
78 |
79 | def read_tiff(file_path_no_ext, exts=('.tif', '.tiff', '.TIF', '.TIFF')):
80 | """Read a TIFF file from a given path without the TIFF extension.
81 |
82 | Args:
83 | file_path_no_ext: Path to the TIFF file without a file extension.
84 | exts: TIFF extensions to consider when searching for the file.
85 |
86 | Raises:
87 | FileNotFoundError: The given file path does not exist with any of the
88 | given extensions.
89 | IOError: The given file path exists with multiple of the given
90 | extensions.
91 | """
92 | # Get all file paths that exist
93 | file_paths = []
94 | for ext in exts:
95 | # Make sure the file path does not already end with a tiff extension.
96 | assert not file_path_no_ext.endswith(ext)
97 | file_path = file_path_no_ext + ext
98 | if os.path.exists(file_path):
99 | file_paths.append(file_path)
100 |
101 | if len(file_paths) == 0:
102 | raise FileNotFoundError('Could not find a file with a TIFF extension'
103 | f' at {file_path_no_ext}')
104 | elif len(file_paths) > 1:
105 | raise IOError('Found multiple files with a TIFF extension at'
106 | f' {file_path_no_ext}'
107 | '\nPlease specify which TIFF extension to use via the'
108 | ' `exts` parameter of this function.')
109 |
110 | return tiff.imread(file_paths[0])
111 |
112 |
113 | def generate_toy_dataset(num_images, image_width, image_height, gt_size):
114 | """Generate a toy dataset to test the evaluation script.
115 |
116 | Args:
117 | num_images: Number of images that the toy dataset contains.
118 | image_width: Width of the dataset images in pixels.
119 | image_height: Height of the dataset images in pixels.
120 | gt_size: Size of rectangular ground truth regions that are
121 | artificially generated on the dataset images.
122 |
123 | Returns:
124 | anomaly_maps: List of numpy arrays that contain random anomaly maps.
125 | ground_truth_map: Corresponding list of numpy arrays that specify a
126 | rectangular ground truth region.
127 | """
128 | # Fix a random seed for reproducibility.
129 | np.random.seed(1338)
130 |
131 | # Create synthetic evaluation data with random anomaly scores and
132 | # simple ground truth maps.
133 | anomaly_maps = []
134 | ground_truth_maps = []
135 | for _ in range(num_images):
136 | # Sample a random anomaly map.
137 | anomaly_map = np.random.random((image_height, image_width))
138 |
139 | # Construct a fixed ground truth map.
140 | ground_truth_map = np.zeros((image_height, image_width))
141 | ground_truth_map[0:gt_size, 0:gt_size] = 1
142 |
143 | anomaly_maps.append(anomaly_map)
144 | ground_truth_maps.append(ground_truth_map)
145 |
146 | return anomaly_maps, ground_truth_maps
147 |
--------------------------------------------------------------------------------
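A small worked check of the x_max correction in trapezoid(...), assuming generic_util.py is on the import path. For the linear curve y = x on [0, 1], integrating up to x_max = 0.5 should give 0.5**2 / 2 = 0.125, which the interpolated correction term reproduces exactly:

```python
# Minimal check of trapezoid() and its x_max correction term.
from generic_util import trapezoid

area = trapezoid(x=[0.0, 1.0], y=[0.0, 1.0], x_max=0.5)
print(area)  # expected: 0.125 (exact, since y = x is linear)
```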
/evaluation/print_metrics.py:
--------------------------------------------------------------------------------
1 | """Print the key metrics of multiple experiments to the standard output.
2 | """
3 |
4 | __author__ = "Paul Bergmann, David Sattlegger"
5 | __copyright__ = "2021, MVTec Software GmbH"
6 |
7 | import argparse
8 | import json
9 | import os
10 | from os.path import join
11 |
12 | import numpy as np
13 | from tabulate import tabulate
14 |
15 | from generic_util import OBJECT_NAMES
16 |
17 |
18 | def parse_user_arguments():
19 | """Parse user arguments.
20 |
21 | Returns: Parsed arguments.
22 | """
23 | parser = argparse.ArgumentParser(description="""Parse user arguments.""")
24 |
25 | parser.add_argument('--metrics_folder',
26 | default="./metrics/",
27 | help="""Path to the folder that contains the evaluation
28 | results.""")
29 |
30 | return parser.parse_args()
31 |
32 |
33 | def extract_table_rows(metrics_folder, metric):
34 | """Extract all rows to create a table that displays a given metric for each
35 | evaluated experiment.
36 |
37 | Args:
38 | metrics_folder: Base folder that contains evaluation results.
39 | metric: Name of the metric to be extracted. Choose between
40 | 'au_pro' for localization and
41 | 'classification_au_roc' for classification.
42 |
43 | Returns:
44 | List of table rows. Each row contains the experiment name and the
45 | extracted metrics for each evaluated object as well as the mean
46 | performance.
47 | """
48 | assert metric in ['au_pro', 'classification_au_roc']
49 |
50 |     # Iterate over each experiment.
51 | exp_ids = os.listdir(metrics_folder)
52 | exp_id_to_json_path = {
53 | exp_id: join(metrics_folder, exp_id, 'metrics.json')
54 | for exp_id in exp_ids
55 | if os.path.exists(join(metrics_folder, exp_id, 'metrics.json'))
56 | }
57 |
58 | # If there is a metrics.json file in the metrics_folder, also print that.
59 | # This is the case when evaluate_experiment.py has been called.
60 | root_metrics_json_path = join(metrics_folder, 'metrics.json')
61 | if os.path.exists(root_metrics_json_path):
62 | exp_id = join(os.path.split(metrics_folder)[-1], 'metrics.json')
63 | exp_id_to_json_path[exp_id] = root_metrics_json_path
64 |
65 | rows = []
66 | for exp_id, json_path in exp_id_to_json_path.items():
67 |
68 | # Each row starts with the name of the experiment.
69 | row = [exp_id]
70 |
71 | # Open the metrics file.
72 | with open(json_path) as file:
73 | metrics = json.load(file)
74 |
75 | # Parse performance metrics for each evaluated object if available.
76 | for obj in OBJECT_NAMES:
77 | if obj in metrics:
78 | row.append(np.round(metrics[obj][metric], decimals=3))
79 | else:
80 | row.append('-')
81 |
82 | # Parse mean performance.
83 | row.append(np.round(metrics['mean_' + metric], decimals=3))
84 | rows.append(row)
85 |
86 | return rows
87 |
88 |
89 | def main():
90 | """Print the key metrics of multiple experiments to the standard output.
91 | """
92 | # Parse user arguments.
93 | args = parse_user_arguments()
94 |
95 | # Create the table rows. One row for each experiment.
96 | rows_pro = extract_table_rows(args.metrics_folder, 'au_pro')
97 | rows_roc = extract_table_rows(args.metrics_folder, 'classification_au_roc')
98 |
99 | # Print localization result table.
100 | print("\nAU PRO (localization)")
101 | print(
102 | tabulate(
103 | rows_pro, headers=['Experiment'] + OBJECT_NAMES + ['Mean'],
104 | tablefmt='fancy_grid'))
105 |
106 | # Print classification result table.
107 | print("\nAU ROC (classification)")
108 | print(
109 | tabulate(
110 | rows_roc, headers=['Experiment'] + OBJECT_NAMES + ['Mean'],
111 | tablefmt='fancy_grid'))
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
--------------------------------------------------------------------------------
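For orientation, the folder layout that extract_table_rows(...) consumes, in the style of the tree at the top of this repository (the experiment names are hypothetical):

metrics/
├── fca_wide
│   └── metrics.json
├── fca_vgg
│   └── metrics.json
└── metrics.json        <- optional, from a direct evaluate_experiment.py run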
/evaluation/pro_curve_util.py:
--------------------------------------------------------------------------------
1 | """Utility function that computes a PRO curve, given pairs of anomaly and ground
2 | truth maps.
3 |
4 | The PRO curve can also be integrated up to a constant integration limit.
5 | """
6 | import numpy as np
7 | from scipy.ndimage import label  # the 'measurements' submodule is removed in recent SciPy
8 |
9 |
10 | def compute_pro(anomaly_maps, ground_truth_maps):
11 | """Compute the PRO curve for a set of anomaly maps with corresponding ground
12 | truth maps.
13 |
14 | Args:
15 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a
16 | real-valued anomaly score at each pixel.
17 |
18 | ground_truth_maps: List of ground truth maps (2D numpy arrays) that
19 | contain binary-valued ground truth labels for each pixel.
20 | 0 indicates that a pixel is anomaly-free.
21 | 1 indicates that a pixel contains an anomaly.
22 |
23 | Returns:
24 | fprs: numpy array of false positive rates.
25 | pros: numpy array of corresponding PRO values.
26 | """
27 |
28 | print("Compute PRO curve...")
29 |
30 | # Structuring element for computing connected components.
31 | structure = np.ones((3, 3), dtype=int)
32 |
33 | num_ok_pixels = 0
34 | num_gt_regions = 0
35 |
36 | shape = (len(anomaly_maps),
37 | anomaly_maps[0].shape[0],
38 | anomaly_maps[0].shape[1])
39 | fp_changes = np.zeros(shape, dtype=np.uint32)
40 | assert shape[0] * shape[1] * shape[2] < np.iinfo(fp_changes.dtype).max, \
41 | 'Potential overflow when using np.cumsum(), consider using np.uint64.'
42 |
43 | pro_changes = np.zeros(shape, dtype=np.float64)
44 |
45 | for gt_ind, gt_map in enumerate(ground_truth_maps):
46 |
47 | # Compute the connected components in the ground truth map.
48 | labeled, n_components = label(gt_map, structure)
49 | num_gt_regions += n_components
50 |
51 | # Compute the mask that gives us all ok pixels.
52 | ok_mask = labeled == 0
53 | num_ok_pixels_in_map = np.sum(ok_mask)
54 | num_ok_pixels += num_ok_pixels_in_map
55 |
56 | # Compute by how much the FPR changes when each anomaly score is
57 | # added to the set of positives.
58 | # fp_change needs to be normalized later when we know the final value
59 | # of num_ok_pixels -> right now it is only the change in the number of
60 | # false positives
61 | fp_change = np.zeros_like(gt_map, dtype=fp_changes.dtype)
62 | fp_change[ok_mask] = 1
63 |
64 | # Compute by how much the PRO changes when each anomaly score is
65 | # added to the set of positives.
66 | # pro_change needs to be normalized later when we know the final value
67 | # of num_gt_regions.
68 | pro_change = np.zeros_like(gt_map, dtype=np.float64)
69 | for k in range(n_components):
70 | region_mask = labeled == (k + 1)
71 | region_size = np.sum(region_mask)
72 | pro_change[region_mask] = 1. / region_size
73 |
74 | fp_changes[gt_ind, :, :] = fp_change
75 | pro_changes[gt_ind, :, :] = pro_change
76 |
77 | # Flatten the numpy arrays before sorting.
78 | anomaly_scores_flat = np.array(anomaly_maps).ravel()
79 | fp_changes_flat = fp_changes.ravel()
80 | pro_changes_flat = pro_changes.ravel()
81 |
82 | # Sort all anomaly scores.
83 | print(f"Sort {len(anomaly_scores_flat)} anomaly scores...")
84 | sort_idxs = np.argsort(anomaly_scores_flat).astype(np.uint32)[::-1]
85 |
86 | # Info: np.take(a, ind, out=a) followed by b=a instead of
87 |     # b=a[ind] turned out to be more memory efficient.
88 | np.take(anomaly_scores_flat, sort_idxs, out=anomaly_scores_flat)
89 | anomaly_scores_sorted = anomaly_scores_flat
90 | np.take(fp_changes_flat, sort_idxs, out=fp_changes_flat)
91 | fp_changes_sorted = fp_changes_flat
92 | np.take(pro_changes_flat, sort_idxs, out=pro_changes_flat)
93 | pro_changes_sorted = pro_changes_flat
94 |
95 | del sort_idxs
96 |
97 | # Get the (FPR, PRO) curve values.
98 | np.cumsum(fp_changes_sorted, out=fp_changes_sorted)
99 | fp_changes_sorted = fp_changes_sorted.astype(np.float32, copy=False)
100 | np.divide(fp_changes_sorted, num_ok_pixels, out=fp_changes_sorted)
101 | fprs = fp_changes_sorted
102 |
103 | np.cumsum(pro_changes_sorted, out=pro_changes_sorted)
104 | np.divide(pro_changes_sorted, num_gt_regions, out=pro_changes_sorted)
105 | pros = pro_changes_sorted
106 |
107 | # Merge (FPR, PRO) points that occur together at the same threshold.
108 | # For those points, only the final (FPR, PRO) point should be kept.
109 | # That is because that point is the one that takes all changes
110 | # to the FPR and the PRO at the respective threshold into account.
111 | # -> keep_mask is True if the subsequent score is different from the
112 | # score at the respective position.
113 | # anomaly_scores_sorted = [7, 4, 4, 4, 3, 1, 1]
114 |     # -> keep_mask =          [T, F, F, T, T, F, T]
115 | keep_mask = np.append(np.diff(anomaly_scores_sorted) != 0, np.True_)
116 | del anomaly_scores_sorted
117 |
118 | fprs = fprs[keep_mask]
119 | pros = pros[keep_mask]
120 | del keep_mask
121 |
122 | # To mitigate the adding up of numerical errors during the np.cumsum calls,
123 | # make sure that the curve ends at (1, 1) and does not contain values > 1.
124 | np.clip(fprs, a_min=None, a_max=1., out=fprs)
125 | np.clip(pros, a_min=None, a_max=1., out=pros)
126 |
127 | # Make the fprs and pros start at 0 and end at 1.
128 | zero = np.array([0.])
129 | one = np.array([1.])
130 |
131 | return np.concatenate((zero, fprs, one)), np.concatenate((zero, pros, one))
132 |
133 |
134 | def main():
135 | """
136 | Compute the area under the PRO curve for a toy dataset and an algorithm
137 | that randomly assigns anomaly scores to each pixel. The integration
138 | limit can be specified.
139 | """
140 |
141 | from generic_util import trapezoid, generate_toy_dataset
142 |
143 | integration_limit = 0.3
144 |
145 | # Generate a toy dataset.
146 | anomaly_maps, ground_truth_maps = generate_toy_dataset(
147 | num_images=200, image_width=500, image_height=300, gt_size=10)
148 |
149 | # Compute the PRO curve for this dataset.
150 | all_fprs, all_pros = compute_pro(
151 | anomaly_maps=anomaly_maps,
152 | ground_truth_maps=ground_truth_maps)
153 |
154 | au_pro = trapezoid(all_fprs, all_pros, x_max=integration_limit)
155 | au_pro /= integration_limit
156 | print(f"AU-PRO (FPR limit: {integration_limit}): {au_pro}")
157 |
158 |
159 | if __name__ == "__main__":
160 | main()
161 |
--------------------------------------------------------------------------------
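A hand-checkable example for compute_pro(...): a single 2x2 image whose only ground-truth pixel carries the highest anomaly score. The region is fully recovered before any false positive occurs, so PRO is already 1.0 at FPR 0 and the normalized AU-PRO equals 1.0 (assuming pro_curve_util.py and generic_util.py are importable):

```python
import numpy as np

from generic_util import trapezoid
from pro_curve_util import compute_pro

# One image: the single anomalous pixel (top-left) has the top score 0.9.
anomaly_maps = [np.array([[0.9, 0.8],
                          [0.1, 0.2]])]
ground_truth_maps = [np.array([[1, 0],
                               [0, 0]])]

fprs, pros = compute_pro(anomaly_maps, ground_truth_maps)
au_pro = trapezoid(fprs, pros, x_max=0.3) / 0.3
print(au_pro)  # expected: 1.0
```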
/evaluation/roc_curve_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions that compute a ROC curve and integrate its area from a set
3 | of anomaly maps and corresponding ground truth classification labels.
4 | """
5 | import numpy as np
6 |
7 |
8 | def compute_classification_roc(
9 | anomaly_maps,
10 | scoring_function,
11 | ground_truth_labels):
12 | """Compute the ROC curve for anomaly classification on the image level.
13 |
14 | Args:
15 | anomaly_maps: List of anomaly maps (2D numpy arrays) that contain
16 | a real-valued anomaly score at each pixel.
17 | scoring_function: Function that turns anomaly maps into a single
18 | real valued anomaly score.
19 |
20 | ground_truth_labels: List of integers that indicate the ground truth
21 | class for each input image. 0 corresponds to an anomaly-free sample
22 | while a value != 0 indicates an anomalous sample.
23 |
24 | Returns:
25 | fprs: List of false positive rates.
26 |         tprs: List of corresponding true positive rates.
27 | """
28 | assert len(anomaly_maps) == len(ground_truth_labels)
29 |
30 | # Compute the anomaly score for each anomaly map.
31 | anomaly_scores = map(scoring_function, anomaly_maps)
32 | num_scores = len(anomaly_maps)
33 |
34 | # Sort samples by anomaly score. Keep track of ground truth label.
35 | sorted_samples = \
36 | sorted(zip(anomaly_scores, ground_truth_labels), key=lambda x: x[0])
37 |
38 | # Compute the number of OK and NOK samples from the ground truth.
39 | ground_truth_labels_np = np.array(ground_truth_labels)
40 | num_nok = ground_truth_labels_np[ground_truth_labels_np != 0].size
41 | num_ok = ground_truth_labels_np[ground_truth_labels_np == 0].size
42 |
43 | # Initially, every NOK sample is correctly classified as anomalous
44 | # (tpr = 1.0), and every OK sample is incorrectly classified as anomalous
45 | # (fpr = 1.0).
46 | fprs = [1.0]
47 | tprs = [1.0]
48 |
49 | # Keep track of the current number of false and true positive predictions.
50 | num_fp = num_ok
51 | num_tp = num_nok
52 |
53 | # Compute new true and false positive rates when successively increasing
54 | # the threshold.
55 | next_score = None
56 |
57 | for i, (current_score, label) in enumerate(sorted_samples):
58 |
59 | if label == 0:
60 | num_fp -= 1
61 | else:
62 | num_tp -= 1
63 |
64 | if i < num_scores - 1:
65 | next_score = sorted_samples[i + 1][0]
66 | else:
67 | next_score = None # end of list
68 |
69 | if (next_score != current_score) or (next_score is None):
70 | fprs.append(num_fp / num_ok)
71 | tprs.append(num_tp / num_nok)
72 |
73 | # Return (FPR, TPR) pairs in increasing order.
74 | fprs = fprs[::-1]
75 | tprs = tprs[::-1]
76 |
77 | return fprs, tprs
78 |
79 |
80 | def main():
81 | """
82 | Compute the area under the ROC curve for a toy dataset and an algorithm
83 | that randomly assigns anomaly scores to each image pixel.
84 | """
85 |
86 | from generic_util import trapezoid, generate_toy_dataset
87 |
88 | # Generate a toy dataset.
89 | anomaly_maps, _ = generate_toy_dataset(
90 | num_images=10000, image_width=30, image_height=30, gt_size=0)
91 |
92 | # Assign a random classification label to each image.
93 | np.random.seed(42)
94 | labels = np.random.randint(2, size=len(anomaly_maps))
95 |
96 | # Compute the ROC curve.
97 | all_fprs, all_tprs = compute_classification_roc(anomaly_maps=anomaly_maps,
98 | scoring_function=np.max,
99 | ground_truth_labels=labels)
100 |
101 | # Compute the area under the ROC curve.
102 | au_roc = trapezoid(all_fprs, all_tprs)
103 | print(f"AU-ROC: {au_roc}")
104 |
105 |
106 | if __name__ == "__main__":
107 | main()
108 |
--------------------------------------------------------------------------------
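A hand-checkable example for compute_classification_roc(...): four images whose maximum scores are 0.1, 0.4, 0.4, 0.8 with labels 0, 1, 0, 1. The two tied scores collapse into a single (FPR, TPR) point, and the area under the resulting curve is 0.875:

```python
import numpy as np

from generic_util import trapezoid
from roc_curve_util import compute_classification_roc

anomaly_maps = [np.full((2, 2), s) for s in (0.1, 0.4, 0.4, 0.8)]
labels = [0, 1, 0, 1]

fprs, tprs = compute_classification_roc(anomaly_maps, np.max, labels)
print(fprs)                   # expected: [0.0, 0.0, 0.5, 1.0]
print(tprs)                   # expected: [0.0, 0.5, 1.0, 1.0]
print(trapezoid(fprs, tprs))  # expected: 0.875
```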
/example/000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/example/000.png
--------------------------------------------------------------------------------
/example/000_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/example/000_out.png
--------------------------------------------------------------------------------
/feature_extractors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torchvision import models
4 | from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights
5 |
6 | import op_utils
7 | from op_utils import scale_features, reflect_pad
8 | from PyTorchSteerablePyramid import extract_steerable_features
9 |
10 |
11 | class Colors:
12 | def __init__(self, post_scale=False):
13 | self.post_scale = post_scale
14 |
15 | def __call__(self, image: torch.tensor) -> torch.tensor:
16 | if self.post_scale:
17 | return scale_features(image)
18 | else:
19 | return image
20 |
21 |
22 | class RandomKernels:
23 | def __init__(self, input_c=3, projections=((256, 1),), patch_size=7, device=None):
24 | self.ms_kernels = [self.build_kernels(input_c, num_proj, patch_size, device) for num_proj, _ in projections]
25 | self.scales = [scale for _, scale in projections]
26 |
27 | @staticmethod
28 | def build_kernels(c, num_proj, patch_size, device):
29 | kernels = torch.randn(num_proj, c * patch_size ** 2, device=device)
30 | kernels = kernels / torch.norm(kernels, dim=1, keepdim=True)
31 | kernels = kernels.reshape(num_proj, c, patch_size, patch_size)
32 | return kernels
33 |
34 | def __call__(self, image: torch.tensor) -> torch.tensor:
35 | parts = []
36 | for kernels, scale in zip(self.ms_kernels, self.scales):
37 | scaled_im = F.interpolate(image, scale_factor=scale, mode='bilinear')
38 | scaled_f = F.conv2d(reflect_pad(scaled_im, patch_size=kernels.shape[-1]), kernels)
39 | parts.append(F.interpolate(scaled_f, size=image.shape[-2:], mode='bilinear'))
40 | return torch.cat(parts, dim=1)
41 |
42 |
43 | class NeuralExtractor:
44 | def __init__(self, post_scale=True, device=None):
45 | self.normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(-1, 1, 1)
46 | self.normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(-1, 1, 1)
47 | self.post_scale = post_scale
48 | self.device = device
49 |
50 | def normalize(self, image: torch.tensor):
51 | return (image - self.normalization_mean) / self.normalization_std
52 |
53 | def post(self, features):
54 | if self.post_scale:
55 | return scale_features(features)
56 | return features
57 |
58 |
59 | class VggExtractor(NeuralExtractor):
60 | def __init__(self, layers=('conv_1', 'conv_2', 'conv_3'), post_scale=True, device=None):
61 | super(VggExtractor, self).__init__(post_scale, device)
62 | self.cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()
63 | self.layers = layers
64 |
65 | def get_vgg_features(self, image: torch.tensor) -> torch.tensor:
66 | out = self.normalize(image)
67 | vgg_features = []
68 | i = 1
69 | for layer in self.cnn.children():
70 | out = layer(out)
71 | if isinstance(layer, torch.nn.Conv2d) and f'conv_{i}' in self.layers:
72 | vgg_features.append(out.clone().detach())
73 | i += 1
74 |
75 | return vgg_features
76 |
77 | @staticmethod
78 | def union_features(vgg_features):
79 | hw = vgg_features[0].shape[-2:]
80 | same_size = [F.interpolate(f, size=hw, mode='bilinear') for f in vgg_features]
81 | return torch.cat(same_size, dim=1)
82 |
83 | @torch.no_grad()
84 | def __call__(self, image: torch.tensor) -> torch.tensor:
85 | vgg_features = self.get_vgg_features(image)
86 | union = self.union_features(vgg_features)
87 | return self.post(union)
88 |
89 |
90 | class WideResnetExtractor(NeuralExtractor):
91 | def __init__(self, post_scale=True, device=None):
92 | super(WideResnetExtractor, self).__init__(post_scale, device)
93 | cnn = wide_resnet50_2(weights=Wide_ResNet50_2_Weights.IMAGENET1K_V1).eval().to(device)
94 | self.model = torch.nn.Sequential(*list(cnn.children())[:6])
95 |
96 | @torch.no_grad()
97 | def __call__(self, image: torch.tensor) -> torch.tensor:
98 | features = self.model(self.normalize(image))
99 | return self.post(features)
100 |
101 |
102 | class SteerableExtractor:
103 | def __init__(self, post_scale=True, o=4, m=1, scale_factor=2):
104 | self.post_scale = post_scale
105 | self.o = o
106 | self.m = m
107 | self.scale_factor = scale_factor
108 |
109 | def __call__(self, image: torch.tensor) -> torch.tensor:
110 | features = extract_steerable_features(image, self.o, self.m, self.scale_factor)
111 | if self.post_scale:
112 | features = scale_features(features)
113 | return features
114 |
115 |
116 | class LawTextureEnergyMeasure:
117 | def __init__(self, mean_patch_size=15, device=None):
118 | self.mean_patch_size = mean_patch_size
119 | # noinspection SpellCheckingInspection
120 | lesr = torch.tensor([[1, 4, 6, 4, 1],
121 | [-1, -2, 0, 2, 1],
122 | [-1, 0, 2, 0, -1],
123 | [1, -4, 6, -4, 1]], dtype=torch.float32, device=device)
124 | outers = torch.einsum('ni,mj->nmij', lesr, lesr)
125 | self.kernels = outers.reshape(outers.shape[0] * outers.shape[1], 1, *outers.shape[-2:]) # 16 x 1 x 5 x 5
126 |
127 | def __call__(self, image: torch.tensor) -> torch.tensor:
128 | image = torch.mean(image, dim=1, keepdim=True) # Grayscale B x 1 x H x W
129 | if self.mean_patch_size != 0:
130 | image = image - op_utils.blur(image, self.mean_patch_size, sigma=20)
131 | energy_maps = F.conv2d(reflect_pad(image, patch_size=5), self.kernels) / 10
132 | features = torch.stack([
133 | energy_maps[:, 5], # E5E5
134 | energy_maps[:, 10], # S5S5
135 | energy_maps[:, 15], # R5R5
136 | (energy_maps[:, 1] + energy_maps[:, 4]) / 2.0, # L5E5 + E5L5
137 | (energy_maps[:, 2] + energy_maps[:, 8]) / 2.0, # L5S5 + S5L5
138 | (energy_maps[:, 3] + energy_maps[:, 12]) / 2.0, # L5R5 + R5L5
139 | (energy_maps[:, 6] + energy_maps[:, 9]) / 2.0, # E5S5 + S5E5
140 | (energy_maps[:, 7] + energy_maps[:, 13]) / 2.0, # E5R5 + R5E5
141 | (energy_maps[:, 11] + energy_maps[:, 14]) / 2.0, # S5R5 + R5S5
142 | ], dim=1)
143 | return features
144 |
--------------------------------------------------------------------------------
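A minimal sketch of running two of the extractors above on a random image; the input size and kernel settings are illustrative choices, not repo defaults:

```python
import torch

from feature_extractors import LawTextureEnergyMeasure, RandomKernels

image = torch.rand(1, 3, 256, 256)  # B x C x H x W in [0, 1]

# 64 random projection kernels at full resolution.
random_fe = RandomKernels(input_c=3, projections=((64, 1),), patch_size=7)
print(random_fe(image).shape)  # torch.Size([1, 64, 256, 256])

# Laws' texture energy measures: 9 symmetrized filter responses.
ltem_fe = LawTextureEnergyMeasure(mean_patch_size=15)
print(ltem_fe(image).shape)    # torch.Size([1, 9, 256, 256])
```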
/main.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import hydra
4 | import torch
5 | from hydra.utils import instantiate
6 | from omegaconf import DictConfig, OmegaConf
7 | from tqdm import tqdm
8 |
9 | from op_utils import load_image_tensor
10 |
11 |
12 | def compute_anomaly_map(image: torch.tensor, feature_extractor: Callable, stat_method: Callable):
13 | features = feature_extractor(image)
14 | age = stat_method(features)
15 | return age
16 |
17 |
18 | def run_anomaly_detection(dataset, cfg: DictConfig):
19 | fe = instantiate(cfg.fe.feature_extractor)
20 | sc = instantiate(cfg.sc.method)
21 | for in_image_path in tqdm(dataset):
22 | image = load_image_tensor(in_image_path, cfg.image_size, device=cfg.device)
23 | anomaly_map = compute_anomaly_map(image, fe, sc)
24 |
25 | dataset.save_output(anomaly_map, in_image_path)
26 |
27 |
28 | @hydra.main(version_base=None, config_path="conf", config_name="base")
29 | def my_app(cfg: DictConfig) -> None:
30 | print(OmegaConf.to_yaml(cfg))
31 | data_manager = instantiate(cfg.dataset.data_manager)
32 | run_anomaly_detection(data_manager, cfg)
33 | data_manager.run_evaluation()
34 |
35 |
36 | if __name__ == "__main__":
37 | my_app()
38 |
--------------------------------------------------------------------------------
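The Hydra entry point wires three instantiated pieces together: a data manager (cfg.dataset), a feature extractor (cfg.fe), and a statistical comparison method (cfg.sc). A minimal sketch of the same fe -> sc pipeline without Hydra; the extractor, method, tile size, and image size below are illustrative choices:

```python
import torch

from feature_extractors import LawTextureEnergyMeasure
from sc_methods import sc_moments

image = torch.rand(1, 3, 320, 320)  # stand-in for load_image_tensor(...)
fe = LawTextureEnergyMeasure()
anomaly_map = sc_moments(fe(image), tile_size=(33, 33))
print(anomaly_map.shape)  # torch.Size([320, 320])
```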
/op_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchvision
4 | import cv2
5 | import numpy as np
6 |
7 | def scale_features(tensor: torch.tensor) -> torch.tensor:
8 | # tensor shape B x C x H x W
9 | tf = tensor.flatten(start_dim=2)
10 | mini = tf.min(dim=-1).values[..., None, None]
11 | maxi = tf.max(dim=-1).values[..., None, None]
12 | div = (maxi - mini + 1e-8)
13 | return (tensor - mini) / div
14 |
15 | # noinspection PyProtectedMember,PyUnresolvedReferences
16 | def get_gaussian_kernel(device, tile_size, s: float):
17 | return torchvision.transforms._functional_tensor._get_gaussian_kernel2d(tile_size, [s, s], torch.float32, device)
18 |
19 | def blur(image, kernel_size=7, sigma=None):
20 | if sigma is None:
21 | sigma = kernel_size / 4
22 | shape = image.shape
23 | im_b = image[(None,) * (4 - len(shape))]
24 | # noinspection PyUnresolvedReferences
25 | return torchvision.transforms.functional.gaussian_blur(im_b, kernel_size, sigma=sigma).view(shape)
26 |
27 |
28 | def load_image_tensor(path, image_size, device=None):
29 | input_image = cv2.imread(str(path))
30 | if image_size == "original":
31 | image_size = input_image.shape[:2]
32 | else:
33 | image_size = tuple(image_size)
34 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
35 | return F.interpolate(image_to_tensor(input_image, device=device), size=image_size, mode='bilinear')
36 |
37 |
38 | def image_to_tensor(image, device):
39 | image = torch.tensor(image, device=device, dtype=torch.float32)
40 | image = image.permute(2, 0, 1)[None] / 255
41 | return image
42 |
43 |
44 | def tensor_to_image(image: torch.tensor):
45 | out = image.detach().cpu()
46 | out = (out.squeeze(0).permute(1, 2, 0) * 255)
47 | return out.numpy().astype(np.uint8)
48 |
49 |
50 | def reflect_pad(image, patch_size=7):
51 | p = patch_size // 2
52 | return torch.nn.functional.pad(image, (p, p, p, p), mode='reflect')
53 |
--------------------------------------------------------------------------------
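A quick check of scale_features: each (batch, channel) slice is shifted and rescaled so its minimum maps to 0 and its maximum to (almost) 1, the epsilon in the divisor keeping constant channels finite:

```python
import torch

from op_utils import scale_features

t = torch.tensor([[[[2.0, 4.0],
                    [6.0, 10.0]]]])  # B=1, C=1, 2x2
print(scale_features(t).flatten())   # expected: ~[0.0, 0.25, 0.5, 1.0]
```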
/outputs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/outputs/.gitkeep
--------------------------------------------------------------------------------
/sc_methods.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchvision
4 |
5 | from op_utils import blur, scale_features, reflect_pad, get_gaussian_kernel
6 |
7 |
8 | def sc_moments(features, tile_size, powers=(1, 2, 3, 4), sigma=6.0):
9 | assert features.shape[0] == 1 # Only implemented for one image, but easy to extend...
10 | ps = features.new_tensor(list(powers))
11 | b, c, h, w = features.shape
12 | moments = torch.pow(features[:, :, None, :, :], ps[None, None, :, None, None]).reshape(b, c * len(ps), h, w)
13 | local_moments = blur(moments, kernel_size=tile_size[0], sigma=sigma)[0] # C*powers x H x W
14 |
15 | def dist(one, many):
16 | vec_arr = one[:, None, None].expand_as(many)
17 | loss = torch.mean((vec_arr - many) ** 2, dim=0)
18 | del vec_arr
19 | return loss
20 |
21 | representative = torch.mean(local_moments.flatten(start_dim=-2), dim=-1) # C*powers
22 | return dist(representative, local_moments)
23 |
24 |
25 | def sc_hist(features, tile_size, bins=10, sigma=6.0):
26 | assert features.shape[0] == 1 # Only implemented for one image, but easy to extend...
27 | scaled = scale_features(features)
28 | indices = torch.floor(0.999 * bins * scaled).type(torch.int64)
29 |     one_hot = F.one_hot(indices, num_classes=bins)  # fixed width even if the top bin is empty
30 | hist = one_hot.permute(0, 1, 4, 2, 3).flatten(start_dim=1, end_dim=2).type(torch.float32)
31 | aggregated_hist = blur(hist, tile_size[0], sigma=sigma)[0] # (C*bins) x H x W
32 | aggregated_hist = aggregated_hist.view(features.shape[1], bins, *features.shape[-2:]) # C x bins x H x W
33 |
34 | def dist(one, many):
35 | # Computes EMD distance between two 1D histograms
36 | vec_arr = one[:, None, None].expand_as(many)
37 | c1 = torch.cumsum(vec_arr, dim=0)
38 | c2 = torch.cumsum(many, dim=0)
39 | loss = torch.mean(torch.abs(c1 - c2), dim=0)
40 | del vec_arr
41 | return loss
42 |
43 | representative = torch.mean(aggregated_hist.flatten(start_dim=-2), dim=-1) # C x bins
44 | return torch.stack([dist(c_rep, c_hist) for (c_rep, c_hist) in zip(representative, aggregated_hist)]).mean(dim=0)
45 |
46 |
47 | def get_reference(reference_type, tile_size, num_ref=1):
48 | def reference_median(features):
49 | unf = F.unfold(features, tile_size, stride=tile_size)[0].T.reshape(-1, features.shape[1], *tile_size)
50 | val, _ = torch.sort(unf.view(-1, features.shape[1], tile_size[0] * tile_size[1]), dim=-1)
51 | return torch.median(val, dim=0).values # C x tile**2
52 |
53 | def reference_random(features):
54 | h, w = features.shape[-2:]
55 | h0 = torch.randint(0, h - tile_size[0], (num_ref,), device=features.device)
56 | w0 = torch.randint(0, w - tile_size[1], (num_ref,), device=features.device)
57 | grid = torch.meshgrid(torch.arange(tile_size[0], device=features.device),
58 | torch.arange(tile_size[1], device=features.device), indexing='ij')
59 |
60 | x_ind = (grid[0].reshape(-1)[None] + h0[:, None]).view(-1)
61 | y_ind = (grid[1].reshape(-1)[None] + w0[:, None]).view(-1)
62 | refs = features[..., x_ind, y_ind].reshape(*features.shape[:-2], num_ref, -1)
63 | refs = torch.sort(refs[0].permute(1, 0, 2), dim=-1).values # num_ref x C x tile**2
64 | assert num_ref == 1 # Current public code only allows 1 (random) reference
65 | return refs[0]
66 |
67 | if reference_type == 'median':
68 | return reference_median
69 | if reference_type == 'random':
70 | return reference_random
71 | else:
72 |         raise ValueError(f"Unknown reference_type: {reference_type}")
73 |
74 |
75 | class ScSWW:
76 | def __init__(self, tile_size, chunk_size=8, sigma=6.0, reference_selection='median'):
77 | self.tile_size = tuple(tile_size)
78 | self.chunk_size = chunk_size
79 | self.sigma = sigma
80 | self.ref_selection = get_reference(reference_selection, self.tile_size)
81 | self.gaussian_kernel = None
82 |
83 | def generate_all_sets(self, features):
84 | b, c, h, w = features.shape
85 | padded = reflect_pad(features, self.tile_size[0])
86 | unf = torch.nn.functional.unfold(padded, (self.tile_size[0], padded.shape[-1]), stride=(1, 1))
87 | unf = unf[0].T.reshape(h, c, self.tile_size[0], padded.shape[-1])
88 | for i in range(0, len(unf), self.chunk_size):
89 | chunk = unf[i:i + self.chunk_size]
90 | unf_2 = torch.nn.functional.unfold(chunk, self.tile_size, stride=(1, 1))
91 | unf_2 = unf_2.transpose(1, 2).reshape(chunk.shape[0], w, c, -1)
92 | yield unf_2
93 |
94 | def __call__(self, features):
95 | r_set = self.ref_selection(features)
96 | generator = self.generate_all_sets(features)
97 | if self.gaussian_kernel is None:
98 | self.gaussian_kernel = get_gaussian_kernel(features.device, self.tile_size, self.sigma).reshape(-1)
99 | parts = []
100 | for f_set in generator:
101 | fvalues, ind = torch.sort(f_set, dim=-1) # h x W x C x tile**2
102 | vec_arr = r_set[None, None].expand_as(fvalues)
103 | weights = torch.gather(self.gaussian_kernel.expand_as(fvalues), dim=-1, index=ind)
104 | loss = (F.l1_loss(fvalues, vec_arr, reduction='none') * weights).sum(dim=-1).mean(dim=-1)
105 | del vec_arr
106 | parts.append(loss)
107 | p = torch.cat(parts, dim=0)
108 | return p
109 |
110 |
111 | class ScFCA(ScSWW):
112 | def __init__(self, tile_size, chunk_size=8, sigma_p=3.0, reference_selection='median', k_s=5, sigma_s=1.0):
113 | super(ScFCA, self).__init__(tile_size, chunk_size, sigma_p, reference_selection)
114 | assert tile_size[0] == tile_size[1] # Only implemented for square patches
115 | self.p_size = (tile_size[0] // 2)
116 | if sigma_s is not None:
117 | self.local_blur = torchvision.transforms.GaussianBlur(k_s, sigma=sigma_s)
118 | else:
119 | self.local_blur = None
120 |
121 | def __call__(self, features):
122 | r_set = self.ref_selection(features)
123 | wp = features.shape[-1] + 2 * self.p_size
124 | generator = self.generate_all_sets(features)
125 | if self.gaussian_kernel is None:
126 | self.gaussian_kernel = get_gaussian_kernel(features.device, self.tile_size, self.sigma).reshape(-1)
127 | parts = []
128 | for f_set in generator:
129 | fvalues, ind = torch.sort(f_set, dim=-1) # h x W x C x tile**2
130 | vec_arr = r_set[None, None].expand_as(fvalues)
131 | diff = F.l1_loss(fvalues, vec_arr, reduction='none')
132 | diff_re = torch.gather(diff, dim=-1, index=torch.argsort(ind)).mean(dim=2, keepdim=True) # h x W x 1 x t**2
133 | if self.local_blur is not None:
134 | diff_re = self.local_blur(diff_re.view(-1, 1, *self.tile_size)).reshape(diff_re.shape)
135 | diff_re = diff_re * self.gaussian_kernel # h x W x 1 x tile**2
136 | diff_re = diff_re.permute(0, 2, 3, 1).reshape(f_set.shape[0], -1, features.shape[-1]) # h x 1*tile**2 x W
137 | c_fold = F.fold(diff_re, (self.tile_size[0], wp), kernel_size=self.tile_size) # h x C x tile x WP
138 | parts.append(c_fold)
139 | combined = torch.cat(parts, dim=0) # H x 1 x tile x WP
140 | folded = F.fold(combined.permute(1, 2, 3, 0).reshape(1, -1, features.shape[-2]),
141 | output_size=(wp, wp), kernel_size=(self.tile_size[0], wp))
142 |         folded = folded[0, 0, self.p_size:-self.p_size, self.p_size:-self.p_size]  # Remove extra pad -> H x W
143 |
144 | return folded
145 |
146 |
147 | def sc_aota(features, tile_size, sigma=100.0, k=400):
148 | def one_to_many_dist(values, dist_f):
149 | distances = []
150 | for i in range(values.shape[-2]):
151 | for j in range(values.shape[-1]):
152 | distances.append(dist_f(values[:, i, j], values))
153 | return values.new_tensor(distances).reshape(values.shape[-2:])
154 |
155 | local_moments = blur(features, kernel_size=tile_size[0], sigma=sigma)[0] # C x H x W
156 |
157 | def dist(one, many):
158 | vec_arr = one[:, None, None].expand_as(many)
159 | loss = ((vec_arr - many) ** 2).sum(dim=0).view(-1)
160 | del vec_arr
161 | loss = torch.topk(loss, k=k, largest=False).values.mean()
162 | return loss
163 |
164 | result = one_to_many_dist(local_moments, dist_f=dist)
165 | return blur(F.interpolate(result[None, None], (320, 320)), kernel_size=25, sigma=4.0)[0, 0]
166 |
--------------------------------------------------------------------------------
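A minimal sketch of running the sliding-window Wasserstein comparison on random features; the channel count, spatial size, and tile size are illustrative. ScFCA accepts the same tile_size argument and can be swapped in directly:

```python
import torch

from sc_methods import ScSWW

features = torch.rand(1, 8, 64, 64)  # B x C x H x W feature maps
sww = ScSWW(tile_size=(15, 15), reference_selection='median')
anomaly_map = sww(features)
print(anomaly_map.shape)  # torch.Size([64, 64])
```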
/static/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TArdelean/AnomalyLocalizationFCA/afe8326790de30f33cf0e91db7c099bdc6608a84/static/teaser.png
--------------------------------------------------------------------------------