├── requirements.txt
├── README.md
├── augment.py
└── CNN_generator.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==0.25.0
2 | opencv_python==4.1.2.30
3 | numpy==1.17.2
4 | matplotlib==3.1.1
5 | Keras==2.2.4
6 | tensorflow==1.12.0
7 | Pillow==7.1.2
8 | scikit_learn==0.22.2.post1
9 | scikit-image==0.16.2
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image_Generator
2 | Train Image Classifier with Augmentation and Keras Fit Generator
3 |
4 | 1. ```augment.py```
5 | Contains image augmentation methods: Gaussian, salt-and-pepper, Poisson and speckle noise, plus vertical and horizontal flips
6 |
7 | 2. ```CNN_generator.py```
8 | Trains a convolutional neural network (CNN) image classifier with data augmentation, using the Keras ```fit_generator``` method
9 |
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import numpy as np
4 | import skimage
5 | from matplotlib import pyplot as plt
6 | from skimage.util import random_noise
7 | import os
8 |
class ImageAugment:
    '''Simple image augmentations: noise injection and flips.

    The image can be passed to each method (``ImageAugment().add_gaussian_noise(img)``)
    or stored once at construction (``ImageAugment(img).add_gaussian_noise()``).
    An explicitly passed image always takes precedence over the stored one.

    The noise methods wrap ``skimage.util.random_noise``, which returns a float
    image. Per the scikit-image docs, unsigned-integer input (such as uint8 from
    ``cv2.imread``) is rescaled to [0, 1]; verify this for your skimage version.
    '''

    def __init__(self, image=None):
        # Optional default image used by every method when none is passed.
        self.image = image

    def _resolve(self, image):
        '''Return ``image`` if given, else the image stored on the instance.

        Raises:
            ValueError: if neither an argument nor a stored image is available.
        '''
        if image is None:
            # getattr: subclasses (e.g. Generator) may not call this __init__.
            image = getattr(self, 'image', None)
        if image is None:
            raise ValueError('No image passed and none stored on this ImageAugment')
        return image

    def add_gaussian_noise(self, image=None):
        '''Return the image with additive Gaussian noise (output clipped).'''
        image = self._resolve(image)
        return random_noise(image, mode='gaussian', seed=None, clip=True)

    def add_salt_pepper_noise(self, image=None):
        '''Return the image with salt-and-pepper noise (output clipped).'''
        image = self._resolve(image)
        return random_noise(image, mode='s&p', seed=None, clip=True)

    def add_poisson_noise(self, image=None):
        '''Return the image with Poisson noise generated from the data (output clipped).'''
        image = self._resolve(image)
        return random_noise(image, mode='poisson', seed=None, clip=True)

    def add_speckle_noise(self, image=None):
        '''Return the image with multiplicative speckle noise (output clipped).'''
        image = self._resolve(image)
        return random_noise(image, mode='speckle', seed=None, clip=True)

    def flip_vertical(self, image=None):
        '''Return the image flipped upside-down (cv2.flip code 0).'''
        image = self._resolve(image)
        return cv2.flip(image, 0)

    def flip_horizontal(self, image=None):
        '''Return the image mirrored left-right (cv2.flip code 1).'''
        image = self._resolve(image)
        return cv2.flip(image, 1)

    def do_augmentation(self, image=None):
        '''Apply every augmentation in sequence and return the result.

        Applies the four noise types in turn (gaussian, salt & pepper, poisson,
        speckle), then both flips. The two flips together amount to a 180-degree
        rotation.
        '''
        image = self._resolve(image)
        image = self.add_gaussian_noise(image)
        image = self.add_salt_pepper_noise(image)
        image = self.add_poisson_noise(image)
        image = self.add_speckle_noise(image)
        image = self.flip_vertical(image)
        image = self.flip_horizontal(image)
        return image
