├── LICENSE
├── README.md
├── learning_curve.jpg
├── model_structure.jpg
├── ocr_captcha.py
├── predict_example.py
└── result.jpg

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 nbswords

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ocr-captchas

A practice OCR model for recognizing CAPTCHAs, built with TensorFlow/Keras.

## Data

The example dataset is Kaggle's [CAPTCHA Version 2 Images](https://www.kaggle.com/fournierp/captcha-version-2-images).
## Usage

- Prepare your data in a folder named `captcha_images`.
  - Note that each image's label must be its filename (without the extension), just like in the example data; see the layout sketch below.

- Train and save the model:

```shell
python ocr_captcha.py
```

- Run inference on a single image named `test.png`:

```shell
python predict_example.py
```
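For reference, the expected layout of `captcha_images` looks like this (the filenames here are made-up examples; each file's base name, e.g. `2b827`, is used as its label):

```text
captcha_images/
├── 2b827.png   # an image of the text "2b827"
├── 3bn8x.png
└── ...
```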
## Result

![result](./result.jpg)

## Learning curve

![learning_curve](./learning_curve.jpg)

## Model

CRNN (convolutional recurrent neural network) trained with CTC loss.

![model_structure](./model_structure.jpg)
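For reference, the tensor shapes through the network, read off the layer definitions in `ocr_captcha.py` (200×50 grayscale inputs, with the image width used as the time axis), are roughly:

```text
input image                  (200, 50, 1)
conv 3×3 ×32 + maxpool 2×2 → (100, 25, 32)
conv 3×3 ×64 + maxpool 2×2 → (50, 12, 64)
reshape                    → (50, 768)
dense 64 + dropout         → (50, 64)
BiLSTM 128, then BiLSTM 64 → (50, 256), then (50, 128)
dense + softmax            → (50, 38)    # 36 characters + OOV token + CTC blank
```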
--------------------------------------------------------------------------------
/learning_curve.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nbswords/ocr-captchas/308dce7b66ca2cf828ba2f098396d4a9e05aa32b/learning_curve.jpg
--------------------------------------------------------------------------------
/model_structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nbswords/ocr-captchas/308dce7b66ca2cf828ba2f098396d4a9e05aa32b/model_structure.jpg
--------------------------------------------------------------------------------
/ocr_captcha.py:
--------------------------------------------------------------------------------
"""
## Import packages and load data
"""

import os
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Path of the data directory
data_dir = Path("./captcha_images/")

# Get the images and derive each label from its filename
images = sorted(list(map(str, list(data_dir.glob("*.png")))))
labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
characters = sorted(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
                     'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])


print("Number of images found: ", len(images))
print("Number of labels found: ", len(labels))
print("Number of unique characters: ", len(characters))
print("Characters present: ", characters)

# Parameters
batch_size = 16
img_width = 200
img_height = 50

# Factor by which the image is downsampled by the convolutional blocks.
# We use two convolution blocks, each followed by a pooling layer that
# downsamples the features by a factor of 2, so the total downsampling
# factor is 4.
downsample_factor = 4

# Maximum length of any captcha in the dataset
max_length = max([len(label) for label in labels])


"""
## Preprocessing
"""


# Mapping characters to integers
char_to_num = layers.StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters
num_to_char = layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
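# Illustrative round trip through the lookups (assuming StringLookup's default
# of reserving index 0 for the OOV token, so '0' -> 1, ..., '9' -> 10, 'a' -> 11, ...):
#   char_to_num(tf.strings.unicode_split("2b8", input_encoding="UTF-8"))  # [3, 12, 9]
#   tf.strings.reduce_join(num_to_char([3, 12, 9]))                       # b"2b8"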
def split_data(images, labels, train_size=0.9, shuffle=True):
    # 1. Get the total size of the dataset
    size = len(images)
    # 2. Make an indices array and shuffle it, if required
    indices = np.arange(size)
    if shuffle:
        np.random.shuffle(indices)
    # 3. Get the size of the training set
    train_samples = int(size * train_size)
    # 4. Split data into training and validation sets
    x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]
    x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]
    return x_train, x_valid, y_train, y_valid


# Split the data into training and validation sets
x_train, x_valid, y_train, y_valid = split_data(
    np.array(images), np.array(labels))


def encode_single_sample(img_path, label):
    # 1. Read the image
    img = tf.io.read_file(img_path)
    # 2. Decode and convert to grayscale
    img = tf.io.decode_png(img, channels=1)
    # 3. Convert to float32 in the [0, 1] range
    img = tf.image.convert_image_dtype(img, tf.float32)
    # 4. Resize to the desired size
    img = tf.image.resize(img, [img_height, img_width])
    # 5. Transpose the image because we want the time
    #    dimension to correspond to the width of the image
    img = tf.transpose(img, perm=[1, 0, 2])
    # 6. Map the characters in the label to numbers
    label = char_to_num(tf.strings.unicode_split(
        label, input_encoding="UTF-8"))
    # 7. Return a dict because the model expects two inputs
    return {"image": img, "label": label}


"""
## Create `Dataset` objects
"""


train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (
    train_dataset.map(encode_single_sample,
                      num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

validation_dataset = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
validation_dataset = (
    validation_dataset.map(encode_single_sample,
                           num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

"""
## Visualize the data
"""


_, ax = plt.subplots(4, 4, figsize=(10, 5))
for batch in train_dataset.take(1):
    images = batch["image"]
    labels = batch["label"]
    for i in range(16):
        img = (images[i] * 255).numpy().astype("uint8")
        label = tf.strings.reduce_join(
            num_to_char(labels[i])).numpy().decode("utf-8")
        ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")
plt.show()

"""
## Model
"""
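# The model below uses Keras's "endpoint layer" pattern: the CTC loss is computed
# inside a custom layer and registered via `add_loss()`, so the model takes both
# the image and its label as inputs and is compiled without an external loss.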
class CTCLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions
        return y_pred


def build_model():
    # Inputs to the model
    input_img = layers.Input(
        shape=(img_width, img_height, 1), name="image", dtype="float32"
    )
    labels = layers.Input(name="label", shape=(None,), dtype="float32")

    # First conv block
    x = layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block
    x = layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max-pooling layers with pool size and strides of 2,
    # so the downsampled feature maps are 4x smaller. The number of
    # filters in the last conv layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model.
    new_shape = ((img_width // 4), (img_height // 4) * 64)
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.25))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

    # Output layer (vocabulary size + 1 for the CTC blank token)
    x = layers.Dense(
        len(char_to_num.get_vocabulary()) + 1, activation="softmax", name="dense2"
    )(x)

    # Add the CTC layer for calculating the CTC loss at each step
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="ocr_model_v1"
    )
    # Optimizer
    opt = keras.optimizers.Adam()
    # Compile the model and return
    model.compile(optimizer=opt)
    return model


# Get the model
model = build_model()
model.summary()

"""
## Training
"""


epochs = 100
early_stopping_patience = 10
# Add early stopping
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=early_stopping_patience, restore_best_weights=True
)

# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[early_stopping],
)

# Save the model
model.save('captcha_ocr_model')

"""
## Inference
"""


# Get the prediction model by extracting the layers up to the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
prediction_model.summary()
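# CTC greedy decoding takes the most likely class at each time step, then
# collapses consecutive repeats and drops the blank token. An illustrative
# example, with '-' as the blank: "22--b--88-2-77" -> "2b827".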
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text


# Let's check the results on some validation samples
for batch in validation_dataset.take(1):
    batch_images = batch["image"]
    batch_labels = batch["label"]

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    orig_texts = []
    for label in batch_labels:
        label = tf.strings.reduce_join(
            num_to_char(label)).numpy().decode("utf-8")
        orig_texts.append(label)

    _, ax = plt.subplots(4, 4, figsize=(15, 5))
    for i in range(len(pred_texts)):
        img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
        img = img.T
        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")
plt.show()
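# Note: decode_batch_predictions() uses greedy search. For harder captchas,
# beam search may help, e.g.:
#   keras.backend.ctc_decode(pred, input_length=input_len, greedy=False, beam_width=10)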
--------------------------------------------------------------------------------
/predict_example.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import load_model

"""
## Define constants
"""
img_width = 200
img_height = 50
max_length = 5  # Adjust this based on your dataset

# The character set must be the same as the one used in training
characters = sorted(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
                     'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])

# Mapping characters to integers
char_to_num = layers.StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters
num_to_char = layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True)

"""
## Load the model
"""
model = load_model('captcha_ocr_model', compile=False)

# Extract the prediction model
prediction_model = tf.keras.models.Model(model.get_layer(
    name="image").input, model.get_layer(name="dense2").output)
prediction_model.summary()

"""
## Preprocessing functions
"""
# Preprocess a single image the same way as during training
def preprocess_image(img_path):
    # Read the image
    img = tf.io.read_file(img_path)
    # Decode and convert to grayscale
    img = tf.io.decode_png(img, channels=1)
    # Convert to float32 in the [0, 1] range
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Resize to the desired size
    img = tf.image.resize(img, [img_height, img_width])
    # Transpose the image because we want the time dimension to correspond to the width of the image
    img = tf.transpose(img, perm=[1, 0, 2])
    # Expand dims to add a batch dimension
    img = tf.expand_dims(img, axis=0)
    return img

# Decode the network output into text
def decode_prediction(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = tf.keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

"""
## Inference
"""

# Load and preprocess the image
img_path = 'test.png'
img = preprocess_image(img_path)

# Make the prediction
pred = prediction_model.predict(img)
pred_text = decode_prediction(pred)

# Print the prediction
print("Predicted text:", pred_text[0])

# Visualize the image and the prediction
plt.imshow(img[0, :, :, 0].numpy().T, cmap='gray')
plt.title(f"Prediction: {pred_text[0]}")
plt.axis('off')
plt.show()
--------------------------------------------------------------------------------
/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nbswords/ocr-captchas/308dce7b66ca2cf828ba2f098396d4a9e05aa32b/result.jpg
--------------------------------------------------------------------------------