├── .gitignore
├── LICENSE
├── README.md
├── create_dataset.py
├── dataset.py
├── doc
│   ├── Thumbs.db
│   ├── back.png
│   ├── dataset_example.png
│   ├── epoch1.PNG
│   ├── epoch10.PNG
│   ├── epoch120.PNG
│   ├── epoch30.PNG
│   ├── epoch5.PNG
│   ├── epoch50.PNG
│   ├── model.png
│   └── text.png
├── losses.py
├── modules.py
├── network.py
├── requirements.txt
├── results
│   └── show
│       └── Thumbs.db
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # etc 132 | *.png 133 | *.jpg 134 | *.pth 135 | 136 | ./doc/* --------------------------------------------------------------------------------
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ZeroAct 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # Scene Text Remover Pytorch Implementation 2 | 3 | This is a minimal implementation of [Scene text removal via cascaded text stroke detection and erasing](https://arxiv.org/pdf/2011.09768.pdf). This GitHub repository is a study of image inpainting for scene text erasing. Thank you :) 4 | 5 | 6 | 7 | ## Requirements 8 | 9 | Python 3.7 or later with all [requirements.txt](./requirements.txt) dependencies installed, including `torch>=1.6`. To install run: 10 | 11 | ``` 12 | $ pip install -r requirements.txt 13 | ``` 14 | 15 | 16 | 17 | ## Model Summary 18 | 19 | ![model architecture](./doc/model.png) 20 | 21 | This model is composed of U-Net sub-modules. 22 | `Gd` detects the text stroke mask `Ms` from `I` and `M`, and `G'd` detects a more precise stroke mask `M's`. 23 | Similarly, `Gr` generates the text-erased image `Ite`, and `G'r` generates the more precise output `I'te`. 24 | 25 | 26 | 27 | ## Custom Dictionary 28 | 29 | To avoid confusion, I renamed the variables as follows. 30 | 31 | `I` : Input Image (with text)<br>
32 | `Mm` : Text area mask (`M` in the model)
33 | `Ms` : Text stroke mask; output of `Gd`
34 | `Ms_` : Text stroke mask; output of `G'd`
35 | `Msgt` : Text stroke mask; ground truth<br>
36 | `Ite` : Text erased image; output of `Gr`
37 | `Ite_` : Text erased image; output of `G'r`
38 | `Itegt`: Text erased image; ground truth
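
The snippet below is a minimal smoke-test sketch (random tensors stand in for real data, mirroring the `__main__` block of `network.py`) showing how these names map onto the forward cascade in `network.py` and the losses in `losses.py`:

```python
import torch

from network import STRNet
from losses import TSDLoss, TRGLoss

model = STRNet()

I  = torch.randn(2, 3, 256, 256)  # input images with text
Mm = torch.randn(2, 1, 256, 256)  # text area masks (M in the paper)

# Gd -> Ms, Gr -> Ite, G'd -> Ms_, G'r -> Ite_
Ms, Ite, Ms_, Ite_ = model(I, Mm)

# ground truths (random placeholders here)
Msgt  = torch.randn(2, 1, 256, 256)  # text stroke mask, ground truth
Itegt = torch.randn(2, 3, 256, 256)  # text erased image, ground truth

Ltsd = TSDLoss(Msgt, Ms, Ms_)                  # text stroke detection loss
Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)  # text removal generation loss
```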
39 | 40 | 41 | 42 | ## Prepare Dataset 43 | 44 | You need to prepare background images in the `backs` directory and text binary images in the `font_mask` directory. 45 | 46 | ![background image, text image example](./doc/back.png) 47 | [part of background image sample, text binary image sample] 48 | 49 | Executing `python create_dataset.py` will automatically generate the `I`, `Itegt`, `Mm`, `Msgt` data. 50 | (If you already have `I`, `Itegt`, `Mm`, `Msgt`, you can skip this section.) 51 | 52 | ``` 53 | ├─dataset 54 | │ ├─backs 55 | │ │   # background images 56 | │ ├─font_mask 57 | │ │   # text binary images 58 | │ ├─train 59 | │ │ ├─I 60 | │ │ ├─Itegt 61 | │ │ ├─Mm 62 | │ │ └─Msgt 63 | │ └─val 64 | │   ├─I 65 | │   ├─Itegt 66 | │   ├─Mm 67 | │   └─Msgt 68 | ``` 69 | 70 | I generated my dataset with 709 background images and 2410 font masks. 71 | I used 17040 pairs for training and 4260 pairs for validation. 72 | 73 | ![](./doc/dataset_example.png) 74 | 75 | Thanks to [sina-Kim](https://github.com/sina-Kim) for helping me gather background images. 76 | 77 | 78 | 79 | ## Train 80 | 81 | All you need to do is: 82 | 83 | ``` python 84 | python train.py 85 | ``` 86 | 87 | 88 | 89 | ## Result 90 | 91 | From the left: 92 | `I`, `Itegt`, `Ite`, `Ite_`, `Msgt`, `Ms`, `Ms_` 93 | 94 | * Epoch 1<br>
95 | ![](./doc/epoch1.PNG) 96 | * Epoch 5
97 | ![](./doc/epoch5.PNG) 98 | * Epoch 10
99 | ![](./doc/epoch10.PNG) 100 | * Epoch 30
101 | ![](./doc/epoch30.PNG) 102 | * Epoch 50
103 | ![](./doc/epoch50.PNG) 104 | * Epoch 120
105 | ![](./doc/epoch120.PNG) 106 | 107 | These results are not good enough for a real task yet; I think the main reasons are the limited dataset and the simplicity of the model. 108 | Still, it was a good experience for me to implement the paper. 109 | 110 | 111 | 112 | ## Issue 113 | 114 | If you have trouble running this code, please use the issue tab. Thank you. 115 | 116 | --------------------------------------------------------------------------------
/create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import random 5 | import progressbar 6 | 7 | import numpy as np 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | rand_color = lambda : (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 12 | rand_pos = lambda a, b: (random.randint(a, b-1), random.randint(a, b-1)) 13 | 14 | target_size = 256 15 | imgs_per_back = 30  # number of crops generated per background image 16 | 17 | backs = glob.glob('./dataset/backs/*.png') 18 | fonts = glob.glob('./dataset/font_mask/*.png') 19 | 20 | os.makedirs('./dataset/train/I', exist_ok=True) 21 | os.makedirs('./dataset/train/Itegt', exist_ok=True) 22 | os.makedirs('./dataset/train/Mm', exist_ok=True) 23 | os.makedirs('./dataset/train/Msgt', exist_ok=True) 24 | 25 | os.makedirs('./dataset/val/I', exist_ok=True) 26 | os.makedirs('./dataset/val/Itegt', exist_ok=True) 27 | os.makedirs('./dataset/val/Mm', exist_ok=True) 28 | os.makedirs('./dataset/val/Msgt', exist_ok=True) 29 | 30 | t_idx = len(os.listdir('./dataset/train/I')) 31 | v_idx = len(os.listdir('./dataset/val/I')) 32 | 33 | bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back) 34 | bar.start() 35 | for back in backs: 36 | back_img = cv2.imread(back) 37 | bh, bw, _ = back_img.shape 38 | if bh < target_size or bw < target_size: 39 | back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC) 40 | bh, bw, _ = back_img.shape 41 | 42 | for bi in range(imgs_per_back): 43 | sx, sy = random.randint(0, bw-target_size), random.randint(0, bh-target_size) 44 | 45 | Itegt = back_img[sy:sy+target_size, sx:sx+target_size, :].copy() 46 | I = Itegt.copy() 47 | Mm = np.zeros_like(I) 48 | Msgt = np.zeros_like(I) 49 | 50 | hist = [] 51 | for font in random.sample(fonts, random.randint(2, 4)): 52 | font_img = cv2.imread(font) 53 | mask_img = np.ones_like(font_img, dtype=np.uint8)*255 54 | 55 | height, width, _ = font_img.shape 56 | 57 | angle = random.randint(-30, +30) 58 | fs = random.randint(90, 120) 59 | ratio = fs / height - 0.2 60 | 61 | matrix = cv2.getRotationMatrix2D((width/2, height/2), angle, ratio) 62 | font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)  # warpAffine's fourth positional argument is dst, so the interpolation mode must be passed via flags 63 | mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC) 64 | 65 | h, w, _ = font_rot.shape 66 | 67 | font_in_I = np.zeros_like(I) 68 | mask_in_I = np.zeros_like(I) 69 | 70 | allow = 0  # overlap tolerance, relaxed a little on every failed attempt 71 | while True: 72 | sx, sy = rand_pos(0, target_size-w) 73 | 74 | done = True 75 | for sx_, sy_ in hist: 76 | if (sx_ - sx)**2 + (sy_ - sy)**2 < (fs * ratio)**2 - allow: 77 | done = False 78 | break 79 | allow += 5 80 | 81 | if done: 82 | hist.append([sx, sy]) 83 | break 84 | 85 | font_in_I[sy:sy+h, sx:sx+w, :] = font_rot 86 | mask_in_I[sy:sy+h, sx:sx+w, :] = mask_rot 87 | 88 | font_in_I[font_in_I > 30] = 255 89 | mask_in_I[mask_in_I > 30] = 255 90 | 91 | I = cv2.bitwise_and(I, 255-font_in_I) 92 | I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8)) 93 | 94 | Mm = cv2.bitwise_or(Mm, mask_in_I) 95 | Msgt = cv2.bitwise_or(Msgt, 
font_in_I) 96 | 97 | if bi < imgs_per_back*0.8: 98 | cv2.imwrite(f'dataset/train/I/{t_idx}.png', I) 99 | cv2.imwrite(f'dataset/train/Itegt/{t_idx}.png', Itegt) 100 | cv2.imwrite(f'dataset/train/Mm/{t_idx}.png', Mm) 101 | cv2.imwrite(f'dataset/train/Msgt/{t_idx}.png', Msgt) 102 | t_idx += 1 103 | else: 104 | cv2.imwrite(f'dataset/val/I/{v_idx}.png', I) 105 | cv2.imwrite(f'dataset/val/Itegt/{v_idx}.png', Itegt) 106 | cv2.imwrite(f'dataset/val/Mm/{v_idx}.png', Mm) 107 | cv2.imwrite(f'dataset/val/Msgt/{v_idx}.png', Msgt) 108 | v_idx += 1 109 | 110 | bar.update(t_idx + v_idx) 111 | bar.finish() -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os, cv2 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | def mat_to_tensor(mat): 8 | mat = mat.transpose((2, 0, 1)) 9 | tensor = torch.Tensor(mat) 10 | return tensor 11 | 12 | def tensor_to_mat(tensor): 13 | mat = tensor.detach().cpu().numpy() 14 | mat = mat.transpose((0, 2, 3, 1)) 15 | return mat 16 | 17 | def preprocess_image(img, target_shape: tuple): 18 | img = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32) 19 | img = img / 255. 20 | if len(img.shape) == 2: 21 | img = img.reshape(*img.shape, 1) 22 | 23 | return img 24 | 25 | def postprocess_image(img): 26 | img = img * 255 27 | img = np.clip(img, 0, 255) 28 | return img.astype(np.uint8) 29 | 30 | class CustomDataset(Dataset): 31 | def __init__(self, 32 | data_dir, 33 | set_name="train", 34 | target_size=(256, 256)): 35 | 36 | super().__init__() 37 | 38 | self.root_dir = os.path.join(data_dir, set_name) 39 | self.target_size = target_size 40 | 41 | self.I_dir = os.path.join(self.root_dir, "I") 42 | self.Itegt_dir = os.path.join(self.root_dir, "Itegt") 43 | self.Mm_dir = os.path.join(self.root_dir, "Mm") 44 | self.Msgt_dir = os.path.join(self.root_dir, "Msgt") 45 | 46 | self.datas = os.listdir(self.I_dir) 47 | 48 | def __len__(self): 49 | return len(self.datas) 50 | 51 | def __getitem__(self, idx): 52 | img_name = self.datas[idx] 53 | 54 | I = cv2.imread(os.path.join(self.I_dir, img_name)) 55 | Itegt = cv2.imread(os.path.join(self.Itegt_dir, img_name)) 56 | Mm = cv2.imread(os.path.join(self.Mm_dir, img_name), cv2.IMREAD_GRAYSCALE) 57 | Msgt = cv2.imread(os.path.join(self.Msgt_dir, img_name), cv2.IMREAD_GRAYSCALE) 58 | 59 | I = mat_to_tensor(preprocess_image(I, self.target_size)) 60 | Itegt = mat_to_tensor(preprocess_image(Itegt, self.target_size)) 61 | Mm = mat_to_tensor(preprocess_image(Mm, self.target_size)) 62 | Msgt = mat_to_tensor(preprocess_image(Msgt, self.target_size)) 63 | 64 | return I, Itegt, Mm, Msgt 65 | 66 | 67 | if __name__ == "__main__": 68 | ds = CustomDataset('dataset', 'train') 69 | 70 | I, Itegt, Mm, Ms = ds.__getitem__(0) 71 | print(f"Dataset length : {len(ds)}") 72 | print(f"I shape : {I.shape}") 73 | print(f"Itegt shape : {Itegt.shape}") 74 | print(f"Mm shape : {Mm.shape}") 75 | print(f"Ms shape : {Ms.shape}") 76 | -------------------------------------------------------------------------------- /doc/Thumbs.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/Thumbs.db -------------------------------------------------------------------------------- /doc/back.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/back.png -------------------------------------------------------------------------------- /doc/dataset_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/dataset_example.png -------------------------------------------------------------------------------- /doc/epoch1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch1.PNG -------------------------------------------------------------------------------- /doc/epoch10.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch10.PNG -------------------------------------------------------------------------------- /doc/epoch120.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch120.PNG -------------------------------------------------------------------------------- /doc/epoch30.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch30.PNG -------------------------------------------------------------------------------- /doc/epoch5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch5.PNG -------------------------------------------------------------------------------- /doc/epoch50.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch50.PNG -------------------------------------------------------------------------------- /doc/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/model.png -------------------------------------------------------------------------------- /doc/text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/text.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def TSDLoss(Mgt, Ms, Ms_, r=10): 6 | return torch.mean(torch.abs(Ms-Mgt) + r * torch.abs(Ms_-Mgt)) 7 | 8 | def TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_, rm=5, rs=5, rr=10): 9 | 10 | Mw = torch.ones_like(Mm) + rm * Mm + rs * Ms 11 | Mw_ = torch.ones_like(Mm) + rm * Mm + rs * Ms_ 12 | 13 | Ltrg = torch.mean(torch.abs(torch.mul(Ite, Mw) - torch.mul(Itegt, 
Mw)) + \ 14 | rr * torch.abs(torch.mul(Ite_, Mw_) - torch.mul(Itegt, Mw_))) 15 | 16 | return Ltrg 17 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # dis_conv 6 | # (https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L84) 7 | class _dis_conv(nn.Module): 8 | 9 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2): 10 | super().__init__() 11 | 12 | self._conv = nn.Sequential( 13 | nn.utils.spectral_norm( 14 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 15 | ), 16 | nn.LeakyReLU(inplace=True) 17 | ) 18 | 19 | # weight initialization 20 | def weight_init(m): 21 | if isinstance(m, nn.Conv2d): 22 | # nn.utils.spectral_norm(m.weight) 23 | nn.init.zeros_(m.bias) 24 | 25 | self.apply(weight_init) 26 | 27 | def forward(self, x): 28 | return self._conv(x) 29 | 30 | # weights are fixed to one, bias to zero 31 | class _one_conv(nn.Module): 32 | def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2): 33 | super().__init__() 34 | 35 | self._conv = nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 37 | ) 38 | 39 | # weight initialization 40 | def weight_init(m): 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.ones_(m.weight) 43 | nn.init.zeros_(m.bias) 44 | m.weight.requires_grad = False 45 | m.bias.requires_grad = False 46 | 47 | self.apply(weight_init) 48 | 49 | def forward(self, x): 50 | return self._conv(x) 51 | 52 | class _double_conv2d(nn.Module): 53 | 54 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, mid_channels=None): 55 | super().__init__() 56 | 57 | if not mid_channels: 58 | mid_channels = out_channels 59 | 60 | self.double_conv = nn.Sequential( 61 | nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding), 62 | nn.BatchNorm2d(mid_channels), 63 | nn.ReLU(inplace=True), 64 | 65 | nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding), 66 | nn.BatchNorm2d(out_channels), 67 | nn.ReLU(inplace=True) 68 | ) 69 | 70 | # weight initialization 71 | def weight_init(m): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu')) 74 | nn.init.zeros_(m.bias) 75 | 76 | self.apply(weight_init) 77 | 78 | def forward(self, x): 79 | return self.double_conv(x) 80 | 81 | 82 | class _down_conv2d(nn.Module): 83 | 84 | def __init__(self, 85 | in_channels, 86 | out_channels, 87 | kernel_size): 88 | 89 | super().__init__() 90 | 91 | self.seq_model = nn.Sequential( 92 | nn.MaxPool2d(2), 93 | _double_conv2d(in_channels, out_channels) 94 | ) 95 | 96 | 97 | def forward(self, x): 98 | return self.seq_model(x) 99 | 100 | 101 | class _up_conv2d(nn.Module): 102 | 103 | def __init__(self, 104 | in_channels, 105 | out_channels, 106 | kernel_size): 107 | 108 | super().__init__() 109 | 110 | self.conv_t = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2) 111 | self.conv = _double_conv2d(in_channels, out_channels) 112 | 113 | # x1 : input, x2 : matching down_conv2d output 114 | def forward(self, x1, x2): 115 | x1 = self.conv_t(x1) 116 | 117 | diffY = x2.size()[2] - x1.size()[2] 118 | diffX = x2.size()[3] - x1.size()[3] 119 | 120 | x1 = F.pad(x1, [diffX // 2, diffX - 
diffX // 2, 121 | diffY // 2, diffY - diffY // 2]) 122 | 123 | x = torch.cat([x2, x1], dim=1) 124 | return self.conv(x) 125 | 126 | 127 | class _final_conv2d(nn.Module): 128 | 129 | def __init__(self, 130 | in_channels, 131 | out_channels, 132 | kernel_size): 133 | 134 | super().__init__() 135 | 136 | self.conv = nn.Conv2d(in_channels, out_channels, 1, 1) 137 | 138 | def forward(self, x): 139 | return self.conv(x) -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules import \ 6 | _double_conv2d, _down_conv2d, _up_conv2d, _final_conv2d, _dis_conv, _one_conv 7 | 8 | from losses import TSDLoss, TRGLoss 9 | 10 | # Text Stroke Detection (GD in paper) 11 | class TSDNet(nn.Module): 12 | def __init__(self): 13 | super(TSDNet, self).__init__() 14 | 15 | self.inc = _double_conv2d(4, 16, 3) 16 | self.down1 = _down_conv2d(16, 32, 3) 17 | self.down2 = _down_conv2d(32, 64, 3) 18 | self.down3 = _down_conv2d(64, 128, 3) 19 | 20 | self.up1 = _up_conv2d(128, 64, 3) 21 | self.up2 = _up_conv2d(64, 32, 3) 22 | self.up3 = _up_conv2d(32, 16, 3) 23 | 24 | self.outc = _final_conv2d(16, 1, 3) 25 | 26 | def forward(self, Igt, M): 27 | x = torch.cat([Igt, M], dim=1) 28 | x1 = self.inc(x) 29 | x2 = self.down1(x1) 30 | x3 = self.down2(x2) 31 | x4 = self.down3(x3) 32 | 33 | x = self.up1(x4, x3) 34 | x = self.up2(x, x2) 35 | x = self.up3(x, x1) 36 | 37 | M = self.outc(x) 38 | return M 39 | 40 | # Text Removal Generation (GR, GR' in paper) 41 | class TRGNet(nn.Module): 42 | def __init__(self): 43 | super(TRGNet, self).__init__() 44 | 45 | self.inc = _double_conv2d(5, 16, 5, 2) 46 | self.down1 = _down_conv2d(16, 32, 3) 47 | self.down2 = _down_conv2d(32, 64, 3) 48 | self.down3 = _down_conv2d(64, 128, 3) 49 | 50 | self.mid_layer = _double_conv2d(128, 128, 3) 51 | 52 | self.up1 = _up_conv2d(128, 64, 3) 53 | self.up2 = _up_conv2d(64, 32, 3) 54 | self.up3 = _up_conv2d(32, 16, 3) 55 | 56 | self.outc = _final_conv2d(16, 3, 3) 57 | 58 | def forward(self, Igt, M, Ms): 59 | x = torch.cat([Igt, M, Ms], dim=1) 60 | x1 = self.inc(x) 61 | 62 | x2 = self.down1(x1) 63 | x3 = self.down2(x2) 64 | x4 = self.down3(x3) 65 | 66 | x4 = torch.add(self.mid_layer(x4), x4) 67 | 68 | x = self.up1(x4, x3) 69 | x = self.up2(x, x2) 70 | x = self.up3(x, x1) 71 | 72 | M = self.outc(x) 73 | return M 74 | 75 | # Text Stroke Detection _ (G'D in paper) 76 | class TSDNet_(nn.Module): 77 | def __init__(self): 78 | super(TSDNet_, self).__init__() 79 | 80 | self.inc = _double_conv2d(5, 16, 3) 81 | self.down1 = _down_conv2d(16, 32, 3) 82 | self.down2 = _down_conv2d(32, 64, 3) 83 | self.down3 = _down_conv2d(64, 128, 3) 84 | 85 | self.up1 = _up_conv2d(128, 64, 3) 86 | self.up2 = _up_conv2d(64, 32, 3) 87 | self.up3 = _up_conv2d(32, 16, 3) 88 | 89 | self.outc = _final_conv2d(16, 1, 3) 90 | 91 | def forward(self, Ite, M, Ms): 92 | x = torch.cat([Ite, M, Ms], dim=1) 93 | x1 = self.inc(x) 94 | x2 = self.down1(x1) 95 | x3 = self.down2(x2) 96 | x4 = self.down3(x3) 97 | 98 | x = self.up1(x4, x3) 99 | x = self.up2(x, x2) 100 | x = self.up3(x, x1) 101 | 102 | M = self.outc(x) 103 | return M 104 | 105 | # weighted patch based discriminator (D, Dm in paper) 106 | # build_sn_patch_gan_discriminator 107 | # (https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_model.py) 108 | class Discriminator(nn.Module): 109 | def __init__(self): 110 | 111 | 
super(Discriminator, self).__init__() 112 | 113 | self.Dm = nn.Sequential(  # fixed one-weight convs turn the text-region mask into per-patch weights 114 | _one_conv(1, 1, 5, 2, 2), 115 | nn.Sigmoid(), 116 | _one_conv(1, 1, 5, 2, 2), 117 | nn.Sigmoid(), 118 | _one_conv(1, 1, 5, 2, 2), 119 | nn.Sigmoid(), 120 | _one_conv(1, 1, 5, 2, 2), 121 | nn.Sigmoid(), 122 | _one_conv(1, 1, 5, 2, 2), 123 | nn.Sigmoid() 124 | ) 125 | 126 | self.D = nn.Sequential(  # SN-PatchGAN feature extractor 127 | _dis_conv(3, 64, 5, 2, 2), 128 | _dis_conv(64, 128, 5, 2, 2), 129 | _dis_conv(128, 256, 5, 2, 2), 130 | _dis_conv(256, 256, 5, 2, 2), 131 | _dis_conv(256, 256, 5, 2, 2) 132 | ) 133 | 134 | self.pool = nn.AvgPool2d(8) 135 | self.linear = nn.Linear(256, 1) 136 | # self.sigmoid = nn.Sigmoid() 137 | 138 | def forward(self, Mm, Ite_): 139 | mi = self.Dm(Mm) 140 | di = self.D(Ite_) 141 | 142 | y = torch.mul(mi, di) 143 | # y = self.pool(y) 144 | # y = self.linear(y.view(-1, 256)) 145 | return y 146 | 147 | 148 | class STRNet(nn.Module): 149 | def __init__(self): 150 | 151 | super(STRNet, self).__init__() 152 | 153 | self.tsdnet = TSDNet() 154 | self.trgnet = TRGNet() 155 | self.tsdnet_ = TSDNet_() 156 | self.trgnet_ = TRGNet() 157 | 158 | self.discrim = Discriminator() 159 | 160 | def forward(self, I, Mm): 161 | Ms = self.tsdnet(I, Mm) 162 | Ite = self.trgnet(I, Mm, Ms) 163 | Ms_ = self.tsdnet_(Ite, Mm, Ms) 164 | Ite_ = self.trgnet_(Ite, Mm, Ms_)  # G'r consumes the refined stroke mask from G'd (cf. the commented example below) 165 | 166 | return Ms, Ite, Ms_, Ite_ 167 | 168 | 169 | if __name__ == "__main__": 170 | from torch.optim import Adam 171 | 172 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 173 | # device='cpu' 174 | print(device) 175 | 176 | # I : input image 177 | # Itegt: text erased image, ground truth 178 | # M : Text Region mask 179 | # Ms : Text Stroke mask (from tsdnet) 180 | # 181 | 182 | I = torch.randn((2, 3, 256, 256)).to(device) 183 | print(f"I shape\n : {I.shape}") 184 | 185 | Itegt = torch.randn((2, 3, 256, 256)).to(device) 186 | print(f"Itegt shape\n : {Itegt.shape}") 187 | 188 | Mm = torch.randn((2, 1, 256, 256)).to(device) 189 | print(f"Mm shape\n : {Mm.shape}") 190 | 191 | Msgt = torch.randn((2, 1, 256, 256)).to(device) 192 | print(f"Msgt shape\n : {Msgt.shape}") 193 | 194 | One = torch.ones((2, 1)).to(device) 195 | Zero = torch.zeros((2, 1)).to(device) 196 | 197 | model = STRNet().to(device) 198 | 199 | model_optim = Adam(model.parameters(), 0.0001) 200 | discrim_optim = Adam(model.discrim.parameters(), 0.0001) 201 | bce_loss = nn.BCEWithLogitsLoss() 202 | 203 | Ms, Ite, Ms_, Ite_ = model.forward(I, Mm) 204 | 205 | Ltsd = TSDLoss(Msgt, Ms, Ms_) 206 | Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_) 207 | # Lgsn = -bce_loss(model.discrim(Mm, Ite_), One) 208 | Lgsn = -torch.mean(model.discrim(Mm, Ite_)) 209 | 210 | total_loss = Ltsd + Ltrg + Lgsn 211 | 212 | model_optim.zero_grad() 213 | total_loss.backward() 214 | model_optim.step() 215 | 216 | Ms, Ite, Ms_, Ite_ = model.forward(I, Mm) 217 | # Ldsn = F.relu(1-bce_loss(model.discrim(Mm, Itegt), One)) + \ 218 | # F.relu(1+bce_loss(model.discrim(Mm, Ite_), Zero)) 219 | Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \ 220 | torch.mean(F.relu(1+model.discrim(Mm, Ite_))) 221 | 222 | discrim_optim.zero_grad() 223 | Ldsn.backward() 224 | discrim_optim.step() 225 | 226 | 227 | # Igt = torch.randn((2, 3, 256, 256)).to(device) 228 | # print(f"Igt shape\n : {Igt.shape}") 229 | 230 | # M = torch.randn((2, 1, 256, 256)).to(device) 231 | # print(f"M shape\n : {M.shape}") 232 | 233 | # Mgt = torch.randn((2, 1, 256, 256)).to(device) 234 | # print(f"Mgt shape\n : {Mgt.shape}") 235 | 236 | # # models 237 | # tsdnet = TSDNet().to(device) 238 | # 
trgnet = TRGNet().to(device) 239 | # tsdnet_ = TSDNet_().to(device) 240 | # trgnet_ = TRGNet().to(device) 241 | 242 | # discriminator = Discriminator().to(device) 243 | 244 | # # optim 245 | # from torch.optim import Adam 246 | # total_optim = Adam(list(tsdnet.parameters()) + list(trgnet.parameters()) + 247 | # list(tsdnet_.parameters()) + list(trgnet_.parameters()), 0.0001) 248 | # total_optim.zero_grad() 249 | 250 | # discr_optim = Adam(discriminator.parameters()) 251 | # discr_optim.zero_grad() 252 | 253 | # # inference 254 | # Ms = tsdnet(Igt, M) 255 | # # print(f"tsdnet output Ms shape\n : {Ms.shape}") 256 | # Ite = trgnet(Igt, M, Ms) 257 | # # print(f"trgnet output Ite shape\n : {Ite.shape}") 258 | # Ms_ = tsdnet_(Ite, M, Ms) 259 | # # print(f"tsdnet_ output Ms_ shape\n : {Ms_.shape}") 260 | # Ite_ = trgnet_(Ite, M, Ms_) 261 | # # print(f"Final trgnet_ output Ite_ shape\n : {Ite.shape}") 262 | 263 | # # calculate loss 264 | # Lgsn = -discriminator.forward_with_loss(M, Ite) 265 | 266 | # from losses import TSDLoss, TRGLoss 267 | # Ltsd = TSDLoss(Mgt, Ms, Ms_) 268 | # Ltrg = TRGLoss(M, Ms, Ms_, Igt, Ite, Ite_) 269 | 270 | # total_loss = Ltsd + Ltrg + Lgsn 271 | 272 | # # train total model 273 | # total_loss.backward() 274 | # total_optim.step() 275 | # print(total_loss.detach().cpu().item()) 276 | 277 | # # train discriminator 278 | # Ms = tsdnet(Igt, M) 279 | # # print(f"tsdnet output Ms shape\n : {Ms.shape}") 280 | # Ite = trgnet(Igt, M, Ms) 281 | # # print(f"trgnet output Ite shape\n : {Ite.shape}") 282 | # Ms_ = tsdnet_(Ite, M, Ms) 283 | # # print(f"tsdnet_ output Ms_ shape\n : {Ms_.shape}") 284 | # Ite_ = trgnet_(Ite, M, Ms_) 285 | # # print(f"Final trgnet_ output Ite_ shape\n : {Ite.shape}") 286 | 287 | # Ldsn = discriminator.forward_with_loss(M, Igt) + discriminator.forward_with_loss(M, Ite_) 288 | # discr_loss = Ldsn 289 | 290 | # discr_loss.backward() 291 | # discr_optim.step() 292 | # print(discr_loss.detach().cpu().item()) 293 | --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6 2 | opencv-python 3 | matplotlib 4 | numpy 5 | tqdm 6 | progressbar2  # provides the progressbar module imported by create_dataset.py --------------------------------------------------------------------------------
/results/show/Thumbs.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/results/show/Thumbs.db --------------------------------------------------------------------------------
/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os, argparse, time, tqdm, random, cv2 3 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | from torch.optim import Adam 12 | from torch import nn 13 | 14 | from dataset import CustomDataset, postprocess_image, tensor_to_mat 15 | from network import STRNet 16 | from losses import TSDLoss, TRGLoss 17 | 18 | random_seed = 123 19 | 20 | torch.manual_seed(random_seed) 21 | torch.cuda.manual_seed(random_seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | np.random.seed(random_seed) 25 | random.seed(random_seed) 26 | 27 | def get_args(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("-d", 
"--data_path", default='dataset', help="data root path") 30 | 31 | parser.add_argument("-e", "--num_epochs", default=100, type=int, help="num epochs") 32 | parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size > 1") 33 | 34 | parser.add_argument("-n", "--num_workers", default=8, type=int, help="num_workers for DataLoader") 35 | parser.add_argument("-sn", "--show_num", default=4, type=int, help="show result images during training num") 36 | 37 | args = parser.parse_args() 38 | 39 | return args 40 | 41 | def load_weights_from_directory(model, weight_path) -> int: 42 | if weight_path.endswith('.pth'): 43 | wp = weight_path 44 | else: 45 | wps = sorted(os.listdir(weight_path), key=lambda x: int(x.split('_')[0])) 46 | if wps: 47 | wp = wps[-1] 48 | else: 49 | return 0 50 | 51 | print(f"Loading weights from {wp}...") 52 | model.load_state_dict(torch.load(os.path.join(weight_path, wp))) 53 | return int(wp.split('_')[0]) 54 | 55 | if __name__ == "__main__": 56 | 57 | args = get_args() 58 | 59 | ### Path 60 | model_path = "results" 61 | weight_path = os.path.join(model_path, "weights") 62 | show_path = os.path.join(model_path, "show") 63 | 64 | os.makedirs(model_path, exist_ok=True) 65 | os.makedirs(weight_path, exist_ok=True) 66 | os.makedirs(show_path, exist_ok=True) 67 | 68 | 69 | ### Hyperparameters 70 | epochs = args.num_epochs 71 | batch_size = args.batch_size 72 | if batch_size <= 1: 73 | raise "Batch size should bigger than 1 for batch normalization" 74 | 75 | num_workers = args.num_workers 76 | show_num = args.show_num 77 | 78 | ### DataLoader 79 | dataloader_params = {'batch_size': batch_size, 80 | 'shuffle': True, 81 | 'drop_last': True, 82 | 'num_workers': num_workers} 83 | 84 | train_data = CustomDataset(args.data_path, set_name="train") 85 | train_gen = DataLoader(train_data, **dataloader_params) 86 | 87 | dataloader_params = {'batch_size': 1, 88 | 'shuffle': True, 89 | 'drop_last': False, 90 | 'num_workers': num_workers} 91 | val_data = CustomDataset(args.data_path, set_name="val") 92 | val_gen = DataLoader(val_data, **dataloader_params) 93 | 94 | steps_per_epoch = len(train_gen) 95 | 96 | ### Model 97 | device = "cuda" if torch.cuda.is_available() else "cpu" 98 | print(f"Using {device}...") 99 | 100 | model = STRNet().to(device) 101 | 102 | # load best weight 103 | initial_epoch = load_weights_from_directory(model, weight_path) + 1 104 | print(f"Training start from epoch {initial_epoch}") 105 | 106 | # Train Setting 107 | model_optim = Adam(model.parameters(), 0.0001, (0.5, 0.9)) 108 | 109 | ### Train 110 | for epoch in range(initial_epoch, epochs): 111 | ## training 112 | train_loss = [] 113 | train_discrim_loss = [] 114 | 115 | model.train() 116 | pgbar = tqdm.tqdm(train_gen, total=len(train_gen)) 117 | pgbar.set_description(f"Epoch {epoch}/{epochs}") 118 | for I, Itegt, Mm, Msgt in pgbar: 119 | 120 | I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device) 121 | 122 | # train model 123 | Ms, Ite, Ms_, Ite_ = model.forward(I, Mm) 124 | 125 | Ltsd = TSDLoss(Msgt, Ms, Ms_) 126 | Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_) 127 | Lgsn = -torch.mean(model.discrim(Mm, Ite_)) 128 | 129 | total_loss = Ltsd + Ltrg + Lgsn 130 | 131 | model_optim.zero_grad() 132 | total_loss.backward() 133 | model_optim.step() 134 | 135 | # train discriminator 136 | Ms, Ite, Ms_, Ite_ = model.forward(I, Mm) 137 | Ldsn = torch.mean(F.relu(1-model.discrim(Mm, Itegt))) + \ 138 | torch.mean(F.relu(1+model.discrim(Mm, Ite_))) 139 | 140 | 
discrim_optim.zero_grad() 141 | Ldsn.backward() 142 | discrim_optim.step()  # update only the discriminator on the hinge loss; the generators were updated above 143 | 144 | ltsd = Ltsd.detach().cpu().item() 145 | ltrg = Ltrg.detach().cpu().item() 146 | lgsn = Lgsn.detach().cpu().item() 147 | train_loss.append(total_loss.detach().cpu().item()) 148 | train_discrim_loss.append(Ldsn.detach().cpu().item()) 149 | 150 | pgbar.set_postfix_str(f"total loss : {train_loss[-1]:.6f} ltsd : {ltsd:.6f} ltrg : {ltrg:.6f} lgsn : {lgsn:.6f} d_loss : {train_discrim_loss[-1]:.6f}") 151 | 152 | train_loss = sum(train_loss)/len(train_loss) 153 | 154 | ## validation 155 | val_loss = [] 156 | 157 | # will be saved in the show directory 158 | result_images = [] 159 | 160 | model.eval() 161 | pgbar = tqdm.tqdm(val_gen, total=len(val_gen)) 162 | pgbar.set_description("Validating...") 163 | for I, Itegt, Mm, Msgt in pgbar: 164 | 165 | I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device) 166 | 167 | # forward pass only; no gradients are needed during validation 168 | with torch.no_grad(): Ms, Ite, Ms_, Ite_ = model.forward(I, Mm) 169 | 170 | Ltsd = TSDLoss(Msgt, Ms, Ms_) 171 | Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_) 172 | Lgsn = -torch.mean(model.discrim(Mm, Ite_)) 173 | 174 | total_loss = Ltsd + Ltrg + Lgsn 175 | 176 | val_loss.append(total_loss.detach().cpu().item()) 177 | 178 | pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}") 179 | 180 | if len(result_images) < args.show_num: 181 | result_images.append([I.cpu(), Itegt.cpu(), Ite.cpu(), Ite_.cpu(), Msgt.cpu(), Ms.cpu(), Ms_.cpu()]) 182 | else: 183 | break  # only enough batches to fill the preview grid are evaluated 184 | 185 | val_loss = sum(val_loss) / len(val_loss) 186 | 187 | ## visualize 188 | fig, axs = plt.subplots(args.show_num, 1, figsize=(10, 2*args.show_num)) 189 | fig.suptitle("I, Itegt, Ite, Ite_, Msgt, Ms, Ms_") 190 | for i, (I, Itegt, Ite, Ite_, Msgt, Ms, Ms_) in enumerate(result_images): 191 | 192 | I = postprocess_image(tensor_to_mat(I))[0] 193 | Itegt = postprocess_image(tensor_to_mat(Itegt))[0] 194 | Ite = postprocess_image(tensor_to_mat(Ite))[0] 195 | Ite_ = postprocess_image(tensor_to_mat(Ite_))[0] 196 | Msgt = postprocess_image(tensor_to_mat(Msgt))[0] 197 | Ms = postprocess_image(tensor_to_mat(Ms))[0] 198 | Ms_ = postprocess_image(tensor_to_mat(Ms_))[0] 199 | 200 | Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR) 201 | Ms = cv2.cvtColor(Ms, cv2.COLOR_GRAY2BGR) 202 | Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR) 203 | 204 | axs[i].imshow(np.hstack([I, Itegt, Ite, Ite_, Msgt, Ms, Ms_])) 205 | axs[i].set_xticks([]) 206 | axs[i].set_yticks([]) 207 | 208 | fig.savefig(os.path.join(model_path, "show", f"epoch_{epoch}.png")) 209 | plt.close() 210 | 211 | print(f"train_loss : {train_loss}, val_loss : {val_loss}") 212 | print() 213 | time.sleep(0.2) 214 | 215 | torch.save(model.state_dict(), os.path.join(weight_path, f"{epoch}_train_{train_loss}_val_{val_loss}.pth")) 216 | --------------------------------------------------------------------------------
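
As a closing note, here is a hypothetical inference sketch that is not part of the repository. It assumes a checkpoint saved by `train.py` (the file name below is illustrative) and an input image plus its text-region mask on disk, and it reuses only helpers the repository actually defines (`preprocess_image`, `mat_to_tensor`, `tensor_to_mat`, `postprocess_image`, `STRNet`):

```python
import cv2
import torch

from dataset import preprocess_image, postprocess_image, mat_to_tensor, tensor_to_mat
from network import STRNet

device = "cuda" if torch.cuda.is_available() else "cpu"

model = STRNet().to(device)
# hypothetical checkpoint path; use whatever file train.py actually saved
model.load_state_dict(torch.load("results/weights/120_train_0.1_val_0.1.pth", map_location=device))
model.eval()

# input image and its text-region mask (assumed to exist on disk)
I  = mat_to_tensor(preprocess_image(cv2.imread("sample.png"), (256, 256))).unsqueeze(0).to(device)
Mm = mat_to_tensor(preprocess_image(cv2.imread("sample_mask.png", cv2.IMREAD_GRAYSCALE), (256, 256))).unsqueeze(0).to(device)

with torch.no_grad():
    Ms, Ite, Ms_, Ite_ = model(I, Mm)

# the refined output Ite_ is the final text-erased image
cv2.imwrite("erased.png", postprocess_image(tensor_to_mat(Ite_))[0])
```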