├── results
│   ├── res01.png
│   ├── res02.png
│   ├── res03.png
│   ├── res04.png
│   └── res05.png
├── losses
│   └── losses_12.png
├── config.py
├── test.py
├── README.md
├── datasets.py
├── unet.py
├── train.py
└── generate_synthetic_dataset.py
/results/res01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res01.png
--------------------------------------------------------------------------------
/results/res02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res02.png
--------------------------------------------------------------------------------
/results/res03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res03.png
--------------------------------------------------------------------------------
/results/res04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res04.png
--------------------------------------------------------------------------------
/results/res05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/results/res05.png
--------------------------------------------------------------------------------
/losses/losses_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n0obcoder/UNet-based-Denoising-Autoencoder-In-PyTorch/HEAD/losses/losses_12.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # path to saving models
4 | models_dir = 'models'
5 |
6 | # path to saving loss plots
7 | losses_dir = 'losses'
8 |
9 | # path to the data directories
10 | data_dir = 'data'
11 | train_dir = 'train'
12 | val_dir = 'val'
13 | imgs_dir = 'imgs'
14 | noisy_dir = 'noisy'
15 | debug_dir = 'debug'
16 |
17 | # depth of UNet
18 | depth = 4 # try decreasing the depth value if there is a memory error
19 |
20 | # text file to get text from
21 | txt_file_dir = 'shitty_text.txt'
22 |
23 | # maximum number of synthetic images to generate
24 | num_synthetic_imgs = 18000
25 | train_percentage = 0.8
26 |
27 | resume = False # False for training from scratch, True for loading a previously saved checkpoint
28 | ckpt = 'model01.pth' # checkpoint file to load the weights from; only used when resume is True
29 | lr = 1e-5 # learning rate
30 | epochs = 12 # epochs to train for
31 |
32 | # batch size for train and val loaders
33 | batch_size = 32 # try decreasing the batch_size if there is a memory error
34 |
35 | # log interval for training and validation
36 | log_interval = 25
37 |
38 | test_dir = os.path.join(data_dir, val_dir, noisy_dir)
39 | res_dir = 'results'
40 | test_bs = 64
41 |
--------------------------------------------------------------------------------
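Not part of the repo, but a quick sketch for sanity-checking these settings before a run; every other script consumes them via `import config as cfg`:

```
# quick sanity check of the configuration (illustrative only)
import os
import config as cfg

print('models_dir :', cfg.models_dir)
print('test_dir   :', cfg.test_dir)      # data/val/noisy with the defaults
print('depth      :', cfg.depth)         # lower this (and batch_size) on small GPUs
print('batch_size :', cfg.batch_size)
if cfg.resume:
    print('resuming from:', os.path.join(cfg.models_dir, cfg.ckpt))
else:
    print('training from scratch')
```
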
/test.py:
--------------------------------------------------------------------------------
1 | import os, shutil, cv2
2 | import numpy as np
3 | import torch
4 | from torchvision import transforms
5 | from unet import UNet
6 | from datasets import custom_test_dataset
7 | import config as cfg
8 |
9 | res_dir = cfg.res_dir
10 |
11 | if os.path.exists(res_dir):
12 | shutil.rmtree(res_dir)
13 |
14 | if not os.path.exists(res_dir):
15 | os.mkdir(res_dir)
16 |
17 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
18 | print('device: ', device)
19 |
20 | transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
21 |
22 | test_dir = cfg.test_dir
23 | test_dataset = custom_test_dataset(test_dir, transform = transform)
24 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = cfg.test_bs, shuffle = False)
25 |
26 | print('\nlen(test_dataset) : {}'.format(len(test_dataset)))
27 | print('len(test_loader) : {} @bs={}'.format(len(test_loader), cfg.test_bs))
28 |
29 | # defining the model
30 | model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device)
31 |
32 | ckpt_path = os.path.join(cfg.models_dir, cfg.ckpt)
33 | ckpt = torch.load(ckpt_path, map_location = device) # map_location lets CPU-only machines load GPU-trained weights
34 | print(f'\nckpt loaded: {ckpt_path}')
35 | model_state_dict = ckpt['model_state_dict']
36 | model.load_state_dict(model_state_dict)
37 | model.to(device)
38 |
39 | def get_img_strip(tensr):
40 | # shape: [bs,1,h,w]
41 | bs, _ , h, w = tensr.shape
42 | tensr2np = (tensr.cpu().numpy().clip(0,1)*255).astype(np.uint8)
43 | canvas = np.ones((h, w*bs), dtype = np.uint8)
44 | for i in range(tensr.shape[0]):
45 | patch_to_paste = tensr2np[i, 0, :, :]
46 | canvas[:, i*w: (i+1)*w] = patch_to_paste
47 | return canvas
48 |
49 | def denoise(noisy_imgs, out): # stacks the noisy inputs (top) over the model outputs (bottom) for comparison
50 | noisy_imgs = get_img_strip(noisy_imgs)
51 | out = get_img_strip(out)
52 | denoised = np.concatenate((noisy_imgs, out), axis = 0)
53 | return denoised
54 |
55 | print('\nDenoising noisy images...')
56 | model.eval()
57 | with torch.no_grad():
58 | for batch_idx, noisy_imgs in enumerate(test_loader):
59 | print('batch: {}/{}'.format(str(batch_idx + 1).zfill(len(str(len(test_loader)))), len(test_loader)), end='\r')
60 | noisy_imgs = noisy_imgs.to(device)
61 | out = model(noisy_imgs)
62 | denoised = denoise(noisy_imgs, out)
63 | cv2.imwrite(os.path.join(res_dir, f'denoised{str(batch_idx).zfill(3)}.jpg'), denoised)
64 |
65 | print('\n\nresults saved in \'{}\' directory'.format(res_dir))
66 |
67 | print('\nFin.')
68 |
--------------------------------------------------------------------------------
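test.py processes whole directories in batches; for denoising a single image, a minimal sketch under the same assumptions (the input file name `my_noisy_word.png` is a placeholder, and the plain `cv2.resize` stands in for the resize-and-pad logic of `custom_test_dataset`):

```
# denoise a single grayscale image with a trained checkpoint (illustrative sketch)
import os, cv2, torch
import numpy as np
from unet import UNet
import config as cfg

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device)
ckpt = torch.load(os.path.join(cfg.models_dir, cfg.ckpt), map_location = device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

img = cv2.imread('my_noisy_word.png', 0)   # placeholder path, read in grayscale as in the datasets
img = cv2.resize(img, (256, 64))           # simplification of custom_test_dataset's resize + pad to 64x256
x = torch.from_numpy(img).float().div(255).unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, 64, 256]

with torch.no_grad():
    out = model(x)
denoised = (out[0, 0].cpu().numpy().clip(0, 1) * 255).astype(np.uint8)
cv2.imwrite('denoised_single.jpg', denoised)
```
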
/README.md:
--------------------------------------------------------------------------------
1 | # UNet-based-Denoising-Autoencoder-In-PyTorch
2 | Cleaning printed text using a denoising autoencoder based on the UNet architecture, in PyTorch
3 |
4 | ## Acknowledgement
5 | The UNet implementation used here is borrowed from https://github.com/jvanvugt/pytorch-unet.
6 | The only modification made to that implementation is the addition of dropout layers.
7 |
8 | ## Requirements
9 | * torch >= 0.4
10 | * torchvision >= 0.2.2
11 | * opencv-python
12 | * numpy >= 1.7.3
13 | * matplotlib
14 | * tqdm
15 |
16 | ## Generating Synthetic Data
17 | Set the total number of synthetic images to generate (**num_synthetic_imgs**) and the fraction used for training (**train_percentage**) in *config.py*.
18 | Then run
19 | ```
20 | python generate_synthetic_dataset.py
21 | ```
22 | It will generate the synthetic data in a directory named *data* (can be changed in *config.py*) in the root directory.
23 |
24 | ## Training
25 | Set the desired values of **lr**, **epochs** and **batch_size** in *config.py*
26 | ### Start Training
27 | In *config.py*,
28 | * set **resume** to False
29 |
30 | ```
31 | python train.py
32 | ```
33 | ### Resume Training
34 | In *config.py*,
35 | * set **resume** to True and
36 | * set **ckpt** to the name of the checkpoint to load from the *models* directory, e.g. ckpt = 'model02.pth'
37 |
38 | ```
39 | python train.py
40 | ```
41 |
42 | ## Losses
43 | The model was trained for 12 epochs with the configuration in `config.py`. The resulting loss curves are shown below.
44 |
45 | ![losses](losses/losses_12.png)
46 | ## Testing
47 | In *config.py*,
48 | * set **ckpt** to the name of the checkpoint to load from the *models* directory, e.g. ckpt = 'model02.pth'
49 | * set **test_dir** to the directory containing the noisy images to denoise ('data/val/noisy' by default)
50 | * set **test_bs** to the desired batch size for the test set (64 by default)
51 | ```
52 | python test.py
53 | ```
54 | Once testing is done, the results will be saved in a directory named *results*.
55 |
56 | ## Results: Noisy (Top) and Denoised (Bottom) Image Pairs
57 |
58 | * ![](results/res01.png)
59 | * ![](results/res02.png)
60 | * ![](results/res03.png)
61 | * ![](results/res04.png)
62 | * ![](results/res05.png)
63 |
--------------------------------------------------------------------------------
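Putting the README steps together, a minimal end-to-end run with the defaults in *config.py* boils down to three commands:

```
python generate_synthetic_dataset.py   # writes data/train and data/val
python train.py                        # saves modelXX.pth under models/ and loss plots under losses/
python test.py                         # writes denoised comparison strips under results/
```
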
/datasets.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, glob, cv2, sys
3 | from torch.utils.data import Dataset
4 |
5 | class DAE_dataset(Dataset):
6 | def __init__(self, data_dir, transform = None):
7 | self.data_dir = data_dir
8 | self.transform = transform
9 | self.imgs_data = self.get_data(os.path.join(self.data_dir, 'imgs'))
10 | self.noisy_imgs_data = self.get_data(os.path.join(self.data_dir, 'noisy'))
11 |
12 | def get_data(self, data_path):
13 | data = []
14 | for img_path in sorted(glob.glob(data_path + os.sep + '*')): # sorted so clean and noisy images align by index
15 | data.append(img_path)
16 | return data
17 |
18 | def __getitem__(self, index):
19 | # read the clean and noisy images in grayscale
20 | img = cv2.imread(self.imgs_data[index] ,0)
21 | noisy_img = cv2.imread(self.noisy_imgs_data[index] ,0)
22 |
23 | if self.transform is not None:
24 | img = self.transform(img)
25 | noisy_img = self.transform(noisy_img)
26 |
27 | return img, noisy_img
28 |
29 | def __len__(self):
30 | return len(self.imgs_data)
31 |
32 | class custom_test_dataset(Dataset):
33 | def __init__(self, data_dir, transform = None, out_size = (64, 256)):
34 | assert out_size[0] <= out_size[1], 'output height should not be greater than output width'
35 | self.data_dir = data_dir
36 | self.transform = transform
37 | self.out_size = out_size
38 | self.imgs_data = self.get_data(self.data_dir)
39 |
40 | def get_data(self, data_path):
41 | data = []
42 | for img_path in sorted(glob.glob(data_path + os.sep + '*')): # sorted for a deterministic order
43 | data.append(img_path)
44 | return data
45 |
46 | def __getitem__(self, index):
48 | # read the image in grayscale
48 | img = cv2.imread(self.imgs_data[index] ,0)
49 |
50 | # check if img height exceeds out_size height
51 | if img.shape[0] > self.out_size[0]:
52 | resize_factor = self.out_size[0]/img.shape[0]
53 | img = cv2.resize(img, (0, 0), fx=resize_factor, fy=resize_factor)
54 |
55 | # check if img width exceeds out_size width
56 | if img.shape[1] > self.out_size[1]:
57 | resize_factor = self.out_size[1]/img.shape[1]
58 | img = cv2.resize(img, (0, 0), fx=resize_factor, fy=resize_factor)
59 |
60 | # add padding where required
61 | # pad height
62 | pad_height = self.out_size[0] - img.shape[0]
63 | pad_top = int(pad_height/2)
64 | pad_bottom = self.out_size[0] - img.shape[0] - pad_top
65 | # pad width
66 | pad_width = self.out_size[1] - img.shape[1]
67 | pad_left = int(pad_width/2)
68 | pad_right = self.out_size[1] - img.shape[1] - pad_left
69 |
70 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), constant_values=(0,0))
71 |
72 | if self.transform is not None:
73 | img = self.transform(img)
74 |
75 | return img
76 |
77 | def __len__(self):
78 | return len(self.imgs_data)
79 |
--------------------------------------------------------------------------------
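A minimal sketch, not part of the repo, showing how these dataset classes are wired to data loaders; it mirrors what train.py and test.py already do with the same transform chain:

```
# wire the dataset classes above to loaders (mirrors train.py / test.py; illustrative only)
import os
import torch
from torchvision import transforms
from datasets import DAE_dataset, custom_test_dataset
import config as cfg

transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

train_ds = DAE_dataset(os.path.join(cfg.data_dir, cfg.train_dir), transform = transform)
test_ds = custom_test_dataset(cfg.test_dir, transform = transform)  # resizes/pads inputs to 64x256

train_loader = torch.utils.data.DataLoader(train_ds, batch_size = cfg.batch_size, shuffle = True)

imgs, noisy_imgs = next(iter(train_loader))  # clean targets, noisy inputs
print(imgs.shape, noisy_imgs.shape)          # [batch_size, 1, 64, 256] each with the default data
```
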
/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | class UNet(nn.Module):
6 | def __init__(
7 | self,
8 | in_channels=1,
9 | n_classes=2,
10 | depth=5,
11 | wf=6,
12 | padding=False,
13 | batch_norm=False,
14 | up_mode='upconv',
15 | ):
16 | """
17 | Implementation of
18 | U-Net: Convolutional Networks for Biomedical Image Segmentation
19 | (Ronneberger et al., 2015)
20 | https://arxiv.org/abs/1505.04597
21 | Using the default arguments will yield the exact version used
22 | in the original paper
23 | Args:
24 | in_channels (int): number of input channels
25 | n_classes (int): number of output channels
26 | depth (int): depth of the network
27 | wf (int): number of filters in the first layer is 2**wf
28 | padding (bool): if True, apply padding such that the input shape
29 | is the same as the output.
30 | This may introduce artifacts
31 | batch_norm (bool): Use BatchNorm after layers with an
32 | activation function
33 | up_mode (str): one of 'upconv' or 'upsample'.
34 | 'upconv' will use transposed convolutions for
35 | learned upsampling.
36 | 'upsample' will use bilinear upsampling.
37 | """
38 | super(UNet, self).__init__()
39 | assert up_mode in ('upconv', 'upsample')
40 | self.padding = padding
41 | self.depth = depth
42 | prev_channels = in_channels
43 | self.down_path = nn.ModuleList()
44 | for i in range(depth):
45 | self.down_path.append(
46 | UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
47 | )
48 | prev_channels = 2 ** (wf + i)
49 |
50 | self.up_path = nn.ModuleList()
51 | for i in reversed(range(depth - 1)):
52 | self.up_path.append(
53 | UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
54 | )
55 | prev_channels = 2 ** (wf + i)
56 |
57 | self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)
58 |
59 | def forward(self, x):
60 | blocks = []
61 | for i, down in enumerate(self.down_path):
62 | x = down(x)
63 | if i != len(self.down_path) - 1:
64 | blocks.append(x)
65 | x = F.max_pool2d(x, 2)
66 |
67 | for i, up in enumerate(self.up_path):
68 | x = up(x, blocks[-i - 1])
69 |
70 | output = self.last(x)
71 |
72 | return output
73 |
74 |
75 | class UNetConvBlock(nn.Module):
76 | def __init__(self, in_size, out_size, padding, batch_norm):
77 | super(UNetConvBlock, self).__init__()
78 | block = []
79 |
80 | block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
81 | block.append(nn.ReLU())
82 | if batch_norm:
83 | block.append(nn.BatchNorm2d(out_size))
84 |
85 | block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
86 | block.append(nn.ReLU())
87 | block.append(nn.Dropout2d(p=0.15)) # dropout added to the original implementation (see README)
88 | if batch_norm:
89 | block.append(nn.BatchNorm2d(out_size))
90 |
91 | self.block = nn.Sequential(*block)
92 |
93 | def forward(self, x):
94 | out = self.block(x)
95 | return out
96 |
97 |
98 | class UNetUpBlock(nn.Module):
99 | def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
100 | super(UNetUpBlock, self).__init__()
101 | if up_mode == 'upconv':
102 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
103 | elif up_mode == 'upsample':
104 | self.up = nn.Sequential(
105 | nn.Upsample(mode='bilinear', scale_factor=2),
106 | nn.Conv2d(in_size, out_size, kernel_size=1),
107 | )
108 |
109 | self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)
110 |
111 | def center_crop(self, layer, target_size):
112 | _, _, layer_height, layer_width = layer.size()
113 | diff_y = (layer_height - target_size[0]) // 2
114 | diff_x = (layer_width - target_size[1]) // 2
115 | return layer[
116 | :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
117 | ]
118 |
119 | def forward(self, x, bridge):
120 | up = self.up(x)
121 | crop1 = self.center_crop(bridge, up.shape[2:])
122 | out = torch.cat([up, crop1], 1)
123 | out = self.conv_block(out)
124 |
125 | return out
126 |
--------------------------------------------------------------------------------
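A quick shape check of the configuration this repo actually trains (in_channels=1, n_classes=1, depth=4, padding=True); with padded convolutions the output matches the input size, which is what lets train.py compare `out` against `imgs` with MSE directly:

```
# verify that the padded UNet maps a 64x256 grayscale batch to the same size
import torch
from unet import UNet

model = UNet(in_channels = 1, n_classes = 1, depth = 4, padding = True)
x = torch.randn(2, 1, 64, 256)         # [batch, channels, height, width]
with torch.no_grad():
    y = model(x)
print(x.shape, '->', y.shape)          # torch.Size([2, 1, 64, 256]) -> torch.Size([2, 1, 64, 256])
```
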
/train.py:
--------------------------------------------------------------------------------
1 | import sys, os, time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | import matplotlib.pyplot as plt
7 | plt.switch_backend('agg') # for servers not supporting display
8 |
9 | # import necessary libraries for defining the optimizers
10 | import torch.optim as optim
11 | from torch.optim import lr_scheduler
12 | from torchvision import transforms
13 |
14 | from unet import UNet
15 | from datasets import DAE_dataset
16 | import config as cfg
17 |
18 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
19 | print('device: ', device)
20 |
21 | script_time = time.time()
22 |
23 | def q(text = ''):
24 | print('> {}'.format(text))
25 | sys.exit()
26 |
27 | data_dir = cfg.data_dir
28 | train_dir = cfg.train_dir
29 | val_dir = cfg.val_dir
30 |
31 | models_dir = cfg.models_dir
32 | if not os.path.exists(models_dir):
33 | os.mkdir(models_dir)
34 |
35 | losses_dir = cfg.losses_dir
36 | if not os.path.exists(losses_dir):
37 | os.mkdir(losses_dir)
38 |
39 | def count_parameters(model):
40 | num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
41 | return num_parameters/1e6 # in terms of millions
42 |
43 | def plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch):
44 | fig = plt.figure(figsize=(16,16))
45 | fig.suptitle('loss trends', fontsize=20)
46 | ax1 = fig.add_subplot(221)
47 | ax2 = fig.add_subplot(222)
48 | ax3 = fig.add_subplot(223)
49 | ax4 = fig.add_subplot(224)
50 |
51 | ax1.title.set_text('epoch train loss VS #epochs')
52 | ax1.set_xlabel('#epochs')
53 | ax1.set_ylabel('epoch train loss')
54 | ax1.plot(train_epoch_loss)
55 |
56 | ax2.title.set_text('epoch val loss VS #epochs')
57 | ax2.set_xlabel('#epochs')
58 | ax2.set_ylabel('epoch val loss')
59 | ax2.plot(val_epoch_loss)
60 |
61 | ax3.title.set_text('batch train loss VS #batches')
62 | ax3.set_xlabel('#batches')
63 | ax3.set_ylabel('batch train loss')
64 | ax3.plot(running_train_loss)
65 |
66 | ax4.title.set_text('batch val loss VS #batches')
67 | ax4.set_xlabel('#batches')
68 | ax4.set_ylabel('batch val loss')
69 | ax4.plot(running_val_loss)
70 |
71 | plt.savefig(os.path.join(losses_dir,'losses_{}.png'.format(str(epoch + 1).zfill(2))))
72 |
73 | transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
74 |
75 | train_dataset = DAE_dataset(os.path.join(data_dir, train_dir), transform = transform)
76 | val_dataset = DAE_dataset(os.path.join(data_dir, val_dir), transform = transform)
77 |
78 | print('\nlen(train_dataset) : ', len(train_dataset))
79 | print('len(val_dataset) : ', len(val_dataset))
80 |
81 | batch_size = cfg.batch_size
82 |
83 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
84 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = False)
85 |
86 | print('\nlen(train_loader): {} @bs={}'.format(len(train_loader), batch_size))
87 | print('len(val_loader) : {} @bs={}'.format(len(val_loader), batch_size))
88 |
89 | # defining the model
90 | model = UNet(n_classes = 1, depth = cfg.depth, padding = True).to(device) # try decreasing the depth value if there is a memory error
91 |
92 | resume = cfg.resume
93 |
94 | if not resume:
95 | print('\nfrom scratch')
96 | train_epoch_loss = []
97 | val_epoch_loss = []
98 | running_train_loss = []
99 | running_val_loss = []
100 | epochs_till_now = 0
101 | else:
102 | ckpt_path = os.path.join(models_dir, cfg.ckpt)
103 | ckpt = torch.load(ckpt_path, map_location = device) # map_location lets CPU-only machines load GPU-trained weights
104 | print(f'\nckpt loaded: {ckpt_path}')
105 | model_state_dict = ckpt['model_state_dict']
106 | model.load_state_dict(model_state_dict)
107 | model.to(device)
108 | losses = ckpt['losses']
109 | running_train_loss = losses['running_train_loss']
110 | running_val_loss = losses['running_val_loss']
111 | train_epoch_loss = losses['train_epoch_loss']
112 | val_epoch_loss = losses['val_epoch_loss']
113 | epochs_till_now = ckpt['epochs_till_now']
114 |
115 | lr = cfg.lr
116 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
117 | loss_fn = nn.MSELoss()
118 |
119 | log_interval = cfg.log_interval
120 | epochs = cfg.epochs
121 |
122 | ###
123 | print('\nmodel has {} M parameters'.format(count_parameters(model)))
124 | print(f'\nloss_fn : {loss_fn}')
125 | print(f'lr : {lr}')
126 | print(f'epochs_till_now: {epochs_till_now}')
127 | print(f'epochs from now: {epochs}')
128 | ###
129 |
130 | for epoch in range(epochs_till_now, epochs_till_now+epochs):
131 | print('\n===== EPOCH {}/{} ====='.format(epoch + 1, epochs_till_now + epochs))
132 | print('\nTRAINING...')
133 | epoch_train_start_time = time.time()
134 | model.train()
135 | for batch_idx, (imgs, noisy_imgs) in enumerate(train_loader):
136 | batch_start_time = time.time()
137 | imgs = imgs.to(device)
138 | noisy_imgs = noisy_imgs.to(device)
139 |
140 | optimizer.zero_grad()
141 | out = model(noisy_imgs)
142 |
143 | loss = loss_fn(out, imgs)
144 | running_train_loss.append(loss.item())
145 | loss.backward()
146 | optimizer.step()
147 |
148 | if (batch_idx + 1)%log_interval == 0:
149 | batch_time = time.time() - batch_start_time
150 | m,s = divmod(batch_time, 60)
151 | print('train loss @batch_idx {}/{}: {} in {} mins {} secs (per batch)'.format(str(batch_idx+1).zfill(len(str(len(train_loader)))), len(train_loader), loss.item(), int(m), round(s, 2)))
152 |
153 | train_epoch_loss.append(np.mean(running_train_loss[-len(train_loader):])) # mean over this epoch's batches only
154 |
155 | epoch_train_time = time.time() - epoch_train_start_time
156 | m,s = divmod(epoch_train_time, 60)
157 | h,m = divmod(m, 60)
158 | print('\nepoch train time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s)))
159 |
160 | print('\nVALIDATION...')
161 | epoch_val_start_time = time.time()
162 | model.eval()
163 | with torch.no_grad():
164 | for batch_idx, (imgs, noisy_imgs) in enumerate(val_loader):
165 |
166 | imgs = imgs.to(device)
167 | noisy_imgs = noisy_imgs.to(device)
168 |
169 | out = model(noisy_imgs)
170 | loss = loss_fn(out, imgs)
171 |
172 | running_val_loss.append(loss.item())
173 |
174 | if (batch_idx + 1)%log_interval == 0:
175 | print('val loss @batch_idx {}/{}: {}'.format(str(batch_idx+1).zfill(len(str(len(val_loader)))), len(val_loader), loss.item()))
176 |
177 | val_epoch_loss.append(np.mean(running_val_loss[-len(val_loader):])) # mean over this epoch's val batches only
178 |
179 | epoch_val_time = time.time() - epoch_val_start_time
180 | m,s = divmod(epoch_val_time, 60)
181 | h,m = divmod(m, 60)
182 | print('\nepoch val time: {} hrs {} mins {} secs'.format(int(h), int(m), int(s)))
183 |
184 | plot_losses(running_train_loss, running_val_loss, train_epoch_loss, val_epoch_loss, epoch)
185 |
186 | torch.save({'model_state_dict': model.state_dict(),
187 | 'losses': {'running_train_loss': running_train_loss,
188 | 'running_val_loss': running_val_loss,
189 | 'train_epoch_loss': train_epoch_loss,
190 | 'val_epoch_loss': val_epoch_loss},
191 | 'epochs_till_now': epoch+1},
192 | os.path.join(models_dir, 'model{}.pth'.format(str(epoch + 1).zfill(2))))
193 |
194 | total_script_time = time.time() - script_time
195 | m, s = divmod(total_script_time, 60)
196 | h, m = divmod(m, 60)
197 | print(f'\ntotal time taken for running this script: {int(h)} hrs {int(m)} mins {int(s)} secs')
198 |
199 | print('\nFin.')
200 |
--------------------------------------------------------------------------------
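Checkpoints written by train.py bundle the weights, the full loss history and the epoch counter. A small sketch for inspecting one after a run (the name `model12.pth` assumes the default 12 epochs were completed):

```
# inspect a checkpoint written by train.py (illustrative only)
import os
import torch
import config as cfg

ckpt = torch.load(os.path.join(cfg.models_dir, 'model12.pth'), map_location = 'cpu')
print('epochs_till_now :', ckpt['epochs_till_now'])
print('last train loss :', ckpt['losses']['train_epoch_loss'][-1])
print('last val loss   :', ckpt['losses']['val_epoch_loss'][-1])
print('weight tensors  :', len(ckpt['model_state_dict']))
```
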
/generate_synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | import sys, os, shutil, cv2
2 | import numpy as np
3 | from tqdm import tqdm
9 |
10 | import config as cfg
11 |
12 | def q(text = ''):
13 | print(f'>{text}<')
14 | sys.exit()
15 |
16 | data_dir = cfg.data_dir
17 | train_dir = cfg.train_dir
18 | val_dir = cfg.val_dir
19 |
20 | imgs_dir = cfg.imgs_dir
21 | noisy_dir = cfg.noisy_dir
22 | debug_dir = cfg.debug_dir
23 |
24 | train_data_dir = os.path.join(data_dir, train_dir)
25 | val_data_dir = os.path.join(data_dir, val_dir)
26 |
27 | if os.path.exists(data_dir):
28 | shutil.rmtree(data_dir)
29 |
30 | if not os.path.exists(data_dir):
31 | os.mkdir(data_dir)
32 |
33 | if not os.path.exists(train_data_dir):
34 | os.mkdir(train_data_dir)
35 |
36 | if not os.path.exists(val_data_dir):
37 | os.mkdir(val_data_dir)
38 |
39 | img_train_dir = os.path.join(data_dir, train_dir, imgs_dir)
40 | noisy_train_dir = os.path.join(data_dir, train_dir, noisy_dir)
41 | debug_train_dir = os.path.join(data_dir, train_dir, debug_dir)
42 |
43 | img_val_dir = os.path.join(data_dir, val_dir, imgs_dir)
44 | noisy_val_dir = os.path.join(data_dir, val_dir, noisy_dir)
45 | debug_val_dir = os.path.join(data_dir, val_dir, debug_dir)
46 |
47 | dir_list = [img_train_dir, noisy_train_dir, debug_train_dir, img_val_dir, noisy_val_dir, debug_val_dir]
48 | for dir_path in dir_list:
49 | if not os.path.exists(dir_path):
50 | os.mkdir(dir_path)
51 |
52 | def get_word_list():
53 | f = open(cfg.txt_file_dir, encoding='utf-8', mode="r")
54 | text = f.read()
55 | f.close()
56 | lines_list = str.split(text, '\n')
57 | while '' in lines_list:
58 | lines_list.remove('')
59 |
60 | lines_word_list = [str.split(line) for line in lines_list]
61 | words_list = [words for sublist in lines_word_list for words in sublist]
62 |
63 | return words_list
64 |
65 |
66 | words_list = get_word_list()
67 | print('\nnumber of words in the txt file: ', len(words_list))
68 |
69 | # list of all the font styles
70 | font_list = [cv2.FONT_HERSHEY_COMPLEX,
71 | cv2.FONT_HERSHEY_COMPLEX_SMALL,
72 | cv2.FONT_HERSHEY_DUPLEX,
73 | cv2.FONT_HERSHEY_PLAIN,
74 | cv2.FONT_HERSHEY_SIMPLEX,
75 | cv2.FONT_HERSHEY_TRIPLEX,
76 | cv2.FONT_ITALIC] # cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, cv2.FONT_HERSHEY_SCRIPT_COMPLEX, cursive
77 |
78 | # size of the synthetic images to be generated
79 | syn_h, syn_w = 64, 256
80 |
81 | # scale factor
82 | scale_h, scale_w = 4, 4
83 |
84 | # initial size of the image, scaled up by the factor of scale_h and scale_w
85 | h, w = syn_h*scale_h, syn_w*scale_w
86 |
87 | img_count = 1
88 | word_count = 0
89 | num_imgs = int(cfg.num_synthetic_imgs) # max number of synthetic images to be generated
90 | train_num = int(num_imgs*cfg.train_percentage) # training percent
91 | print('\nnum_imgs : ', num_imgs)
92 | print('train_num: ', train_num)
93 |
94 |
95 | word_start_x = 5 # min space left on the left side of the printed text
96 | word_start_y = 5
97 | word_end_y = 5 # min space left on the bottom side of the printed text
98 |
99 | def get_text():
100 | global word_count, words_list
101 | # text to be printed on the blank image
102 | num_words = np.random.randint(1,8)
103 |
104 | # renew the word list in case we run out of words
105 | if (word_count + num_words) >= len(words_list):
106 |
107 | print('===\nrecycling the words_list')
108 |
109 | words_list = get_word_list()
110 | word_count = 0
111 |
112 | print_text = ''
113 | for _ in range(num_words):
114 | print_text += str.split(words_list[word_count])[0] + ' '
115 | word_count += 1
116 | print_text = print_text.strip() # to get rid of the trailing space
117 | return print_text
118 |
119 | def get_text_height(img, fontColor):
120 | black_coords = np.where(img == fontColor)
121 | # finding the extremes of the printed text
122 | ymin = np.min(black_coords[0])
123 | ymax = np.max(black_coords[0])
124 | # xmin = np.min(black_coords[1])
125 | # xmax = np.max(black_coords[1])
126 | ''' # for vizualising
127 | cv2.line(img, (0,ymin),(1000,ymin), 0,2)
128 | cv2.line(img, (0,ymax),(1000,ymax), 0,2)
129 | cv2.imshow('ymax', img)
130 | '''
131 | return ymax - ymin
132 |
133 | def print_lines(img, font, bottomLeftCornerOfText, fontColor, fontScale, lineType, thickness):
134 |
135 | line_num = 0
136 | y_line_list = []
137 |
138 | # print('img.shape: ', img.shape)
139 | # print('initial bottomLeftCornerOfText: ', bottomLeftCornerOfText)
140 |
141 | while bottomLeftCornerOfText[1] <= img.shape[0]:
142 | # get a line of text
143 | print_text = get_text()
144 |
145 | # put it on a blank image and get its height
146 | if line_num == 0:
147 | # get the correct text height
148 | big_img = np.ones((500,300), dtype = np.uint8)*255
149 | big_img_text = print_text.upper()
150 | cv2.putText(img = big_img, text = big_img_text, org = (0,200), fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType)
151 | text_height = get_text_height(big_img, fontColor)
152 | # print('text_height: ', text_height)
153 |
154 | if text_height > bottomLeftCornerOfText[1]:
155 | bottomLeftCornerOfText = (bottomLeftCornerOfText[0], np.random.randint(word_start_y, int(img.shape[0]*0.5)) + text_height)
156 |
157 | cv2.putText(img = img, text = print_text.upper(), org = bottomLeftCornerOfText, fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType)
158 | y_line_list.append(bottomLeftCornerOfText[1])
159 | else:
160 | # sampling the chances of adding one more line of text
161 | one_more_line = np.random.choice([0, 1], p = [0.4, 0.6])
162 | if not one_more_line:
163 | break
164 | cv2.putText(img = img, text = print_text.upper(), org = bottomLeftCornerOfText, fontFace = font, fontScale = fontScale, color = fontColor, thickness = thickness, lineType = lineType)
165 | y_line_list.append(bottomLeftCornerOfText[1])
166 | # calculate the (text_height+line break space) left on the bottom
167 | bottom_space_left = int(text_height*(1 + np.random.randint(20, 40)/100))
168 | # bottom_space_left = int(text_height*(1))
169 | # print('bottom_space_left: ', bottom_space_left)
170 |
171 | # update the bottomLeftCornerOfText
172 | bottomLeftCornerOfText = (bottomLeftCornerOfText[0], bottomLeftCornerOfText[1] + bottom_space_left)
173 |
174 | # print('bottomLeftCornerOfText: ', bottomLeftCornerOfText)
175 | line_num += 1
176 | '''
177 | for l in y_line_list:
178 | cv2.line(img, (0, l), (1000, l), 0, 1)
179 | '''
180 | return img, y_line_list, text_height
181 |
182 | def get_noisy_img(img, y_line_list, text_height):
183 |
184 | # adding noise (horizontal and vertical lines) on the image containing text
185 | noisy_img = img.copy()
186 |
187 | # adding horizontal line (noise)
188 | for y_line in y_line_list:
189 |
190 | # samples the possibility of adding a horizontal line
191 | add_horizontal_line = np.random.choice([0, 1], p = [0.5, 0.5])
192 | if not add_horizontal_line:
193 | continue
194 |
195 | # shift y_line randomly in the y-axis within a defined limit
196 | limit = int(text_height*0.3)
197 | if limit == 0: # this happens when the text used for getting the text height is '-', ',', '=' and other little symbols like these
198 | limit = 10
199 | y_line += np.random.randint(-limit, limit)
200 |
201 | h_start_x = 0 #np.random.randint(0,xmin) # min x of the horizontal line
202 | h_end_x = np.random.randint(int(noisy_img.shape[1]*0.8), noisy_img.shape[1]) # max x of the horizontal line
203 | h_length = h_end_x - h_start_x + 1
204 | num_h_lines = np.random.randint(10,30) # partitions to be made in the horizontal line (necessary to make it look like naturally broken lines)
205 | h_lines = []
206 | h_start_temp = h_start_x
207 | next_line = True
208 |
209 | num_line = 0
210 | while (next_line) and (num_line < num_h_lines):
211 | if h_start_temp < h_end_x:
212 | h_end_temp = np.random.randint(h_start_temp + 1, h_end_x + 1)
213 | if h_end_temp < h_end_x:
214 | h_lines.append([h_start_temp, h_end_temp])
215 | h_start_temp = h_end_temp + 1
216 | num_line += 1
217 | else:
218 | h_lines.append([h_start_temp, h_end_x])
219 | num_line += 1
220 | next_line = False
221 | else:
222 | next_line = False
223 |
224 | for h_line in h_lines:
225 | col = np.random.choice(['black', 'white'], p = [0.65, 0.35]) # probabilities of line segment being a solid one or a broken one
226 | if col == 'black':
227 | x_points = list(range(h_line[0], h_line[1] + 1))
228 | x_points_black_prob = np.random.choice([0,1], size = len(x_points), p = [0.2, 0.8])
229 |
230 | for idx, x in enumerate(x_points):
231 | if x_points_black_prob[idx]:
232 | noisy_img[ y_line - np.random.randint(4): y_line + np.random.randint(4), x] = np.random.randint(0,30)
233 |
234 | # adding vertical line (noise)
235 | vertical_bool = {'left': np.random.choice([0,1], p =[0.3, 0.7]), 'right': np.random.choice([0,1])} # whether to draw a vertical line on the left and/or right side of the image
236 | for left_right, bool_ in vertical_bool.items():
237 | if bool_:
238 | # print('left_right: ', left_right)
239 | if left_right == 'left':
240 | v_start_x = np.random.randint(5, int(noisy_img.shape[1]*0.06))
241 | else:
242 | v_start_x = np.random.randint(int(noisy_img.shape[1]*0.95), noisy_img.shape[1] - 5)
243 |
244 | v_start_y = np.random.randint(0, int(noisy_img.shape[0]*0.06))
245 | v_end_y = np.random.randint(int(noisy_img.shape[0]*0.95), noisy_img.shape[0])
246 |
247 | y_points = list(range(v_start_y, v_end_y + 1))
248 | y_points_black_prob = np.random.choice([0,1], size = len(y_points), p = [0.2, 0.8])
249 |
250 | for idx, y in enumerate(y_points):
251 | if y_points_black_prob[idx]:
252 | noisy_img[y, v_start_x - np.random.randint(4): v_start_x + np.random.randint(4)] = np.random.randint(0,30)
253 |
254 | return noisy_img
255 |
256 | def degrade_qualities(img, noisy_img):
257 | '''
258 | This function takes in a couple of images (color or grayscale), downsizes it to a
259 | randomly chosen size and then resizes it to the original size,
260 | degrading the quality of the images in the process.
261 | '''
262 | h, w = img.shape[0], img.shape[1]
263 | fx=np.random.randint(50,100)/100
264 | fy=np.random.randint(50,100)/100
265 | # print('fx, fy: ', fx, fy)
266 | img_small = cv2.resize(img, (0,0), fx = fx, fy = fy)
267 | img = cv2.resize(img_small,(w,h))
268 |
269 | noisy_img_small = cv2.resize(noisy_img, (0,0), fx = fx, fy = fy)
270 | noisy_img = cv2.resize(noisy_img_small,(w,h))
271 |
272 | return img, noisy_img
273 |
274 | def get_debug_image(img, noisy_img):
275 | debug_img = np.ones((2*h, w), dtype = np.uint8)*255 # to visualize the generated images (clean and noisy)
276 | debug_img[0:h, :] = img
277 | debug_img[h:2*h, :] = noisy_img
278 | cv2.line(debug_img, (0, h), (debug_img.shape[1], h), 150, 5)
279 | return debug_img
280 |
281 | def erode_dilate(img, noisy_img):
282 | # erode the image
283 | kernel = np.ones((3,3),np.uint8)
284 | erosion_iteration = np.random.randint(1,3)
285 | dilate_iteration = np.random.randint(0,2)
286 | img = cv2.erode(img,kernel,iterations = erosion_iteration)
287 | noisy_img = cv2.erode(noisy_img,kernel,iterations = erosion_iteration)
288 | img = cv2.dilate(img,kernel,iterations = dilate_iteration)
289 | noisy_img = cv2.dilate(noisy_img,kernel,iterations = dilate_iteration)
290 | return img, noisy_img
291 |
292 | def write_images(img, noisy_img, debug_img):
293 | global img_count
294 |
295 | img = 255 - cv2.resize(img, (0,0), fx = 1/scale_w, fy = 1/scale_h)
296 | noisy_img = 255 - cv2.resize(noisy_img, (0,0), fx = 1/scale_w, fy = 1/scale_h)
297 | debug_img = 255 - cv2.resize(debug_img, (0,0), fx = 1/scale_w, fy = 1/scale_h)
298 |
299 | if img_count <= train_num:
300 | cv2.imwrite(os.path.join(data_dir, train_dir, imgs_dir, '{}.jpg'.format(str(img_count).zfill(6))), img)
301 | cv2.imwrite(os.path.join(data_dir, train_dir, noisy_dir, '{}.jpg'.format(str(img_count).zfill(6))), noisy_img)
302 | cv2.imwrite(os.path.join(data_dir, train_dir, debug_dir, '{}.jpg'.format(str(img_count).zfill(6))), debug_img)
303 | else:
304 | cv2.imwrite(os.path.join(data_dir, val_dir, imgs_dir, '{}.jpg'.format(str(img_count).zfill(6))), img)
305 | cv2.imwrite(os.path.join(data_dir, val_dir, noisy_dir, '{}.jpg'.format(str(img_count).zfill(6))), noisy_img)
306 | cv2.imwrite(os.path.join(data_dir, val_dir, debug_dir, '{}.jpg'.format(str(img_count).zfill(6))), debug_img)
307 |
308 | img_count += 1
309 |
310 | print('\nsynthesizing image data...')
311 | for i in tqdm(range(num_imgs)):
312 | # make a blank image
313 | img = np.ones((h, w), dtype = np.uint8)*255
314 |
315 | # set random parameters
316 | font = font_list[np.random.randint(len(font_list))]
317 | bottomLeftCornerOfText = (np.random.randint(word_start_x, int(img.shape[1]/3)), np.random.randint(0, int(img.shape[0]*0.8))) # (x, y)
318 | fontColor = np.random.randint(0,30)
319 | fontScale = np.random.randint(1800, 2400)/1000
320 | lineType = np.random.randint(1,3)
321 | thickness = np.random.randint(1,7)
322 |
323 | # put text
324 | img, y_line_list, text_height = print_lines(img, font, bottomLeftCornerOfText, fontColor, fontScale, lineType, thickness)
325 |
326 | # add noise
327 | noisy_img = get_noisy_img(img, y_line_list, text_height)
328 |
329 | # degrade_quality
330 | img, noisy_img = degrade_qualities(img, noisy_img)
331 |
332 | # morphological operations
333 | img, noisy_img = erode_dilate(img, noisy_img)
334 |
335 | # make debug image
336 | debug_img = get_debug_image(img, noisy_img)
337 |
338 | # write images
339 | write_images(img, noisy_img, debug_img)
340 |
341 | '''
342 | cv2.imshow('textonimage', img)
343 | cv2.imshow('noisy_img', noisy_img)
344 | cv2.waitKey()
345 | '''
346 |
--------------------------------------------------------------------------------
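After generation, a quick sanity check of the first clean/noisy pair (illustrative only; `000001.jpg` is the first file the script writes, and the debug images under the debug directories already show the same stacking):

```
# sanity-check the first generated clean/noisy pair
import os, cv2
import numpy as np
import config as cfg

clean = cv2.imread(os.path.join(cfg.data_dir, cfg.train_dir, cfg.imgs_dir, '000001.jpg'), 0)
noisy = cv2.imread(os.path.join(cfg.data_dir, cfg.train_dir, cfg.noisy_dir, '000001.jpg'), 0)
pair = np.concatenate((noisy, clean), axis = 0)  # noisy on top, clean target below
cv2.imwrite('pair_check.jpg', pair)
print('clean:', clean.shape, 'noisy:', noisy.shape)  # both (64, 256) after the final downscale
```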