├── .gitignore ├── LICENSE ├── config.py ├── dataset.py ├── engine.py ├── input └── .keep ├── model.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 abhishek thakur 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | DATA_DIR = "input/captcha_images_v2/" 2 | BATCH_SIZE = 8 3 | IMAGE_WIDTH = 300 4 | IMAGE_HEIGHT = 75 5 | NUM_WORKERS = 8 6 | EPOCHS = 200 7 | DEVICE = "cuda" 8 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import albumentations 2 | import torch 3 | 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from PIL import ImageFile 8 | 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | 12 | class ClassificationDataset: 13 | def __init__(self, image_paths, targets, resize=None): 14 | # resize = (height, width) 15 | self.image_paths = image_paths 16 | self.targets = targets 17 | self.resize = resize 18 | 19 | mean = (0.485, 0.456, 0.406) 20 | std = (0.229, 0.224, 0.225) 21 | self.aug = albumentations.Compose( 22 | [ 23 | albumentations.Normalize( 24 | mean, std, max_pixel_value=255.0, always_apply=True 25 | ) 26 | ] 27 | ) 28 | 29 | def __len__(self): 30 | return len(self.image_paths) 31 | 32 | def __getitem__(self, item): 33 | image = Image.open(self.image_paths[item]).convert("RGB") 34 | targets = self.targets[item] 35 | 36 | if self.resize is not None: 37 | image = image.resize( 38 | (self.resize[1], self.resize[0]), resample=Image.BILINEAR 39 | ) 40 | 41 | image = np.array(image) 42 | augmented = self.aug(image=image) 43 | image = augmented["image"] 44 | image = np.transpose(image, (2, 0, 1)).astype(np.float32) 45 | 46 | return { 47 | "images": torch.tensor(image, dtype=torch.float), 48 | "targets": torch.tensor(targets, dtype=torch.long), 49 | } 50 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import config 4 | 5 | 6 | def train_fn(model, data_loader, optimizer): 7 | model.train() 8 | fin_loss = 0 9 | tk0 = tqdm(data_loader, total=len(data_loader)) 10 | for data in tk0: 11 | for key, value in data.items(): 12 | data[key] = value.to(config.DEVICE) 13 | optimizer.zero_grad() 14 | _, loss = model(**data) 15 | loss.backward() 16 | optimizer.step() 17 | fin_loss += loss.item() 18 | return fin_loss / len(data_loader) 19 | 20 | 21 | def eval_fn(model, data_loader): 22 | model.eval() 23 | fin_loss = 0 24 | fin_preds = [] 25 | tk0 = tqdm(data_loader, total=len(data_loader)) 26 | for data in tk0: 27 | for key, value in data.items(): 28 | data[key] = value.to(config.DEVICE) 29 | batch_preds, loss = model(**data) 30 | fin_loss += loss.item() 31 | fin_preds.append(batch_preds) 32 | return fin_preds, fin_loss / len(data_loader) 33 | -------------------------------------------------------------------------------- /input/.keep: -------------------------------------------------------------------------------- 1 | all the captcha images in this folder 2 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class CaptchaModel(nn.Module): 7 | def __init__(self, num_chars): 8 | super(CaptchaModel, self).__init__() 9 | self.conv_1 = nn.Conv2d(3, 128, kernel_size=(3, 6), padding=(1, 1)) 10 | self.pool_1 = nn.MaxPool2d(kernel_size=(2, 2)) 11 | self.conv_2 = nn.Conv2d(128, 64, kernel_size=(3, 6), padding=(1, 1)) 12 | self.pool_2 = nn.MaxPool2d(kernel_size=(2, 2)) 13 | self.linear_1 = nn.Linear(1152, 64) 14 | self.drop_1 = nn.Dropout(0.2) 15 | self.lstm = nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True) 16 | self.output = nn.Linear(64, num_chars + 1) 17 | 18 | def forward(self, images, targets=None): 19 | bs, _, _, _ = images.size() 20 | x = F.relu(self.conv_1(images)) 21 | x = self.pool_1(x) 22 | x = F.relu(self.conv_2(x)) 23 | x = self.pool_2(x) 24 | x = x.permute(0, 3, 1, 2) 25 | x = x.view(bs, x.size(1), -1) 26 | x = F.relu(self.linear_1(x)) 27 | x = self.drop_1(x) 28 | x, _ = self.lstm(x) 29 | x = self.output(x) 30 | x = x.permute(1, 0, 2) 31 | 32 | if targets is not None: 33 | log_probs = F.log_softmax(x, 2) 34 | input_lengths = torch.full( 35 | size=(bs,), fill_value=log_probs.size(0), dtype=torch.int32 36 | ) 37 | target_lengths = torch.full( 38 | size=(bs,), fill_value=targets.size(1), dtype=torch.int32 39 | ) 40 | loss = nn.CTCLoss(blank=0)( 41 | log_probs, targets, input_lengths, target_lengths 42 | ) 43 | return x, loss 44 | 45 | return x, None 46 | 47 | 48 | if __name__ == "__main__": 49 | cm = CaptchaModel(19) 50 | img = torch.rand((1, 3, 50, 200)) 51 | x, _ = cm(img, torch.rand((1, 5))) 52 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import numpy as np 5 | 6 | import albumentations 7 | from sklearn import preprocessing 8 | from sklearn import model_selection 9 | from sklearn import metrics 10 | 11 | import config 12 | import dataset 13 | import engine 14 | from model import CaptchaModel 15 | 16 | 17 | from torch import nn 18 | 19 | 20 | def remove_duplicates(x): 21 | if len(x) < 2: 22 | return x 23 | fin = "" 24 | for j in x: 25 | if fin == "": 26 | fin = j 27 | else: 28 | if j == fin[-1]: 29 | continue 30 | else: 31 | fin = fin + j 32 | return fin 33 | 34 | 35 | def decode_predictions(preds, encoder): 36 | preds = preds.permute(1, 0, 2) 37 | preds = torch.softmax(preds, 2) 38 | preds = torch.argmax(preds, 2) 39 | preds = preds.detach().cpu().numpy() 40 | cap_preds = [] 41 | for j in range(preds.shape[0]): 42 | temp = [] 43 | for k in preds[j, :]: 44 | k = k - 1 45 | if k == -1: 46 | temp.append("§") 47 | else: 48 | p = encoder.inverse_transform([k])[0] 49 | temp.append(p) 50 | tp = "".join(temp).replace("§", "") 51 | cap_preds.append(remove_duplicates(tp)) 52 | return cap_preds 53 | 54 | 55 | def run_training(): 56 | image_files = glob.glob(os.path.join(config.DATA_DIR, "*.png")) 57 | targets_orig = [x.split("/")[-1][:-4] for x in image_files] 58 | targets = [[c for c in x] for x in targets_orig] 59 | targets_flat = [c for clist in targets for c in clist] 60 | 61 | lbl_enc = preprocessing.LabelEncoder() 62 | lbl_enc.fit(targets_flat) 63 | targets_enc = [lbl_enc.transform(x) for x in targets] 64 | targets_enc = np.array(targets_enc) 65 | targets_enc = targets_enc + 1 66 | 67 | ( 68 | train_imgs, 69 | test_imgs, 70 | train_targets, 71 | test_targets, 72 | _, 73 | test_targets_orig, 74 | ) = model_selection.train_test_split( 75 | image_files, targets_enc, targets_orig, test_size=0.1, random_state=42 76 | ) 77 | 78 | train_dataset = dataset.ClassificationDataset( 79 | image_paths=train_imgs, 80 | targets=train_targets, 81 | resize=(config.IMAGE_HEIGHT, config.IMAGE_WIDTH), 82 | ) 83 | train_loader = torch.utils.data.DataLoader( 84 | train_dataset, 85 | batch_size=config.BATCH_SIZE, 86 | num_workers=config.NUM_WORKERS, 87 | shuffle=True, 88 | ) 89 | test_dataset = dataset.ClassificationDataset( 90 | image_paths=test_imgs, 91 | targets=test_targets, 92 | resize=(config.IMAGE_HEIGHT, config.IMAGE_WIDTH), 93 | ) 94 | test_loader = torch.utils.data.DataLoader( 95 | test_dataset, 96 | batch_size=config.BATCH_SIZE, 97 | num_workers=config.NUM_WORKERS, 98 | shuffle=False, 99 | ) 100 | 101 | model = CaptchaModel(num_chars=len(lbl_enc.classes_)) 102 | model.to(config.DEVICE) 103 | 104 | optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) 105 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 106 | optimizer, factor=0.8, patience=5, verbose=True 107 | ) 108 | for epoch in range(config.EPOCHS): 109 | train_loss = engine.train_fn(model, train_loader, optimizer) 110 | valid_preds, test_loss = engine.eval_fn(model, test_loader) 111 | valid_captcha_preds = [] 112 | for vp in valid_preds: 113 | current_preds = decode_predictions(vp, lbl_enc) 114 | valid_captcha_preds.extend(current_preds) 115 | combined = list(zip(test_targets_orig, valid_captcha_preds)) 116 | print(combined[:10]) 117 | test_dup_rem = [remove_duplicates(c) for c in test_targets_orig] 118 | accuracy = metrics.accuracy_score(test_dup_rem, valid_captcha_preds) 119 | print( 120 | f"Epoch={epoch}, Train Loss={train_loss}, Test Loss={test_loss} Accuracy={accuracy}" 121 | ) 122 | scheduler.step(test_loss) 123 | 124 | 125 | if __name__ == "__main__": 126 | run_training() 127 | --------------------------------------------------------------------------------