42 |
class Generator(ImageAugment):
    '''Iterate over, augment and display the files under a directory tree.

    Args:
        path (str): root directory walked with os.walk.
    '''

    def __init__(self, path):
        self.path = path
        # Persistent iterator behind on_next(); created lazily on first use.
        self._gen_obj = None

    def image_generator(self):
        '''Yield the full path of every file under self.path, cycling forever.

        The file list is built once per generator, in os.walk order. After the
        last file the sequence starts again from the first.

        Raises:
            ValueError: if no files exist under self.path.
        '''
        file_list = [os.path.join(root, file)
                     for root, dirs, files in os.walk(self.path)
                     for file in files]
        if not file_list:
            raise ValueError('No files found under {}'.format(self.path))
        i = 0
        while True:
            yield file_list[i]
            i = (i + 1) % len(file_list)

    def on_next(self):
        '''Return the next file path.

        A single generator is kept on the instance, so successive calls advance
        through the files instead of restarting at the first one.
        '''
        if getattr(self, '_gen_obj', None) is None:
            self._gen_obj = self.image_generator()
        return next(self._gen_obj)

    def augment_and_show(self):
        '''Display augmented images one by one until Esc or 'q' is pressed.

        Any other key advances to the next file. Files that OpenCV cannot
        decode (cv2.imread returns None) are reported and skipped.

        Raises:
            ValueError: if a full pass over the files finds no readable image.
        '''
        unreadable = set()
        while True:
            path = self.on_next()
            print(path)
            image = cv2.imread(path)
            if image is None:
                if path in unreadable:
                    # Cycled back to an already-skipped file without showing anything.
                    raise ValueError('No readable images under {}'.format(self.path))
                unreadable.add(path)
                print('Skipping unreadable file: ', path)
                continue
            unreadable.clear()
            cv2.imshow('image', self.do_augmentation(image))
            key = cv2.waitKey(0) & 0xFF
            if key in (27, ord('q')):  # 27 is the Esc key code
                break
        cv2.destroyAllWindows()

    def label_generator(self, feat, labels):
        '''Yield (feat[i], labels[i]) pairs forever, wrapping around at the end.

        Raises:
            ValueError: if feat is empty.
        '''
        num_examples = len(feat)
        if num_examples == 0:
            raise ValueError('label_generator needs at least one example')
        i = 0
        while True:
            yield feat[i], labels[i]
            i = (i + 1) % num_examples
75 |
76 |
if __name__ == "__main__":
    # Interactive demo. The guard keeps `from augment import ImageAugment`
    # (used by CNN_generator.py) from opening windows or blocking on import.
    image_gen = Generator('./data')
    image_gen.augment_and_show()

    # Side-by-side comparison of each augmentation on a single image:
    # image_path = './data/image1.jpg'
    # I = cv2.imread(image_path, 1)
    #
    # aug = ImageAugment()
    # gauss = aug.add_gaussian_noise(I)
    # sp = aug.add_salt_pepper_noise(I)
    # poisson = aug.add_poisson_noise(I)
    # speckle = aug.add_speckle_noise(I)
    # flipv = aug.flip_vertical(I)
    # fliph = aug.flip_horizontal(I)
    # augmented = aug.do_augmentation(I)
    #
    # plt.subplot(421), plt.imshow(I), plt.title('Origin')
    # plt.subplot(422), plt.imshow(gauss), plt.title('Gaussian')
    # plt.subplot(423), plt.imshow(sp), plt.title('Salt and Pepper')
    # plt.subplot(424), plt.imshow(poisson), plt.title('Poisson')
    # plt.subplot(425), plt.imshow(speckle), plt.title('Speckle')
    # plt.subplot(426), plt.imshow(flipv), plt.title('Flip Vertical')
    # plt.subplot(427), plt.imshow(fliph), plt.title('Flip Horizontal')
    # plt.subplot(428), plt.imshow(augmented), plt.title('Augmented')
    #
    # plt.show()
