├── requirements.txt ├── README.md ├── augment.py └── CNN_generator.py /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==0.25.0 2 | opencv_python==4.1.2.30 3 | numpy==1.17.2 4 | matplotlib==3.1.1 5 | Keras==2.2.4 6 | tensorflow==1.12.0 7 | Pillow==7.1.2 8 | scikit_learn==0.22.2.post1 9 | skimage==0.0 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image_Generator 2 | Train Image Classifier with Augmentation and Keras Fit Generator 3 | 4 | 1. ```augment.py```
5 | Contains different image augmentation methods 6 | 7 | 2. ```CNN_generator.py```
8 | Train Convolutional Neural Network model for image classifier with augmentation using Keras ```fit_generator``` method 9 | -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import numpy as np 4 | import skimage 5 | from matplotlib import pyplot as plt 6 | from skimage.util import random_noise 7 | import os 8 | 9 | class ImageAugment: 10 | def add_gaussian_noise(self,image): 11 | gauss = random_noise(image, mode='gaussian', seed=None, clip=True) 12 | return gauss 13 | 14 | def add_salt_pepper_noise(self,image): 15 | sp = random_noise(image, mode='s&p', seed=None, clip=True) 16 | return sp 17 | 18 | def add_poisson_noise(self,image): 19 | poisson = random_noise(image, mode='poisson', seed=None, clip=True) 20 | return poisson 21 | 22 | def add_speckle_noise(self,image): 23 | speckle = random_noise(image, mode='speckle', seed=None, clip=True) 24 | return speckle 25 | 26 | def flip_vertical(self,image): 27 | flipVertical = cv2.flip(image, 0) 28 | return flipVertical 29 | 30 | def flip_horizontal(self,image): 31 | flipHorizontal = cv2.flip(image, 1) 32 | return flipHorizontal 33 | 34 | def do_augmentation(self,image): 35 | image = self.add_gaussian_noise(image) 36 | image = self.add_salt_pepper_noise(image) 37 | image = self.add_poisson_noise(image) 38 | image = self.add_speckle_noise(image) 39 | image = self.flip_vertical(image) 40 | image = self.flip_horizontal(image) 41 | return image 42 | 43 | class Generator(ImageAugment): 44 | def __init__(self,path): 45 | self.path = path 46 | 47 | def image_generator(self): 48 | path = self.path 49 | file_list = [] 50 | for root,dirs,files in os.walk(path): 51 | for file in files: 52 | file_list.append(os.path.join(root,file)) 53 | i=0 54 | while(True): 55 | yield file_list[i] 56 | i+=1 57 | 58 | def on_next(self): 59 | gen_obj = self.image_generator() 60 | return next(gen_obj) 61 | 62 | def augment_and_show(self): 63 | print(self.on_next()) 64 | while True: 65 | image = cv2.imread(self.on_next()) 66 | image = self.do_augmentation(image) 67 | cv2.imshow('image', image) 68 | cv2.waitKey(0) 69 | 70 | def label_generator(self,feat,labels): 71 | i=0 72 | while (True): 73 | yield feat[i],labels[i] 74 | i+=1 75 | 76 | 77 | # image_path = './data/image1.jpg' 78 | # I = cv2.imread(image_path, 1) 79 | # 80 | # aug = ImageAugment() 81 | # gauss = aug.add_gaussian_noise(I) 82 | # sp = aug.add_salt_pepper_noise(I) 83 | # poisson = aug.add_poisson_noise(I) 84 | # speckle = aug.add_speckle_noise(I) 85 | # flipv = aug.flip_vertical(I) 86 | # fliph = aug.flip_horizontal(I) 87 | # 88 | image_gen = Generator('./data') 89 | # augmented = image_gen.do_augmentation(I) 90 | 91 | gen = image_gen.augment_and_show() 92 | 93 | # plt.subplot(421), plt.imshow(I), plt.title('Origin') 94 | # plt.subplot(422), plt.imshow(gauss), plt.title('Gaussian') 95 | # plt.subplot(423), plt.imshow(sp), plt.title('Salt and Pepper') 96 | # plt.subplot(424), plt.imshow(poisson), plt.title('Poisson') 97 | # plt.subplot(425), plt.imshow(speckle), plt.title('Speckle') 98 | # plt.subplot(426), plt.imshow(flipv), plt.title('Flip Vertical') 99 | # plt.subplot(427), plt.imshow(fliph), plt.title('Flip Horizontal') 100 | # plt.subplot(428), plt.imshow(augmented), plt.title('Augmented') 101 | # 102 | # plt.show() 103 | -------------------------------------------------------------------------------- /CNN_generator.py: -------------------------------------------------------------------------------- 1 | from keras.utils import to_categorical 2 | from keras.models import Sequential 3 | from keras.layers import Dense, Conv2D, Flatten 4 | import pandas as pd 5 | import os 6 | import cv2 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | import tensorflow as tf 10 | from augment import ImageAugment 11 | 12 | class Generator(): 13 | def __init__(self, feat, labels, width, height): 14 | self.feat = feat 15 | self.labels = labels 16 | self.width = width 17 | self.height = height 18 | 19 | def gen(self): 20 | ''' 21 | Yields generator object for training or evaluation without batching 22 | Yields: 23 | im: np.array of (1,width,height,1) of images 24 | label: np.array of one-hot vector of label (1,num_labels) 25 | ''' 26 | feat = self.feat 27 | labels = self.labels 28 | width = self.width 29 | height = self.height 30 | i=0 31 | while (True): 32 | im = cv2.imread(feat[i],0) 33 | im = im.reshape(width,height,1) 34 | im = np.expand_dims(im,axis=0) 35 | label = np.expand_dims(labels[i],axis=0) 36 | yield im,label 37 | i+=1 38 | 39 | if i>=len(feat): 40 | i=0 41 | 42 | 43 | def gen_test(self): 44 | ''' 45 | Yields generator object to do prediction 46 | Yields: 47 | im: np.array of (1,width,height,1) of images 48 | ''' 49 | feat = self.feat 50 | width = self.width 51 | height = self.height 52 | i=0 53 | while (True): 54 | im = cv2.imread(feat[i],0) 55 | im = im.reshape(width,height,1) 56 | im = np.expand_dims(im,axis=0) 57 | yield im 58 | i+=1 59 | 60 | 61 | def gen_batching(self, batch_size): 62 | ''' 63 | Yields generator object with batching of batch_size 64 | Args: 65 | batch_size (int): batch_size 66 | Yields: 67 | feat_batch: np.array of (batch_size,width,height,1) of images 68 | label_batch: np.array of (batch_size,num_labels) 69 | ''' 70 | feat = self.feat 71 | labels = self.labels 72 | width = self.width 73 | height = self.height 74 | num_examples = len(feat) 75 | num_batch = num_examples/batch_size 76 | X = [] 77 | for n in range(num_examples): 78 | im = cv2.imread(feat[n],0) 79 | try: 80 | im = im.reshape(width,height,1) 81 | except: 82 | print('Error on this image: ', feat[n]) 83 | X.append(im) 84 | X = np.array(X) 85 | 86 | feat_batch = np.zeros((batch_size,width,height,1)) 87 | label_batch = np.zeros((batch_size,labels.shape[1])) 88 | while(True): 89 | for i in range(batch_size): 90 | index = np.random.randint(X.shape[0],size=1)[0] #shuffle the data 91 | feat_batch[i] = X[index] 92 | label_batch[i] = labels[index] 93 | yield feat_batch,label_batch 94 | 95 | # def on_next(self): 96 | # ''' 97 | # Advance to the next generator object 98 | # ''' 99 | # gen_obj = self.gen_test() 100 | # return next(gen_obj) 101 | # 102 | # def gen_show(self, pred): 103 | # ''' 104 | # Show the image generator object 105 | # ''' 106 | # i=0 107 | # while(True): 108 | # image = self.on_next() 109 | # image = np.squeeze(image,axis=0) 110 | # cv2.imshow('image', image) 111 | # cv2.waitKey(0) 112 | # i+=1 113 | 114 | def gen_augment(self,batch_size,augment): 115 | ''' 116 | Yields generator object with batching of batch_size and augmentation. 117 | The number of examples for 1 batch will be multiplied based on the number of augmentation 118 | 119 | augment represents [speckle, gaussian, poisson]. It means, the augmentation will be done on the augment list element that is 1 120 | for example, augment = [1,1,0] corresponds to adding speckle noise and gaussian noise 121 | if batch_size = 100, the number of examples in each batch will become 300 122 | 123 | Args: 124 | batch_size (int): batch_size 125 | augment (list): list that defines what kind of augmentation we want to do 126 | Yields: 127 | feat_batch: np.array of (batch_size*n_augment,width,height,1) of images 128 | label_batch: np.array of (batch_size*n_augment,num_labels) 129 | ''' 130 | feat = self.feat 131 | labels = self.labels 132 | width = self.width 133 | height = self.height 134 | 135 | num_examples = len(feat) 136 | num_batch = num_examples/batch_size 137 | X = [] 138 | for n in range(num_examples): 139 | im = cv2.imread(feat[n],0) 140 | try: 141 | im = im.reshape(width,height,1) 142 | except: 143 | print('Error on this image: ', feat[n]) 144 | X.append(im) 145 | X = np.array(X) 146 | 147 | n_augment = augment.count(1) 148 | print('Number of augmentations: ', n_augment) 149 | feat_batch = np.zeros(((n_augment+1)*batch_size,width,height,1)) 150 | label_batch = np.zeros(((n_augment+1)*batch_size,labels.shape[1])) 151 | 152 | while(True): 153 | i=0 154 | while (i<=batch_size): 155 | index = np.random.randint(X.shape[0],size=1)[0] #shuffle the data 156 | aug = ImageAugment(X[index]) 157 | feat_batch[i] = X[index] 158 | label_batch[i] = labels[index] 159 | 160 | j=0 161 | if augment[0] == 1: 162 | feat_batch[(j*n_augment)+i+batch_size] = aug.add_speckle_noise() 163 | label_batch[(j*n_augment)+i+batch_size] = labels[index] 164 | j+=1 165 | 166 | if augment[1] == 1: 167 | feat_batch[(j*n_augment)+i+batch_size] = aug.add_gaussian_noise() 168 | label_batch[(j*n_augment)+i+batch_size] = labels[index] 169 | j+=1 170 | 171 | if augment[2] == 1: 172 | feat_batch[(j*n_augment)+i+batch_size] = aug.add_poisson_noise() 173 | label_batch[(j*n_augment)+i+batch_size] = labels[index] 174 | j+=1 175 | 176 | i+=1 177 | 178 | 179 | yield feat_batch,label_batch 180 | 181 | def CNN_model(width,height): 182 | # #create model 183 | model = Sequential() 184 | model.add(Conv2D(64, kernel_size=3, activation="relu", input_shape=(width,height,1))) 185 | model.add(Conv2D(32, kernel_size=3, activation="relu")) 186 | model.add(Flatten()) 187 | model.add(Dense(labels.shape[1], activation="softmax")) 188 | 189 | model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 190 | return model 191 | 192 | 193 | if __name__ == "__main__": 194 | input_dir = './mnist' 195 | output_file = 'dataset.csv' 196 | 197 | filename = [] 198 | label = [] 199 | for root,dirs,files in os.walk(input_dir): 200 | for file in files: 201 | full_path = os.path.join(root,file) 202 | filename.append(full_path) 203 | label.append(os.path.basename(os.path.dirname(full_path))) 204 | 205 | data = pd.DataFrame(data={'filename': filename, 'label':label}) 206 | data.to_csv(output_file,index=False) 207 | 208 | labels = pd.get_dummies(data.iloc[:,1]).values 209 | 210 | X, X_val, y, y_val = train_test_split( 211 | filename, labels, 212 | test_size=0.2, 213 | random_state=1234, 214 | shuffle=True, 215 | stratify=labels 216 | ) 217 | 218 | X_train, X_test, y_train, y_test = train_test_split( 219 | X, y, 220 | test_size=0.025, 221 | random_state=1234, 222 | shuffle=True, 223 | stratify=y 224 | ) 225 | 226 | width = 28 227 | height = 28 228 | 229 | test_data = pd.DataFrame(data={'filename': X_test}) 230 | 231 | 232 | image_gen_train = Generator(X_train,y_train,width,height) 233 | image_gen_val = Generator(X_val,y_val,width,height) 234 | image_gen_test = Generator(X_test,None,width,height) 235 | 236 | 237 | batch_size = 900 238 | print('len data: ', len(X_train)) 239 | print('len test data: ', len(X_test)) 240 | 241 | #augment represents [speckle, gaussian, poisson]. It means, the augmentation will be done on the augment list element that is 1 242 | #for example, augment = [1,1,0] corresponds to adding speckle noise and gaussian noise 243 | augment = [1,1,1] 244 | model = CNN_model(width,height) 245 | 246 | model.fit_generator( 247 | generator=image_gen_train.gen_augment(batch_size=batch_size,augment=augment), 248 | steps_per_epoch=np.ceil(len(X_train)/batch_size), 249 | epochs=20, 250 | verbose=1, 251 | validation_data=image_gen_val.gen(), 252 | validation_steps=len(X_val) 253 | ) 254 | model.save('model_aug_3.h5') 255 | model = tf.keras.models.load_model('model_aug_3.h5') 256 | 257 | #Try evaluate_generator 258 | image_gen_test = Generator(X_test,y_test,width,height) 259 | print(model.evaluate_generator( 260 | generator=image_gen_test.gen(), 261 | steps=len(X_test) 262 | )) 263 | 264 | #Try predict_generator 265 | image_gen_test = Generator(X_test,None,width,height) 266 | pred = model.predict_generator( 267 | generator=image_gen_test.gen_test(), 268 | steps=len(X_test) 269 | ) 270 | pred = np.argmax(pred,axis=1) 271 | # image_gen_test = Generator(X_test,pred,width*3,height*3) 272 | # image_gen_test.gen_show(pred) 273 | wrong_pred = [] 274 | for i,ex in enumerate(zip(pred,y_test)): 275 | if ex[0] != np.argmax(ex[1]): 276 | wrong_pred.append(i) 277 | print(wrong_pred) 278 | 279 | # for i in range(len(X_test)): 280 | # im = cv2.imread(X_test[i],0) 281 | # im = cv2.putText(im, str(pred[i]), (10,15), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2) 282 | # print(i) 283 | # cv2.imshow('image',im) 284 | # cv2.waitKey(0) 285 | --------------------------------------------------------------------------------