├── README.md
├── image
│   ├── bent_000_origin.png
│   ├── bent_000_rec.png
│   ├── bent_000_residual.png
│   └── bent_000_visual.png
├── network.py
├── options.py
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # AutoEncoder with SSIM loss
2 |
3 | This is a third-party implementation of the paper **Improving Unsupervised Defect Segmentation by Applying Structural Similarity to Autoencoders** (Bergmann et al., 2019).
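For reference, here are the objectives behind the loss options. `--loss ssim_loss` in `train.py` minimizes one minus the batch mean of `tf.image.ssim`. `ssim_l1_loss` adds an L1 term weighted by `--weight`. The formulas below assume the `tf.image.ssim` defaults as I understand them: an 11×11 Gaussian window with σ = 1.5, $k_1 = 0.01$, $k_2 = 0.03$, and dynamic range $L$ = `max_val` = 1. Here $\mu$, $\sigma^2$ and $\sigma_{pq}$ are the local means, variances and covariance of corresponding windows $p$ (input) and $q$ (reconstruction):

```math
\begin{aligned}
\mathrm{SSIM}(p, q) &= \frac{(2\mu_p\mu_q + c_1)(2\sigma_{pq} + c_2)}{(\mu_p^2 + \mu_q^2 + c_1)(\sigma_p^2 + \sigma_q^2 + c_2)}, \qquad c_1 = (k_1 L)^2,\ c_2 = (k_2 L)^2 \\
\mathcal{L}_{\mathrm{SSIM}} &= 1 - \frac{1}{N}\sum_{n=1}^{N} \overline{\mathrm{SSIM}}(x_n, \hat{x}_n), \qquad
\mathcal{L}_{\mathrm{SSIM+L1}} = \mathcal{L}_{\mathrm{SSIM}} + w\,\overline{\lvert x - \hat{x} \rvert} \\
r_{\mathrm{SSIM}}(i, j) &= 1 - \mathrm{SSIM}(p_{ij}, q_{ij})
\end{aligned}
```

The symbols are as follows:

- $\overline{\mathrm{SSIM}}$ averages the SSIM map over window positions and channels.
- $\overline{\lvert x - \hat{x} \rvert}$ is the mean absolute error over the batch.
- $N$ is the batch size and $w$ is `--weight`.
- $p_{ij}$, $q_{ij}$ are the windows centered at pixel $(i, j)$.

At test time, `test.py` computes the residual map $r_{\mathrm{SSIM}}$ with scikit-image's `structural_similarity` (`win_size=11`). For color images it averages that map over the channels.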
4 |
5 |
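6 | ![input](image/bent_000_origin.png) ![reconstruction](image/bent_000_rec.png) ![residual](image/bent_000_residual.png) ![visualization](image/bent_000_visual.png)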
7 |
8 |
9 | ## Requirements
10 | - `tensorflow==2.2.0`
11 | - `scikit-image` (imported as `skimage`), `opencv-python` (`cv2`) and `numpy`
12 |
13 | ## Datasets
14 | MVTec AD dataset: https://www.mvtec.com/company/research/datasets/mvtec-ad/
15 |
16 | ## Code examples
17 |
18 | ### Step 1. Set the *DATASET_PATH* variable.
19 |
20 | Set [DATASET_PATH](options.py#L46) to the root folder of the downloaded MVTec AD dataset, or pass `--train_data_dir` and `--test_dir` explicitly.
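`Options.parse()` derives the default paths from `DATASET_PATH` and `--name`. Before training, a quick sanity check can confirm that a category folder has the expected `train/good` and `test/<defect type>` layout. The sketch below does this. The helper names `mvtec_paths` and `check_layout` are hypothetical and not part of this repository.

```python
import os


def mvtec_paths(dataset_path, name):
    """Default train/test folders, built the same way as in Options.parse()."""
    train_dir = dataset_path + '/' + name + '/train/good'
    test_dir = dataset_path + '/' + name + '/test'
    return train_dir, test_dir


def check_layout(dataset_path, name):
    """Return a list of problems; an empty list means the expected folders are present."""
    train_dir, test_dir = mvtec_paths(dataset_path, name)
    problems = []
    if not os.path.isdir(train_dir) or not os.listdir(train_dir):
        problems.append('missing or empty: ' + train_dir)
    if not os.path.isdir(test_dir):
        problems.append('missing: ' + test_dir)
    elif not any(os.path.isdir(os.path.join(test_dir, d)) for d in os.listdir(test_dir)):
        # test.py evaluates every entry of test/ as a sub-folder (one per defect type)
        problems.append('no sub-folders in: ' + test_dir)
    return problems
```

For example, `check_layout('D:/user/dataset/mvtec_anomaly_detection', 'bottle')` should return `[]` once the dataset has been extracted there.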
21 |
22 | ### Step 2. Train and test **SSIM-AE**.
23 |
24 | - **bottle** object
25 | ```bash
26 | python train.py --name bottle --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0.
27 | python test.py --name bottle --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask W
28 | ```
29 | - **cable** object
30 | ```bash
31 | python train.py --name cable --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
32 | python test.py --name cable --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
33 | ```
34 | - **capsule** object
35 | ```bash
36 | python train.py --name capsule --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
37 | python test.py --name capsule --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask W
38 | ```
39 | - **carpet** texture
40 | ```bash
41 | python train.py --name carpet --loss ssim_loss --im_resize 512 --patch_size 128 --z_dim 100 --do_aug --rotate_angle_vari 10
42 | python test.py --name carpet --loss ssim_loss --im_resize 512 --patch_size 128 --z_dim 100
43 | ```
44 | - **grid** texture
45 | ```bash
46 | python train.py --name grid --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --grayscale --do_aug
47 | python test.py --name grid --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --grayscale
48 | ```
49 | - **hazelnut** object
50 | ```bash
51 | python train.py --name hazelnut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate_crop 0.
52 | python test.py --name hazelnut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
53 | ```
54 | - **leather** texture
55 | ```bash
56 | python train.py --name leather --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug
57 | python test.py --name leather --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
58 | ```
59 | - **metal_nut** object
60 | ```bash
61 | python train.py --name metal_nut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate_crop 0. --p_horizonal_flip 0. --p_vertical_flip 0.
62 | python test.py --name metal_nut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
63 | ```
64 | - **pill** object
65 | ```bash
66 | python train.py --name pill --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
67 | python test.py --name pill --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
68 | ```
69 | - **screw** object
70 | ```bash
71 | python train.py --name screw --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --do_aug --p_rotate 0.
72 | python test.py --name screw --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --bg_mask W
73 | ```
74 | - **tile** texture
75 | ```bash
76 | python train.py --name tile --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug
77 | python test.py --name tile --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
78 | ```
79 | - **toothbrush** object
80 | ```bash
81 | python train.py --name toothbrush --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_vertical_flip 0.
82 | python test.py --name toothbrush --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
83 | ```
84 | - **transistor** object
85 | ```bash
86 | python train.py --name transistor --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_vertical_flip 0.
87 | python test.py --name transistor --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
88 | ```
89 | - **wood** texture
90 | ```bash
91 | python train.py --name wood --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug --rotate_angle_vari 15
92 | python test.py --name wood --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
93 | ```
94 | - **zipper** object
95 | ```bash
96 | python train.py --name zipper --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --do_aug --p_rotate 0.
97 | python test.py --name zipper --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale
98 | ```
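The resize, patch and stride flags interact in two places:

- two derived options in `Options.parse()`, `p_crop` and `mask_size`;
- the sliding window in `utils.get_patch`.

The sketch below re-derives both in plain Python. It shows, for example, why object categories (`--im_resize 266 --patch_size 256`) are evaluated as a single centered 256×256 patch while texture categories are tiled. `patch_settings` and `patch_offsets` are illustrative helpers, not functions from the repository, and `stride=32` assumes the default `--stride`.

```python
def patch_settings(im_resize, patch_size, stride=32):
    """Mirror the derived options p_crop and mask_size from Options.parse()."""
    p_crop = 1 if patch_size != im_resize else 0
    mask_size = patch_size if im_resize - patch_size < stride else im_resize
    return p_crop, mask_size


def patch_offsets(size, patch_size, stride=32):
    """Top-left offsets along one axis visited by the sliding window in utils.get_patch."""
    return list(range(0, size - patch_size + 1, stride))


for name, im_resize, patch_size in [('bottle', 266, 256), ('grid', 256, 128), ('carpet', 512, 128)]:
    p_crop, mask_size = patch_settings(im_resize, patch_size)
    n_patches = len(patch_offsets(mask_size, patch_size)) ** 2
    print(name, 'mask_size =', mask_size, 'patches per test image =', n_patches)
```

Check this when changing these flags: if `mask_size - patch_size` is not a multiple of `--stride`, the last offset stops short of the border. `patch2img` would then divide by a zero coverage count in the uncovered bottom and right margins.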
99 |
100 | ## Overview of Results
101 |
102 | **Classification**
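103 | During testing, I simply classify a test image as defective if any pixel of the final (post-processed) residual mask is anomalous. This rule is strict on anomaly-free images, which is why accuracy in the `ok` column below is relatively low.
104 | Note that the **threshold** strongly affects the outcome and should be selected carefully. By default, `test.py` estimates the SSIM and L1 thresholds as the `--percent`-th percentile of the residuals on the last 20% of the good training images. You can also set them explicitly with `--ssim_threshold` and `--l1_threshold`.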
105 |
106 | | Category | ok (%) | nok (%) | average (%) |
107 | |---|---:|---:|---:|
108 | | bottle | 90.0 | 98.4 | 96.4 |
109 | | cable | 0.0 | 45.7 | 28.0 |
110 | | capsule | 34.8 | 89.6 | 78.0 |
111 | | carpet | 42.9 | 98.9 | 88.9 |
112 | | grid | 100.0 | 94.7 | 96.2 |
113 | | hazelnut | 55.0 | 98.6 | 82.7 |
114 | | leather | 71.9 | 92.4 | 87.1 |
115 | | metal nut | 22.7 | 67.7 | 59.1 |
116 | | pill | 11.5 | 75.9 | 65.9 |
117 | | screw | 0.5 | 90.0 | 68.1 |
118 | | tile | 100.0 | 3.6 | 30.8 |
119 | | toothbrush | 83.3 | 100.0 | 95.2 |
120 | | transistor | 23.3 | 97.5 | 53.0 |
121 | | wood | 89.5 | 76.7 | 79.7 |
122 | | zipper | 68.8 | 81.5 | 78.8 |
123 |
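205 | \* Classification accuracy in % (`ok`: anomaly-free test images, `nok`: defective test images). SSIM loss, 200 epochs; the threshold differs per category.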
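The image-level rule and the percentile threshold described above can be illustrated with NumPy alone. This is a simplified sketch on synthetic residual maps, not the repository's evaluation code. `estimate_threshold` and `is_defective` are hypothetical helpers. The sketch also omits the border down-weighting, background mask and morphological opening that `test.py` applies.

It shows why the rule is strict. A per-pixel threshold at the 98th percentile of good residuals still flags about 2% of the pixels of a new good image. Without post-processing, nearly every good image would therefore be classified as defective.

```python
import numpy as np


def estimate_threshold(good_residual_maps, percent=98.0):
    """Threshold = `percent`-th percentile of all residual values pooled over good images."""
    return float(np.percentile(np.stack(good_residual_maps), percent))


def is_defective(binary_mask):
    """Image-level decision: defective if any pixel of the final mask is set."""
    return bool(np.any(binary_mask > 0))


rng = np.random.default_rng(0)
# synthetic residual maps of anomaly-free images (values are illustrative only)
good_maps = [rng.uniform(0.0, 0.1, size=(64, 64)) for _ in range(5)]
threshold = estimate_threshold(good_maps, percent=98.0)

defect_map = rng.uniform(0.0, 0.1, size=(64, 64))
defect_map[20:30, 20:30] = 0.5  # synthetic defect region
fresh_good_map = rng.uniform(0.0, 0.1, size=(64, 64))

print('threshold:', round(threshold, 4))
print('defect image flagged:', is_defective(defect_map > threshold))
print('good image flagged without post-processing:', is_defective(fresh_good_map > threshold))
print('fraction of good pixels above threshold:', np.mean(fresh_good_map > threshold))
```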
206 |
207 | ## Discussion
208 | - **SSIM + L1 metrics**
209 | SSIM compares local luminance, contrast and structure on a single channel (for color images, `test.py` averages the per-channel SSIM maps), so it can miss some color defects. I therefore combine SSIM and L1 residuals for anomaly segmentation: a pixel is flagged if either residual exceeds its threshold.
210 | - **VAE**
211 | I tried a VAE but observed no performance improvement.
212 | - **InstanceNorm**
213 | I also tried adding InstanceNorm (IN) layers to accelerate convergence, but droplet artifacts appeared in some cases. This artifact is also discussed in the **StyleGAN2** paper.
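The per-pixel SSIM + L1 combination used in `test.py` can be isolated in a few lines of NumPy. This simplified sketch uses hypothetical helper names and synthetic residual maps, and leaves out the background mask and morphological opening.

First, both maps are multiplied by the border weights from `get_depressing_mask`: 0.2 on a 5-pixel border, 1 elsewhere. `test.py` does this for the SSIM-based losses. A pixel is then anomalous if either its SSIM residual or its L1 residual exceeds its threshold.

The two synthetic regions stand in for two kinds of defect. One is seen only by the SSIM residual. The other is seen only by the L1 residual, as with a color change.

```python
import numpy as np


def depressing_mask(size, border=5, weight=0.2):
    """Border weights as in get_depressing_mask (test.py): `weight` on the border, 1 inside."""
    m = np.full((size, size), weight)
    m[border:size - border, border:size - border] = 1.0
    return m


def combine_residuals(ssim_res, l1_res, ssim_thr, l1_thr, border=5):
    """Binary mask: anomalous where the SSIM residual OR the L1 residual exceeds its threshold."""
    w = depressing_mask(ssim_res.shape[0], border)
    return ((ssim_res * w > ssim_thr) | (l1_res * w > l1_thr)).astype(np.uint8)


size = 32
ssim_res = np.zeros((size, size))
l1_res = np.zeros((size, size))
ssim_res[4:8, 10:14] = 0.9   # structural defect; row 4 lies in the down-weighted border
l1_res[20:24, 20:24] = 0.6   # region only the L1 residual responds to (e.g. a color change)

mask = combine_residuals(ssim_res, l1_res, ssim_thr=0.5, l1_thr=0.3)
print('flagged pixels:', int(mask.sum()))
```

Taking the union keeps SSIM's sensitivity to structural changes while letting the L1 residual catch defects that SSIM scores as similar.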
214 |
215 | ## Supplementary materials
216 | My notes https://www.yuque.com/books/share/8c7613f7-7571-4bfa-865a-689de3763c59?#
217 | password `ixgg`
218 |
219 | ## References
220 | @inproceedings{inproceedings,
221 | author = {Bergmann, Paul and Löwe, Sindy and Fauser, Michael and Sattlegger, David and Steger, Carsten},
222 | year = {2019},
223 | month = {01},
224 | pages = {372-380},
225 | title = {Improving Unsupervised Defect Segmentation by Applying Structural Similarity to Autoencoders},
226 | doi = {10.5220/0007364503720380}
227 | }
228 |
229 | Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger. MVTec AD - A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; in: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2019
230 |
231 |
232 |
--------------------------------------------------------------------------------
/image/bent_000_origin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_origin.png
--------------------------------------------------------------------------------
/image/bent_000_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_rec.png
--------------------------------------------------------------------------------
/image/bent_000_residual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_residual.png
--------------------------------------------------------------------------------
/image/bent_000_visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_visual.png
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU
2 | from tensorflow.keras.models import Model
3 |
4 |
5 | def AutoEncoder(cfg):
6 |     """Convolutional autoencoder that maps a patch to a 1 x 1 x z_dim code and back.
7 |
8 |     Supports patch_size 128 or 256: the extra stride-2 layers for 256 keep the
9 |     encoder output at 8 x 8 before the final 8 x 8 'valid' convolution.
10 |     """
11 |     input_img = Input(shape=(cfg.patch_size, cfg.patch_size, cfg.input_channel))
12 |
13 |     # encoder
14 |     h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(input_img)
15 |     h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
16 |     if cfg.patch_size == 256:
17 |         h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
18 |     h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
19 |     h = Conv2D(cfg.flc*2, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
20 |     h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
21 |     h = Conv2D(cfg.flc*4, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
22 |     h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
23 |     h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
24 |     # latent code: 8 x 8 feature map -> 1 x 1 x z_dim
25 |     encoded = Conv2D(cfg.z_dim, (8, 8), strides=1, activation='linear', padding='valid')(h)
26 |
27 |     # decoder: 1 x 1 -> 8 x 8, then upsampled back to patch_size
28 |     h = Conv2DTranspose(cfg.flc, (8, 8), strides=1, activation=LeakyReLU(alpha=0.2), padding='valid')(encoded)
29 |     h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
30 |     h = Conv2D(cfg.flc*4, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
31 |     h = Conv2DTranspose(cfg.flc*2, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
32 |     h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
33 |     h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
34 |     h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
35 |     h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
36 |     if cfg.patch_size == 256:
37 |         h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
38 |
39 |     # sigmoid output matches inputs scaled to [0, 1]
40 |     decoded = Conv2DTranspose(cfg.input_channel, (4, 4), strides=2, activation='sigmoid', padding='same')(h)
41 |
42 |     return Model(input_img, decoded)
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | class Options():
5 |     def __init__(self):
6 |         self.parser = argparse.ArgumentParser()
7 |
8 |         self.parser.add_argument('--name', type=str, default='leather')
9 |         self.parser.add_argument('--train_data_dir', type=str, default=None)
10 |         self.parser.add_argument('--test_dir', type=str, default=None)
11 |         self.parser.add_argument('--sub_folder', type=str, nargs='*', default=None, help='test sub-folders to evaluate; default: all of them')
12 |         self.parser.add_argument('--do_aug', action='store_true', help='whether to do data augmentation before training')
13 |         self.parser.add_argument('--aug_dir', type=str, default=None)
14 |         self.parser.add_argument('--chechpoint_dir', type=str, default=None)
15 |         self.parser.add_argument('--save_dir', type=str, default=None)
16 |
17 |         self.parser.add_argument('--augment_num', type=int, default=10000)
18 |         self.parser.add_argument('--im_resize', type=int, default=256, help='scale images to this size')
19 |         self.parser.add_argument('--patch_size', type=int, default=128, help='then crop to this size')
20 |         self.parser.add_argument("--grayscale", action='store_true', help='color or grayscale input image')
21 |         self.parser.add_argument('--p_rotate', type=float, default=0.3, help='probability to do image rotation')
22 |         self.parser.add_argument('--rotate_angle_vari', type=float, default=45.0, help='rotate image between [-angle, +angle]')
23 |         self.parser.add_argument('--p_rotate_crop', type=float, default=1.0, help='probability to crop inner rotated image')
24 |         self.parser.add_argument('--p_horizonal_flip', type=float, default=0.3, help='probability to do horizontal flip')
25 |         self.parser.add_argument('--p_vertical_flip', type=float, default=0.3, help='probability to do vertical flip')
26 |
27 |         self.parser.add_argument('--z_dim', type=int, default=100, help='dimension of the latent space vector')
28 |         self.parser.add_argument('--flc', type=int, default=32, help='number of the first hidden layer channels')
29 |
30 |         self.parser.add_argument('--epochs', type=int, default=200, help='maximum training epochs')
31 |         self.parser.add_argument('--batch_size', type=int, default=128)
32 |         self.parser.add_argument('--loss', type=str, default='ssim_loss', help='loss type in [ssim_loss, ssim_l1_loss, l2_loss]')
33 |         self.parser.add_argument('--weight', type=float, default=1, help='weight of the L1 term if using ssim_l1_loss')
34 |         self.parser.add_argument('--lr', type=float, default=2e-4, help='learning rate of Adam')
35 |         self.parser.add_argument('--decay', type=float, default=1e-5, help='learning-rate decay of Adam')
36 |
37 |
38 |         self.parser.add_argument('--weight_file', type=str, default=None, help='if None, the latest weight file is selected automatically')
39 |         self.parser.add_argument('--stride', type=int, default=32, help='step length of the sliding window')
40 |         self.parser.add_argument('--ssim_threshold', type=float, default=None, help='ssim threshold for testing')
41 |         self.parser.add_argument('--l1_threshold', type=float, default=None, help='l1 threshold for testing')
42 |         self.parser.add_argument('--percent', type=float, default=98.0, help='percentile of residuals on good training images used to estimate the thresholds')
43 |         self.parser.add_argument('--bg_mask', type=str, default=None, help='background mask, B means black, W means white')
44 |
45 |     def parse(self):
46 |         DATASET_PATH = 'D:/user/dataset/mvtec_anomaly_detection'  # root of the MVTec AD dataset
47 |         self.opt = self.parser.parse_args()
48 |
49 |         if not self.opt.train_data_dir:
50 |             self.opt.train_data_dir = DATASET_PATH+'/'+self.opt.name+'/train/good'
51 |         if not self.opt.test_dir:
52 |             self.opt.test_dir = DATASET_PATH+'/'+self.opt.name+'/test'
53 |         if not self.opt.sub_folder:
54 |             self.opt.sub_folder = os.listdir(self.opt.test_dir)
55 |         if not self.opt.aug_dir:
56 |             self.opt.aug_dir = './train_patches/'+self.opt.name
57 |         if not self.opt.chechpoint_dir:
58 |             self.opt.chechpoint_dir = './results/'+self.opt.name+'/chechpoints/'+self.opt.loss
59 |         if not self.opt.save_dir:
60 |             self.opt.save_dir = './results/'+self.opt.name+'/reconst/ssim_l1_metric_'+self.opt.loss
61 |
62 |         if not os.path.exists(self.opt.chechpoint_dir):
63 |             os.makedirs(self.opt.chechpoint_dir)
64 |         if not os.path.exists(self.opt.aug_dir):
65 |             os.makedirs(self.opt.aug_dir)
66 |         if not os.path.exists(self.opt.save_dir):
67 |             os.makedirs(self.opt.save_dir)
68 |
69 |         self.opt.input_channel = 1 if self.opt.grayscale else 3
70 |         # random-crop patches during augmentation only when patches are smaller than the resized image
71 |         self.opt.p_crop = 1 if self.opt.patch_size != self.opt.im_resize else 0
72 |         # evaluate a single centered patch when the image is less than one stride larger than a patch
73 |         self.opt.mask_size = self.opt.patch_size if self.opt.im_resize - self.opt.patch_size < self.opt.stride else self.opt.im_resize
74 |
75 |         return self.opt
76 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.metrics import structural_similarity as ssim
3 | from skimage import morphology
4 | from glob import glob
5 | import cv2
6 | import os
7 |
8 | from utils import read_img, get_patch, patch2img, set_img_color, bg_mask
9 | from network import AutoEncoder
10 | from options import Options
11 |
12 |
13 | cfg = Options().parse()
14 |
15 | # network
16 | autoencoder = AutoEncoder(cfg)
17 |
18 | if cfg.weight_file:
19 |     autoencoder.load_weights(os.path.join(cfg.chechpoint_dir, cfg.weight_file))
20 | else:
21 |     # checkpoints are saved as '{epoch:02d}-{val_loss:.5f}.hdf5': pick the highest epoch
22 |     weight_files = [f for f in os.listdir(cfg.chechpoint_dir) if f.endswith('.hdf5')]
23 |     latest_file = max(weight_files, key=lambda f: int(f.split('-')[0]))
24 |     print('load latest weight file: ', latest_file)
25 |     autoencoder.load_weights(os.path.join(cfg.chechpoint_dir, latest_file))
26 | # autoencoder.summary()
27 |
28 |
29 | def get_residual_map(img_path, cfg):
30 |     """Reconstruct one test image and return it with its SSIM and L1 residual maps."""
31 |     test_img = read_img(img_path, cfg.grayscale)
32 |
33 |     if test_img.shape[:2] != (cfg.im_resize, cfg.im_resize):
34 |         test_img = cv2.resize(test_img, (cfg.im_resize, cfg.im_resize))
35 |     if cfg.im_resize != cfg.mask_size:
36 |         # center-crop to the evaluated region
37 |         tmp = (cfg.im_resize - cfg.mask_size) // 2
38 |         test_img = test_img[tmp:tmp+cfg.mask_size, tmp:tmp+cfg.mask_size]
39 |
40 |     test_img_ = test_img / 255.
41 |     if test_img_.ndim == 2:
42 |         # grayscale images are read as (H, W); add the channel axis the network expects
43 |         test_img_ = test_img_[..., np.newaxis]
44 |
45 |     if test_img.shape[:2] == (cfg.patch_size, cfg.patch_size):
46 |         test_img_ = np.expand_dims(test_img_, 0)
47 |         decoded_img = autoencoder.predict(test_img_)
48 |     else:
49 |         # reconstruct overlapping sliding-window patches and average them back together
50 |         patches = get_patch(test_img_, cfg.patch_size, cfg.stride)
51 |         patches = autoencoder.predict(patches)
52 |         decoded_img = patch2img(patches, cfg.im_resize, cfg.patch_size, cfg.stride)
53 |
54 |     rec_img = np.reshape((decoded_img * 255.).astype('uint8'), test_img.shape)
55 |
56 |     if cfg.grayscale:
57 |         ssim_residual_map = 1 - ssim(test_img, rec_img, win_size=11, full=True)[1]
58 |         l1_residual_map = np.abs(test_img / 255. - rec_img / 255.)
59 |     else:
60 |         # average the per-channel SSIM maps and L1 residuals over the color channels
61 |         ssim_residual_map = ssim(test_img, rec_img, win_size=11, full=True, multichannel=True)[1]
62 |         ssim_residual_map = 1 - np.mean(ssim_residual_map, axis=2)
63 |         l1_residual_map = np.mean(np.abs(test_img / 255. - rec_img / 255.), axis=2)
64 |
65 |     return test_img, rec_img, ssim_residual_map, l1_residual_map
66 |
67 |
68 | def get_threshold(cfg):
69 |     """Estimate both thresholds as the cfg.percent-th percentile of residuals on good images."""
70 |     print('estimating threshold...')
71 |     # use the last 20% of the good training images
72 |     valid_good_list = glob(cfg.train_data_dir + '/*png')
73 |     num_valid_data = int(np.ceil(len(valid_good_list) * 0.2))
74 |     total_rec_ssim, total_rec_l1 = [], []
75 |     for img_path in valid_good_list[-num_valid_data:]:
76 |         _, _, ssim_residual_map, l1_residual_map = get_residual_map(img_path, cfg)
77 |         total_rec_ssim.append(ssim_residual_map)
78 |         total_rec_l1.append(l1_residual_map)
79 |     total_rec_ssim = np.array(total_rec_ssim)
80 |     total_rec_l1 = np.array(total_rec_l1)
81 |     ssim_threshold = float(np.percentile(total_rec_ssim, cfg.percent))
82 |     l1_threshold = float(np.percentile(total_rec_l1, cfg.percent))
83 |     print('ssim_threshold: %f, l1_threshold: %f' % (ssim_threshold, l1_threshold))
84 |     if not cfg.ssim_threshold:
85 |         cfg.ssim_threshold = ssim_threshold
86 |     if not cfg.l1_threshold:
87 |         cfg.l1_threshold = l1_threshold
88 |
89 |
90 | def get_depressing_mask(cfg):
91 |     # weight residuals in a 5-pixel border by 0.2 to suppress responses at the image edges
92 |     depr_mask = np.ones((cfg.mask_size, cfg.mask_size)) * 0.2
93 |     depr_mask[5:cfg.mask_size-5, 5:cfg.mask_size-5] = 1
94 |     cfg.depr_mask = depr_mask
95 |
96 |
97 | def get_results(file_list, cfg, prefix=''):
98 |     """Segment each image and save the mask, input, reconstruction and overlay to cfg.save_dir."""
99 |     for img_path in file_list:
100 |         img_name = os.path.splitext(os.path.basename(img_path))[0]
101 |         test_img, rec_img, ssim_residual_map, l1_residual_map = get_residual_map(img_path, cfg)
102 |
103 |         ssim_residual_map *= cfg.depr_mask
104 |         if 'ssim' in cfg.loss:
105 |             l1_residual_map *= cfg.depr_mask
106 |
107 |         # a pixel is anomalous if either residual exceeds its threshold
108 |         mask = np.zeros((cfg.mask_size, cfg.mask_size))
109 |         mask[ssim_residual_map > cfg.ssim_threshold] = 1
110 |         mask[l1_residual_map > cfg.l1_threshold] = 1
111 |         # keep only responses on the object: 'B' = black background, 'W' = white background
112 |         if cfg.bg_mask == 'B':
113 |             bg_m = bg_mask(test_img.copy(), 50, cv2.THRESH_BINARY, cfg.grayscale)
114 |             mask *= bg_m
115 |         elif cfg.bg_mask == 'W':
116 |             bg_m = bg_mask(test_img.copy(), 200, cv2.THRESH_BINARY_INV, cfg.grayscale)
117 |             mask *= bg_m
118 |         # morphological opening removes small isolated responses
119 |         kernel = morphology.disk(4)
120 |         mask = morphology.opening(mask, kernel)
121 |         mask *= 255
122 |
123 |         vis_img = set_img_color(test_img.copy(), mask, weight_foreground=0.3, grayscale=cfg.grayscale)
124 |
125 |         out_path = os.path.join(cfg.save_dir, prefix + '_' + img_name)
126 |         cv2.imwrite(out_path + '_residual.png', mask.astype(np.uint8))
127 |         cv2.imwrite(out_path + '_origin.png', test_img)
128 |         cv2.imwrite(out_path + '_rec.png', rec_img)
129 |         cv2.imwrite(out_path + '_visual.png', vis_img)
130 |
131 |
132 | if __name__ == '__main__':
133 |     if not cfg.ssim_threshold or not cfg.l1_threshold:
134 |         get_threshold(cfg)
135 |
136 |     get_depressing_mask(cfg)
137 |
138 |     if cfg.sub_folder:
139 |         # one sub-folder per defect type; output files are prefixed with its name
140 |         for k in cfg.sub_folder:
141 |             test_list = glob(cfg.test_dir + '/' + k + '/*')
142 |             get_results(test_list, cfg, prefix=k)
143 |     else:
144 |         test_list = glob(cfg.test_dir + '/*')
145 |         get_results(test_list, cfg)
146 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
2 | from tensorflow.keras.utils import Sequence
3 | from tensorflow.keras.optimizers import Adam
4 | import tensorflow as tf
5 | import numpy as np
6 | import cv2
7 | from glob import glob
8 | import os
9 |
10 | from network import AutoEncoder
11 | from utils import generate_image_list, augment_images, read_img
12 | from options import Options
13 |
14 |
15 | cfg = Options().parse()
16 |
17 |
18 | class data_flow(Sequence):
19 |     """Yields (x, x) batches of patches scaled to [0, 1]; the autoencoder target is its input."""
20 |
21 |     def __init__(self, filenames, batch_size, grayscale):
22 |         self.filenames = filenames
23 |         self.batch_size = batch_size
24 |         self.grayscale = grayscale
25 |
26 |     def __len__(self):
27 |         return int(np.ceil(len(self.filenames) / float(self.batch_size)))
28 |
29 |     def __getitem__(self, idx):
30 |         batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
31 |         batch_x = np.array([read_img(filename, self.grayscale) for filename in batch_x])
32 |         if batch_x.ndim == 3:
33 |             # grayscale images are read as (H, W); add the channel axis the network expects
34 |             batch_x = batch_x[..., np.newaxis]
35 |
36 |         batch_x = batch_x / 255.
37 |         return batch_x, batch_x
38 |
39 |
40 | # data: with --do_aug, first write augmented patches to cfg.aug_dir
41 | if cfg.aug_dir and cfg.do_aug:
42 |     img_list = generate_image_list(cfg)
43 |     augment_images(img_list, cfg)
44 |
45 | # cfg.aug_dir is always set by Options.parse(), so training reads the augmented patches (run once with --do_aug)
46 | dataset_dir = cfg.aug_dir if cfg.aug_dir else cfg.train_data_dir
47 | file_list = glob(dataset_dir + '/*')
48 | # the last 20% of the files are used for validation
49 | num_valid_data = int(np.ceil(len(file_list) * 0.2))
50 | data_train = data_flow(file_list[:-num_valid_data], cfg.batch_size, cfg.grayscale)
51 | data_valid = data_flow(file_list[-num_valid_data:], cfg.batch_size, cfg.grayscale)
52 |
53 | # loss
54 | if cfg.loss == 'ssim_loss':
55 |
56 |     @tf.function
57 |     def ssim_loss(gt, y_pred, max_val=1.0):
58 |         return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
59 |
60 |     loss = ssim_loss
61 | elif cfg.loss == 'ssim_l1_loss':
62 |
63 |     @tf.function
64 |     def ssim_l1_loss(gt, y_pred, max_val=1.0):
65 |         ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
66 |         L1 = tf.reduce_mean(tf.abs(gt - y_pred))
67 |         return ssim_loss + L1 * cfg.weight
68 |
69 |     loss = ssim_l1_loss
70 | else:
71 |     # any other value (e.g. l2_loss) falls back to mean squared error
72 |     loss = 'mse'
73 |
74 | # network
75 | autoencoder = AutoEncoder(cfg)
76 | optimizer = Adam(learning_rate=cfg.lr, decay=cfg.decay)
77 | autoencoder.compile(optimizer=optimizer, loss=loss, metrics=['mae'] if loss == 'mse' else ['mse'])
78 | autoencoder.summary()
79 |
80 | earlystopping = EarlyStopping(patience=20)
81 |
82 | # keep only checkpoints that improve val_loss; test.py loads the one with the highest epoch
83 | checkpoint = ModelCheckpoint(os.path.join(cfg.chechpoint_dir, '{epoch:02d}-{val_loss:.5f}.hdf5'), save_best_only=True,
84 |                              save_freq='epoch', mode='auto', verbose=1, save_weights_only=True)
85 |
86 | autoencoder.fit(data_train, epochs=cfg.epochs, validation_data=data_valid, callbacks=[checkpoint, earlystopping])
87 |
88 | # save reconstructions of the validation patches as a visual check
89 | decoded_imgs = autoencoder.predict(data_valid)
90 | save_snapshot_dir = os.path.join(cfg.chechpoint_dir, 'snapshot')
91 | os.makedirs(save_snapshot_dir, exist_ok=True)
92 | for i, decoded in enumerate(decoded_imgs):
93 |     cv2.imwrite(os.path.join(save_snapshot_dir, str(i) + '_rec_valid.png'), (decoded * 255).astype('uint8'))
94 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import cv2
4 | import random
5 | import os
6 |
7 |
8 | def read_img(img_path, grayscale):
9 |     """Read an image as BGR (H, W, 3), or as (H, W) if grayscale."""
10 |     if grayscale:
11 |         im = cv2.imread(img_path, 0)
12 |     else:
13 |         im = cv2.imread(img_path)
14 |     return im
15 |
16 |
17 | def random_crop(image, new_size):
18 |     h, w = image.shape[:2]
19 |     # randint's upper bound is exclusive: +1 also allows h == new_size or w == new_size
20 |     y = np.random.randint(0, h - new_size + 1)
21 |     x = np.random.randint(0, w - new_size + 1)
22 |     image = image[y:y+new_size, x:x+new_size]
23 |     return image
24 |
25 |
26 | def rotate_image(img, angle, crop):
27 |     """Rotate about the image center; if crop, keep the inner rectangle without black corners."""
28 |     h, w = img.shape[:2]
29 |     angle %= 360
30 |     M_rotate = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
31 |     img_rotated = cv2.warpAffine(img, M_rotate, (w, h))
32 |     if crop:
33 |         angle_crop = angle % 180
34 |         if angle_crop > 90:
35 |             angle_crop = 180 - angle_crop
36 |         theta = angle_crop * np.pi / 180.0
37 |         hw_ratio = float(h) / float(w)
38 |         tan_theta = np.tan(theta)
39 |         numerator = np.cos(theta) + np.sin(theta) * tan_theta
40 |         r = hw_ratio if h > w else 1 / hw_ratio
41 |         denominator = r * tan_theta + 1
42 |         crop_mult = numerator / denominator
43 |         w_crop = int(round(crop_mult*w))
44 |         h_crop = int(round(crop_mult*h))
45 |         x0 = int((w-w_crop)/2)
46 |         y0 = int((h-h_crop)/2)
47 |         img_rotated = img_rotated[y0:y0+h_crop, x0:x0+w_crop]
48 |     return img_rotated
49 |
50 |
51 | def random_rotate(img, angle_vari, p_crop):
52 |     angle = np.random.uniform(-angle_vari, angle_vari)
53 |     crop = np.random.random() < p_crop
54 |     return rotate_image(img, angle, crop)
55 |
56 |
57 | def generate_image_list(args):
58 |     """Spread args.augment_num augmented samples as evenly as possible over the training images."""
59 |     filenames = os.listdir(args.train_data_dir)
60 |     num_imgs = len(filenames)
61 |     num_ave_aug = int(math.floor(args.augment_num/num_imgs))
62 |     rem = args.augment_num - num_ave_aug*num_imgs
63 |     lucky_seq = [True]*rem + [False]*(num_imgs-rem)
64 |     random.shuffle(lucky_seq)
65 |
66 |     img_list = [
67 |         (os.sep.join([args.train_data_dir, filename]), num_ave_aug+1 if lucky else num_ave_aug)
68 |         for filename, lucky in zip(filenames, lucky_seq)
69 |     ]
70 |
71 |     return img_list
72 |
73 |
74 | def augment_images(filelist, args):
75 |     """Write n randomly rotated/cropped/flipped copies of each image to args.aug_dir."""
76 |     for filepath, n in filelist:
77 |         img = read_img(filepath, args.grayscale)
78 |         if img.shape[:2] != (args.im_resize, args.im_resize):
79 |             img = cv2.resize(img, (args.im_resize, args.im_resize))
80 |         filename = filepath.split(os.sep)[-1]
81 |         dot_pos = filename.rfind('.')
82 |         imgname = filename[:dot_pos]
83 |         ext = filename[dot_pos:]
84 |
85 |         print('Augmenting {} ...'.format(filename))
86 |         for i in range(n):
87 |             img_varied = img.copy()
88 |             varied_imgname = '{}_{:0>3d}_'.format(imgname, i)
89 |
90 |             if random.random() < args.p_rotate:
91 |                 img_varied_ = random_rotate(
92 |                     img_varied,
93 |                     args.rotate_angle_vari,
94 |                     args.p_rotate_crop)
95 |                 # keep the rotation only if the result is still large enough for a patch
96 |                 if img_varied_.shape[0] >= args.patch_size and img_varied_.shape[1] >= args.patch_size:
97 |                     img_varied = img_varied_
98 |                     varied_imgname += 'r'
99 |
100 |             if random.random() < args.p_crop:
101 |                 img_varied = random_crop(
102 |                     img_varied,
103 |                     args.patch_size)
104 |                 varied_imgname += 'c'
105 |
106 |             if random.random() < args.p_horizonal_flip:
107 |                 img_varied = cv2.flip(img_varied, 1)
108 |                 varied_imgname += 'h'
109 |
110 |             if random.random() < args.p_vertical_flip:
111 |                 img_varied = cv2.flip(img_varied, 0)
112 |                 varied_imgname += 'v'
113 |
114 |             output_filepath = os.sep.join([
115 |                 args.aug_dir,
116 |                 '{}{}'.format(varied_imgname, ext)])
117 |             cv2.imwrite(output_filepath, img_varied)
118 |
119 |
120 | def get_patch(image, new_size, stride):
121 |     """Extract new_size x new_size windows with the given stride, row by row."""
122 |     h, w = image.shape[:2]
123 |     i, j = new_size, new_size
124 |     patch = []
125 |     while i <= h:
126 |         while j <= w:
127 |             patch.append(image[i - new_size:i, j - new_size:j])
128 |             j += stride
129 |         j = new_size
130 |         i += stride
131 |     return np.array(patch)
132 |
133 |
134 | def patch2img(patches, im_size, patch_size, stride):
135 |     """Average overlapping (N, patch, patch, C) patches back into an (im_size, im_size, C) image."""
136 |     img = np.zeros((im_size, im_size, patches.shape[3]+1))
137 |     i, j = patch_size, patch_size
138 |     k = 0
139 |     while i <= im_size:
140 |         while j <= im_size:
141 |             img[i - patch_size:i, j - patch_size:j, :-1] += patches[k]
142 |             # the extra last channel counts how many patches cover each pixel
143 |             img[i - patch_size:i, j - patch_size:j, -1] += np.ones((patch_size, patch_size))
144 |             k += 1
145 |             j += stride
146 |         j = patch_size
147 |         i += stride
148 |     mask = np.repeat(img[:, :, -1][..., np.newaxis], patches.shape[3], 2)
149 |     img = img[:, :, :-1] / mask
150 |     return img
151 |
152 |
153 | def set_img_color(img, predict_mask, weight_foreground, grayscale):
154 |     """Blend a red overlay into the pixels where predict_mask == 255."""
155 |     if grayscale:
156 |         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
157 |     # copy, so that the blend below mixes the red overlay with the untouched image
158 |     origin = img.copy()
159 |     img[np.where(predict_mask == 255)] = (0, 0, 255)
160 |     cv2.addWeighted(img, weight_foreground, origin, (1 - weight_foreground), 0, img)
161 |     return img
162 |
163 |
164 | def bg_mask(img, value, mode, grayscale):
165 |     """Return an (H, W) mask that is 1 on the foreground object (holes filled) and 0 elsewhere."""
166 |     if not grayscale:
167 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
168 |     _, thresh = cv2.threshold(img, value, 255, mode)
169 |
170 |     # [-2] selects the contour list under both the OpenCV 3 and OpenCV 4 return signatures
171 |     contours = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
172 |     if len(contours) == 0:
173 |         # no foreground found: do not suppress any response
174 |         return np.ones(img.shape)
175 |
176 |     # fill every contour (thickness -1) so that holes inside the object count as foreground
177 |     filled = np.zeros_like(thresh, np.uint8)
178 |     cv2.drawContours(filled, contours, -1, 255, -1)
179 |     mask_ = np.ones(filled.shape)
180 |     mask_[filled <= 127] = 0
181 |     return mask_
--------------------------------------------------------------------------------