├── Data
│   ├── real
│   │   └── .gitkeep
│   └── syn
│       └── .gitkeep
├── images
│   └── .gitkeep
├── saved_models
│   └── .gitkeep
├── test_images
│   └── .gitkeep
├── main.py
├── README.md
├── cfg.py
├── losses.py
├── evaluate.py
├── inference.py
├── creat_mask_for_scut_syn.py
├── train.py
├── ssim.py
├── utils.py
├── dataset.py
└── model.py

--------------------------------------------------------------------------------
/Data/real/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Data/syn/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/images/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/saved_models/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/test_images/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from utils import *
from cfg import Config as C
from model import TFPNet
from losses import *
from train import train
from dataset import *
import torch


def main():
    epochs = C.epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = TFPNet()
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=C.lr)
    criterion = Loss().to(device)
    train_loader, test_loader = get_loaders()
    train(net, epochs, train_loader, test_loader, optimizer, criterion)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TFPNet

If you want to train on the SCUT-8k dataset, first run `creat_mask_for_scut_syn.py` to generate the bounding-box annotation files.

## Sample output of TPFNet

The following two sample videos show the results of the TPFNet network. In each video, the left-most frame shows the original video, the middle frame shows the video after text erasure by TPFNet, and the right-most frame shows the text attention map.

Poem video: https://drive.google.com/file/d/11l0STpYpxlxsaXXwXiJS1gyYrTiH37ho/view?usp=share_link

Sports news video: https://drive.google.com/file/d/1-6ra8poRShUXkNe5p0QUfT4k8KRX_EVa/view
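## Quick start

A minimal sketch of the intended workflow (dataset paths and checkpoint names follow `cfg.py`; the image path below is only illustrative):

```python
# Train: reads dataset paths from cfg.py and writes checkpoints to saved_models/
#   python main.py
# Evaluate PSNR/SSIM on the test split (needs saved_models/weight_best.pth):
#   python evaluate.py
# Visualize the three model outputs for a single image:
from inference import get_image_predictions

get_image_predictions('test_images/example.jpg', size=512)  # illustrative path
```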
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
class Config:
    # paths for the synthetic train dataset
    train_x_syn = 'Data/syn/syn_train/img/*.png'
    train_y_syn = 'Data/syn/syn_train/label/*.png'
    train_mask_syn = 'Data/syn/train/all_gts/*.txt'

    # paths for the synthetic test dataset
    test_x_syn = 'Data/syn/syn_test/img/*.png'
    test_y_syn = 'Data/syn/syn_test/label/*.png'
    test_mask_syn = 'Data/syn/test/all_gts/*.txt'

    # paths for the SCUT-real train dataset
    train_x = 'Data/real/train/all_images/*.jpg'
    train_y = 'Data/real/train/all_labels/*.jpg'
    train_mask = 'Data/real/train/all_gts/*.txt'

    # paths for the SCUT-real test dataset
    test_x = 'Data/real/test/all_images/*.jpg'
    test_y = 'Data/real/test/all_labels/*.jpg'
    test_mask = 'Data/real/test/all_gts/*.txt'

    # train config
    batch_size = 32
    epochs = 400
    saved_model_path = 'saved_models/'  # matches the saved_models/ directory in the repo
    num_worker = 3
    lr = 1e-4

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1e-7):
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice


class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()

        self.l1 = nn.L1Loss()
        self.l2 = nn.L1Loss()
        self.l3 = nn.L1Loss()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, pred1, pred2, pred3, y1, y2, y3):
        # L1 on all three heads: text mask, high-pass map, final image
        l1 = self.l1(pred1, y1)
        l2 = self.l2(pred2, y2)
        l3 = self.l3(pred3, y3)

        # BCE + Dice only on the mask head; pred1 is a sigmoid output in [0, 1]
        bi = self.bce(pred1, y1)
        dice = self.dice(pred1, y1)
        loss = (l1 + l2 + l3) / 3.0
        return loss + 0.5 * (bi + dice)
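
# Smoke test with random tensors (illustrative shapes only; pred1/y1 are the
# 1-channel masks, pred2/y2 the 1-channel high-pass maps, pred3/y3 the images):
if __name__ == '__main__':
    import torch
    criterion = Loss()
    pred1 = torch.rand(2, 1, 64, 64)  # must stay in [0, 1] for nn.BCELoss
    pred2 = torch.rand(2, 1, 64, 64)
    pred3 = torch.rand(2, 3, 64, 64)
    print(criterion(pred1, pred2, pred3, pred1.round(), pred2, pred3).item())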
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
from utils import *
from ssim import ssim
from model import TFPNet
from dataset import *
import glob
from tqdm import tqdm
from cfg import Config as C
from torch.utils.data import DataLoader
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_test = glob.glob(C.test_x)
y_test = glob.glob(C.test_y)

test_ds = TestDeTextDataset(x_test, y_test)
test_loader = DataLoader(test_ds, 1, shuffle=False)

checkpoint = torch.load(C.saved_model_path + 'weight_best.pth')
model = TFPNet()
model.load_state_dict(checkpoint['weights'])
model.to(device)


psnr1 = []
ssim1 = []

with torch.no_grad():
    model.eval()
    for i, data in tqdm(enumerate(test_loader)):
        x = data[0].to(device)
        y = data[1].to(device)

        # the third head is the text-erased image
        _, _, pred = model(x)

        ps = psnr(pred, y)
        sm = ssim(pred, y).item()

        psnr1.append(ps)
        ssim1.append(sm)

print('PSNR: ', sum(psnr1) / len(psnr1))
print('SSIM: ', sum(ssim1) / len(ssim1))

--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import torch
import cv2
import matplotlib.pyplot as plt
from model import TFPNet
from cfg import Config as C

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = TFPNet()
checkpoint = torch.load(C.saved_model_path + 'weight_best.pth')
model.load_state_dict(checkpoint['weights'])
model.to(device)
model.eval()


def get_image_predictions(img_path, size):
    with torch.no_grad():
        f, axrr = plt.subplots(1, 3, figsize=(25, 25))

        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (size, size))

        axrr[0].imshow(img)

        img = img / 255.
        img = torch.from_numpy(img.astype('float32')).permute(2, 0, 1).unsqueeze(0).to(device)

        pred1, pred2, pred3 = model(img)

        pred3 = pred3.detach().cpu().clamp(0, 1).squeeze(0).permute(1, 2, 0).numpy()
        pred1 = pred1.detach().cpu().clamp(0, 1).squeeze(0).permute(1, 2, 0).squeeze(-1).numpy()

        axrr[1].imshow(pred3)
        axrr[2].imshow(pred1 > 0.4)
        plt.show()


if __name__ == '__main__':
    # illustrative path; point this at your own image
    get_image_predictions('test_images/example.jpg', size=512)
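
# Output semantics: pred1 is the predicted text mask (sigmoid output,
# binarized at 0.4 above for display), pred2 the high-pass map, and
# pred3 the final text-erased image.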
--------------------------------------------------------------------------------
/creat_mask_for_scut_syn.py:
--------------------------------------------------------------------------------
import glob2
import os
import cv2
import numpy as np
import imutils

input_dir = 'Data/syn/syn_train'
output_dir = os.path.join(input_dir, 'bbox')

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

images = glob2.glob(os.path.join(input_dir, 'img', '*.png'))
print("#items: ", len(images))
DEBUG = False

for i, item in enumerate(images):
    label_img_path = os.path.join(input_dir, 'label', os.path.basename(item))
    img = cv2.imread(item)
    label_img = cv2.imread(label_img_path)
    gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray_label_img = cv2.cvtColor(label_img, cv2.COLOR_BGR2GRAY)

    # pixels where the input and the text-free label differ strongly are
    # assumed to belong to text regions
    diff = np.abs(gray_img.astype(np.float32) - gray_label_img.astype(np.float32))
    diff_threshold = 50
    diff[diff < diff_threshold] = 1
    diff[diff >= diff_threshold] = 0
    diff = (diff * 255).astype("uint8")

    thresh = cv2.threshold(diff, 0, 255,
                           cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
    cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
                            cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)

    # loop over the contours, keeping only bounding boxes above a minimum area
    cnt_list = []
    bbox_size_threshold = 50
    for c in cnts:
        (x, y, w, h) = cv2.boundingRect(c)
        if w * h < bbox_size_threshold:
            continue
        cnt_list.append((x, y, x + w, y + h))
        if DEBUG:
            cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv2.rectangle(label_img, (x, y), (x + w, y + h), (0, 0, 255), 2)

    # show the intermediate images when debugging
    if DEBUG:
        cv2.imshow("Original", img)
        cv2.imshow("Modified", label_img)
        cv2.imshow("Diff", diff)
        cv2.imshow("Thresh", thresh)
        cv2.waitKey(0)

    with open(os.path.join(output_dir, os.path.basename(item).split('.')[0] + ".txt"),
              'w') as f:
        for box in cnt_list:
            f.write(','.join([str(i) for i in box]) + '\n')

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm
from cfg import Config as C
from utils import *
from ssim import ssim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(model, epochs, train_loader, test_loader, optimizer, criterion):
    current_psnr = 15  # only checkpoints beating this validation PSNR are saved as best
    for epoch in range(epochs):
        print('Epoch: {}'.format(epoch))
        l_train = []
        ps_train = []
        ps_test = []
        model.train()

        for i, data in tqdm(enumerate(train_loader)):
            x, y1, y2, y3 = data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device)
            optimizer.zero_grad()
            pred1, pred2, pred3 = model(x)
            loss = criterion(pred1, pred2, pred3, y1, y2, y3)
            # extra SSIM penalty on the high-pass head
            ssim2 = ssim(pred2, y2)
            loss_ssim = (1 - ssim2) * 2
            loss += loss_ssim

            psnr1 = psnr(pred3.detach(), y3)
            loss.backward()
            optimizer.step()
            l_train.append(loss.item())
            ps_train.append(psnr1)
        print("Epoch loss: ", sum(l_train) / len(l_train))
        print('Epoch {} PSNR: '.format(epoch), sum(ps_train) / len(ps_train))

        with torch.no_grad():
            model.eval()
            mss_val = []
            for i, data in tqdm(enumerate(test_loader)):
                x, y1, y2, y3 = data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device)
                pred1, pred2, pred3 = model(x)

                psnr1 = psnr(pred3, y3)
                val_mss = ssim(pred3, y3)

                mss_val.append(val_mss)
                ps_test.append(psnr1)

        print('Val Epoch {} PSNR: '.format(epoch), sum(ps_test) / len(ps_test))
        print('VAL SSIM: ', sum(mss_val) / len(mss_val))

        if current_psnr < sum(ps_test) / len(ps_test):
            checkpoint = {
                'weights': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            print("saving best one.....")
            torch.save(checkpoint, C.saved_model_path + 'weight_best.pth')
            current_psnr = sum(ps_test) / len(ps_test)

        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'weights': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(checkpoint, C.saved_model_path + "checkpoint_SCUT{}.pth".format(epoch + 1))
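
# Resume sketch (assumes a checkpoint written by this script):
#   checkpoint = torch.load(C.saved_model_path + 'checkpoint_SCUT10.pth')
#   model.load_state_dict(checkpoint['weights'])
#   optimizer.load_state_dict(checkpoint['optimizer'])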
--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------
import torch
from math import exp
from torch.autograd import Variable
import torch.nn.functional as F


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    # local means via a per-channel Gaussian filter
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # local variances and covariance
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # stability constants for a dynamic range of 1.0
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
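
# Usage sketch: ssim() expects NCHW tensors with values in [0, 1] and returns
# the mean SSIM over the batch; identical inputs give 1.0:
#   x = torch.rand(1, 3, 64, 64)
#   assert abs(ssim(x, x).item() - 1.0) < 1e-4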
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import random
from PIL import Image, ImageDraw
import albumentations as A
import math


def poly_to_mask(poly):
    # rasterize every polygon in the annotation file into one 512x512 binary
    # mask, then replicate it to three channels
    filee = open(poly, 'r')
    mask = np.zeros((512, 512))
    lines = filee.readlines()
    for line in lines:
        line = line.replace('\n', '')
        line = line.split(',')
        polygon = [int(i) for i in line]

        img = Image.fromarray(np.zeros((512, 512), dtype='uint8'))
        ImageDraw.Draw(img).polygon(polygon, outline=1, fill=1)
        mask += np.array(img)
    mask = np.expand_dims((mask > 0).astype('uint8'), axis=2)

    return np.concatenate((mask, mask, mask), axis=2) * 255


def bbox_to_mask(img, txt):
    img = cv2.imread(img)
    mm = open(txt, 'r')
    mask = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)

    for i in mm.readlines():
        i = i.split(',')
        i = [int(k) for k in i]
        mask[i[1]:i[3], i[0]:i[2]] = 255
    return mask


def transforms(x1, x2, x3):
    # apply at most one augmentation per call; the same transform is applied
    # to the input image x1, the label image x2, and the mask x3
    if random.uniform(0, 1) > 0.4:
        t2 = A.HorizontalFlip(p=1)
        x1 = t2(image=x1)['image']
        x2 = t2(image=x2)['image']
        x3 = t2(image=x3, mask=x3)['mask']

    elif random.uniform(0, 1) > 0.4:
        t2 = A.RandomBrightnessContrast(p=1)
        x1 = t2(image=x1)['image']
        x2 = t2(image=x2)['image']
        x3 = t2(image=x3, mask=x3)['mask']

    elif random.uniform(0, 1) > 0.5:
        t2 = A.ElasticTransform(p=1, alpha=120, sigma=120 * 0.5, alpha_affine=120 * 0.3)
        x1 = t2(image=x1)['image']
        x2 = t2(image=x2)['image']
        x3 = t2(image=x3, mask=x3)['mask']

    elif random.uniform(0, 1) > 0.5:
        t2 = A.GridDistortion(p=1)
        x1 = t2(image=x1)['image']
        x2 = t2(image=x2)['image']
        x3 = t2(image=x3, mask=x3)['mask']

    elif random.uniform(0, 1) > 0.:
        t2 = A.OpticalDistortion(distort_limit=8, shift_limit=0.7, p=1)
        x1 = t2(image=x1)['image']
        x2 = t2(image=x2)['image']
        x3 = t2(image=x3, mask=x3)['mask']

    return x1, x2, x3


def psnr(pred, gt):
    # PSNR in dB, assuming inputs in [0, 1]
    pred = pred.clamp(0, 1).detach().cpu().numpy()
    gt = gt.clamp(0, 1).detach().cpu().numpy()
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(1.0 / rmse)
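
# Worked example: an RMSE of 0.1 gives 20 * log10(1 / 0.1) = 20 dB,
# and an RMSE of 0.01 gives 40 dB.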
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import cv2
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from cfg import Config as C
from utils import *


class DeTextDataset(Dataset):
    def __init__(self, img_paths=None, mask_paths=None, poly_paths=None, a=True, test=False):
        self.a = a  # whether to apply augmentations
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.poly_paths = poly_paths
        assert len(self.img_paths) == len(self.mask_paths)
        self.test = test

    def __len__(self):
        return len(self.img_paths)

    def Lowpass(self, img):
        temp = img.copy()
        dst = cv2.GaussianBlur(temp, (5, 5), cv2.BORDER_DEFAULT)
        return dst

    def Highpass(self, img):
        # blur, then take the Laplacian to keep only high-frequency structure
        temp = img.copy()
        dst = cv2.GaussianBlur(temp, (3, 3), 0)
        source_gray = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY)
        dest = cv2.Laplacian(source_gray, cv2.CV_16S, ksize=3)
        abs_dest = cv2.convertScaleAbs(dest)
        return abs_dest

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        mask_path = self.mask_paths[index]
        poly_path = self.poly_paths[index]

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (512, 512))

        mask = cv2.imread(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = cv2.resize(mask, (512, 512))

        poly = poly_to_mask(poly_path)
        if self.a:
            image, mask, poly = transforms(image, mask, poly)

        highpass = self.Highpass(mask)

        image = image.astype(np.float32)
        image = image / 255.0
        image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)

        poly = poly.astype('float32')
        poly = torch.from_numpy(poly)
        poly = (poly.permute(2, 0, 1) / 255.0)

        highpass = highpass.astype(np.float32)
        highpass = highpass[:, :, np.newaxis]
        highpass = highpass / 255.0
        highpass = torch.from_numpy(highpass)
        highpass = highpass.permute(2, 0, 1)

        mask = mask.astype(np.float32)
        mask = mask / 255.0
        mask = torch.from_numpy(mask)
        mask = mask.permute(2, 0, 1)

        return image, poly[0:1, :, :], highpass, mask


class TestDeTextDataset(Dataset):
    def __init__(self, img_paths=None, mask_paths=None, size=1024):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.size = size
        assert len(self.img_paths) == len(self.mask_paths)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        mask_path = self.mask_paths[index]

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.size, self.size))

        image = image.astype(np.float32)
        image = image / 255.0
        image = torch.from_numpy(image)
        image = image.permute(2, 0, 1)

        mask = cv2.imread(mask_path)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        mask = cv2.resize(mask, (self.size, self.size))

        mask = mask.astype(np.float32)
        mask = mask / 255.0
        mask = torch.from_numpy(mask)
        mask = mask.permute(2, 0, 1)

        return image, mask


def get_loaders():
    x_train = glob.glob(C.train_x)
    y_train = glob.glob(C.train_y)
    mask_train = glob.glob(C.train_mask)

    x_test = glob.glob(C.test_x)
    y_test = glob.glob(C.test_y)
    mask_test = glob.glob(C.test_mask)

    train_ds = DeTextDataset(x_train, y_train,
                             mask_train,
                             a=True, test=False)

    test_ds = DeTextDataset(x_test, y_test,
                            mask_test,
                            a=False, test=True)

    train_loader = DataLoader(train_ds, batch_size=C.batch_size, num_workers=C.num_worker, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=C.batch_size, num_workers=C.num_worker, shuffle=False)

    return train_loader, test_loader
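
# Each DeTextDataset item is (input image, 1-channel text mask, 1-channel
# high-pass map, 3-channel text-free label image), matching the three model
# heads; TestDeTextDataset yields only (input image, label image).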
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision import models


class backbone(nn.Module):
    def __init__(self):
        super(backbone, self).__init__()

        model = models.efficientnet_b6(pretrained=True)
        self.l1 = nn.Sequential(*list(model.features.children())[:2])
        self.l2 = nn.Sequential(*list(model.features[2][:3]))
        self.l3 = nn.Sequential(*list(model.features[3][:4]))
        self.l4 = nn.Sequential(*list(model.features[4][:4]))
        self.l5 = nn.Sequential(*list(model.features[5][:2]))
        self.l6 = nn.Sequential(*list(model.features[6][:2]))

    def forward(self, x):
        x1 = self.l1(x)
        x2 = self.l2(x1)
        x3 = self.l3(x2)
        x4 = self.l4(x3)
        x5 = self.l5(x4)
        x6 = self.l6(x5)
        # x4 feeds l5 but is not used as a skip connection
        return [x1, x2, x3, x5, x6]


class ConvbnGelu(nn.Module):
    def __init__(self, inchannels, outchannels, kernel_size=3, stride=1, padding=1):
        super(ConvbnGelu, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inchannels, outchannels, kernel_size, stride, padding),
            nn.GELU(),
            nn.BatchNorm2d(outchannels)
        )

    def forward(self, x):
        return self.conv(x)


class SEAttention(nn.Module):  # squeeze-and-excitation channel attention
    def __init__(self, in_channels, reduced_dim=16):
        super(SEAttention, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # C x H x W -> C x 1 x 1
            nn.Conv2d(in_channels, reduced_dim, 1),
            nn.SiLU(),
            nn.Conv2d(reduced_dim, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.se(x)


class ResidualBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(ResidualBlock, self).__init__()

        self.c1 = ConvbnGelu(inchannels=in_c, outchannels=out_c)
        self.c2 = ConvbnGelu(inchannels=out_c, outchannels=out_c)
        self.c3 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_c)
        self.se = SEAttention(in_channels=out_c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.c1(x)
        x2 = self.c2(x1)
        # shortcut path: 1x1 conv + BN + channel attention
        x3 = self.c3(x)
        x3 = self.bn3(x3)
        x3 = self.se(x3)
        x4 = x2 + x3
        x4 = self.relu(x4)
        return x4


class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(EncoderBlock, self).__init__()

        self.r1 = ResidualBlock(in_c, out_c)
        self.pool = nn.AvgPool2d(2, stride=2)

    def forward(self, x):
        x = self.r1(x)
        p = self.pool(x)
        return x, p


class DecoderBlock(nn.Module):
    def __init__(self, in_c, skip_c, out_c):
        super(DecoderBlock, self).__init__()

        self.upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1)
        self.r1 = ResidualBlock(skip_c + out_c, out_c)

    def forward(self, x, s):
        # upsample, concatenate the skip feature, then refine
        x = self.upsample(x)
        x = torch.cat([x, s], axis=1)
        x = self.r1(x)
        return x


class GBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(GBlock, self).__init__()

        self.c = nn.Conv2d(in_c + in_c, out_c, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x1, x2, x3):
        # gate x3 with a sigmoid map computed from x1 and x2
        x = torch.cat([x1, x2], axis=1)
        x = self.c(x)
        x = self.sig(x)
        x = x * x3
        return x
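
# TFPNet runs three parallel decoder branches over shared EfficientNet-B6
# features: ld* predicts the text mask (1 channel, sigmoid), hd* the
# high-pass map (1 channel, tanh), and id* the text-erased image
# (3 channels, tanh). At every stage a GBlock gates the image branch with
# a map computed from the mask and high-pass features.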
class TFPNet(nn.Module):
    def __init__(self):
        super(TFPNet, self).__init__()

        self.backbone = backbone()

        self.ld1 = DecoderBlock(in_c=344, skip_c=200, out_c=256)
        self.hd1 = DecoderBlock(in_c=344, skip_c=200, out_c=256)
        self.id1 = DecoderBlock(in_c=344, skip_c=200, out_c=256)
        self.g1 = GBlock(in_c=256, out_c=256)

        self.ld2 = DecoderBlock(in_c=256, skip_c=72, out_c=128)
        self.hd2 = DecoderBlock(in_c=256, skip_c=72, out_c=128)
        self.id2 = DecoderBlock(in_c=256, skip_c=72, out_c=128)
        self.g2 = GBlock(in_c=128, out_c=128)

        self.ld3 = DecoderBlock(in_c=128, skip_c=40, out_c=64)
        self.hd3 = DecoderBlock(in_c=128, skip_c=40, out_c=64)
        self.id3 = DecoderBlock(in_c=128, skip_c=40, out_c=64)
        self.g3 = GBlock(in_c=64, out_c=64)

        self.ld4 = DecoderBlock(in_c=64, skip_c=32, out_c=32)
        self.hd4 = DecoderBlock(in_c=64, skip_c=32, out_c=32)
        self.id4 = DecoderBlock(in_c=64, skip_c=32, out_c=32)
        self.g4 = GBlock(in_c=32, out_c=32)

        self.ld5 = nn.Sequential(nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
                                 nn.Conv2d(32, 1, 3, 1, 1),
                                 nn.Sigmoid()
                                 )
        self.hd5 = nn.Sequential(nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
                                 nn.Conv2d(32, 1, 3, 1, 1),
                                 nn.Tanh()
                                 )
        self.id5 = nn.Sequential(nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
                                 nn.Conv2d(32, 3, 3, 1, 1),
                                 nn.Tanh()
                                 )

    def forward(self, x):
        x1, x2, x3, x4, x5 = self.backbone(x)

        """block 1"""
        l1 = self.ld1(x5, x4)
        h1 = self.hd1(x5, x4)
        i1 = self.id1(x5, x4)
        i1 = self.g1(l1, h1, i1)

        """block 2"""
        l2 = self.ld2(l1, x3)
        h2 = self.hd2(h1, x3)
        i2 = self.id2(i1, x3)
        i2 = self.g2(l2, h2, i2)

        """block 3"""
        l3 = self.ld3(l2, x2)
        h3 = self.hd3(h2, x2)
        i3 = self.id3(i2, x2)
        i3 = self.g3(l3, h3, i3)

        """block 4"""
        l4 = self.ld4(l3, x1)
        h4 = self.hd4(h3, x1)
        i4 = self.id4(i3, x1)
        i4 = self.g4(l4, h4, i4)

        """block 5 [last block]"""
        l5 = self.ld5(l4)
        h5 = self.hd5(h4)
        i5 = self.id5(i4)

        return l5, h5, i5

--------------------------------------------------------------------------------