├── .gitattributes
├── .gitignore
├── README.md
├── assets
│   ├── dice.png
│   ├── iou.png
│   ├── iou_formular1.png
│   ├── pipeline.png
│   └── unet.jpg
├── dataset
│   ├── .gitkeep
│   └── data
│       └── .gitkeep
├── models.py
├── prepare_dataset.py
├── pretrained
│   └── README.md
├── requirements.txt
├── test.py
├── test
│   ├── sample1.png
│   ├── sample2.png
│   ├── sample3.png
│   └── sample4.png
├── train.py
└── utils
    ├── __init__.py
    ├── haarcascade_frontalface_default.xml
    ├── image.py
    └── metrics.py

/.gitattributes:
--------------------------------------------------------------------------------
*.pt filter=lfs diff=lfs merge=lfs -text

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Others
logs/
backup/

# IntelliJ
.idea

weights/
pretrained/model_checkpoint.pt

dataset/data
dataset/temp
dataset/train

test/output*.png
test/sample_real*.*

*.h5

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ML_IDCard_Segmentation (Pytorch) - WIP
Machine learning project to identify an ID card in an image.

### Objectives
The goal of this project is to recognize an ID card in a photo, cut it out using semantic segmentation and
transform the perspective so that you get a frontal view of the card.

## Additional Information
Dataset: [MIDV-500](https://arxiv.org/abs/1807.05786)
Pytorch Version: 1.7.1, CUDA 11.2

Trained on an NVIDIA GeForce RTX 3090

## Installation
1. Create and activate a new environment.
```
conda create -n idcard python=3.9.1
source activate idcard
```
2. Install dependencies.
```
pip install -r requirements.txt
```

## Download and Prepare Dataset
Downloads the image files (images and ground truth) and
splits the data into training, validation and test sets.
```
python prepare_dataset.py
```

### Training of the neural network
```
python train.py --resumeTraining=True
```
`resumeTraining` is optional and resumes training from an existing `./pretrained/model_checkpoint.pt`.

### Test the trained model
```
python test.py test/sample1.png --output_mask=test/output_mask.png --output_prediction=test/output_pred.png --model=./pretrained/model_final.pt
```

Call `python test.py --help` for possible arguments.

### Additional commands
Starts the TensorBoard visualization.
```
tensorboard --logdir=logs/
```

## Background Information

### Model
A [U-NET](https://arxiv.org/abs/1505.04597) was used as the model.
U-Net is a convolutional neural network that was developed for biomedical image segmentation at the
Computer Science Department of the University of Freiburg, Germany.
The network is based on the fully convolutional network, and its architecture was modified and extended to work with
fewer training images and to yield more precise segmentations.
Segmentation of a 512×512 image takes less than a second on a modern GPU.

![U-Net](assets/unet.jpg "U-Net")

### Metrics
The metric [IoU](https://arxiv.org/abs/1902.09630) (Intersection over Union / Jaccard coefficient) was used
to measure the quality of the model.
The closer the Jaccard coefficient is to 1, the more similar the two sets are; its minimum value is 0.
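For a predicted mask A and a ground-truth mask B it is defined as:

```
IoU(A, B) = |A ∩ B| / |A ∪ B|
```

The implementation used during training is `iou_score` in `utils/metrics.py`. A minimal usage sketch (assuming `output` holds the raw UNet logits and `target` the binary ground-truth mask, both tensors of shape `(N, 1, H, W)`):

```
from utils.metrics import iou_score

# sigmoid activation and 0.5-thresholding happen inside iou_score
score = iou_score(output, target)  # scalar in [0, 1]
```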
![IoU](assets/iou_formular1.png "IoU")

Example:
![IoU](assets/iou.png "IoU")

## Results for validation set (trained on the complete dataset)
Intersection over Union: 0.9939

Pipeline Example:
![Pipeline](assets/pipeline.png "Workflow Pipeline")

--------------------------------------------------------------------------------
/assets/dice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/assets/dice.png

--------------------------------------------------------------------------------
/assets/iou.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/assets/iou.png

--------------------------------------------------------------------------------
/assets/iou_formular1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/assets/iou_formular1.png

--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/assets/pipeline.png

--------------------------------------------------------------------------------
/assets/unet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/assets/unet.jpg

--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/dataset/.gitkeep

--------------------------------------------------------------------------------
/dataset/data/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
"""
UNet
The main UNet model implementation
"""

import torch
import torch.nn as nn

# Utility Functions
# With a 3x3 kernel, padding=1 keeps the input and output spatial dimensions identical.


def conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )


def down_pooling():
    return nn.MaxPool2d(2)


def up_pooling(in_channels, out_channels, kernel_size=2, stride=2):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )


# UNet class
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        # contracting path (encoder)
        self.conv1 = conv_bn_relu(n_channels, 64)
        self.conv2 = conv_bn_relu(64, 128)
        self.conv3 = conv_bn_relu(128, 256)
        self.conv4 = conv_bn_relu(256, 512)
        self.conv5 = conv_bn_relu(512, 1024)
        self.down_pooling = down_pooling()

        # dropout
        self.dropout = nn.Dropout(0.5)

        # expanding path (decoder)
        self.up_pool6 = up_pooling(1024, 512)
        self.conv6 = conv_bn_relu(1024, 512)
        self.up_pool7 = up_pooling(512, 256)
        self.conv7 = conv_bn_relu(512, 256)
        self.up_pool8 = up_pooling(256, 128)
        self.conv8 = conv_bn_relu(256, 128)
        self.up_pool9 = up_pooling(128, 64)
        self.conv9 = conv_bn_relu(128, 64)

        # output
        self.conv10 = nn.Conv2d(64, n_classes, 1)

        # Kaiming (He) weight initialization for all convolutions
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        # contracting path
        x1 = self.conv1(x)
        p1 = self.down_pooling(x1)
        x2 = self.conv2(p1)
        p2 = self.down_pooling(x2)
        x3 = self.conv3(p2)
        p3 = self.down_pooling(x3)
        x4 = self.conv4(p3)
        p4 = self.down_pooling(x4)
        x5 = self.conv5(p4)

        x5 = self.dropout(x5)

        # expanding path with skip connections
        p6 = self.up_pool6(x5)
        x6 = torch.cat([p6, x4], dim=1)
        x6 = self.conv6(x6)

        x6 = self.dropout(x6)

        p7 = self.up_pool7(x6)
        x7 = torch.cat([p7, x3], dim=1)
        x7 = self.conv7(x7)

        p8 = self.up_pool8(x7)
        x8 = torch.cat([p8, x2], dim=1)
        x8 = self.conv8(x8)

        p9 = self.up_pool9(x8)
        x9 = torch.cat([p9, x1], dim=1)
        x9 = self.conv9(x9)

        output = self.conv10(x9)

        return output

--------------------------------------------------------------------------------
/prepare_dataset.py:
--------------------------------------------------------------------------------
import cv2
import json
import numpy as np
import os
import random
import re
import shutil
import wget
import zipfile
from PIL import Image
from glob import glob

# Alternative subset: German documents only.
'''
download_links = [
    'ftp://smartengines.com/midv-500/dataset/12_deu_drvlic_new.zip',
    'ftp://smartengines.com/midv-500/dataset/13_deu_drvlic_old.zip',
    'ftp://smartengines.com/midv-500/dataset/14_deu_id_new.zip',
    'ftp://smartengines.com/midv-500/dataset/15_deu_id_old.zip',
    'ftp://smartengines.com/midv-500/dataset/16_deu_passport_new.zip',
    'ftp://smartengines.com/midv-500/dataset/17_deu_passport_old.zip']
'''

download_links = ['ftp://smartengines.com/midv-500/dataset/01_alb_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/02_aut_drvlic_new.zip',
                  'ftp://smartengines.com/midv-500/dataset/03_aut_id_old.zip',
                  'ftp://smartengines.com/midv-500/dataset/04_aut_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/05_aze_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/06_bra_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/07_chl_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/08_chn_homereturn.zip',
                  'ftp://smartengines.com/midv-500/dataset/09_chn_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/10_cze_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/11_cze_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/12_deu_drvlic_new.zip',
                  'ftp://smartengines.com/midv-500/dataset/13_deu_drvlic_old.zip',
                  'ftp://smartengines.com/midv-500/dataset/14_deu_id_new.zip',
                  'ftp://smartengines.com/midv-500/dataset/15_deu_id_old.zip',
                  'ftp://smartengines.com/midv-500/dataset/16_deu_passport_new.zip',
                  'ftp://smartengines.com/midv-500/dataset/17_deu_passport_old.zip',
                  'ftp://smartengines.com/midv-500/dataset/18_dza_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/19_esp_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/20_esp_id_new.zip',
                  'ftp://smartengines.com/midv-500/dataset/21_esp_id_old.zip',
                  'ftp://smartengines.com/midv-500/dataset/22_est_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/23_fin_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/24_fin_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/25_grc_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/26_hrv_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/27_hrv_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/28_hun_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/29_irn_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/30_ita_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/31_jpn_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/32_lva_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/33_mac_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/34_mda_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/35_nor_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/36_pol_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/37_prt_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/38_rou_drvlic.zip',
                  'ftp://smartengines.com/midv-500/dataset/39_rus_internalpassport.zip',
                  'ftp://smartengines.com/midv-500/dataset/40_srb_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/41_srb_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/42_svk_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/43_tur_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/44_ukr_id.zip',
                  'ftp://smartengines.com/midv-500/dataset/45_ukr_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/46_ury_passport.zip',
                  'ftp://smartengines.com/midv-500/dataset/47_usa_bordercrossing.zip',
                  'ftp://smartengines.com/midv-500/dataset/48_usa_passportcard.zip',
                  'ftp://smartengines.com/midv-500/dataset/49_usa_ssn82.zip',
                  'ftp://smartengines.com/midv-500/dataset/50_xpo_id.zip']

# Length of the URL prefix 'ftp://smartengines.com/midv-500/dataset/' in front of the filename
PATH_OFFSET = 40
TARGET_PATH = 'dataset/data/'

TEMP_PATH = 'dataset/temp/'
TEMP_IMAGE_PATH = TEMP_PATH + 'image/'
TEMP_MASK_PATH = TEMP_PATH + 'mask/'

DATA_PATH = 'dataset/train/'

SEED = 230


def read_image(img, label):
    image = cv2.imread(img)
    mask = np.zeros(image.shape, dtype=np.uint8)
    with open(label, 'r') as f:
        quad = json.load(f)
    coords = np.array(quad['quad'], dtype=np.int32)
    cv2.fillPoly(mask, coords.reshape(-1, 4, 2), color=(255, 255, 255))
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    mask = cv2.resize(mask, (mask.shape[1] // 2, mask.shape[0] // 2))
    image = cv2.resize(image, (image.shape[1] // 2, image.shape[0] // 2))
    mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)[1]
    return image, mask


def download_and_unzip():
    # Remove the temp directory and create a new one
    if os.path.exists(TEMP_PATH):
        shutil.rmtree(TEMP_PATH, ignore_errors=True)

    os.mkdir(TEMP_PATH)
    os.mkdir(TEMP_IMAGE_PATH)
    os.mkdir(TEMP_MASK_PATH)

    # Counter for filenames
    file_idx = 1

    print('Collect and prepare datasets...')

    for link in download_links:
        filename = link[PATH_OFFSET:]
        full_filename = TARGET_PATH + filename
        directory_name = TARGET_PATH + link[PATH_OFFSET:-4]

        print('Dataset available... ', directory_name)
        if not os.path.exists(directory_name):
            if not os.path.isfile(full_filename):
                # file not found, download it
                print('Downloading:', link)
                wget.download(link, TARGET_PATH)

            # Unzip archive
            with zipfile.ZipFile(full_filename, 'r') as zip_ref:
                zip_ref.extractall(TARGET_PATH)

        print('Prepare dataset... ', directory_name)
        img_dir_path = './' + directory_name + '/images/'
        gt_dir_path = './' + directory_name + '/ground_truth/'

        # Remove unnecessary files (the template .tif/.json named after the archive)
        if os.path.isfile(img_dir_path + filename.replace('.zip', '.tif')):
            os.remove(img_dir_path + filename.replace('.zip', '.tif'))
        if os.path.isfile(gt_dir_path + filename.replace('.zip', '.json')):
            os.remove(gt_dir_path + filename.replace('.zip', '.json'))

        # Render a binary mask for every image from its ground-truth quad and store both as PNG
        for images, ground_truth in zip(sorted(os.listdir(img_dir_path)), sorted(os.listdir(gt_dir_path))):
            img_list = sorted(glob(img_dir_path + images + '/*.tif'))
            label_list = sorted(glob(gt_dir_path + ground_truth + '/*.json'))
            for img, label in zip(img_list, label_list):
                image, mask = read_image(img, label)
                cv2.imwrite(TEMP_IMAGE_PATH + 'image' + str(file_idx) + '.png', image)
                cv2.imwrite(TEMP_MASK_PATH + 'image' + str(file_idx) + '.png', mask)

                file_idx += 1

        print('----------------------------------------------------------------------')


def train_validation_split():
    # Remove the training data directory and create a new one
    if os.path.exists(DATA_PATH):
        shutil.rmtree(DATA_PATH, ignore_errors=True)

    # Create folders to hold images and masks
    folders = ['train_frames/image', 'train_masks/image', 'val_frames/image', 'val_masks/image', 'test_frames/image',
               'test_masks/image']

    for folder in folders:
        os.makedirs(DATA_PATH + folder)

    # Get all frames and masks, sort them, and shuffle them to generate the data sets.
    all_frames = os.listdir(TEMP_IMAGE_PATH)
    all_masks = os.listdir(TEMP_MASK_PATH)

    all_frames.sort(key=lambda var: [int(x) if x.isdigit() else x
                                     for x in re.findall(r'[^0-9]|[0-9]+', var)])
    all_masks.sort(key=lambda var: [int(x) if x.isdigit() else x
                                    for x in re.findall(r'[^0-9]|[0-9]+', var)])

    random.seed(SEED)
    random.shuffle(all_frames)

    # Split the frames 70% train / 20% validation / 10% test
    train_split = int(0.7 * len(all_frames))
    val_split = int(0.9 * len(all_frames))

    train_frames = all_frames[:train_split]
    val_frames = all_frames[train_split:val_split]
    test_frames = all_frames[val_split:]

    # Generate the corresponding mask lists (frames and masks share filenames)
    train_masks = [f for f in all_masks if f in train_frames]
    val_masks = [f for f in all_masks if f in val_frames]
    test_masks = [f for f in all_masks if f in test_frames]

    # Add train, val, test frames and masks to the relevant folders
    def add_frames(dir_name, image):
        img = Image.open(TEMP_IMAGE_PATH + image)
        img.save(os.path.join(DATA_PATH, dir_name, image))

    def add_masks(dir_name, image):
        img = Image.open(TEMP_MASK_PATH + image)
        img.save(os.path.join(DATA_PATH, dir_name, image))

    frame_folders = [(train_frames, 'train_frames/image'), (val_frames, 'val_frames/image'),
                     (test_frames, 'test_frames/image')]
    mask_folders = [(train_masks, 'train_masks/image'), (val_masks, 'val_masks/image'),
                    (test_masks, 'test_masks/image')]

    print('Split images into train, test and validation...')

    # Add frames
    for folder in frame_folders:
        array = folder[0]
        name = [folder[1]] * len(array)
        list(map(add_frames, name, array))

    # Add masks
    for folder in mask_folders:
        array = folder[0]
        name = [folder[1]] * len(array)
        list(map(add_masks, name, array))


def main():
    download_and_unzip()

    train_validation_split()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/pretrained/README.md:
--------------------------------------------------------------------------------
### Download model
Download the [model_final.pt](https://drive.google.com/file/d/1u88se4-G0F-r_ntqm3IblhzMkgxMbzVN/view?usp=sharing) from my Google Drive.
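Place the file at `./pretrained/model_final.pt`; that is the path the test command in the main README passes to `test.py` via `--model`.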

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# Note: the +cu110 wheels are not on PyPI; install with
# pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
matplotlib
numpy==1.16.4
pandas
wget==3.0
opencv-python
opencv-contrib-python
pillow
pytesseract
tesseract
imutils
librosa
seaborn
tqdm
xgboost
PyWavelets
torch==1.7.1+cu110
torchaudio==0.7.2
torchsummary==1.5.1
torchvision==0.8.2+cu110
scikit-image==0.18.1
scikit-learn==0.24.1
scikit-plot==0.3.7
h5py
ipykernel
jupyter

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import cv2
import numpy as np
import os
import pathlib

import torch

import models
from utils import image

parser = argparse.ArgumentParser(description='Semantic segmentation of IDCard in Image.')
parser.add_argument('input', type=str, help='Image (with IDCard) Input file')
parser.add_argument('--output_mask', type=str, default='output_mask.png', help='Output file for mask')
parser.add_argument('--output_prediction', type=str, default='output_pred.png', help='Output file for image')
parser.add_argument('--model', type=str, default='./pretrained/model_checkpoint.pt', help='Path to checkpoint file')

args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

INPUT_FILE = args.input
OUTPUT_MASK = args.output_mask
OUTPUT_FILE = args.output_prediction
MODEL_FILE = args.model


def predict_image(model, image):
    with torch.no_grad():
        output = model(image.to(device))

    output = output.detach().cpu().numpy()[0]
    output = output.transpose((1, 2, 0))
    # render the raw logits as uint8 and invert-threshold them into a binary card mask
    output = np.uint8(output)
    _, output = cv2.threshold(output, 127, 255, cv2.THRESH_BINARY_INV)

    return output


def main():
    if not os.path.isfile(INPUT_FILE):
        print('Input image not found ', INPUT_FILE)
    else:
        if not os.path.isfile(MODEL_FILE):
            print('Model not found ', MODEL_FILE)

        else:
            print('Load model... ', MODEL_FILE)
            model = models.UNet(n_channels=1, n_classes=1)

            checkpoint = torch.load(pathlib.Path(MODEL_FILE), map_location=device)
            # accept both a plain state dict (model_final.pt) and a full training
            # checkpoint (model_checkpoint.pt, see saveCheckpoint in train.py)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            model.load_state_dict(checkpoint)
            model.to(device)
            model.eval()

            print('Load image... ', INPUT_FILE)
            img, h, w = image.load_image(INPUT_FILE)

            print('Prediction...')
            output_image = predict_image(model, img)

            print('Resize mask to original size...')
            mask_image = cv2.resize(output_image, (w, h))
            cv2.imwrite(OUTPUT_MASK, mask_image)

            print('Cut it out...')
            warped = image.extract_idcard(cv2.imread(INPUT_FILE), mask_image)
            cv2.imwrite(OUTPUT_FILE, warped)

            print('Done.')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/test/sample1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/test/sample1.png

--------------------------------------------------------------------------------
/test/sample2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/test/sample2.png

--------------------------------------------------------------------------------
/test/sample3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/test/sample3.png

--------------------------------------------------------------------------------
/test/sample4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/test/sample4.png

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import sys
import os
import numpy as np
import time
import argparse
import random
from os import walk
import pathlib
from PIL import Image

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

import models
from utils.metrics import multi_acc, iou_score

NO_OF_EPOCHS = 500
BATCH_SIZE = 32
IMAGE_SIZE = (256, 256)

SEED = 230

CHECKPOINT_PATH = pathlib.Path("./pretrained/model_checkpoint.pt")
FINAL_PATH = pathlib.Path("./pretrained/model_final.pt")

parser = argparse.ArgumentParser(description='Training Semantic segmentation of IDCard in Image.')
# type=bool would treat any non-empty string (including 'False') as True, so parse the string explicitly
parser.add_argument('--resumeTraining', type=lambda s: s.lower() in ('true', '1', 'yes'),
                    default=False, help='Resume Training')

args = parser.parse_args()
RESUME_TRAINING = args.resumeTraining

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def seed_torch(seed=SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True


class SegmentationImageDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        _, _, self.filenames = next(walk(image_dir))

    @classmethod
    def preprocess(cls, pil_img, normalize=True):
        pil_img = pil_img.convert('L')

        pil_img = pil_img.resize(IMAGE_SIZE)
        img_nd = np.array(pil_img)

        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))

        if normalize:
            img_trans = img_trans / 255

        return img_trans

    def __getitem__(self, idx):
        image = Image.open(self.image_dir + '/' + self.filenames[idx])
        mask = Image.open(self.mask_dir + '/' + self.filenames[idx])

        image = self.preprocess(image)
        mask = self.preprocess(mask)

        image = torch.from_numpy(image).type(torch.FloatTensor)
        mask = torch.from_numpy(mask).type(torch.FloatTensor)

        return image, mask

    def __len__(self):
        return len(self.filenames)


def saveCheckpoint(filename, epoch, model, optimizer, batchsize):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        "batch_size": batchsize,
    }

    # save all important stuff
    torch.save(checkpoint, filename)


def train(model, data_loader, criterion, optimizer, scheduler, num_epochs=5, epochs_earlystopping=10):
    logdir = './logs/' + time.strftime("%Y%m%d_%H%M%S")
    logdir = os.path.join(logdir)
    pathlib.Path(logdir).mkdir(parents=True, exist_ok=True)
    tb_writer = SummaryWriter(log_dir=logdir)

    best_acc = 0.0
    best_loss = sys.float_info.max
    best_iou = 0.0

    # number of consecutive epochs without improvement in validation IoU
    early_stopping = 0

    for epoch in range(num_epochs):
        result = []
        early_stopping += 1

        for phase in ['train', 'val']:
            if phase == 'train':
                # put the model in training mode
                model.train()
            else:
                # put the model in validation mode
                model.eval()

            # keep track of training and validation loss
            batch_nums = 0
            running_loss = 0.0
            running_iou = 0.0
            running_corrects = 0.0

            for (data, labels) in data_loader[phase]:
                # load the data and target to respective device
                (data, labels) = (data.to(device), labels.to(device))

                with torch.set_grad_enabled(phase == 'train'):
                    # feed the input
                    output = model(data)

                    # calculate the loss
                    loss = criterion(output, labels)

                    if phase == 'train':
                        # backward pass: compute gradient of the loss with respect to model parameters
                        loss.backward()

                        optimizer.step()

                        # zero the grad to stop it from accumulating
                        optimizer.zero_grad()

                # statistics
                batch_nums += 1
                running_loss += loss.item()
                running_iou += iou_score(output, labels)
                running_corrects += multi_acc(output, labels)

            if phase == 'train':
                scheduler.step(running_iou)

            # epoch statistics
            epoch_loss = running_loss / batch_nums
            epoch_iou = running_iou / batch_nums
            epoch_acc = running_corrects / batch_nums

            result.append('{} Loss: {:.4f} Acc: {:.4f} IoU: {:.4f}'.format(phase, epoch_loss, epoch_acc, epoch_iou))

            tb_writer.add_scalar('Loss/' + phase, epoch_loss, epoch)
            tb_writer.add_scalar('IoU/' + phase, epoch_iou, epoch)
            tb_writer.add_scalar('Accuracy/' + phase, epoch_acc, epoch)

            if phase == 'val' and epoch_iou > best_iou:
                early_stopping = 0

                best_acc = epoch_acc
                best_loss = epoch_loss
                best_iou = epoch_iou
                saveCheckpoint(CHECKPOINT_PATH, epoch, model, optimizer, BATCH_SIZE)
                print(
                    'Checkpoint saved - Loss: {:.4f} Acc: {:.4f} IoU: {:.4f}'.format(epoch_loss, epoch_acc, epoch_iou))

        print(result)

        if early_stopping >= epochs_earlystopping:
            break

    print('-----------------------------------------')
    print('Final Result: Loss: {:.4f} Acc: {:.4f} IoU: {:.4f}'.format(best_loss, best_acc, best_iou))
    print('-----------------------------------------')


def main():
    seed_torch()

    print('Create datasets...')
    train_dataset = SegmentationImageDataset('./dataset/train/train_frames/image', './dataset/train/train_masks/image')
    validation_dataset = SegmentationImageDataset('./dataset/train/val_frames/image', './dataset/train/val_masks/image')

    print('Create dataloader...')
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    dataloader = {"train": train_dataloader,
                  "val": validation_dataloader}

    print('Initialize model...')
    model = models.UNet(n_channels=1, n_classes=1)
    model = model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.2, patience=3, verbose=True)

    if RESUME_TRAINING:
        print('Load Model to resume training...')
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    print('Start training...')
    train(model, dataloader, criterion, optimizer, scheduler, num_epochs=NO_OF_EPOCHS)

    print('Save final model...')
    torch.save(model.state_dict(), FINAL_PATH)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobiassteidle/ML_IDCard_Segmentation_Pytorch/2815811b2add9c5983ca4559c2398a44f1ac2533/utils/__init__.py

--------------------------------------------------------------------------------
/utils/image.py:
--------------------------------------------------------------------------------
# https://github.com/KMKnation/Four-Point-Invoice-Transform-with-OpenCV/blob/master/four_point_object_extractor.py

import cv2
import numpy as np
from PIL import Image
import torch


def load_image(input_file):
    source_img = Image.open(input_file).convert('L')

    # pad to a square image
    offset = 20
    size = max(source_img.size)
    image = Image.new('L', (size + offset, size + offset))
    image.paste(source_img, (offset // 2, offset // 2))

    width, height = image.size

    # resize for inference
    image = image.resize((256, 256))
    img_nd = np.array(image)

    # expand grayscale image to 3 dimensions
    if len(img_nd.shape) == 2:
        img_nd = np.expand_dims(img_nd, axis=2)

    # HWC to CHW
    img_trans = img_nd.transpose((2, 0, 1))
    img_trans = img_trans / 255

    # reshape to 1-batched tensor
    img_trans = img_trans.reshape(1, 1, 256, 256)

    return torch.from_numpy(img_trans).type(torch.FloatTensor), height, width


def order_points(pts):
    # order corner points: top-left, top-right, bottom-right, bottom-left
    rect = np.zeros((4, 2), dtype="float32")

    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]

    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]

    return rect


def four_point_transform(image, pts):
    rect = order_points(pts)
    (tl, tr, br, bl) = rect
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))

    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))

    dst = np.array([
        [0, 0],
        [maxWidth - 1, 0],
        [maxWidth - 1, maxHeight - 1],
        [0, maxHeight - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped


def find_contours(image, thickness=3):
    contours, hierarchy = cv2.findContours(image.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    contour_image = np.zeros_like(image)
    cv2.drawContours(contour_image, contours, -1, 255, thickness)
    return contour_image, contours, hierarchy


def extract_idcard(raw_image, mask_image):
    contour_image, contours, hierarchy = find_contours(mask_image)

    # keep the largest contour that can be approximated by four corner points
    cnts = sorted(contours, key=cv2.contourArea, reverse=True)
    screenCntList = []
    for cnt in cnts:
        peri = cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, 0.02 * peri, True)
        screenCnt = approx

        if (len(screenCnt) == 4):
            screenCntList.append(screenCnt)

    assert len(screenCntList) == 1
    new_points = np.array([[points[0][0], points[0][1]] for points in screenCntList[0]])

    warped = four_point_transform(raw_image, new_points)
    # keep the OpenCV BGR channel order so cv2.imwrite in test.py stores the colors correctly
    return warped

--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
import torch


def multi_acc(pred, label):
    # binary segmentation: threshold the sigmoid probabilities at 0.5
    # (an argmax over the single channel dimension would always return class 0)
    tags = (torch.sigmoid(pred) > 0.5).float()
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    acc = acc * 100
    return acc


def iou_score(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    output_ = output > 0.5
    target_ = target > 0.5
    intersection = (output_ & target_).sum()
    union = (output_ | target_).sum()

    return (intersection + smooth) / (union + smooth)

--------------------------------------------------------------------------------