├── .gitignore
├── LICENSE
├── README.md
├── create_dataset.py
├── dataset.py
├── doc
│   ├── Thumbs.db
│   ├── back.png
│   ├── dataset_example.png
│   ├── epoch1.PNG
│   ├── epoch10.PNG
│   ├── epoch120.PNG
│   ├── epoch30.PNG
│   ├── epoch5.PNG
│   ├── epoch50.PNG
│   ├── model.png
│   └── text.png
├── losses.py
├── modules.py
├── network.py
├── requirements.txt
├── results
│   └── show
│       └── Thumbs.db
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# etc
*.png
*.jpg
*.pth

doc/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 ZeroAct

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Scene Text Remover Pytorch Implementation

This is a minimal implementation of [Scene text removal via cascaded text stroke detection and erasing](https://arxiv.org/pdf/2011.09768.pdf). This GitHub repository is for studying image inpainting for scene text erasing. Thank you :)


## Requirements

Python 3.7 or later with all [requirements.txt](./requirements.txt) dependencies installed, including `torch>=1.6`. To install, run:

```
$ pip install -r requirements.txt
```


## Model Summary

![model](./doc/model.png)

The model is built from four cascaded U-Net style sub-modules.
`Gd` detects the text stroke mask `Ms` from the input image `I` and the text region mask `M`, and `G'd` refines it into a more precise stroke mask `M's`.
Similarly, `Gr` generates the text-erased image `Ite`, and `G'r` refines it into the more precise output `I'te`.

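The cascade maps directly onto `STRNet` in [network.py](./network.py). Below is a minimal sketch of one forward pass (shapes follow the sanity check at the bottom of network.py; the batch size of 2 is illustrative):

```python
import torch
from network import STRNet

model = STRNet()

I  = torch.randn(2, 3, 256, 256)  # input image with text
Mm = torch.randn(2, 1, 256, 256)  # text region mask (M above)

# Gd -> Ms, Gr -> Ite, G'd -> Ms_, G'r -> Ite_
Ms, Ite, Ms_, Ite_ = model(I, Mm)
```
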
## Custom Dictionary

To avoid confusion, I renamed the variables as follows.

`I`    : input image (with text)
`Mm`   : text area mask (`M` in the model)
`Ms`   : text stroke mask; output of `Gd`
`Ms_`  : text stroke mask; output of `G'd`
`Msgt` : text stroke mask; ground truth
`Ite`  : text erased image; output of `Gr`
`Ite_` : text erased image; output of `G'r`
`Itegt`: text erased image; ground truth

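In code, these names are plain tensors. For reference, the shapes used throughout this implementation (`B` is the batch size; `dataset.py` loads images with OpenCV and scales them to `[0, 1]`):

```python
# I, Itegt, Ite, Ite_ : (B, 3, 256, 256) float32 images in [0, 1] (OpenCV BGR order)
# Mm, Msgt, Ms, Ms_   : (B, 1, 256, 256) float32 single-channel masks
```
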
## Prepare Dataset

You need to prepare background images in the `backs` directory and text binary images in the `font_mask` directory.

![back](./doc/back.png) ![text](./doc/text.png)
[part of background image sample, text binary image sample]

Executing `python create_dataset.py` will automatically generate the `I`, `Itegt`, `Mm`, `Msgt` data.
(If you already have `I`, `Itegt`, `Mm`, `Msgt`, you can skip this section.)

```
├─dataset
│  ├─backs
│  │    # background images
│  ├─font_mask
│  │    # text binary images
│  ├─train
│  │  ├─I
│  │  ├─Itegt
│  │  ├─Mm
│  │  └─Msgt
│  └─val
│     ├─I
│     ├─Itegt
│     ├─Mm
│     └─Msgt
```

I generated my dataset with 709 background images and 2410 font masks.
I used 17040 pairs for training and 4260 pairs for validation.

![dataset example](./doc/dataset_example.png)

Thanks to [sina-Kim](https://github.com/sina-Kim) for helping me gather the background images.

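You can sanity-check the generated data with `CustomDataset` from [dataset.py](./dataset.py); a minimal check, assuming the default `dataset` layout above:

```python
from dataset import CustomDataset

ds = CustomDataset('dataset', set_name='train')
I, Itegt, Mm, Msgt = ds[0]

print(len(ds))             # number of training pairs
print(I.shape, Mm.shape)   # torch.Size([3, 256, 256]) torch.Size([1, 256, 256])
```
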
## Train

All you need to do is:

```
python train.py
```

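`train.py` also exposes a few options through `get_args` (see [train.py](./train.py)), for example:

```
python train.py --data_path dataset --num_epochs 100 --batch_size 16 --num_workers 8 --show_num 4
```
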
## Result

From the left:
`I`, `Itegt`, `Ite`, `Ite_`, `Msgt`, `Ms`, `Ms_`

* Epoch 1
![epoch1](./doc/epoch1.PNG)
* Epoch 5
![epoch5](./doc/epoch5.PNG)
* Epoch 10
![epoch10](./doc/epoch10.PNG)
* Epoch 30
![epoch30](./doc/epoch30.PNG)
* Epoch 50
![epoch50](./doc/epoch50.PNG)
* Epoch 120
![epoch120](./doc/epoch120.PNG)

These results are not good enough for real-world use yet. I think the main reasons are the small dataset and the simplicity of the model.
Still, it was a good experience for me to implement the paper.

## Issue

If you have trouble running this code, please use the Issues tab. Thank you.
--------------------------------------------------------------------------------
/create_dataset.py:
--------------------------------------------------------------------------------
import os
import cv2
import glob
import random
import progressbar

import numpy as np

# random BGR color / random top-left position helpers
rand_color = lambda: (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
rand_pos = lambda a, b: (random.randint(a, b-1), random.randint(a, b-1))

target_size = 256
imgs_per_back = 30

backs = glob.glob('./dataset/backs/*.png')
fonts = glob.glob('./dataset/font_mask/*.png')

os.makedirs('./dataset/train/I', exist_ok=True)
os.makedirs('./dataset/train/Itegt', exist_ok=True)
os.makedirs('./dataset/train/Mm', exist_ok=True)
os.makedirs('./dataset/train/Msgt', exist_ok=True)

os.makedirs('./dataset/val/I', exist_ok=True)
os.makedirs('./dataset/val/Itegt', exist_ok=True)
os.makedirs('./dataset/val/Mm', exist_ok=True)
os.makedirs('./dataset/val/Msgt', exist_ok=True)

# continue numbering from any previously generated samples
t_idx = len(os.listdir('./dataset/train/I'))
v_idx = len(os.listdir('./dataset/val/I'))

bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back)
bar.start()
for back in backs:
    back_img = cv2.imread(back)
    bh, bw, _ = back_img.shape
    if bh < target_size or bw < target_size:
        back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
        bh, bw, _ = back_img.shape

    for bi in range(imgs_per_back):
        # random target_size crop of the background
        sx, sy = random.randint(0, bw-target_size), random.randint(0, bh-target_size)

        Itegt = back_img[sy:sy+target_size, sx:sx+target_size, :].copy()
        I = Itegt.copy()
        Mm = np.zeros_like(I)
        Msgt = np.zeros_like(I)

        hist = []
        for font in random.sample(fonts, random.randint(2, 4)):
            font_img = cv2.imread(font)
            mask_img = np.ones_like(font_img, dtype=np.uint8)*255

            height, width, _ = font_img.shape

            # random rotation and scale for each pasted word
            angle = random.randint(-30, +30)
            fs = random.randint(90, 120)
            ratio = fs / height - 0.2

            matrix = cv2.getRotationMatrix2D((width/2, height/2), angle, ratio)
            # interpolation must be passed as flags= (the 4th positional arg is dst)
            font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
            mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC)

            h, w, _ = font_rot.shape

            font_in_I = np.zeros_like(I)
            mask_in_I = np.zeros_like(I)

            # find a position that does not overlap previous words too much;
            # relax the constraint a little on every failed try
            allow = 0
            while True:
                sx, sy = rand_pos(0, target_size-w)

                done = True
                for sx_, sy_ in hist:
                    if (sx_ - sx)**2 + (sy_ - sy)**2 < (fs * ratio)**2 - allow:
                        done = False
                        break
                allow += 5

                if done:
                    hist.append([sx, sy])
                    break

            font_in_I[sy:sy+h, sx:sx+w, :] = font_rot
            mask_in_I[sy:sy+h, sx:sx+w, :] = mask_rot

            # binarize after interpolation
            font_in_I[font_in_I > 30] = 255
            mask_in_I[mask_in_I > 30] = 255

            # paint the stroke region with a random color
            I = cv2.bitwise_and(I, 255-font_in_I)
            I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8))

            Mm = cv2.bitwise_or(Mm, mask_in_I)
            Msgt = cv2.bitwise_or(Msgt, font_in_I)

        # 80/20 train/val split per background
        if bi < imgs_per_back*0.8:
            cv2.imwrite(f'dataset/train/I/{t_idx}.png', I)
            cv2.imwrite(f'dataset/train/Itegt/{t_idx}.png', Itegt)
            cv2.imwrite(f'dataset/train/Mm/{t_idx}.png', Mm)
            cv2.imwrite(f'dataset/train/Msgt/{t_idx}.png', Msgt)
            t_idx += 1
        else:
            cv2.imwrite(f'dataset/val/I/{v_idx}.png', I)
            cv2.imwrite(f'dataset/val/Itegt/{v_idx}.png', Itegt)
            cv2.imwrite(f'dataset/val/Mm/{v_idx}.png', Mm)
            cv2.imwrite(f'dataset/val/Msgt/{v_idx}.png', Msgt)
            v_idx += 1

        bar.update(t_idx + v_idx)
bar.finish()
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os, cv2
import numpy as np

import torch
from torch.utils.data import Dataset

def mat_to_tensor(mat):
    # HWC -> CHW
    mat = mat.transpose((2, 0, 1))
    tensor = torch.Tensor(mat)
    return tensor

def tensor_to_mat(tensor):
    # NCHW -> NHWC
    mat = tensor.detach().cpu().numpy()
    mat = mat.transpose((0, 2, 3, 1))
    return mat

def preprocess_image(img, target_shape: tuple):
    img = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC).astype(np.float32)
    img = img / 255.
    if len(img.shape) == 2:  # keep a channel axis for grayscale images
        img = img.reshape(*img.shape, 1)

    return img

def postprocess_image(img):
    img = img * 255
    img = np.clip(img, 0, 255)
    return img.astype(np.uint8)

class CustomDataset(Dataset):
    def __init__(self,
                 data_dir,
                 set_name="train",
                 target_size=(256, 256)):

        super().__init__()

        self.root_dir = os.path.join(data_dir, set_name)
        self.target_size = target_size

        self.I_dir = os.path.join(self.root_dir, "I")
        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
        self.Mm_dir = os.path.join(self.root_dir, "Mm")
        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")

        self.datas = os.listdir(self.I_dir)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        img_name = self.datas[idx]

        I = cv2.imread(os.path.join(self.I_dir, img_name))
        Itegt = cv2.imread(os.path.join(self.Itegt_dir, img_name))
        Mm = cv2.imread(os.path.join(self.Mm_dir, img_name), cv2.IMREAD_GRAYSCALE)
        Msgt = cv2.imread(os.path.join(self.Msgt_dir, img_name), cv2.IMREAD_GRAYSCALE)

        I = mat_to_tensor(preprocess_image(I, self.target_size))
        Itegt = mat_to_tensor(preprocess_image(Itegt, self.target_size))
        Mm = mat_to_tensor(preprocess_image(Mm, self.target_size))
        Msgt = mat_to_tensor(preprocess_image(Msgt, self.target_size))

        return I, Itegt, Mm, Msgt


if __name__ == "__main__":
    ds = CustomDataset('dataset', 'train')

    I, Itegt, Mm, Ms = ds[0]
    print(f"Dataset length : {len(ds)}")
    print(f"I shape : {I.shape}")
    print(f"Itegt shape : {Itegt.shape}")
    print(f"Mm shape : {Mm.shape}")
    print(f"Ms shape : {Ms.shape}")

--------------------------------------------------------------------------------
/doc/Thumbs.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/Thumbs.db
--------------------------------------------------------------------------------
/doc/back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/back.png
--------------------------------------------------------------------------------
/doc/dataset_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/dataset_example.png
--------------------------------------------------------------------------------
/doc/epoch1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch1.PNG
--------------------------------------------------------------------------------
/doc/epoch10.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch10.PNG
--------------------------------------------------------------------------------
/doc/epoch120.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch120.PNG
--------------------------------------------------------------------------------
/doc/epoch30.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch30.PNG
--------------------------------------------------------------------------------
/doc/epoch5.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch5.PNG
--------------------------------------------------------------------------------
/doc/epoch50.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/epoch50.PNG
--------------------------------------------------------------------------------
/doc/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/model.png
--------------------------------------------------------------------------------
/doc/text.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/doc/text.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch

# Text stroke detection loss:
#   L_tsd = E[ |Ms - Mgt| + r * |Ms_ - Mgt| ]
# where Ms / Ms_ are the coarse / refined stroke masks and Mgt is the ground truth.
def TSDLoss(Mgt, Ms, Ms_, r=10):
    return torch.mean(torch.abs(Ms-Mgt) + r * torch.abs(Ms_-Mgt))

# Text removal generation loss: a weighted L1 between generated and ground-truth
# images, where text region (Mm) and stroke (Ms) pixels are weighted more heavily:
#   Mw    = 1 + rm*Mm + rs*Ms
#   L_trg = E[ |Ite*Mw - Itegt*Mw| + rr * |Ite_*Mw_ - Itegt*Mw_| ]
def TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_, rm=5, rs=5, rr=10):

    Mw = torch.ones_like(Mm) + rm * Mm + rs * Ms
    Mw_ = torch.ones_like(Mm) + rm * Mm + rs * Ms_

    Ltrg = torch.mean(torch.abs(torch.mul(Ite, Mw) - torch.mul(Itegt, Mw)) + \
                      rr * torch.abs(torch.mul(Ite_, Mw_) - torch.mul(Itegt, Mw_)))

    return Ltrg

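# A quick smoke test for the two losses (hypothetical usage, not part of the
# original file; shapes follow dataset.py's (C, H, W) = (1 or 3, 256, 256)):
if __name__ == "__main__":
    B = 2
    Msgt, Ms, Ms_, Mm = (torch.rand(B, 1, 256, 256) for _ in range(4))
    Itegt, Ite, Ite_ = (torch.rand(B, 3, 256, 256) for _ in range(3))

    print(f"TSDLoss : {TSDLoss(Msgt, Ms, Ms_).item():.4f}")
    print(f"TRGLoss : {TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_).item():.4f}")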
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# dis_conv
# (https://github.com/JiahuiYu/generative_inpainting/blob/3a5324373ba52c68c79587ca183bc10b9e57b783/inpaint_ops.py#L84)
class _dis_conv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
            ),
            nn.LeakyReLU(inplace=True)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                # nn.utils.spectral_norm(m.weight)
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)

# weights are fixed to one, bias to zero
class _one_conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=2):
        super().__init__()

        self._conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
                m.weight.requires_grad = False
                m.bias.requires_grad = False

        self.apply(weight_init)

    def forward(self, x):
        return self._conv(x)

class _double_conv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, mid_channels=None):
        super().__init__()

        if not mid_channels:
            mid_channels = out_channels

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # weight initialization
        def weight_init(m):
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(m.bias)

        self.apply(weight_init)

    def forward(self, x):
        return self.double_conv(x)


class _down_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.seq_model = nn.Sequential(
            nn.MaxPool2d(2),
            _double_conv2d(in_channels, out_channels)
        )

    def forward(self, x):
        return self.seq_model(x)


class _up_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.conv_t = nn.ConvTranspose2d(in_channels, in_channels//2, 2, 2)
        self.conv = _double_conv2d(in_channels, out_channels)

    # x1 : input, x2 : matching down_conv2d output
    def forward(self, x1, x2):
        x1 = self.conv_t(x1)

        # pad x1 so it can be concatenated with the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class _final_conv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size):

        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1)

    def forward(self, x):
        return self.conv(x)
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules import \
    _double_conv2d, _down_conv2d, _up_conv2d, _final_conv2d, _dis_conv, _one_conv

from losses import TSDLoss, TRGLoss

# Text Stroke Detection (GD in paper)
class TSDNet(nn.Module):
    def __init__(self):
        super(TSDNet, self).__init__()

        self.inc = _double_conv2d(4, 16, 3)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 1, 3)

    def forward(self, I, M):
        x = torch.cat([I, M], dim=1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        Ms = self.outc(x)
        return Ms

# Text Removal Generation (GR, G'R in paper)
class TRGNet(nn.Module):
    def __init__(self):
        super(TRGNet, self).__init__()

        self.inc = _double_conv2d(5, 16, 5, 2)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.mid_layer = _double_conv2d(128, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 3, 3)

    def forward(self, I, M, Ms):
        x = torch.cat([I, M, Ms], dim=1)
        x1 = self.inc(x)

        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # residual connection around the bottleneck
        x4 = torch.add(self.mid_layer(x4), x4)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        Ite = self.outc(x)
        return Ite

# Text Stroke Detection _ (G'D in paper)
class TSDNet_(nn.Module):
    def __init__(self):
        super(TSDNet_, self).__init__()

        self.inc = _double_conv2d(5, 16, 3)
        self.down1 = _down_conv2d(16, 32, 3)
        self.down2 = _down_conv2d(32, 64, 3)
        self.down3 = _down_conv2d(64, 128, 3)

        self.up1 = _up_conv2d(128, 64, 3)
        self.up2 = _up_conv2d(64, 32, 3)
        self.up3 = _up_conv2d(32, 16, 3)

        self.outc = _final_conv2d(16, 1, 3)

    def forward(self, Ite, M, Ms):
        x = torch.cat([Ite, M, Ms], dim=1)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)

        Ms_ = self.outc(x)
        return Ms_

# weighted patch based discriminator (D, Dm in paper)
# build_sn_patch_gan_discriminator
# (https://github.com/JiahuiYu/generative_inpainting/blob/master/inpaint_model.py)
class Discriminator(nn.Module):
    def __init__(self):

        super(Discriminator, self).__init__()

        # Dm downsamples the text region mask into per-patch weights
        self.Dm = nn.Sequential(
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid(),
            _one_conv(1, 1, 5, 2, 2),
            nn.Sigmoid()
        )

        self.D = nn.Sequential(
            _dis_conv(3, 64, 5, 2, 2),
            _dis_conv(64, 128, 5, 2, 2),
            _dis_conv(128, 256, 5, 2, 2),
            _dis_conv(256, 256, 5, 2, 2),
            _dis_conv(256, 256, 5, 2, 2)
        )

        self.pool = nn.AvgPool2d(8)
        self.linear = nn.Linear(256, 1)
        # self.sigmoid = nn.Sigmoid()

    def forward(self, Mm, Ite_):
        mi = self.Dm(Mm)   # per-patch mask weights
        di = self.D(Ite_)  # per-patch scores

        # weight the patch scores by the downsampled text region mask
        y = torch.mul(mi, di)
        # y = self.pool(y)
        # y = self.linear(y.view(-1, 256))
        return y


class STRNet(nn.Module):
    def __init__(self):

        super(STRNet, self).__init__()

        self.tsdnet = TSDNet()
        self.trgnet = TRGNet()
        self.tsdnet_ = TSDNet_()
        self.trgnet_ = TRGNet()

        self.discrim = Discriminator()

    def forward(self, I, Mm):
        Ms = self.tsdnet(I, Mm)
        Ite = self.trgnet(I, Mm, Ms)
        Ms_ = self.tsdnet_(Ite, Mm, Ms)
        # the second generator refines Ite using the refined stroke mask Ms_
        Ite_ = self.trgnet_(Ite, Mm, Ms_)

        return Ms, Ite, Ms_, Ite_


if __name__ == "__main__":
    from torch.optim import Adam

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # device = 'cpu'
    print(device)

    # I    : input image
    # Itegt: ground truth text erased image
    # Mm   : text region mask
    # Msgt : ground truth text stroke mask

    I = torch.randn((2, 3, 256, 256)).to(device)
    print(f"I shape\n : {I.shape}")

    Itegt = torch.randn((2, 3, 256, 256)).to(device)
    print(f"Itegt shape\n : {Itegt.shape}")

    Mm = torch.randn((2, 1, 256, 256)).to(device)
    print(f"Mm shape\n : {Mm.shape}")

    Msgt = torch.randn((2, 1, 256, 256)).to(device)
    print(f"Msgt shape\n : {Msgt.shape}")

    One = torch.ones((2, 1)).to(device)
    Zero = torch.zeros((2, 1)).to(device)

    model = STRNet().to(device)

    model_optim = Adam(model.parameters(), 0.0001)
    discrim_optim = Adam(model.discrim.parameters(), 0.0001)
    bce_loss = nn.BCEWithLogitsLoss()

    # generator step
    Ms, Ite, Ms_, Ite_ = model(I, Mm)

    Ltsd = TSDLoss(Msgt, Ms, Ms_)
    Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
    # Lgsn = -bce_loss(model.discrim(Mm, Ite_), One)
    Lgsn = -torch.mean(model.discrim(Mm, Ite_))

    total_loss = Ltsd + Ltrg + Lgsn

    model_optim.zero_grad()
    total_loss.backward()
    model_optim.step()

    # discriminator step (hinge loss)
    Ms, Ite, Ms_, Ite_ = model(I, Mm)
    # Ldsn = F.relu(1-bce_loss(model.discrim(Mm, Itegt), One)) + \
    #        F.relu(1+bce_loss(model.discrim(Mm, Ite_), Zero))
    Ldsn = torch.mean(F.relu(1 - model.discrim(Mm, Itegt))) + \
           torch.mean(F.relu(1 + model.discrim(Mm, Ite_)))

    discrim_optim.zero_grad()
    Ldsn.backward()
    discrim_optim.step()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.6
opencv-python
matplotlib
numpy
tqdm
progressbar2
--------------------------------------------------------------------------------
/results/show/Thumbs.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeroAct/SceneTextRemover-pytorch/788e49aa5a0b8d9004bb641c2538eca5f83e3c31/results/show/Thumbs.db
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os, argparse, time, tqdm, random, cv2
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam

from dataset import CustomDataset, postprocess_image, tensor_to_mat
from network import STRNet
from losses import TSDLoss, TRGLoss

random_seed = 123

torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data_path", default='dataset', help="data root path")

    parser.add_argument("-e", "--num_epochs", default=100, type=int, help="num epochs")
    parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size > 1")

    parser.add_argument("-n", "--num_workers", default=8, type=int, help="num_workers for DataLoader")
    parser.add_argument("-sn", "--show_num", default=4, type=int, help="number of result images to save during training")

    args = parser.parse_args()

    return args

def load_weights_from_directory(model, weight_path) -> int:
    """Load the latest .pth in weight_path (or a direct .pth file) and return its epoch."""
    if weight_path.endswith('.pth'):
        wp = weight_path
    else:
        wps = sorted([w for w in os.listdir(weight_path) if w.endswith('.pth')],
                     key=lambda x: int(x.split('_')[0]))
        if wps:
            wp = os.path.join(weight_path, wps[-1])
        else:
            return 0

    print(f"Loading weights from {wp}...")
    model.load_state_dict(torch.load(wp))
    return int(os.path.basename(wp).split('_')[0])

if __name__ == "__main__":

    args = get_args()

    ### Path
    model_path = "results"
    weight_path = os.path.join(model_path, "weights")
    show_path = os.path.join(model_path, "show")

    os.makedirs(model_path, exist_ok=True)
    os.makedirs(weight_path, exist_ok=True)
    os.makedirs(show_path, exist_ok=True)


    ### Hyperparameters
    epochs = args.num_epochs
    batch_size = args.batch_size
    if batch_size <= 1:
        raise ValueError("Batch size should be bigger than 1 for batch normalization")

    num_workers = args.num_workers
    show_num = args.show_num

    ### DataLoader
    dataloader_params = {'batch_size': batch_size,
                         'shuffle': True,
                         'drop_last': True,
                         'num_workers': num_workers}

    train_data = CustomDataset(args.data_path, set_name="train")
    train_gen = DataLoader(train_data, **dataloader_params)

    dataloader_params = {'batch_size': 1,
                         'shuffle': True,
                         'drop_last': False,
                         'num_workers': num_workers}
    val_data = CustomDataset(args.data_path, set_name="val")
    val_gen = DataLoader(val_data, **dataloader_params)

    steps_per_epoch = len(train_gen)

    ### Model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device}...")

    model = STRNet().to(device)

    # resume from the latest saved weight, if any
    initial_epoch = load_weights_from_directory(model, weight_path) + 1
    print(f"Training start from epoch {initial_epoch}")

    # Train Setting
    model_optim = Adam(model.parameters(), 0.0001, (0.5, 0.9))
    discrim_optim = Adam(model.discrim.parameters(), 0.0001, (0.5, 0.9))

    ### Train
    for epoch in range(initial_epoch, epochs + 1):
        ## training
        train_loss = []
        train_discrim_loss = []

        model.train()
        pgbar = tqdm.tqdm(train_gen, total=len(train_gen))
        pgbar.set_description(f"Epoch {epoch}/{epochs}")
        for I, Itegt, Mm, Msgt in pgbar:

            I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)

            # train generator
            Ms, Ite, Ms_, Ite_ = model(I, Mm)

            Ltsd = TSDLoss(Msgt, Ms, Ms_)
            Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
            Lgsn = -torch.mean(model.discrim(Mm, Ite_))

            total_loss = Ltsd + Ltrg + Lgsn

            model_optim.zero_grad()
            total_loss.backward()
            model_optim.step()

            # train discriminator (hinge loss); update only the discriminator's parameters
            Ms, Ite, Ms_, Ite_ = model(I, Mm)
            Ldsn = torch.mean(F.relu(1 - model.discrim(Mm, Itegt))) + \
                   torch.mean(F.relu(1 + model.discrim(Mm, Ite_)))

            discrim_optim.zero_grad()
            Ldsn.backward()
            discrim_optim.step()

            ltsd = Ltsd.detach().cpu().item()
            ltrg = Ltrg.detach().cpu().item()
            lgsn = Lgsn.detach().cpu().item()
            train_loss.append(total_loss.detach().cpu().item())
            train_discrim_loss.append(Ldsn.detach().cpu().item())

            pgbar.set_postfix_str(f"total loss : {train_loss[-1]:.6f} ltsd : {ltsd:.6f} ltrg : {ltrg:.6f} lgsn : {lgsn:.6f} d_loss : {train_discrim_loss[-1]:.6f}")

        train_loss = sum(train_loss)/len(train_loss)

        ## validation (only a few batches; the loop stops once show_num samples are collected)
        val_loss = []

        # will be saved in the show directory
        result_images = []

        model.eval()
        pgbar = tqdm.tqdm(val_gen, total=len(val_gen))
        pgbar.set_description("Validating...")
        with torch.no_grad():
            for I, Itegt, Mm, Msgt in pgbar:

                I, Itegt, Mm, Msgt = I.to(device), Itegt.to(device), Mm.to(device), Msgt.to(device)

                Ms, Ite, Ms_, Ite_ = model(I, Mm)

                Ltsd = TSDLoss(Msgt, Ms, Ms_)
                Ltrg = TRGLoss(Mm, Ms, Ms_, Itegt, Ite, Ite_)
                Lgsn = -torch.mean(model.discrim(Mm, Ite_))

                total_loss = Ltsd + Ltrg + Lgsn

                val_loss.append(total_loss.detach().cpu().item())

                pgbar.set_postfix_str(f"loss : {sum(val_loss[-10:]) / len(val_loss[-10:]):.6f}")

                if len(result_images) < args.show_num:
                    result_images.append([I.cpu(), Itegt.cpu(), Ite.cpu(), Ite_.cpu(), Msgt.cpu(), Ms.cpu(), Ms_.cpu()])
                else:
                    break

        val_loss = sum(val_loss) / len(val_loss)

        ## visualize
        fig, axs = plt.subplots(args.show_num, 1, figsize=(10, 2*args.show_num))
        fig.suptitle("I, Itegt, Ite, Ite_, Msgt, Ms, Ms_")
        for i, (I, Itegt, Ite, Ite_, Msgt, Ms, Ms_) in enumerate(result_images):

            I = postprocess_image(tensor_to_mat(I))[0]
            Itegt = postprocess_image(tensor_to_mat(Itegt))[0]
            Ite = postprocess_image(tensor_to_mat(Ite))[0]
            Ite_ = postprocess_image(tensor_to_mat(Ite_))[0]
            Msgt = postprocess_image(tensor_to_mat(Msgt))[0]
            Ms = postprocess_image(tensor_to_mat(Ms))[0]
            Ms_ = postprocess_image(tensor_to_mat(Ms_))[0]

            # single-channel masks -> 3 channels so they can be stacked next to the images
            Msgt = cv2.cvtColor(Msgt, cv2.COLOR_GRAY2BGR)
            Ms = cv2.cvtColor(Ms, cv2.COLOR_GRAY2BGR)
            Ms_ = cv2.cvtColor(Ms_, cv2.COLOR_GRAY2BGR)

            axs[i].imshow(np.hstack([I, Itegt, Ite, Ite_, Msgt, Ms, Ms_]))
            axs[i].set_xticks([])
            axs[i].set_yticks([])

        fig.savefig(os.path.join(model_path, "show", f"epoch_{epoch}.png"))
        plt.close()

        print(f"train_loss : {train_loss}, val_loss : {val_loss}")
        print()
        time.sleep(0.2)

        torch.save(model.state_dict(), os.path.join(weight_path, f"{epoch}_train_{train_loss}_val_{val_loss}.pth"))
--------------------------------------------------------------------------------