├── README.md
├── image
│   ├── bent_000_origin.png
│   ├── bent_000_rec.png
│   ├── bent_000_residual.png
│   └── bent_000_visual.png
├── network.py
├── options.py
├── test.py
├── train.py
└── utils.py


/README.md:
--------------------------------------------------------------------------------
# AutoEncoder with SSIM loss

This is a third-party implementation of the paper **Improving Unsupervised Defect Segmentation by Applying Structural Similarity to Autoencoders**.
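For reference, the quantities behind this SSIM-based approach are summarized below. The summary uses the standard SSIM definition, and the notation is chosen here rather than taken from the paper. $\mu$ are local means, $\sigma^2$ local variances, $\sigma_{xy}$ the local covariance, and $c_1, c_2$ small stabilizing constants. $\hat{x}$ is the autoencoder reconstruction, $N$ the batch size, and $W(p)$ the local window around pixel $p$ (11×11 in `test.py`).

The second line corresponds to `ssim_loss` in `train.py`, where `tf.image.ssim` already averages the local SSIM values over each image. The last two lines are the per-pixel residual maps computed in `test.py`. There, the L1 residual uses intensities scaled to $[0, 1]$, and both maps are averaged over channels for color images.

```math
\begin{aligned}
\mathrm{SSIM}(x, y) &= \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)} \\
\mathcal{L}_{\mathrm{SSIM}} &= 1 - \frac{1}{N}\sum_{n=1}^{N} \mathrm{SSIM}(x_n, \hat{x}_n) \\
r_{\mathrm{SSIM}}(p) &= 1 - \mathrm{SSIM}\big(x_{W(p)}, \hat{x}_{W(p)}\big) \\
r_{L1}(p) &= \lvert x(p) - \hat{x}(p) \rvert
\end{aligned}
```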
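

![original](./image/bent_000_origin.png) ![reconstruction](./image/bent_000_rec.png) ![visualization](./image/bent_000_visual.png)


## Requirements
- `tensorflow==2.2.0`
- `scikit-image` (imported as `skimage`)
- `opencv-python` 4.x (imported as `cv2`; `utils.bg_mask` expects the two-value return of `cv2.findContours`)
- `numpy`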

## Datasets
MVTec AD dataset: https://www.mvtec.com/company/research/datasets/mvtec-ad/

## Code examples

### Step 1. Set the *DATASET_PATH* variable.

Set `DATASET_PATH` in `Options.parse()` in [options.py](options.py) to the root directory of the downloaded MVTec AD dataset.

### Step 2. Train and test **SSIM-AE**.

- **bottle** object
```bash
python train.py --name bottle --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0.
python test.py --name bottle --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask W
```
- **cable** object
```bash
python train.py --name cable --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
python test.py --name cable --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
```
- **capsule** object
```bash
python train.py --name capsule --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
python test.py --name capsule --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask W
```
- **carpet** texture
```bash
python train.py --name carpet --loss ssim_loss --im_resize 512 --patch_size 128 --z_dim 100 --do_aug --rotate_angle_vari 10
python test.py --name carpet --loss ssim_loss --im_resize 512 --patch_size 128 --z_dim 100
```
- **grid** texture
```bash
python train.py --name grid --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --grayscale --do_aug
python test.py --name grid --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --grayscale
```
- **hazelnut** object
```bash
python train.py --name hazelnut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate_crop 0.
python test.py --name hazelnut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
```
- **leather** texture
```bash
python train.py --name leather --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug
python test.py --name leather --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
```
- **metal_nut** object
```bash
python train.py --name metal_nut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate_crop 0. --p_horizonal_flip 0. --p_vertical_flip 0.
python test.py --name metal_nut --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
```
- **pill** object
```bash
python train.py --name pill --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_horizonal_flip 0. --p_vertical_flip 0.
python test.py --name pill --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --bg_mask B
```
- **screw** object
```bash
python train.py --name screw --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --do_aug --p_rotate 0.
python test.py --name screw --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --bg_mask W
```
- **tile** texture
```bash
python train.py --name tile --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug
python test.py --name tile --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
```
- **toothbrush** object
```bash
python train.py --name toothbrush --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_vertical_flip 0.
python test.py --name toothbrush --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
```
- **transistor** object
```bash
python train.py --name transistor --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --do_aug --p_rotate 0. --p_vertical_flip 0.
python test.py --name transistor --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500
```
- **wood** texture
```bash
python train.py --name wood --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100 --do_aug --rotate_angle_vari 15
python test.py --name wood --loss ssim_loss --im_resize 256 --patch_size 128 --z_dim 100
```
- **zipper** object
```bash
python train.py --name zipper --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale --do_aug --p_rotate 0.
python test.py --name zipper --loss ssim_loss --im_resize 266 --patch_size 256 --z_dim 500 --grayscale
```

## Overview of Results

**Classification**
During testing, I simply classify a test image as defective if its residual map contains any anomalous response. This rule is strict for anomaly-free images, which results in relatively low accuracy in the `ok` column of the table below.
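The decision rule above can be sketched in plain NumPy. This is a simplified, hypothetical illustration, not the code in `test.py`. The function names `estimate_threshold`, `segment` and `is_defective` are made up for this example. The sketch also leaves out three steps that `test.py` applies before saving the mask: the edge-damping mask, the background mask and the morphological opening.

As in `get_threshold`, each threshold is a percentile (`--percent`, 98 by default) of the residual values on anomaly-free images. A pixel counts as anomalous if either its SSIM residual or its L1 residual exceeds its threshold.

```python
import numpy as np


def estimate_threshold(good_residual_maps, percent=98.0):
    """Percentile of all residual values on anomaly-free images (cf. --percent)."""
    values = np.concatenate([m.ravel() for m in good_residual_maps])
    return float(np.percentile(values, percent))


def segment(ssim_residual, l1_residual, ssim_threshold, l1_threshold):
    """Binary defect mask: a pixel is anomalous if either residual exceeds its threshold."""
    return (ssim_residual > ssim_threshold) | (l1_residual > l1_threshold)


def is_defective(mask):
    """Strict image-level rule: a single anomalous pixel makes the image defective."""
    return bool(mask.any())


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    # Synthetic residual maps standing in for anomaly-free validation images.
    good_ssim = [rng.uniform(0.0, 0.1, (64, 64)) for _ in range(5)]
    good_l1 = [rng.uniform(0.0, 0.1, (64, 64)) for _ in range(5)]
    ssim_th = estimate_threshold(good_ssim)
    l1_th = estimate_threshold(good_l1)

    # Synthetic test image: low residuals everywhere except a 4x4 region with a high SSIM residual.
    test_ssim = rng.uniform(0.0, 0.05, (64, 64))
    test_ssim[10:14, 10:14] = 0.9
    test_l1 = rng.uniform(0.0, 0.05, (64, 64))

    mask = segment(test_ssim, test_l1, ssim_th, l1_th)
    print(is_defective(mask), int(mask.sum()))  # prints: True 16
```

Because each threshold is a 98th percentile, by construction about 2% of the residual values on the anomaly-free validation images lie above it. `test.py` removes most such isolated pixels with a morphological opening (`morphology.disk(4)`). Any response that survives still makes an anomaly-free image count as defective under this rule, which is consistent with the low `ok` accuracies below.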
Note that the **threshold** has a large effect on the results and must be selected carefully. The table reports classification accuracy (%) on anomaly-free (`ok`) test images, on defective (`nok`) test images, and on all test images of each category (`average`).

| category   | ok    | nok   | average |
|------------|-------|-------|---------|
| bottle     | 90.0  | 98.4  | 96.4    |
| cable      | 0.0   | 45.7  | 28.0    |
| capsule    | 34.8  | 89.6  | 78.0    |
| carpet     | 42.9  | 98.9  | 88.9    |
| grid       | 100   | 94.7  | 96.2    |
| hazelnut   | 55.0  | 98.6  | 82.7    |
| leather    | 71.9  | 92.4  | 87.1    |
| metal nut  | 22.7  | 67.7  | 59.1    |
| pill       | 11.5  | 75.9  | 65.9    |
| screw      | 0.5   | 90.0  | 68.1    |
| tile       | 100.0 | 3.6   | 30.8    |
| toothbrush | 83.3  | 100   | 95.2    |
| transistor | 23.3  | 97.5  | 53.0    |
| wood       | 89.5  | 76.7  | 79.7    |
| zipper     | 68.8  | 81.5  | 78.8    |
*SSIM loss, 200 epochs; the threshold differs between categories.*

## Discussion
- **SSIM + L1 metrics**
  SSIM is designed as a similarity measure for grayscale images (in `test.py` it is computed per color channel and then averaged), so in some cases it fails to capture color defects. I therefore combine SSIM with the L1 distance for anomaly segmentation: a pixel is flagged as anomalous if either its SSIM residual or its L1 residual exceeds the corresponding threshold.
- **VAE**
  I also tried a VAE but observed no performance improvement.
- **InstanceNorm**
  I also tried adding InstanceNorm (IN) layers to accelerate convergence, but droplet artifacts appeared in some cases. This artifact is also discussed in the **StyleGAN2** paper.

## Supplementary materials
My notes: https://www.yuque.com/books/share/8c7613f7-7571-4bfa-865a-689de3763c59?#
(password: `ixgg`)

## References
- Paul Bergmann, Sindy Löwe, Michael Fauser, David Sattlegger, Carsten Steger. Improving Unsupervised Defect Segmentation by Applying Structural Similarity to Autoencoders; pp. 372-380, January 2019. doi: 10.5220/0007364503720380
- Paul Bergmann, Michael Fauser, David Sattlegger, Carsten Steger. MVTec AD - A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection; in: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2019


--------------------------------------------------------------------------------
/image/bent_000_origin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_origin.png
--------------------------------------------------------------------------------
/image/bent_000_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_rec.png
--------------------------------------------------------------------------------
/image/bent_000_residual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_residual.png
--------------------------------------------------------------------------------
/image/bent_000_visual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/plutoyuxie/AutoEncoder-SSIM-for-unsupervised-anomaly-detection-/1bd6651a77d796445156d2a3e08d11d805079367/image/bent_000_visual.png
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Model


def AutoEncoder(cfg):
    input_img = Input(shape=(cfg.patch_size, cfg.patch_size, cfg.input_channel))

    # encoder: stride-2 convolutions reduce the patch to 8x8 (one extra stage for 256x256 patches)
    h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(input_img)
    h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    if cfg.patch_size==256:
        h = Conv2D(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*2, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*4, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    # an 8x8 'valid' convolution maps the 8x8 feature map to a 1x1 latent code with z_dim channels
    encoded = Conv2D(cfg.z_dim, (8, 8), strides=1, activation='linear', padding='valid')(h)

    # decoder: upsamples the latent code back to patch_size x patch_size
    h = Conv2DTranspose(cfg.flc, (8, 8), strides=1, activation=LeakyReLU(alpha=0.2), padding='valid')(encoded)
    h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*4, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2DTranspose(cfg.flc*2, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc*2, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2D(cfg.flc, (3, 3), strides=1, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)
    if cfg.patch_size==256:
        h = Conv2DTranspose(cfg.flc, (4, 4), strides=2, activation=LeakyReLU(alpha=0.2), padding='same')(h)

    # sigmoid output matches the input range [0, 1]
    decoded = Conv2DTranspose(cfg.input_channel, (4, 4), strides=2, activation='sigmoid', padding='same')(h)

    return Model(input_img, decoded)
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
import argparse
import os

class Options():
    def __init__(self):
        self.parser = argparse.ArgumentParser()

        # paths and data
        self.parser.add_argument('--name', type=str, default='leather')
        self.parser.add_argument('--train_data_dir', type=str, default=None)
        self.parser.add_argument('--test_dir', type=str, default=None)
        # type=str, not list: with type=list every folder name would be split into single characters
        self.parser.add_argument('--sub_folder', type=str, nargs='*', default=None)
        self.parser.add_argument('--do_aug', action='store_true', help='whether to do data augmentation before training')
        self.parser.add_argument('--aug_dir', type=str, default=None)
        self.parser.add_argument('--chechpoint_dir', type=str, default=None)
        self.parser.add_argument('--save_dir', type=str, default=None)

        # data augmentation
        self.parser.add_argument('--augment_num', type=int, default=10000)
        self.parser.add_argument('--im_resize', type=int, default=256, help='scale images to this size')
        self.parser.add_argument('--patch_size', type=int, default=128, help='then crop to this size')
        self.parser.add_argument("--grayscale", action='store_true', help='color or grayscale input image')
        self.parser.add_argument('--p_rotate', type=float, default=0.3, help='probability to do image rotation')
        self.parser.add_argument('--rotate_angle_vari', type=float, default=45.0, help='rotate image between [-angle, +angle]')
        self.parser.add_argument('--p_rotate_crop', type=float, default=1.0, help='probability to crop inner rotated image')
        self.parser.add_argument('--p_horizonal_flip', type=float, default=0.3, help='probability to do horizontal flip')
        self.parser.add_argument('--p_vertical_flip', type=float, default=0.3, help='probability to do vertical flip')

        # network
        self.parser.add_argument('--z_dim', type=int, default=100, help='dimension of the latent space vector')
        self.parser.add_argument('--flc', type=int, default=32, help='number of channels in the first hidden layer')

        # training
        self.parser.add_argument('--epochs', type=int, default=200, help='maximum training epochs')
        self.parser.add_argument('--batch_size', type=int, default=128)
        self.parser.add_argument('--loss', type=str, default='ssim_loss', help='loss type in [ssim_loss, ssim_l1_loss, l2_loss]')
        self.parser.add_argument('--weight', type=int, default=1, help='weight of the l1_loss term if using ssim_l1_loss')
        self.parser.add_argument('--lr', type=float, default=2e-4, help='learning rate of Adam')
        self.parser.add_argument('--decay', type=float, default=1e-5, help='decay of Adam')

        # testing
        self.parser.add_argument('--weight_file', type=str, default=None, help='if set None, the latest weight file will be automatically selected')
        self.parser.add_argument('--stride', type=int, default=32, help='step length of the sliding window')
        self.parser.add_argument('--ssim_threshold', type=float, default=None, help='ssim threshold for testing')
        self.parser.add_argument('--l1_threshold', type=float, default=None, help='l1 threshold for testing')
        self.parser.add_argument('--percent', type=float, default=98.0, help='for estimating threshold based on valid positive samples')
        self.parser.add_argument('--bg_mask', type=str, default=None, help='background mask, B means black, W means white')

    def parse(self):
        # root directory of the downloaded MVTec AD dataset; change this to your own path
        DATASET_PATH = 'D:/user/dataset/mvtec_anomaly_detection'
        self.opt = self.parser.parse_args()

        # default paths derived from the category name
        if not self.opt.train_data_dir:
            self.opt.train_data_dir = DATASET_PATH+'/'+self.opt.name+'/train/good'
        if not self.opt.test_dir:
            self.opt.test_dir = DATASET_PATH+'/'+self.opt.name+'/test'
        if not self.opt.sub_folder:
            self.opt.sub_folder = os.listdir(self.opt.test_dir)
        if not self.opt.aug_dir:
            self.opt.aug_dir = './train_patches/'+self.opt.name
        if not self.opt.chechpoint_dir:
            self.opt.chechpoint_dir = './results/'+self.opt.name+'/chechpoints/'+self.opt.loss
        if not self.opt.save_dir:
            self.opt.save_dir = './results/'+self.opt.name+'/reconst/ssim_l1_metric_'+self.opt.loss

        if not os.path.exists(self.opt.chechpoint_dir):
            os.makedirs(self.opt.chechpoint_dir)
        if not os.path.exists(self.opt.aug_dir):
            os.makedirs(self.opt.aug_dir)
        if not os.path.exists(self.opt.save_dir):
            os.makedirs(self.opt.save_dir)

        self.opt.input_channel = 1 if self.opt.grayscale else 3
        # random cropping is only needed when patches are smaller than the resized image
        self.opt.p_crop = 1 if self.opt.patch_size != self.opt.im_resize else 0
        # if the sliding window cannot move (margin smaller than stride), test on a single center-cropped patch
        self.opt.mask_size = self.opt.patch_size if self.opt.im_resize - self.opt.patch_size < self.opt.stride else self.opt.im_resize

        return self.opt

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage import morphology
from glob import glob
import cv2
import os

from utils import read_img, get_patch, patch2img, set_img_color, bg_mask
from network import AutoEncoder
from options import Options


cfg = Options().parse()

# network
autoencoder = AutoEncoder(cfg)

if cfg.weight_file:
    autoencoder.load_weights(cfg.chechpoint_dir + '/' + cfg.weight_file)
else:
    # checkpoints are named '{epoch:02d}-{val_loss:.5f}.hdf5' (see train.py); load the one with the highest epoch
    file_list = os.listdir(cfg.chechpoint_dir)
    latest_epoch = max([int(i.split('-')[0]) for i in file_list if 'hdf5' in i])
    print('load latest weight file: ', latest_epoch)
    autoencoder.load_weights(glob(cfg.chechpoint_dir + '/' + '{:02d}'.format(latest_epoch) + '-*.hdf5')[0])
#autoencoder.summary()


def get_residual_map(img_path, cfg):
    test_img = read_img(img_path, cfg.grayscale)

    if test_img.shape[:2] != (cfg.im_resize, cfg.im_resize):
        test_img = cv2.resize(test_img, (cfg.im_resize, cfg.im_resize))
    # center-crop to mask_size when the image is tested as a single patch
    if cfg.im_resize != cfg.mask_size:
        tmp = (cfg.im_resize - cfg.mask_size)//2
        test_img = test_img[tmp:tmp+cfg.mask_size, tmp:tmp+cfg.mask_size]

    test_img_ = test_img / 255.

    if test_img.shape[:2] == (cfg.patch_size, cfg.patch_size):
        test_img_ = np.expand_dims(test_img_, 0)
        decoded_img = autoencoder.predict(test_img_)
    else:
        # reconstruct overlapping sliding-window patches and average them back into one image
        patches = get_patch(test_img_, cfg.patch_size, cfg.stride)
        patches = autoencoder.predict(patches)
        decoded_img = patch2img(patches, cfg.im_resize, cfg.patch_size, cfg.stride)

    rec_img = np.reshape((decoded_img * 255.).astype('uint8'), test_img.shape)

    # residual maps: 1 - local SSIM (11x11 window) and absolute difference in [0, 1]; color maps are averaged over channels
    if cfg.grayscale:
        ssim_residual_map = 1 - ssim(test_img, rec_img, win_size=11, full=True)[1]
        l1_residual_map = np.abs(test_img / 255. - rec_img / 255.)
    else:
        ssim_residual_map = ssim(test_img, rec_img, win_size=11, full=True, multichannel=True)[1]
        ssim_residual_map = 1 - np.mean(ssim_residual_map, axis=2)
        l1_residual_map = np.mean(np.abs(test_img / 255. - rec_img / 255.), axis=2)

    return test_img, rec_img, ssim_residual_map, l1_residual_map


def get_threshold(cfg):
    # each threshold is the cfg.percent-th percentile of the residuals on the last 20% of the anomaly-free training images
    print('estimating threshold...')
    valid_good_list = glob(cfg.train_data_dir + '/*png')
    num_valid_data = int(np.ceil(len(valid_good_list) * 0.2))
    total_rec_ssim, total_rec_l1 = [], []
    for img_path in valid_good_list[-num_valid_data:]:
        _, _, ssim_residual_map, l1_residual_map = get_residual_map(img_path, cfg)
        total_rec_ssim.append(ssim_residual_map)
        total_rec_l1.append(l1_residual_map)
    total_rec_ssim = np.array(total_rec_ssim)
    total_rec_l1 = np.array(total_rec_l1)
    ssim_threshold = float(np.percentile(total_rec_ssim, cfg.percent))
    l1_threshold = float(np.percentile(total_rec_l1, cfg.percent))
    print('ssim_threshold: %f, l1_threshold: %f' %(ssim_threshold, l1_threshold))
    # thresholds given on the command line take precedence
    if not cfg.ssim_threshold:
        cfg.ssim_threshold = ssim_threshold
    if not cfg.l1_threshold:
        cfg.l1_threshold = l1_threshold


def get_depressing_mask(cfg):
    # scale residuals within 5 pixels of the border by 0.2 to damp responses at the image edges
    depr_mask = np.ones((cfg.mask_size, cfg.mask_size)) * 0.2
    depr_mask[5:cfg.mask_size-5, 5:cfg.mask_size-5] = 1
    cfg.depr_mask = depr_mask


def get_results(file_list, cfg, sub_folder=''):
    for img_path in file_list:
        # file name without directory or extension (portable, unlike splitting on '\\')
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        c = sub_folder
        test_img, rec_img, ssim_residual_map, l1_residual_map = get_residual_map(img_path, cfg)

        ssim_residual_map *= cfg.depr_mask
        if 'ssim' in cfg.loss:
            l1_residual_map *= cfg.depr_mask

        # a pixel is anomalous if either residual exceeds its threshold
        mask = np.zeros((cfg.mask_size, cfg.mask_size))
        mask[ssim_residual_map > cfg.ssim_threshold] = 1
        mask[l1_residual_map > cfg.l1_threshold] = 1
        # optional background mask: 'B' keeps pixels brighter than 50, 'W' keeps pixels darker than 200 (holes filled)
        if cfg.bg_mask == 'B':
            bg_m = bg_mask(test_img.copy(), 50, cv2.THRESH_BINARY, cfg.grayscale)
            mask *= bg_m
        elif cfg.bg_mask == 'W':
            bg_m = bg_mask(test_img.copy(), 200, cv2.THRESH_BINARY_INV, cfg.grayscale)
            mask *= bg_m
        # morphological opening removes small isolated detections
        kernel = morphology.disk(4)
        mask = morphology.opening(mask, kernel)
        mask *= 255

        vis_img = set_img_color(test_img.copy(), mask, weight_foreground=0.3, grayscale=cfg.grayscale)

        cv2.imwrite(cfg.save_dir+'/'+c+'_'+img_name+'_residual.png', mask)
        cv2.imwrite(cfg.save_dir+'/'+c+'_'+img_name+'_origin.png', test_img)
        cv2.imwrite(cfg.save_dir+'/'+c+'_'+img_name+'_rec.png', rec_img)
        cv2.imwrite(cfg.save_dir+'/'+c+'_'+img_name+'_visual.png', vis_img)


if __name__ == '__main__':
    if not cfg.ssim_threshold or not cfg.l1_threshold:
        get_threshold(cfg)

    get_depressing_mask(cfg)

    if cfg.sub_folder:
        for k in cfg.sub_folder:
            test_list = glob(cfg.test_dir+'/'+k+'/*')
            get_results(test_list, cfg, k)
    else:
        test_list = glob(cfg.test_dir+'/*')
        get_results(test_list, cfg)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.utils import Sequence
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np
import cv2
from glob import glob
import os

from network import AutoEncoder
from utils import generate_image_list, augment_images, read_img
from options import Options


cfg = Options().parse()

# Keras Sequence that loads images from disk, scaled to [0, 1]; input and target are the same image
class data_flow(Sequence):
    def __init__(self, filenames, batch_size, grayscale):
        self.filenames = filenames
        self.batch_size = batch_size
        self.grayscale = grayscale

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = np.array([read_img(filename, self.grayscale) for filename in batch_x])

        batch_x = batch_x / 255.
        return batch_x, batch_x

# data
if cfg.aug_dir and cfg.do_aug:
    img_list = generate_image_list(cfg)
    augment_images(img_list, cfg)

# Options.parse() always sets aug_dir, so training reads the augmented patches in cfg.aug_dir
dataset_dir = cfg.aug_dir if cfg.aug_dir else cfg.train_data_dir
file_list = glob(dataset_dir + '/*')
# the last 20% of the files are used for validation
num_valid_data = int(np.ceil(len(file_list) * 0.2))
data_train = data_flow(file_list[:-num_valid_data], cfg.batch_size, cfg.grayscale)
data_valid = data_flow(file_list[-num_valid_data:], cfg.batch_size, cfg.grayscale)

# loss
if cfg.loss == 'ssim_loss':

    @tf.function
    def ssim_loss(gt, y_pred, max_val=1.0):
        # 1 - mean SSIM over the batch
        return 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))

    loss = ssim_loss
elif cfg.loss == 'ssim_l1_loss':

    @tf.function
    def ssim_l1_loss(gt, y_pred, max_val=1.0):
        # SSIM loss plus a weighted mean absolute error
        ssim_loss = 1 - tf.reduce_mean(tf.image.ssim(gt, y_pred, max_val=max_val))
        L1 = tf.reduce_mean(tf.abs(gt - y_pred))
        return ssim_loss + L1 * cfg.weight

    loss = ssim_l1_loss
else:
    loss = 'mse'

# network
autoencoder = AutoEncoder(cfg)
optimizer = Adam(lr=cfg.lr, decay=cfg.decay)
autoencoder.compile(optimizer=optimizer, loss=loss, metrics=['mae'] if loss == 'mse' else ['mse'])
autoencoder.summary()

earlystopping = EarlyStopping(patience=20)

checkpoint = ModelCheckpoint(os.path.join(cfg.chechpoint_dir, '{epoch:02d}-{val_loss:.5f}.hdf5'), save_best_only=True,
                             period=1, mode='auto', verbose=1, save_weights_only=True)

autoencoder.fit(data_train, epochs=cfg.epochs, validation_data=data_valid, callbacks=[checkpoint, earlystopping])

# show reconstructed images
decoded_imgs = autoencoder.predict(data_valid)
n = len(decoded_imgs)
save_snapshot_dir = cfg.chechpoint_dir +'/snapshot/'
if not os.path.exists(save_snapshot_dir):
    os.makedirs(save_snapshot_dir)
for i in range(n):
    cv2.imwrite(save_snapshot_dir+str(i)+'_rec_valid.png', (decoded_imgs[i]*255).astype('uint8'))

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import cv2
import random
import os


def read_img(img_path, grayscale):
    # flag 0 loads a single-channel grayscale image; otherwise a 3-channel BGR image
    if grayscale:
        im = cv2.imread(img_path, 0)
    else:
        im = cv2.imread(img_path)
    return im


def random_crop(image, new_size):
    h, w = image.shape[:2]
    # randint's upper bound is exclusive; +1 allows every valid offset, including an image exactly new_size wide
    y = np.random.randint(0, h - new_size + 1)
    x = np.random.randint(0, w - new_size + 1)
    image = image[y:y+new_size, x:x+new_size]
    return image


def rotate_image(img, angle, crop):
    # rotate about the center; if crop is set, keep the central axis-aligned region without black borders
    h, w = img.shape[:2]
    angle %= 360
    M_rotate = cv2.getRotationMatrix2D((w/2, h/2), angle, 1)
    img_rotated = cv2.warpAffine(img, M_rotate, (w, h))
    if crop:
        angle_crop = angle % 180
        if angle_crop > 90:
            angle_crop = 180 - angle_crop
        theta = angle_crop * np.pi / 180.0
        hw_ratio = float(h) / float(w)
        tan_theta = np.tan(theta)
        numerator = np.cos(theta) + np.sin(theta) * tan_theta
        r = hw_ratio if h > w else 1 / hw_ratio
        denominator = r * tan_theta + 1
        crop_mult = numerator / denominator
        w_crop = int(round(crop_mult*w))
        h_crop = int(round(crop_mult*h))
        x0 = int((w-w_crop)/2)
        y0 = int((h-h_crop)/2)
        img_rotated = img_rotated[y0:y0+h_crop, x0:x0+w_crop]
    return img_rotated


def random_rotate(img, angle_vari, p_crop):
    angle = np.random.uniform(-angle_vari, angle_vari)
    crop = False if np.random.random() > p_crop else True
    return rotate_image(img, angle, crop)


def generate_image_list(args):
    # spread augment_num samples as evenly as possible over the training images; the remainder goes to random images
    filenames = os.listdir(args.train_data_dir)
    num_imgs = len(filenames)
    num_ave_aug = int(math.floor(args.augment_num/num_imgs))
    rem = args.augment_num - num_ave_aug*num_imgs
    lucky_seq = [True]*rem + [False]*(num_imgs-rem)
    random.shuffle(lucky_seq)

    img_list = [
        (os.sep.join([args.train_data_dir, filename]), num_ave_aug+1 if lucky else num_ave_aug)
        for filename, lucky in zip(filenames, lucky_seq)
    ]

    return img_list


def augment_images(filelist, args):
    for filepath, n in filelist:
        img = read_img(filepath, args.grayscale)
        if img.shape[:2] != (args.im_resize, args.im_resize):
            img = cv2.resize(img, (args.im_resize, args.im_resize))
        filename = filepath.split(os.sep)[-1]
        dot_pos = filename.rfind('.')
        imgname = filename[:dot_pos]
        ext = filename[dot_pos:]

        print('Augmenting {} ...'.format(filename))
        for i in range(n):
            img_varied = img.copy()
            # file name suffix records the applied operations: r(otate), c(rop), h(orizontal flip), v(ertical flip)
            varied_imgname = '{}_{:0>3d}_'.format(imgname, i)

            if random.random() < args.p_rotate:
                img_varied_ = random_rotate(
                    img_varied,
                    args.rotate_angle_vari,
                    args.p_rotate_crop)
                if img_varied_.shape[0] >= args.patch_size and img_varied_.shape[1] >= args.patch_size:
                    img_varied = img_varied_
                    varied_imgname += 'r'

            if random.random() < args.p_crop:
                img_varied = random_crop(
                    img_varied,
                    args.patch_size)
                varied_imgname += 'c'

            if random.random() < args.p_horizonal_flip:
                img_varied = cv2.flip(img_varied, 1)
                varied_imgname += 'h'

            if random.random() < args.p_vertical_flip:
                img_varied = cv2.flip(img_varied, 0)
                varied_imgname += 'v'

            output_filepath = os.sep.join([
                args.aug_dir,
                '{}{}'.format(varied_imgname, ext)])
            cv2.imwrite(output_filepath, img_varied)


def get_patch(image, new_size, stride):
    # extract new_size x new_size patches with a sliding window (row by row, step = stride)
    h, w = image.shape[:2]
    i, j = new_size, new_size
    patch = []
    while i <= h:
        while j <= w:
            patch.append(image[i - new_size:i, j - new_size:j])
            j += stride
        j = new_size
        i += stride
    return np.array(patch)


def patch2img(patches, im_size, patch_size, stride):
    # the extra last channel counts how many patches cover each pixel, so overlapping predictions are averaged
    img = np.zeros((im_size, im_size, patches.shape[3]+1))
    i, j = patch_size, patch_size
    k = 0
    while i <= im_size:
        while j <= im_size:
            img[i - patch_size:i, j - patch_size:j, :-1] += patches[k]
            img[i - patch_size:i, j - patch_size:j, -1] += np.ones((patch_size, patch_size))
            k += 1
            j += stride
        j = patch_size
        i += stride
    mask=np.repeat(img[:,:,-1][...,np.newaxis], patches.shape[3], 2)
    img = img[:,:,:-1]/mask
    return img


def set_img_color(img, predict_mask, weight_foreground, grayscale):
    if grayscale:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # keep an unmodified copy so that the red overlay is actually blended with the original image
    origin = img.copy()
    img[np.where(predict_mask == 255)] = (0,0,255)
    cv2.addWeighted(img, weight_foreground, origin, (1 - weight_foreground), 0, img)
    return img


def bg_mask(img, value, mode, grayscale):

    if not grayscale:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _,thresh=cv2.threshold(img,value,255,mode)

    def FillHole(mask):
        # draw every contour filled and sum the drawings, so holes inside the object become foreground
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        len_contour = len(contours)
        contour_list = []
        for i in range(len_contour):
            drawing = np.zeros_like(mask, np.uint8)  # create a black image
            img_contour = cv2.drawContours(drawing, contours, i, (255, 255, 255), -1)
            contour_list.append(img_contour)

        out = sum(contour_list)
        return out

    thresh = FillHole(thresh)
    # no contours found: sum([]) is 0, so keep the whole image
    if type(thresh) is int:
        return np.ones(img.shape)
    mask_ = np.ones(thresh.shape)
    mask_[np.where(thresh <= 127)] = 0
    return mask_
--------------------------------------------------------------------------------