103 |
--------------------------------------------------------------------------------
/CNN_generator.py:
--------------------------------------------------------------------------------
1 | from keras.utils import to_categorical
2 | from keras.models import Sequential
3 | from keras.layers import Dense, Conv2D, Flatten
4 | import pandas as pd
5 | import os
6 | import cv2
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 | import tensorflow as tf
10 | from augment import ImageAugment
11 |
class Generator():
    '''
    Data generators that stream grayscale images (and one-hot labels) from file
    paths, for Keras ``fit_generator`` / ``evaluate_generator`` /
    ``predict_generator``.

    Args:
        feat (sequence of str): image file paths
        labels (np.ndarray or None): one-hot labels aligned with ``feat``, shape
            (num_examples, num_labels); may be None when only ``gen_test`` is used
        width (int): first spatial dimension images are reshaped to
        height (int): second spatial dimension images are reshaped to

    NOTE(review): ``cv2.imread`` returns an array of shape (rows, cols), i.e.
    (height, width). ``reshape(width, height, 1)`` does not transpose, so this is
    only correct for square images (e.g. 28x28 MNIST). ``CNN_model`` uses the
    same (width, height, 1) order, so both must change together if this is fixed.
    '''

    def __init__(self, feat, labels, width, height):
        self.feat = feat
        self.labels = labels
        self.width = width
        self.height = height

    def _read_image(self, path):
        '''
        Read ``path`` as a single-channel image shaped (width, height, 1).
        Raises:
            ValueError: if OpenCV cannot decode the file or its pixel count
                does not match width * height
        '''
        im = cv2.imread(path, 0)  # flag 0 -> load as grayscale
        if im is None:
            raise ValueError('Could not read image: {}'.format(path))
        try:
            return im.reshape(self.width, self.height, 1)
        except ValueError:
            raise ValueError(
                'Unexpected image shape {} for: {}'.format(im.shape, path)) from None

    def _load_images(self):
        '''
        Eagerly load every readable image together with its label row.
        Unreadable or wrongly-sized files are reported and skipped, and their
        labels are dropped too so images and labels stay aligned.
        Returns:
            X: np.array of (n_ok,width,height,1) with pixels as read by cv2.imread
            y: np.array of (n_ok,num_labels), the matching rows of self.labels
        Raises:
            ValueError: if no image could be loaded
        '''
        images = []
        kept = []
        for n, path in enumerate(self.feat):
            try:
                im = self._read_image(path)
            except ValueError as err:
                print('Error on this image: ', path, err)
            else:
                images.append(im)
                kept.append(n)
        if not images:
            raise ValueError(
                'No readable images among {} file(s)'.format(len(self.feat)))
        return np.array(images), np.asarray(self.labels)[kept]

    def gen(self):
        '''
        Yields generator object for training or evaluation without batching.
        Cycles through feat in order forever, wrapping around at the end.
        Yields:
            im: np.array of (1,width,height,1) of images
            label: np.array of one-hot vector of label (1,num_labels)
        Raises:
            ValueError: if feat is empty or an image cannot be read
        '''
        feat = self.feat
        labels = self.labels
        num_examples = len(feat)
        if num_examples == 0:
            raise ValueError('Generator has no examples')
        i = 0
        while True:
            im = np.expand_dims(self._read_image(feat[i]), axis=0)
            label = np.expand_dims(labels[i], axis=0)
            yield im, label
            i = (i + 1) % num_examples

    def gen_test(self):
        '''
        Yields generator object to do prediction.
        Cycles through feat in order forever. Wrapping around (like gen) keeps a
        prefetching Keras enqueuer from running off the end of the list.
        Yields:
            im: np.array of (1,width,height,1) of images
        Raises:
            ValueError: if feat is empty or an image cannot be read
        '''
        feat = self.feat
        num_examples = len(feat)
        if num_examples == 0:
            raise ValueError('Generator has no examples')
        i = 0
        while True:
            yield np.expand_dims(self._read_image(feat[i]), axis=0)
            i = (i + 1) % num_examples

    def gen_batching(self, batch_size):
        '''
        Yields generator object with batching of batch_size.
        Each batch samples images uniformly at random with replacement.
        Args:
            batch_size (int): batch_size
        Yields:
            feat_batch: float np.array of (batch_size,width,height,1) of images
            label_batch: float np.array of (batch_size,num_labels)
        '''
        X, y = self._load_images()
        while True:
            index = np.random.randint(X.shape[0], size=batch_size)
            # Fancy indexing builds new arrays for every batch, so batches already
            # queued by Keras' background generator thread are not overwritten.
            yield X[index].astype(np.float64), y[index].astype(np.float64)

    def gen_augment(self, batch_size, augment):
        '''
        Yields batches of randomly sampled images plus noisy copies of them.

        augment is a list of 0/1 flags for [speckle, gaussian, poisson]. Each flag
        equal to 1 adds one noisy copy of every sampled image. For example,
        augment = [1,1,0] adds speckle and gaussian noise, so with
        batch_size = 100 each batch holds 300 examples. Flags beyond the third
        are ignored.

        Batch layout: rows [0, batch_size) are the sampled originals. Rows
        [(j+1)*batch_size, (j+2)*batch_size) are the j-th selected augmentation
        of those same images.

        Args:
            batch_size (int): images sampled (with replacement) per batch
            augment (list): 0/1 flags selecting the augmentations
        Yields:
            feat_batch: float np.array of (batch_size*(n_augment+1),width,height,1)
            label_batch: float np.array of (batch_size*(n_augment+1),num_labels)
        '''
        X, y = self._load_images()

        aug = ImageAugment()
        # Same order as the flags documented above.
        noise_fns = [aug.add_speckle_noise, aug.add_gaussian_noise, aug.add_poisson_noise]
        selected = [fn for fn, flag in zip(noise_fns, augment) if flag == 1]
        n_augment = len(selected)
        print('Number of augmentations: ', n_augment)

        while True:
            index = np.random.randint(X.shape[0], size=batch_size)
            sampled = X[index]
            blocks = [sampled.astype(np.float64)]
            for fn in selected:
                # random_noise returns floats rescaled to [0, 1] for uint8 input.
                # Multiply by 255 so noisy copies share the 0-255 scale of the
                # originals and of the (un-augmented) validation data.
                blocks.append(np.stack([fn(im) for im in sampled]) * 255.0)
            feat_batch = np.concatenate(blocks, axis=0)
            label_batch = np.tile(y[index].astype(np.float64), (n_augment + 1, 1))
            yield feat_batch, label_batch
