├── results
│   ├── res01.png
│   ├── res02.png
│   ├── res03.png
│   ├── res04.png
│   └── res05.png
├── losses
│   └── losses_12.png
├── config.py
├── test.py
├── README.md
├── datasets.py
├── unet.py
├── train.py
└── generate_synthetic_dataset.py

/results/res01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res01.png
--------------------------------------------------------------------------------
/results/res02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res02.png
--------------------------------------------------------------------------------
/results/res03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res03.png
--------------------------------------------------------------------------------
/results/res04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res04.png
--------------------------------------------------------------------------------
/results/res05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res05.png
--------------------------------------------------------------------------------
/losses/losses_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/losses/losses_12.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os

# path to saving models
models_dir = 'models'

# path to saving loss plots
losses_dir = 'losses'

# path to the data directories
data_dir = 'data'
train_dir = 'train'
val_dir = 'val'
imgs_dir = 'imgs'
noisy_dir = 'noisy'
debug_dir = 'debug'

# depth of UNet
depth = 4 # try decreasing the depth value if there is a memory error

# text file to get text from
txt_file_dir = 'shitty_text.txt'

# maximum number of synthetic images to generate
num_synthetic_imgs = 18000
train_percentage = 0.8

resume = False # False for training from scratch, True for loading a previously saved weight
ckpt = 'model01.pth' # model file to load the weights from, only used when resume is True
lr = 1e-5 # learning rate
epochs = 12 # epochs to train for

# batch size for train and val loaders
batch_size = 32 # try decreasing the batch_size if there is a memory error

# log interval for training and validation
log_interval = 25

test_dir = os.path.join(data_dir, val_dir, noisy_dir)
res_dir = 'results'
test_bs = 64
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os, shutil, cv2
import numpy as np
import torch
from torchvision import transforms
from unet import UNet
from datasets import custom_test_dataset
import config as cfg

res_dir = cfg.res_dir

if os.path.exists(res_dir):
    shutil.rmtree(res_dir)

if not os.path.exists(res_dir):
    os.mkdir(res_dir)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('device: ', device)

transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

test_dir = cfg.test_dir
test_dataset = custom_test_dataset(test_dir, transform = transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = cfg.test_bs, shuffle = False)

print('\nlen(test_dataset) : {}'.format(len(test_dataset)))
print('len(test_loader) : {} @bs={}'.format(len(test_loader), cfg.test_bs))

# defining the model
model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device)

ckpt_path = os.path.join(cfg.models_dir, cfg.ckpt)
ckpt = torch.load(ckpt_path, map_location = device) # map_location lets CPU-only machines load GPU-trained weights
print(f'\nckpt loaded: {ckpt_path}')
model_state_dict = ckpt['model_state_dict']
model.load_state_dict(model_state_dict)
model.to(device)

def get_img_strip(tensr):
    # tensr shape: [bs, 1, h, w] -> horizontal grayscale strip of shape [h, w*bs]
    bs, _, h, w = tensr.shape
    tensr2np = (tensr.cpu().numpy().clip(0, 1)*255).astype(np.uint8)
    canvas = np.ones((h, w*bs), dtype = np.uint8)
    for i in range(bs):
        patch_to_paste = tensr2np[i, 0, :, :]
        canvas[:, i*w: (i+1)*w] = patch_to_paste
    return canvas

def denoise(noisy_imgs, out):
    # stack the noisy input strip (top) over the model output strip (bottom) for visual comparison
    noisy_imgs = get_img_strip(noisy_imgs)
    out = get_img_strip(out)
    denoised = np.concatenate((noisy_imgs, out), axis = 0)
    return denoised

print('\nDenoising noisy images...')
model.eval()
with torch.no_grad():
    for batch_idx, noisy_imgs in enumerate(test_loader):
        print('batch: {}/{}'.format(str(batch_idx + 1).zfill(len(str(len(test_loader)))), len(test_loader)), end='\r')
        noisy_imgs = noisy_imgs.to(device)
        out = model(noisy_imgs)
        denoised = denoise(noisy_imgs, out)
        cv2.imwrite(os.path.join(res_dir, f'denoised{str(batch_idx).zfill(3)}.jpg'), denoised)

print('\n\nresults saved in \'{}\' directory'.format(res_dir))

print('\nFin.')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UNet-based-Denoising-Autoencoder-In-PyTorch
Cleaning printed text using a Denoising Autoencoder based on the UNet architecture, in PyTorch

## Acknowledgement
The UNet architecture used here is borrowed from https://github.com/jvanvugt/pytorch-unet.
The only modification made to the UNet architecture from the above link is the addition of dropout layers.

## Requirements
* torch >= 0.4
* torchvision >= 0.2.2
* opencv-python
* numpy >= 1.7.3
* matplotlib
* tqdm

## Generating Synthetic Data
Set the total number of synthetic images to be generated (**num_synthetic_imgs**) and the percentage of training data (**train_percentage**) in *config.py*.
Then run
```
python generate_synthetic_dataset.py
```
It will generate the synthetic data in a directory named *data* (configurable in *config.py*) in the root directory.
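Once the data has been generated, a quick sanity check of a clean/noisy pair can be done with a few lines of Python. This is a minimal sketch (not part of the training pipeline), assuming the default *data* directory and the `DAE_dataset` class from *datasets.py*:
```
# sketch: load one clean/noisy pair from the generated training split
import os
from torchvision import transforms
from datasets import DAE_dataset
import config as cfg

transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
train_ds = DAE_dataset(os.path.join(cfg.data_dir, cfg.train_dir), transform = transform)

clean, noisy = train_ds[0]   # each is a 1 x 64 x 256 float tensor in [0, 1]
print(len(train_ds), clean.shape, noisy.shape)
```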
## Training
Set the desired values of **lr**, **epochs** and **batch_size** in *config.py*.
### Start Training
In *config.py*,
* set **resume** to False

```
python train.py
```
### Resume Training
In *config.py*,
* set **resume** to True and
* set **ckpt** to the name of the model file to be loaded, e.g. ckpt = 'model02.pth'

```
python train.py
```

## Losses
The model was trained for 12 epochs with the configuration mentioned in `config.py`.

![loss after 12 epochs](losses/losses_12.png)

## Testing
In *config.py*,
* set **ckpt** to the name of the model file to be loaded, e.g. ckpt = 'model02.pth'
* set **test_dir** to the path that contains the noisy images you need to denoise ('data/val/noisy' by default)
* set **test_bs** to the desired batch size for the test set (64 by default)
```
python test.py
```
Once the testing is done, the results will be saved in a directory named *results*. A minimal single-image sketch is also shown after the result images below.

## Results (Noisy (Top) and Denoised (Bottom) Image Pairs)
* ![res01](results/res01.png)
* ![res02](results/res02.png)
* ![res03](results/res03.png)
* ![res04](results/res04.png)
* ![res05](results/res05.png)
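If you only want to denoise a single image rather than the batched comparison strips that `test.py` writes, the following is a minimal sketch. Assumptions: a trained checkpoint such as `models/model12.pth` (a hypothetical file name), and the same grayscale read, resize and 64 x 256 padding that `custom_test_dataset` already performs:
```
# sketch: denoise one image using a saved checkpoint and the existing dataset class
import os, cv2, torch
import numpy as np
from torchvision import transforms
from unet import UNet
from datasets import custom_test_dataset
import config as cfg

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device)
ckpt = torch.load(os.path.join(cfg.models_dir, 'model12.pth'), map_location = device)  # hypothetical checkpoint name
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# custom_test_dataset reads the image in grayscale and resizes/pads it to 64 x 256
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
ds = custom_test_dataset(cfg.test_dir, transform = transform)

with torch.no_grad():
    out = model(ds[0].unsqueeze(0).to(device))  # shape [1, 1, 64, 256]

cv2.imwrite('denoised_single.jpg', (out[0, 0].cpu().numpy().clip(0, 1)*255).astype(np.uint8))
```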
74 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, glob, cv2, sys 3 | from torch.utils.data import Dataset 4 | 5 | class DAE_dataset(Dataset): 6 | def __init__(self, data_dir, transform = None): 7 | self.data_dir = data_dir 8 | self.transform = transform 9 | self.imgs_data = self.get_data(os.path.join(self.data_dir, 'imgs')) 10 | self.noisy_imgs_data = self.get_data(os.path.join(self.data_dir, 'noisy')) 11 | 12 | def get_data(self, data_path): 13 | data = [] 14 | for img_path in glob.glob(data_path + os.sep + '*'): 15 | data.append(img_path) 16 | return data 17 | 18 | def __getitem__(self, index): 19 | # read images in grayscale, then invert them 20 | img = cv2.imread(self.imgs_data[index] ,0) 21 | noisy_img = cv2.imread(self.noisy_imgs_data[index] ,0) 22 | 23 | if self.transform is not None: 24 | img = self.transform(img) 25 | noisy_img = self.transform(noisy_img) 26 | 27 | return img, noisy_img 28 | 29 | def __len__(self): 30 | return len(self.imgs_data) 31 | 32 | class custom_test_dataset(Dataset): 33 | def __init__(self, data_dir, transform = None, out_size = (64, 256)): 34 | assert out_size[0] <= out_size[1], 'height/width of the output image shouldn\'t not be greater than 1' 35 | self.data_dir = data_dir 36 | self.transform = transform 37 | self.out_size = out_size 38 | self.imgs_data = self.get_data(self.data_dir) 39 | 40 | def get_data(self, data_path): 41 | data = [] 42 | for img_path in glob.glob(data_path + os.sep + '*'): 43 | data.append(img_path) 44 | return data 45 | 46 | def __getitem__(self, index): 47 | # read images in grayscale, then invert them 48 | img = cv2.imread(self.imgs_data[index] ,0) 49 | 50 | # check if img height exceeds out_size height 51 | if img.shape[0] > self.out_size[0]: 52 | resize_factor = self.out_size[0]/img.shape[0] 53 | img = cv2.resize(img, (0, 0), fx=resize_factor, fy=resize_factor) 54 | 55 | # check if img width exceeds out_size width 56 | if img.shape[1] > self.out_size[1]: 57 | resize_factor = self.out_size[1]/img.shape[1] 58 | img = cv2.resize(img, (0, 0), fx=resize_factor, fy=resize_factor) 59 | 60 | # add padding where required 61 | # pad height 62 | pad_height = self.out_size[0] - img.shape[0] 63 | pad_top = int(pad_height/2) 64 | pad_bottom = self.out_size[0] - img.shape[0] - pad_top 65 | # pad width 66 | pad_width = self.out_size[1] - img.shape[1] 67 | pad_left = int(pad_width/2) 68 | pad_right = self.out_size[1] - img.shape[1] - pad_left 69 | 70 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), constant_values=(0,0)) 71 | 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | 75 | return img 76 | 77 | def __len__(self): 78 | return len(self.imgs_data) 79 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class UNet(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels=1, 9 | n_classes=2, 10 | depth=5, 11 | wf=6, 12 | padding=False, 13 | batch_norm=False, 14 | up_mode='upconv', 15 | ): 16 | """ 17 | Implementation of 18 | U-Net: Convolutional Networks for Biomedical Image Segmentation 19 | (Ronneberger et al., 2015) 20 | https://arxiv.org/abs/1505.04597 21 | Using the default arguments will yield the exact version used 22 | in the 
original paper 23 | Args: 24 | in_channels (int): number of input channels 25 | n_classes (int): number of output channels 26 | depth (int): depth of the network 27 | wf (int): number of filters in the first layer is 2**wf 28 | padding (bool): if True, apply padding such that the input shape 29 | is the same as the output. 30 | This may introduce artifacts 31 | batch_norm (bool): Use BatchNorm after layers with an 32 | activation function 33 | up_mode (str): one of 'upconv' or 'upsample'. 34 | 'upconv' will use transposed convolutions for 35 | learned upsampling. 36 | 'upsample' will use bilinear upsampling. 37 | """ 38 | super(UNet, self).__init__() 39 | assert up_mode in ('upconv', 'upsample') 40 | self.padding = padding 41 | self.depth = depth 42 | prev_channels = in_channels 43 | self.down_path = nn.ModuleList() 44 | for i in range(depth): 45 | self.down_path.append( 46 | UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm) 47 | ) 48 | prev_channels = 2 ** (wf + i) 49 | 50 | self.up_path = nn.ModuleList() 51 | for i in reversed(range(depth - 1)): 52 | self.up_path.append( 53 | UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm) 54 | ) 55 | prev_channels = 2 ** (wf + i) 56 | 57 | self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1) 58 | 59 | def forward(self, x): 60 | blocks = [] 61 | for i, down in enumerate(self.down_path): 62 | x = down(x) 63 | if i != len(self.down_path) - 1: 64 | blocks.append(x) 65 | x = F.max_pool2d(x, 2) 66 | 67 | for i, up in enumerate(self.up_path): 68 | x = up(x, blocks[-i - 1]) 69 | 70 | output = self.last(x) 71 | 72 | return output 73 | 74 | 75 | class UNetConvBlock(nn.Module): 76 | def __init__(self, in_size, out_size, padding, batch_norm): 77 | super(UNetConvBlock, self).__init__() 78 | block = [] 79 | 80 | block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding))) 81 | block.append(nn.ReLU()) 82 | if batch_norm: 83 | block.append(nn.BatchNorm2d(out_size)) 84 | 85 | block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding))) 86 | block.append(nn.ReLU()) 87 | block.append(nn.Dropout2d(p=0.15)) # edited 88 | if batch_norm: 89 | block.append(nn.BatchNorm2d(out_size)) 90 | 91 | self.block = nn.Sequential(*block) 92 | 93 | def forward(self, x): 94 | out = self.block(x) 95 | return out 96 | 97 | 98 | class UNetUpBlock(nn.Module): 99 | def __init__(self, in_size, out_size, up_mode, padding, batch_norm): 100 | super(UNetUpBlock, self).__init__() 101 | if up_mode == 'upconv': 102 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 103 | elif up_mode == 'upsample': 104 | self.up = nn.Sequential( 105 | nn.Upsample(mode='bilinear', scale_factor=2), 106 | nn.Conv2d(in_size, out_size, kernel_size=1), 107 | ) 108 | 109 | self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm) 110 | 111 | def center_crop(self, layer, target_size): 112 | _, _, layer_height, layer_width = layer.size() 113 | diff_y = (layer_height - target_size[0]) // 2 114 | diff_x = (layer_width - target_size[1]) // 2 115 | return layer[ 116 | :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1]) 117 | ] 118 | 119 | def forward(self, x, bridge): 120 | up = self.up(x) 121 | crop1 = self.center_crop(bridge, up.shape[2:]) 122 | out = torch.cat([up, crop1], 1) 123 | out = self.conv_block(out) 124 | 125 | return out 126 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import sys, os, time, glob, time, pdb, cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | plt.switch_backend('agg') # for servers not supporting display 8 | 9 | # import neccesary libraries for defining the optimizers 10 | import torch.optim as optim 11 | from torch.optim import lr_scheduler 12 | from torchvision import transforms 13 | 14 | from unet import UNet 15 | from datasets import DAE_dataset 16 | import config as cfg 17 | 18 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 19 | print('device: ', device) 20 | 21 | script_time = time.time() 22 | 23 | def q(text = ''): 24 | print('> {}'.format(text)) 25 | sys.exit() 26 | 27 | data_dir = cfg.data_dir 28 | train_dir = cfg.train_dir 29 | val_dir = cfg.val_dir 30 | 31 | models_dir = cfg.models_dir 32 | if not os.path.exists(models_dir): 33 | os.mkdir(models_dir) 34 | 35 | losses_dir = cfg.losses_dir 36 | if not os.path.exists(losses_dir): 37 | os.mkdir(losses_dir) 38 | 39 | def count_parameters(model): 40 | num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 41 | return num_parameters/1e6 # in terms of millions 42 | 43 | def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch): 44 | fig = plt.figure(figsize=(16,16)) 45 | fig.suptitle('loss trends', fontsize=20) 46 | ax1 = fig.add_subplot(221) 47 | ax2 = fig.add_subplot(222) 48 | ax3 = fig.add_subplot(223) 49 | ax4 = fig.add_subplot(224) 50 | 51 | ax1.title.set_text('epoch train loss VS #epochs') 52 | ax1.set_xlabel('#epochs') 53 | ax1.set_ylabel('epoch train loss') 54 | ax1.plot(train_epoch_loss) 55 | 56 | ax2.title.set_text('epoch val loss VS #epochs') 57 | ax2.set_xlabel('#epochs') 58 | ax2.set_ylabel('epoch val loss') 59 | ax2.plot(val_epoch_loss) 60 | 61 | ax3.title.set_text('batch train loss VS #batches') 62 | ax3.set_xlabel('#batches') 63 | ax3.set_ylabel('batch train loss') 64 | ax3.plot(running_train_loss) 65 | 66 | ax4.title.set_text('batch val loss VS #batches') 67 | ax4.set_xlabel('#batches') 68 | ax4.set_ylabel('batch val loss') 69 | ax4.plot(running_val_loss) 70 | 71 | plt.savefig(os.path.join(losses_dir,'losses_{}.png'.format(str(epoch + 1).zfill(2)))) 72 | 73 | transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]) 74 | 75 | train_dataset = DAE_dataset(os.path.join(data_dir, train_dir), transform = transform) 76 | val_dataset = DAE_dataset(os.path.join(data_dir, val_dir), transform = transform) 77 | 78 | print('\nlen(train_dataset) : ', len(train_dataset)) 79 | print('len(val_dataset) : ', len(val_dataset)) 80 | 81 | batch_size = cfg.batch_size 82 | 83 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True) 84 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = not True) 85 | 86 | print('\nlen(train_loader): {} @bs={}'.format(len(train_loader), batch_size)) 87 | print('len(val_loader) : {} @bs={}'.format(len(val_loader), batch_size)) 88 | 89 | # defining the model 90 | model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device) # try decreasing the depth value if there is a memory error 91 | 92 | resume = cfg.resume 93 | 94 | if not resume: 95 | print('\nfrom scratch') 96 | train_epoch_loss = [] 97 | val_epoch_loss = [] 98 | running_train_loss = [] 99 | running_val_loss = [] 100 | 
epochs_till_now = 0 101 | else: 102 | ckpt_path = os.path.join(models_dir, cfg.ckpt) 103 | ckpt = torch.load(ckpt_path) 104 | print(f'\nckpt loaded: {ckpt_path}') 105 | model_state_dict = ckpt['model_state_dict'] 106 | model.load_state_dict(model_state_dict) 107 | model.to(device) 108 | losses = ckpt['losses'] 109 | running_train_loss = losses['running_train_loss'] 110 | running_val_loss = losses['running_val_loss'] 111 | train_epoch_loss = losses['train_epoch_loss'] 112 | val_epoch_loss = losses['val_epoch_loss'] 113 | epochs_till_now = ckpt['epochs_till_now'] 114 | 115 | lr = cfg.lr 116 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr) 117 | loss_fn = nn.MSELoss() 118 | 119 | log_interval = cfg.log_interval 120 | epochs = cfg.epochs 121 | 122 | ### 123 | print('\nmodel has {} M parameters'.format(count_parameters(model))) 124 | print(f'\nloss_fn : {loss_fn}') 125 | print(f'lr : {lr}') 126 | print(f'epochs_till_now: {epochs_till_now}') 127 | print(f'epochs from now: {epochs}') 128 | ### 129 | 130 | for epoch in range(epochs_till_now, epochs_till_now+epochs): 131 | print('\n===== EPOCH {}/{} ====='.format(epoch + 1, epochs_till_now + epochs)) 132 | print('\nTRAINING...') 133 | epoch_train_start_time = time.time() 134 | model.train() 135 | for batch_idx, (imgs, noisy_imgs) in enumerate(train_loader): 136 | batch_start_time = time.time() 137 | imgs = imgs.to(device) 138 | noisy_imgs = noisy_imgs.to(device) 139 | 140 | optimizer.zero_grad() 141 | out = model(noisy_imgs) 142 | 143 | loss = loss_fn(out, imgs) 144 | running_train_loss.append(loss.item()) 145 | loss.backward() 146 | optimizer.step() 147 | 148 | if (batch_idx + 1)%log_interval == 0: 149 | batch_time = time.time() - batch_start_time 150 | m,s = divmod(batch_time, 60) 151 | print('train loss @batch_idx {}/{}: {} in {} mins {} secs (per batch)'.format(str(batch_idx+1).zfill(len(str(len(train_loader)))), len(train_loader), loss.item(), int(m), round(s, 2))) 152 | 153 | train_epoch_loss.append(np.array(running_train_loss).mean()) 154 | 155 | epoch_train_time = time.time() - epoch_train_start_time 156 | m,s = divmod(epoch_train_time, 60) 157 | h,m = divmod(m, 60) 158 | print('\nepoch train time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s))) 159 | 160 | print('\nVALIDATION...') 161 | epoch_val_start_time = time.time() 162 | model.eval() 163 | with torch.no_grad(): 164 | for batch_idx, (imgs, noisy_imgs) in enumerate(val_loader): 165 | 166 | imgs = imgs.to(device) 167 | noisy_imgs = noisy_imgs.to(device) 168 | 169 | out = model(noisy_imgs) 170 | loss = loss_fn(out, imgs) 171 | 172 | running_val_loss.append(loss.item()) 173 | 174 | if (batch_idx + 1)%log_interval == 0: 175 | print('val loss @batch_idx {}/{}: {}'.format(str(batch_idx+1).zfill(len(str(len(val_loader)))), len(val_loader), loss.item())) 176 | 177 | val_epoch_loss.append(np.array(running_val_loss).mean()) 178 | 179 | epoch_val_time = time.time() - epoch_val_start_time 180 | m,s = divmod(epoch_val_time, 60) 181 | h,m = divmod(m, 60) 182 | print('\nepoch val time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s))) 183 | 184 | plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch) 185 | 186 | torch.save({'model_state_dict': model.state_dict(), 187 | 'losses': {'running_train_loss': running_train_loss, 188 | 'running_val_loss': running_val_loss, 189 | 'train_epoch_loss': train_epoch_loss, 190 | 'val_epoch_loss': val_epoch_loss}, 191 | 'epochs_till_now': epoch+1}, 192 | os.path.join(models_dir, 
'model{}.pth'.format(str(epoch + 1).zfill(2)))) 193 | 194 | total_script_time = time.time() - script_time 195 | m, s = divmod(total_script_time, 60) 196 | h, m = divmod(m, 60) 197 | print(f'\ntotal time taken for running this script: {int(h)} hrs {int(m)} mins {int(s)} secs') 198 | 199 | print('\nFin.') 200 | -------------------------------------------------------------------------------- /generate_synthetic_dataset.py: -------------------------------------------------------------------------------- 1 | import sys, os, glob, time, pdb, cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import shutil 9 | 10 | import config as cfg 11 | 12 | def q(text = ''): 13 | print(f'>{text}<') 14 | sys.exit() 15 | 16 | data_dir = cfg.data_dir 17 | train_dir = cfg.train_dir 18 | val_dir = cfg.val_dir 19 | 20 | imgs_dir = cfg.imgs_dir 21 | noisy_dir = cfg.noisy_dir 22 | debug_dir = cfg.debug_dir 23 | 24 | train_data_dir = os.path.join(data_dir, train_dir) 25 | val_data_dir = os.path.join(data_dir, val_dir) 26 | 27 | if os.path.exists(data_dir): 28 | shutil.rmtree(data_dir) 29 | 30 | if not os.path.exists(data_dir): 31 | os.mkdir(data_dir) 32 | 33 | if not os.path.exists(train_data_dir): 34 | os.mkdir(train_data_dir) 35 | 36 | if not os.path.exists(val_data_dir): 37 | os.mkdir(val_data_dir) 38 | 39 | img_train_dir = os.path.join(data_dir, train_dir, imgs_dir) 40 | noisy_train_dir = os.path.join(data_dir, train_dir, noisy_dir) 41 | debug_train_dir = os.path.join(data_dir, train_dir, debug_dir) 42 | 43 | img_val_dir = os.path.join(data_dir, val_dir, imgs_dir) 44 | noisy_val_dir = os.path.join(data_dir, val_dir, noisy_dir) 45 | debug_val_dir = os.path.join(data_dir, val_dir, debug_dir) 46 | 47 | dir_list = [img_train_dir, noisy_train_dir, debug_train_dir, img_val_dir, noisy_val_dir, debug_val_dir] 48 | for dir_path in dir_list: 49 | if not os.path.exists(dir_path): 50 | os.mkdir(dir_path) 51 | 52 | def get_word_list(): 53 | f = open(cfg.txt_file_dir, encoding='utf-8', mode="r") 54 | text = f.read() 55 | f.close() 56 | lines_list = str.split(text, '\n') 57 | while '' in lines_list: 58 | lines_list.remove('') 59 | 60 | lines_word_list = [str.split(line) for line in lines_list] 61 | words_list = [words for sublist in lines_word_list for words in sublist] 62 | 63 | return words_list 64 | 65 | 66 | words_list = get_word_list() 67 | print('\nnumber of words in the txt file: ', len(words_list)) 68 | 69 | # list of all the font styles 70 | font_list = [cv2.FONT_HERSHEY_COMPLEX, 71 | cv2.FONT_HERSHEY_COMPLEX_SMALL, 72 | cv2.FONT_HERSHEY_DUPLEX, 73 | cv2.FONT_HERSHEY_PLAIN, 74 | cv2.FONT_HERSHEY_SIMPLEX, 75 | cv2.FONT_HERSHEY_TRIPLEX, 76 | cv2.FONT_ITALIC] # cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, cv2.FONT_HERSHEY_SCRIPT_COMPLEX, cursive 77 | 78 | # size of the synthetic images to be generated 79 | syn_h, syn_w = 64, 256 80 | 81 | # scale factor 82 | scale_h, scale_w = 4, 4 83 | 84 | # initial size of the image, scaled up by the factor of scale_h and scale_w 85 | h, w = syn_h*scale_h, syn_w*scale_w 86 | 87 | img_count = 1 88 | word_count = 0 89 | num_imgs = int(cfg.num_synthetic_imgs) # max number of synthetic images to be generated 90 | train_num = int(num_imgs*cfg.train_percentage) # training percent 91 | print('\nnum_imgs : ', num_imgs) 92 | print('train_num: ', train_num) 93 | 94 | 95 | word_start_x = 5 # min space left on the left side of the printed text 96 | word_start_y = 5 97 | word_end_y = 5 # min space left on the 
bottom side of the printed text 98 | 99 | def get_text(): 100 | global word_count, words_list 101 | # text to be printed on the blank image 102 | num_words = np.random.randint(1,8) 103 | 104 | # renew the word list in case we run out of words 105 | if (word_count + num_words) >= len(words_list): 106 | 107 | print('===\nrecycling the words_list') 108 | 109 | words_list = get_word_list() 110 | word_count = 0 111 | 112 | print_text = '' 113 | for _ in range(num_words): 114 | print_text += str.split(words_list[word_count])[0] + ' ' 115 | word_count += 1 116 | print_text = print_text.strip() # to get rif of the last space 117 | return print_text 118 | 119 | def get_text_height(img, fontColor): 120 | black_coords = np.where(img == fontColor) 121 | # finding the extremes of the printed text 122 | ymin = np.min(black_coords[0]) 123 | ymax = np.max(black_coords[0]) 124 | # xmin = np.min(black_coords[1]) 125 | # xmax = np.max(black_coords[1]) 126 | ''' # for vizualising 127 | cv2.line(img, (0,ymin),(1000,ymin), 0,2) 128 | cv2.line(img, (0,ymax),(1000,ymax), 0,2) 129 | cv2.imshow('ymax', img) 130 | ''' 131 | return ymax - ymin 132 | 133 | def print_lines(img, font, bottomLeftCornerOfText, fontColor, fontScale, lineType, thickness): 134 | 135 | line_num = 0 136 | y_line_list = [] 137 | 138 | # print('img.shape: ', img.shape) 139 | # print('initial bottomLeftCornerOfText: ', bottomLeftCornerOfText) 140 | 141 | while bottomLeftCornerOfText[1] <= img.shape[0]: 142 | # get a line of text 143 | print_text = get_text() 144 | 145 | # put it on a blank image and get its height 146 | if line_num == 0: 147 | # get the correct text height 148 | big_img = np.ones((500,300), dtype = np.uint8)*255 149 | big_img_text = print_text.upper() 150 | cv2.putText(img = big_img, text = big_img_text, org = (0,200), fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType) 151 | text_height = get_text_height(big_img, fontColor) 152 | # print('text_height: ', text_height) 153 | 154 | if text_height > bottomLeftCornerOfText[1]: 155 | bottomLeftCornerOfText = (bottomLeftCornerOfText[0], np.random.randint(word_start_y, int(img.shape[0]*0.5)) + text_height) 156 | 157 | cv2.putText(img = img, text = print_text.upper(), org = bottomLeftCornerOfText, fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType) 158 | y_line_list.append(bottomLeftCornerOfText[1]) 159 | else: 160 | # sampling the chances of adding one more line of text 161 | one_more_line = np.random.choice([0, 1], p = [0.4, 0.6]) 162 | if not one_more_line: 163 | break 164 | cv2.putText(img = img, text = print_text.upper(), org = bottomLeftCornerOfText, fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType) 165 | y_line_list.append(bottomLeftCornerOfText[1]) 166 | # calculate the (text_height+line break space) left on the bottom 167 | bottom_space_left = int(text_height*(1 + np.random.randint(20, 40)/100)) 168 | # bottom_space_left = int(text_height*(1)) 169 | # print('bottom_space_left: ', bottom_space_left) 170 | 171 | # update the bottomLeftCornerOfText 172 | bottomLeftCornerOfText = (bottomLeftCornerOfText[0], bottomLeftCornerOfText[1] + bottom_space_left) 173 | 174 | # print('bottomLeftCornerOfText: ', bottomLeftCornerOfText) 175 | line_num += 1 176 | ''' 177 | for l in y_line_list: 178 | cv2.line(img, (0, l), (1000, l), 0, 1) 179 | ''' 180 | return img, y_line_list, text_height 181 | 182 | def get_noisy_img(img, y_line_list, 
text_height): 183 | 184 | # adding noise (horizontal and vertical lines) on the image containing text 185 | noisy_img = img.copy() 186 | 187 | # adding horizontal line (noise) 188 | for y_line in y_line_list: 189 | 190 | # samples the possibility of adding a horizontal line 191 | add_horizontal_line = np.random.choice([0, 1], p = [0.5, 0.5]) 192 | if not add_horizontal_line: 193 | continue 194 | 195 | # shift y_line randomly in the y-axis within a defined limit 196 | limit = int(text_height*0.3) 197 | if limit == 0: # this happens when the text used for getting the text height is '-', ',', '=' and other little symbols like these 198 | limit = 10 199 | y_line += np.random.randint(-limit, limit) 200 | 201 | h_start_x = 0 #np.random.randint(0,xmin) # min x of the horizontal line 202 | h_end_x = np.random.randint(int(noisy_img.shape[1]*0.8), noisy_img.shape[1]) # max x of the horizontal line 203 | h_length = h_end_x - h_start_x + 1 204 | num_h_lines = np.random.randint(10,30) # partitions to be made in the horizontal line (necessary to make it look like naturally broken lines) 205 | h_lines = [] 206 | h_start_temp = h_start_x 207 | next_line = True 208 | 209 | num_line = 0 210 | while (next_line) and (num_line < num_h_lines): 211 | if h_start_temp < h_end_x: 212 | h_end_temp = np.random.randint(h_start_temp + 1, h_end_x + 1) 213 | if h_end_temp < h_end_x: 214 | h_lines.append([h_start_temp, h_end_temp]) 215 | h_start_temp = h_end_temp + 1 216 | num_line += 1 217 | else: 218 | h_lines.append([h_start_temp, h_end_x]) 219 | num_line += 1 220 | next_line = False 221 | else: 222 | next_line = False 223 | 224 | for h_line in h_lines: 225 | col = np.random.choice(['black', 'white'], p = [0.65, 0.35]) # probabilities of line segment being a solid one or a broken one 226 | if col == 'black': 227 | x_points = list(range(h_line[0], h_line[1] + 1)) 228 | x_points_black_prob = np.random.choice([0,1], size = len(x_points), p = [0.2, 0.8]) 229 | 230 | for idx, x in enumerate(x_points): 231 | if x_points_black_prob[idx]: 232 | noisy_img[ y_line - np.random.randint(4): y_line + np.random.randint(4), x] = np.random.randint(0,30) 233 | 234 | # adding vertical line (noise) 235 | vertical_bool = {'left': np.random.choice([0,1], p =[0.3, 0.7]), 'right': np.random.choice([0,1])} # [1 or 0, 1 or 0] whether to make vertical left line on left and right side of the image 236 | for left_right, bool_ in vertical_bool.items(): 237 | if bool_: 238 | # print('left_right: ', left_right) 239 | if left_right == 'left': 240 | v_start_x = np.random.randint(5, int(noisy_img.shape[1]*0.06)) 241 | else: 242 | v_start_x = np.random.randint(int(noisy_img.shape[1]*0.95), noisy_img.shape[1] - 5) 243 | 244 | v_start_y = np.random.randint(0, int(noisy_img.shape[0]*0.06)) 245 | v_end_y = np.random.randint(int(noisy_img.shape[0]*0.95), noisy_img.shape[0]) 246 | 247 | y_points = list(range(v_start_y, v_end_y + 1)) 248 | y_points_black_prob = np.random.choice([0,1], size = len(y_points), p = [0.2, 0.8]) 249 | 250 | for idx, y in enumerate(y_points): 251 | if y_points_black_prob[idx]: 252 | noisy_img[y, v_start_x - np.random.randint(4): v_start_x + np.random.randint(4)] = np.random.randint(0,30) 253 | 254 | return noisy_img 255 | 256 | def degrade_qualities(img, noisy_img): 257 | ''' 258 | This function takes in a couple of images (color or grayscale), downsizes it to a 259 | randomly chosen size and then resizes it to the original size, 260 | degrading the quality of the images in the process. 
261 | ''' 262 | h, w = img.shape[0], img.shape[1] 263 | fx=np.random.randint(50,100)/100 264 | fy=np.random.randint(50,100)/100 265 | # print('fx, fy: ', fx, fy) 266 | img_small = cv2.resize(img, (0,0), fx = fx, fy = fy) 267 | img = cv2.resize(img_small,(w,h)) 268 | 269 | noisy_img_small = cv2.resize(noisy_img, (0,0), fx = fx, fy = fy) 270 | noisy_img = cv2.resize(noisy_img_small,(w,h)) 271 | 272 | return img, noisy_img 273 | 274 | def get_debug_image(img, noisy_img): 275 | debug_img = np.ones((2*h, w), dtype = np.uint8)*255 # to visualize the generated images (clean and noisy) 276 | debug_img[0:h, :] = img 277 | debug_img[h:2*h, :] = noisy_img 278 | cv2.line(debug_img, (0, h), (debug_img.shape[1], h), 150, 5) 279 | return debug_img 280 | 281 | def erode_dilate(img, noisy_img): 282 | # erode the image 283 | kernel = np.ones((3,3),np.uint8) 284 | erosion_iteration = np.random.randint(1,3) 285 | dilate_iteration = np.random.randint(0,2) 286 | img = cv2.erode(img,kernel,iterations = erosion_iteration) 287 | noisy_img = cv2.erode(noisy_img,kernel,iterations = erosion_iteration) 288 | img = cv2.dilate(img,kernel,iterations = dilate_iteration) 289 | noisy_img = cv2.dilate(noisy_img,kernel,iterations = dilate_iteration) 290 | return img, noisy_img 291 | 292 | def write_images(img, noisy_img, debug_img): 293 | global img_count 294 | 295 | img = 255 - cv2.resize(img, (0,0), fx = 1/scale_w, fy = 1/scale_h) 296 | noisy_img = 255 - cv2.resize(noisy_img, (0,0), fx = 1/scale_w, fy = 1/scale_h) 297 | debug_img = 255 - cv2.resize(debug_img, (0,0), fx = 1/scale_w, fy = 1/scale_h) 298 | 299 | if img_count <= train_num: 300 | cv2.imwrite(os.path.join(data_dir, train_dir, imgs_dir, '{}.jpg'.format(str(img_count).zfill(6))), img) 301 | cv2.imwrite(os.path.join(data_dir, train_dir, noisy_dir, '{}.jpg'.format(str(img_count).zfill(6))), noisy_img) 302 | cv2.imwrite(os.path.join(data_dir, train_dir, debug_dir, '{}.jpg'.format(str(img_count).zfill(6))), debug_img) 303 | else: 304 | cv2.imwrite(os.path.join(data_dir, val_dir, imgs_dir, '{}.jpg'.format(str(img_count).zfill(6))), img) 305 | cv2.imwrite(os.path.join(data_dir, val_dir, noisy_dir, '{}.jpg'.format(str(img_count).zfill(6))), noisy_img) 306 | cv2.imwrite(os.path.join(data_dir, val_dir, debug_dir, '{}.jpg'.format(str(img_count).zfill(6))), debug_img) 307 | 308 | img_count += 1 309 | 310 | print('\nsynthesizing image data...') 311 | for i in tqdm(range(num_imgs)): 312 | # make a blank image 313 | img = np.ones((h, w), dtype = np.uint8)*255 314 | 315 | # set random parameters 316 | font = font_list[np.random.randint(len(font_list))] 317 | bottomLeftCornerOfText = (np.random.randint(word_start_x, int(img.shape[1]/3)), np.random.randint(0, int(img.shape[0]*0.8))) # (x, y) 318 | fontColor = np.random.randint(0,30) 319 | fontScale = np.random.randint(1800, 2400)/1000 320 | lineType = np.random.randint(1,3) 321 | thickness = np.random.randint(1,7) 322 | 323 | # put text 324 | img, y_line_list, text_height = print_lines(img, font, bottomLeftCornerOfText, fontColor, fontScale, lineType, thickness) 325 | 326 | # add noise 327 | noisy_img = get_noisy_img(img, y_line_list, text_height) 328 | 329 | # degrade_quality 330 | img, noisy_img = degrade_qualities(img, noisy_img) 331 | 332 | # morphological operations 333 | img, noisy_img = erode_dilate(img, noisy_img) 334 | 335 | # make debug image 336 | debug_img = get_debug_image(img, noisy_img) 337 | 338 | # write images 339 | write_images(img, noisy_img, debug_img) 340 | 341 | ''' 342 | cv2.imshow('textonimage', img) 343 | 
cv2.imshow('noisy_img', noisy_img) 344 | cv2.waitKey() 345 | ''' 346 | --------------------------------------------------------------------------------