180 |
def CNN_model(width, height, num_classes=None):
    '''
    Build and compile a small CNN classifier for single-channel images.

    Architecture: Conv2D(64, 3x3, relu) -> Conv2D(32, 3x3, relu) -> Flatten
    -> Dense(num_classes, softmax). The model is compiled with Adam,
    categorical cross-entropy and accuracy.

    Args:
        width (int): first input dimension (input shape is (width, height, 1))
        height (int): second input dimension
        num_classes (int, optional): size of the softmax layer. Defaults to
            ``labels.shape[1]``, where ``labels`` is the module-level one-hot
            array created by this script's ``__main__`` section.
    Returns:
        the compiled keras Sequential model
    Raises:
        ValueError: if num_classes is omitted and no module-level ``labels`` exists
    '''
    if num_classes is None:
        # Backward-compatible fallback: the original always read this global.
        try:
            num_classes = labels.shape[1]
        except NameError:
            raise ValueError(
                'num_classes is required when the module-level `labels` array '
                'is not defined (e.g. when CNN_model is imported)') from None

    model = Sequential()
    model.add(Conv2D(64, kernel_size=3, activation="relu", input_shape=(width, height, 1)))
    model.add(Conv2D(32, kernel_size=3, activation="relu"))
    model.add(Flatten())
    model.add(Dense(num_classes, activation="softmax"))

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
191 |
192 |
if __name__ == "__main__":
    # Reload with the same standalone Keras that saved the model. Saving with
    # `keras` and loading with `tf.keras` mixes two different implementations.
    from keras.models import load_model

    input_dir = './mnist'
    output_file = 'dataset.csv'

    # Collect every file under input_dir. Its label is the name of the
    # directory that directly contains it (e.g. ./mnist/7/x.png -> '7').
    filename = []
    label = []
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            full_path = os.path.join(root, file)
            filename.append(full_path)
            label.append(os.path.basename(os.path.dirname(full_path)))
    if not filename:
        raise SystemExit('No files found under {}'.format(input_dir))

    data = pd.DataFrame(data={'filename': filename, 'label': label})
    data.to_csv(output_file, index=False)

    # One-hot labels, one column per distinct label. This stays a module-level
    # name because CNN_model() reads `labels` to size its output layer.
    labels = pd.get_dummies(data['label']).values

    # Hold out 20% for validation, then 2.5% of the remainder as a test set.
    X, X_val, y, y_val = train_test_split(
        filename, labels,
        test_size=0.2,
        random_state=1234,
        shuffle=True,
        stratify=labels,
    )

    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        test_size=0.025,
        random_state=1234,
        shuffle=True,
        stratify=y,
    )

    width = 28
    height = 28

    image_gen_train = Generator(X_train, y_train, width, height)
    image_gen_val = Generator(X_val, y_val, width, height)

    batch_size = 900
    print('len data: ', len(X_train))
    print('len test data: ', len(X_test))

    # augment flags are [speckle, gaussian, poisson]. Each 1 adds one noisy copy
    # of every sampled image, e.g. [1, 1, 0] -> speckle + gaussian (3x batch size).
    augment = [1, 1, 1]
    model = CNN_model(width, height)

    model.fit_generator(
        generator=image_gen_train.gen_augment(batch_size=batch_size, augment=augment),
        # Keras documents steps_per_epoch as an integer.
        steps_per_epoch=int(np.ceil(len(X_train) / batch_size)),
        epochs=20,
        verbose=1,
        validation_data=image_gen_val.gen(),
        validation_steps=len(X_val),  # gen() yields one image per step
    )
    model.save('model_aug_3.h5')
    model = load_model('model_aug_3.h5')

    # Evaluate on the held-out test set, one image per step.
    image_gen_test = Generator(X_test, y_test, width, height)
    print(model.evaluate_generator(
        generator=image_gen_test.gen(),
        steps=len(X_test),
    ))

    # Predict the test set and report indices of misclassified examples.
    image_gen_test = Generator(X_test, None, width, height)
    pred = model.predict_generator(
        generator=image_gen_test.gen_test(),
        steps=len(X_test),
    )
    pred = np.argmax(pred, axis=1)
    wrong_pred = np.flatnonzero(pred != np.argmax(y_test, axis=1)).tolist()
    print(wrong_pred)

    # To eyeball the predictions:
    # for i in range(len(X_test)):
    #     im = cv2.imread(X_test[i], 0)
    #     im = cv2.putText(im, str(pred[i]), (10, 15), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
    #     print(i)
    #     cv2.imshow('image', im)
    #     cv2.waitKey(0)
285 |
--------------------------------------------------------------------------------