├── pytorch_fid ├── __init__.py ├── __main__.py ├── fid_score.py └── inception.py ├── arial.ttf ├── demo_img ├── imgs │ ├── 1.JPG │ ├── 2.JPG │ ├── 3.JPG │ ├── 4.JPG │ ├── 001.png │ ├── 002.png │ ├── 003.png │ ├── 004.png │ ├── 005.png │ ├── 006.png │ ├── 007.png │ ├── 008.png │ ├── 100_0.jpg │ ├── 100_3.jpg │ ├── 100_4.jpg │ ├── 102_1.jpg │ └── 103_1.jpg └── i_t.txt ├── eval_2k.sh ├── eval_scene.sh ├── configs ├── erase-train.py └── mostel-train.py ├── standard_text.py ├── requirements.txt ├── eval_utils.py ├── .gitignore ├── rec_model.py ├── README.md ├── evaluation.py ├── predict.py ├── train_erase.py ├── loss.py ├── eval_real.py ├── rec_utils.py ├── tps_spatial_transformer.py ├── datagen.py ├── model_erase.py ├── train.py ├── rec_dataset.py ├── model.py └── rec_modules.py /pytorch_fid/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.1' 2 | -------------------------------------------------------------------------------- /arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/arial.ttf -------------------------------------------------------------------------------- /demo_img/imgs/1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/1.JPG -------------------------------------------------------------------------------- /demo_img/imgs/2.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/2.JPG -------------------------------------------------------------------------------- /demo_img/imgs/3.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/3.JPG -------------------------------------------------------------------------------- /demo_img/imgs/4.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/4.JPG -------------------------------------------------------------------------------- /demo_img/imgs/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/001.png -------------------------------------------------------------------------------- /demo_img/imgs/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/002.png -------------------------------------------------------------------------------- /demo_img/imgs/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/003.png -------------------------------------------------------------------------------- /demo_img/imgs/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/004.png -------------------------------------------------------------------------------- /demo_img/imgs/005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/005.png 
-------------------------------------------------------------------------------- /demo_img/imgs/006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/006.png -------------------------------------------------------------------------------- /demo_img/imgs/007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/007.png -------------------------------------------------------------------------------- /demo_img/imgs/008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/008.png -------------------------------------------------------------------------------- /demo_img/imgs/100_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/100_0.jpg -------------------------------------------------------------------------------- /demo_img/imgs/100_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/100_3.jpg -------------------------------------------------------------------------------- /demo_img/imgs/100_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/100_4.jpg -------------------------------------------------------------------------------- /demo_img/imgs/102_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/102_1.jpg -------------------------------------------------------------------------------- /demo_img/imgs/103_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqqyd/MOSTEL/HEAD/demo_img/imgs/103_1.jpg -------------------------------------------------------------------------------- /pytorch_fid/__main__.py: -------------------------------------------------------------------------------- 1 | import pytorch_fid.fid_score 2 | 3 | pytorch_fid.fid_score.main() 4 | -------------------------------------------------------------------------------- /eval_2k.sh: -------------------------------------------------------------------------------- 1 | rm -r tmp-eval-result 2 | python predict.py --config $1 --input_dir datasets/evaluation/Tamper-Syn2k/i_s/ --save_dir tmp-eval-result --checkpoint $2 --slm 3 | python evaluation.py --gt_path datasets/evaluation/Tamper-Syn2k/t_f/ --target_path tmp-eval-result/ 4 | rm -r tmp-eval-result 5 | -------------------------------------------------------------------------------- /eval_scene.sh: -------------------------------------------------------------------------------- 1 | rm -r tmp-eval-result 2 | python predict.py --config $1 --input_dir datasets/evaluation/Tamper-Scene/i_s/ --save_dir tmp-eval-result --checkpoint $2 --slm 3 | python eval_real.py --saved_model models/TPS-ResNet-BiLSTM-Attn.pth --gt_file datasets/evaluation/Tamper-Scene/i_t.txt --image_folder tmp-eval-result 4 | rm -r tmp-eval-result 5 | -------------------------------------------------------------------------------- /demo_img/i_t.txt: -------------------------------------------------------------------------------- 1 | 001.png David 2 | 
002.png typhoon 3 | 003.png Prospect 4 | 004.png play 5 | 005.png camera 6 | 006.png foods 7 | 007.png cats 8 | 008.png Custom 9 | 100_0.jpg Highroad 10 | 100_3.jpg but 11 | 100_4.jpg Greenhouse 12 | 102_1.jpg REFINE'S 13 | 103_1.jpg CANDLE 14 | 1.JPG SNMCOws 15 | 2.JPG MsoXe 16 | 3.JPG CSmoasfa 17 | 4.JPG H2014 18 | -------------------------------------------------------------------------------- /configs/erase-train.py: -------------------------------------------------------------------------------- 1 | # Loss 2 | lb = 1. 3 | lb_mask = 1. 4 | lb_beta = 10. 5 | 6 | # Train 7 | learning_rate = 1e-4 8 | decay_rate = 0.9 9 | beta1 = 0.9 10 | beta2 = 0.999 11 | max_iter = 100000 12 | write_log_interval = 50 13 | save_ckpt_interval = 50000 14 | gen_example_interval = 50000 15 | task_name = 'erase-train' 16 | checkpoint_savedir = 'output/' + task_name + '/' # dont forget '/' 17 | ckpt_path = 'None' 18 | vgg19_weights = 'models/vgg19-dcbb9e9d.pth' 19 | 20 | # data 21 | batch_size = 64 22 | num_workers = 8 23 | data_shape = [64, 256] 24 | data_dir = ['datasets/training/EnsText-patch'] 25 | i_s_dir = 'i_s' 26 | t_b_dir = 't_b' 27 | mask_s_dir = 'mask_s' 28 | example_data_dir = 'demo_img/imgs' 29 | example_result_dir = checkpoint_savedir + 'val_visualization' 30 | 31 | # predict 32 | predict_ckpt_path = None 33 | predict_data_dir = None 34 | predict_result_dir = checkpoint_savedir + 'pred_result' 35 | -------------------------------------------------------------------------------- /standard_text.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image, ImageDraw, ImageFont 4 | 5 | 6 | class Std_Text(object): 7 | def __init__(self, font_path): 8 | self.height = 64 9 | self.max_width = 960 10 | self.border_width = 5 11 | self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 12 | font_height = self.get_valid_height(font_path) 13 | self.font = ImageFont.truetype(font_path, font_height) 14 | 15 | def get_valid_height(self, font_path): 16 | font = ImageFont.truetype(font_path, self.height - 4) 17 | _, font_height = font.getsize(self.char_list) 18 | if font_height <= self.height - 4: 19 | return self.height - 4 20 | else: 21 | return int((self.height - 4)**2 / font_height) 22 | 23 | def draw_text(self, text): 24 | assert len(text) != 0 25 | 26 | char_x = self.border_width 27 | bg = Image.new("RGB", (self.max_width, self.height), color=(127, 127, 127)) 28 | draw = ImageDraw.Draw(bg) 29 | for char in text: 30 | draw.text((char_x, 2), char, fill=(0, 0, 0), font=self.font) 31 | char_size = self.font.getsize(char)[0] 32 | char_x += char_size 33 | 34 | canvas = np.array(bg).astype(np.uint8) 35 | char_x += self.border_width 36 | canvas = canvas[:, :char_x, :] 37 | 38 | return canvas 39 | 40 | 41 | def main(): 42 | font_path = 'arial.ttf' 43 | std_text = Std_Text(font_path) 44 | 45 | tmp = std_text.draw_text('qwertyuiopasdfghjklzxcvbnm') 46 | print(tmp.shape) 47 | cv2.imwrite('tmp.jpg', tmp) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | addict==2.4.0 3 | Augmentor==0.2.10 4 | brotlipy==0.7.0 5 | cachetools==5.3.0 6 | certifi==2022.12.7 7 | cffi==1.15.1 8 | charset-normalizer==2.0.4 9 | click==8.1.3 10 | cryptography==38.0.4 11 | cycler==0.11.0 12 | Cython==0.29.33 13 | 
editdistance==0.6.2 14 | Flask==2.2.2 15 | flit_core==3.6.0 16 | fonttools==4.38.0 17 | future==0.18.3 18 | gevent==22.10.2 19 | gevent-websocket==0.10.1 20 | google-auth==2.16.0 21 | google-auth-oauthlib==0.4.6 22 | greenlet==2.0.2 23 | grpcio==1.51.1 24 | idna==2.10 25 | imageio==2.25.0 26 | imgaug==0.4.0 27 | importlib-metadata==4.11.3 28 | itsdangerous==2.1.2 29 | Jinja2==3.1.2 30 | joblib==1.2.0 31 | kiwisolver==1.4.4 32 | lmdb==1.4.0 33 | Markdown==3.4.1 34 | MarkupSafe==2.1.2 35 | matplotlib==3.5.3 36 | natsort==8.2.0 37 | networkx==2.6.3 38 | nltk==3.8.1 39 | numpy==1.21.6 40 | oauthlib==3.2.2 41 | opencv-python==4.7.0.68 42 | packaging==23.0 43 | Pillow==9.4.0 44 | pip==22.3.1 45 | pluggy==1.0.0 46 | prettytable==3.6.0 47 | protobuf==3.20.3 48 | pyasn1==0.4.8 49 | pyasn1-modules==0.2.8 50 | pycparser==2.20 51 | pygame==2.1.2 52 | pyOpenSSL==19.1.0 53 | pyparsing==3.0.9 54 | PySocks==1.7.1 55 | python-dateutil==2.8.2 56 | PyWavelets==1.3.0 57 | PyYAML==6.0 58 | regex==2022.10.31 59 | requests==2.28.1 60 | requests-oauthlib==1.3.1 61 | rsa==4.9 62 | ruamel.yaml==0.17.21 63 | ruamel.yaml.clib==0.2.6 64 | scikit-image==0.19.3 65 | scipy==1.7.3 66 | setuptools==65.6.3 67 | shapely==2.0.0 68 | six==1.16.0 69 | sortedcontainers==2.4.0 70 | tensorboard==2.11.2 71 | tensorboard-data-server==0.6.1 72 | tensorboard-plugin-wit==1.8.1 73 | terminaltables==3.1.10 74 | tifffile==2021.11.2 75 | toolz==0.12.0 76 | tqdm==4.64.1 77 | typing_extensions==4.4.0 78 | urllib3==1.25.11 79 | wcwidth==0.2.6 80 | Werkzeug==2.2.2 81 | wheel==0.35.1 82 | yapf==0.32.0 83 | zipp==3.11.0 84 | zope.event==4.6 85 | zope.interface==5.5.2 86 | zstandard==0.18.0 87 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Module providing functionality surrounding gaussian function. 
3 | """ 4 | import os 5 | import numpy as np 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision.transforms import Compose, ToTensor, Resize 9 | 10 | 11 | def fspecial_gauss(size, sigma): 12 | """Function to mimic the 'fspecial' gaussian MATLAB function 13 | """ 14 | x, y = np.mgrid[-size // 2 + 1:size // 15 | 2 + 1, -size // 2 + 1:size // 2 + 1] 16 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) 17 | return g / g.sum() 18 | 19 | 20 | def CheckImageFile(filename): 21 | return any(filename.endswith(extention) for extention in [ 22 | '.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP']) 23 | 24 | 25 | def ImageTransform(loadSize): 26 | return Compose([ 27 | Resize(size=loadSize, interpolation=Image.BICUBIC), 28 | ToTensor(), 29 | ]) 30 | 31 | 32 | class devdata(Dataset): 33 | def __init__(self, dataRoot, gtRoot, loadSize=512): 34 | super(devdata, self).__init__() 35 | self.imageFiles = [os.path.join(dataRoot, filename) for filename 36 | in os.listdir(dataRoot) if CheckImageFile(filename)] 37 | self.gtFiles = [os.path.join(gtRoot, filename) for filename 38 | in os.listdir(dataRoot) if CheckImageFile(filename)] 39 | self.loadSize = loadSize 40 | 41 | def __getitem__(self, index): 42 | img = Image.open(self.imageFiles[index]) 43 | gt = Image.open(self.gtFiles[index]) 44 | to_scale = gt.size 45 | inputImage = ImageTransform(to_scale)(img.convert('RGB')) 46 | groundTruth = ImageTransform(to_scale)(gt.convert('RGB')) 47 | path = self.imageFiles[index].split('/')[-1] 48 | 49 | return inputImage, groundTruth, path 50 | 51 | def __len__(self): 52 | return len(self.imageFiles) 53 | -------------------------------------------------------------------------------- /configs/mostel-train.py: -------------------------------------------------------------------------------- 1 | # Loss 2 | lb = 1. 3 | lb_mask = 1. 4 | lb_beta = 10. 5 | lf = 1. 6 | lf_theta_1 = 10. 7 | lf_theta_2 = 1. 8 | lf_theta_3 = 500. 9 | lf_mask = 10. 10 | lf_rec = 0.1 11 | 12 | # Recognizer 13 | with_recognizer = True 14 | use_rgb = True 15 | train_recognizer = True 16 | rec_lr_weight = 1. 
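# The weights above are consumed in loss.py: lb and lf scale the background
# reconstruction loss l_b and the text-synthesis loss l_f, lb_beta/lf_theta_1
# weight the L2 terms, lb_mask/lf_mask the dice losses on the predicted masks,
# lf_theta_2/lf_theta_3 the VGG perceptual and style terms, and lf_rec the
# recognizer cross-entropy (only applied when with_recognizer is True).
# The whole file is loaded with mmcv, e.g. cfg = Config.fromfile('configs/mostel-train.py'),
# as done in predict.py and train_erase.py.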
17 | 18 | # StyleAug 19 | vflip_rate = 0.5 20 | hflip_rate = 0.5 21 | angle_range = [(-15, -5), (5, 15)] 22 | 23 | # Train 24 | learning_rate = 5e-5 25 | decay_rate = 0.9 26 | beta1 = 0.9 27 | beta2 = 0.999 28 | max_iter = 300000 29 | write_log_interval = 50 30 | save_ckpt_interval = 50000 31 | gen_example_interval = 50000 32 | task_name = 'mostel-train' 33 | checkpoint_savedir = 'output/' + task_name + '/' # dont forget '/' 34 | ckpt_path = 'None' 35 | inpaint_ckpt_path = 'models/erase_pretrain.pth' 36 | vgg19_weights = 'models/vgg19-dcbb9e9d.pth' 37 | rec_ckpt_path = 'models/recognizer_pretrain.pth' 38 | 39 | # data 40 | batch_size = 16 41 | real_bs = 2 42 | with_real_data = True if real_bs > 0 else False 43 | num_workers = 8 44 | data_shape = [64, 256] 45 | data_dir = [ 46 | 'datasets/training/train-50k-1', 47 | 'datasets/training/train-50k-2', 48 | 'datasets/training/train-50k-3', 49 | ] 50 | real_data_dir = [ 51 | 'datasets/training/mlt2017-train-patch', 52 | 'datasets/training/mlt2017-val-patch', 53 | 'datasets/training/ic13-test-patch', 54 | ] 55 | i_s_dir = 'i_s' 56 | t_b_dir = 't_b' 57 | t_f_dir = 't_f' 58 | mask_t_dir = 'mask_t' 59 | mask_s_dir = 'mask_s' 60 | txt_dir = 'txt' 61 | font_path = 'arial.ttf' 62 | example_data_dir = 'demo_img/imgs' 63 | example_result_dir = checkpoint_savedir + 'val_visualization' 64 | 65 | # TPS 66 | TPS_ON = True 67 | num_control_points = 10 68 | stn_activation = 'tanh' 69 | tps_inputsize = data_shape 70 | tps_outputsize = data_shape 71 | tps_margins = (0.05, 0.05) 72 | 73 | # predict 74 | predict_ckpt_path = None 75 | predict_data_dir = None 76 | predict_result_dir = checkpoint_savedir + 'pred_result' 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/ 107 | data 108 | .vscode 109 | .idea 110 | .DS_Store 111 | 112 | # custom 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | work_dirs/ 117 | 118 | # Pytorch 119 | *.pth 120 | *.py~ 121 | *.sh~ 122 | -------------------------------------------------------------------------------- /rec_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from rec_modules import TPS_SpatialTransformerNetwork, ResNet_FeatureExtractor, BidirectionalLSTM, Attention 3 | import time 4 | 5 | class Rec_Model(nn.Module): 6 | def __init__(self, cfg): 7 | super(Rec_Model, self).__init__() 8 | """ TPS Transformation """ 9 | input_channel = 3 if cfg.use_rgb else 1 10 | self.Transformation = TPS_SpatialTransformerNetwork( 11 | F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=input_channel) # num_fiducial, imgH, imgW, input_channel 12 | 13 | """ FeatureExtraction """ 14 | self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, 512) # input_channel, output_channel 15 | self.FeatureExtraction_output = 512 # int(imgH/16-1) * 512 16 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 17 | 18 | """ BiLSTM Sequence modeling""" 19 | self.SequenceModeling = nn.Sequential( 20 | BidirectionalLSTM(self.FeatureExtraction_output, 256, 256), # hidden_size, hidden_size 21 | BidirectionalLSTM(256, 256, 256)) # hidden_size, hidden_size, hidden_size 22 | self.SequenceModeling_output = 256 23 | 24 | """ Prediction """ 25 | self.Prediction = Attention(self.SequenceModeling_output, 256, 38) # hidden_size, num_class 26 | 27 | 28 | def forward(self, input, text, is_train=True): 29 | """ TPS Transformation stage """ 30 | input = self.Transformation(input) 31 | 32 | """ Feature extraction stage """ 33 | visual_feature = self.FeatureExtraction(input) 34 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2).contiguous()) # [b, c, h, w] -> [b, w, c, h] 35 | visual_feature = visual_feature.squeeze(3) 36 | 37 | """ BiLSTM Sequence modeling stage """ 38 | contextual_feature = self.SequenceModeling(visual_feature) 39 | 40 | """ Attention Prediction stage """ 41 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=34) # batch_max_length 42 | 43 | return prediction 44 | 
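A minimal smoke test of the recognizer above could look like the sketch below. It mirrors how eval_real.py drives the model: `cfg` is a hypothetical stand-in for the mmcv config (only `use_rgb` is read here), the default character set from eval_real.py is assumed, and the modules in rec_modules.py are assumed to follow the deep-text-recognition-benchmark implementation this file is adapted from, so treat this as illustrative rather than as the project's own test code.

```python
# Illustrative only: untrained weights will produce meaningless strings.
from types import SimpleNamespace

import torch

from rec_model import Rec_Model
from rec_utils import AttnLabelConverter

converter = AttnLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')  # 36 chars + [GO]/[s] = 38 classes
cfg = SimpleNamespace(use_rgb=True)   # assumption: stand-in for Config.fromfile('configs/mostel-train.py')
model = Rec_Model(cfg).eval()

batch_max_length = 25
images = torch.rand(2, 3, 32, 100)    # the TPS stage is built for 32x100 crops
text_for_pred = torch.zeros(2, batch_max_length + 1, dtype=torch.long)  # dummy decoder input, as in eval_real.py

with torch.no_grad():
    preds = model(images, text_for_pred, is_train=False)

_, preds_index = preds.max(2)
length_for_pred = torch.IntTensor([batch_max_length] * 2)
# Decoded strings still contain '[s]'; eval_real.py prunes everything after it.
print(converter.decode(preds_index, length_for_pred))
```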
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Exploring Stroke-Level Modifications for Scene Text Editing 2 | 3 | ## Introduction 4 | This is a PyTorch implementation of the paper [MOSTEL](https://arxiv.org/abs/2212.01982). It edits scene text at the stroke level and can be trained on both labeled synthetic images and unpaired real scene text images. 5 | 6 | ## ToDo List 7 | 8 | - [x] Release code 9 | - [x] Release evaluation datasets 10 | - [x] Document for Installation 11 | - [x] Trained models 12 | - [x] Document for training and testing 13 | 14 | 15 | ## Installation 16 | 17 | ### Requirements 18 | - Python==3.7 19 | - Pytorch==1.7.1 20 | - CUDA==10.1 21 | 22 | ```bash 23 | git clone https://github.com/qqqyd/MOSTEL.git 24 | cd MOSTEL/ 25 | 26 | conda create --name MOSTEL python=3.7 -y 27 | conda activate MOSTEL 28 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html 29 | pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.7/index.html 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ## Training 34 | Prepare the datasets and put them in ```datasets/```. Our training data consists of synthetic data generated by [SRNet-Datagen](https://github.com/youdao-ai/SRNet-Datagen) and real scene text datasets. You can download our datasets [here](https://rec.ustc.edu.cn/share/81ddc320-b05b-11ed-b4bc-5f690f426d88)(password: t6bq) or [OneDrive](https://mailustceducn-my.sharepoint.com/:f:/g/personal/qqqyd_mail_ustc_edu_cn/ElyDQHgge-VNhpwmwXm2YmIBzNGy-1IxH_xq7mzAsWEk0A)(password: t6bq). 35 | 36 | To get better performance, the Background Reconstruction Module can be pre-trained on [SCUT-EnsText](https://github.com/HCIILAB/SCUT-EnsText), and the recognizer can be pre-trained on 50k synthetic images generated by [SRNet-Datagen](https://github.com/youdao-ai/SRNet-Datagen). You can also use our [models](https://rec.ustc.edu.cn/share/56198940-b05c-11ed-a0b3-69e6f6e19d65)(password: 85b5) or [OneDrive](https://mailustceducn-my.sharepoint.com/:f:/g/personal/qqqyd_mail_ustc_edu_cn/En6jxV64oUxAlEOC76Nwo0EBcdHNgZzxt5pZULX5_JGrSA)(password: 85b5). 37 | 38 | 39 | ```bash 40 | python train.py --config configs/mostel-train.py 41 | ``` 42 | 43 | ## Testing and evaluation 44 | Prepare the models and put them in ```models/```. You can download our models [here](https://rec.ustc.edu.cn/share/56198940-b05c-11ed-a0b3-69e6f6e19d65)(password: 85b5) or [OneDrive](https://mailustceducn-my.sharepoint.com/:f:/g/personal/qqqyd_mail_ustc_edu_cn/En6jxV64oUxAlEOC76Nwo0EBcdHNgZzxt5pZULX5_JGrSA)(password: 85b5). 45 | 46 | Generate the predicted results using the following commands: 47 | ```bash 48 | python predict.py --config configs/mostel-train.py --input_dir datasets/evaluation/Tamper-Syn2k/i_s/ --save_dir results-syn2k --checkpoint models/mostel.pth --slm 49 | python predict.py --config configs/mostel-train.py --input_dir datasets/evaluation/Tamper-Scene/i_s/ --save_dir results-scene --checkpoint models/mostel.pth --slm 50 | ``` 51 | 52 | For synthetic data, the evaluation metrics are MSE, PSNR, SSIM and FID. 53 | ```bash 54 | python evaluation.py --gt_path datasets/evaluation/Tamper-Syn2k/t_f/ --target_path results-syn2k/ 55 | ``` 56 | For real data, the evaluation metric is recognition accuracy.
57 | ```bash 58 | python eval_real.py --saved_model models/TPS-ResNet-BiLSTM-Attn.pth --gt_file datasets/evaluation/Tamper-Scene/i_t.txt --image_folder results-scene/ 59 | ``` 60 | 61 | Or you can use ```eval_2k.sh``` and ```eval_scene.sh``` for testing and evaluation. 62 | ```bash 63 | bash eval_2k.sh configs/mostel-train.py models/mostel.pth 64 | bash eval_scene.sh configs/mostel-train.py models/mostel.pth 65 | ``` 66 | 67 | In our experiments, we found that SLM will improve the quantitative performance while leaving some text outline traces, which is not good for visualization. You can add ```--dilate``` for better visualization when generating predicted results. 68 | 69 | ## Citing the related works 70 | 71 | If you find our method useful for your research, please cite 72 | 73 | @inproceedings{qu2023exploring, 74 | title={Exploring stroke-level modifications for scene text editing}, 75 | author={Qu, Yadong and Tan, Qingfeng and Xie, Hongtao and Xu, Jianjun and Wang, Yuxin and Zhang, Yongdong}, 76 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 77 | volume={37}, 78 | number={2}, 79 | pages={2119--2127}, 80 | year={2023} 81 | } 82 | 83 | ## References 84 | 85 | [Niwhskal/SRNet](https://github.com/Niwhskal/SRNet) 86 | 87 | [youdao-ai/SRNet-Datagen](https://github.com/youdao-ai/SRNet-Datagen) 88 | 89 | [clovaai/deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 90 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from scipy import signal, ndimage 7 | from tqdm import tqdm 8 | from pytorch_fid import fid_score 9 | from eval_utils import devdata, fspecial_gauss 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--target_path', type=str, default='', help='results') 14 | parser.add_argument('--gt_path', type=str, default='', help='labels') 15 | parser.add_argument('--no_fid', action='store_true', default=False) 16 | args = parser.parse_args() 17 | img_path = args.target_path 18 | gt_path = args.gt_path 19 | 20 | sum_psnr = 0 21 | sum_ssim = 0 22 | sum_mse = 0 23 | count = 0 24 | sum_time = 0.0 25 | l1_loss = 0 26 | 27 | 28 | def ssim(img1, img2, cs_map=False): 29 | """Return the Structural Similarity Map corresponding to input images img1 30 | and img2 (images are assumed to be uint8) 31 | 32 | This function attempts to mimic precisely the functionality of ssim.m a 33 | MATLAB provided by the author's of SSIM 34 | https://ece.uwaterloo.ca/~z70wang/research/ssim/ssim_index.m 35 | """ 36 | img1 = img1.astype(float) 37 | img2 = img2.astype(float) 38 | 39 | size = min(img1.shape[0], 11) 40 | sigma = 1.5 41 | window = fspecial_gauss(size, sigma) 42 | K1 = 0.01 43 | K2 = 0.03 44 | L = 255 # bitdepth of image 45 | C1 = (K1 * L) ** 2 46 | C2 = (K2 * L) ** 2 47 | mu1 = signal.fftconvolve(img1, window, mode='valid') 48 | mu2 = signal.fftconvolve(img2, window, mode='valid') 49 | mu1_sq = mu1 * mu1 50 | mu2_sq = mu2 * mu2 51 | mu1_mu2 = mu1 * mu2 52 | sigma1_sq = signal.fftconvolve(img1 * img1, window, mode='valid') - mu1_sq 53 | sigma2_sq = signal.fftconvolve(img2 * img2, window, mode='valid') - mu2_sq 54 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') - mu1_mu2 55 | if cs_map: 56 | return (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq 
+ mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)), 57 | (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)) 58 | else: 59 | return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 60 | (sigma1_sq + sigma2_sq + C2)) 61 | 62 | 63 | def msssim(img1, img2): 64 | """This function implements Multi-Scale Structural Similarity (MSSSIM) Image 65 | Quality Assessment according to Z. Wang's "Multi-scale structural similarity 66 | for image quality assessment" Invited Paper, IEEE Asilomar Conference on 67 | Signals, Systems and Computers, Nov. 2003 68 | 69 | Author's MATLAB implementation:- 70 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 71 | """ 72 | level = 5 73 | weight = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 74 | downsample_filter = np.ones((2, 2)) / 4.0 75 | mssim = np.array([]) 76 | mcs = np.array([]) 77 | for l in range(level): 78 | ssim_map, cs_map = ssim(img1, img2, cs_map=True) 79 | mssim = np.append(mssim, ssim_map.mean()) 80 | mcs = np.append(mcs, cs_map.mean()) 81 | filtered_im1 = ndimage.filters.convolve(img1, downsample_filter, 82 | mode='reflect') 83 | filtered_im2 = ndimage.filters.convolve(img2, downsample_filter, 84 | mode='reflect') 85 | im1 = filtered_im1[:: 2, :: 2] 86 | im2 = filtered_im2[:: 2, :: 2] 87 | 88 | # Note: Remove the negative and add it later to avoid NaN in exponential. 89 | sign_mcs = np.sign(mcs[0: level - 1]) 90 | sign_mssim = np.sign(mssim[level - 1]) 91 | mcs_power = np.power(np.abs(mcs[0: level - 1]), weight[0: level - 1]) 92 | mssim_power = np.power(np.abs(mssim[level - 1]), weight[level - 1]) 93 | return np.prod(sign_mcs * mcs_power) * sign_mssim * mssim_power 94 | 95 | 96 | imgData = devdata(dataRoot=img_path, gtRoot=gt_path) 97 | data_loader = DataLoader( 98 | imgData, 99 | batch_size=1, 100 | shuffle=False, 101 | num_workers=0, 102 | drop_last=False) 103 | 104 | for idx, (img, lbl, path) in tqdm(enumerate(data_loader), total=len(data_loader)): 105 | mse = ((lbl - img)**2).mean() 106 | sum_mse += mse 107 | if mse == 0: 108 | continue 109 | count += 1 110 | 111 | psnr = 10 * math.log10(1 / mse) 112 | sum_psnr += psnr 113 | 114 | R = lbl[0, 0, :, :] 115 | G = lbl[0, 1, :, :] 116 | B = lbl[0, 2, :, :] 117 | YGT = .299 * R + .587 * G + .114 * B 118 | R = img[0, 0, :, :] 119 | G = img[0, 1, :, :] 120 | B = img[0, 2, :, :] 121 | YBC = .299 * R + .587 * G + .114 * B 122 | mssim = msssim(np.array(YGT * 255), np.array(YBC * 255)) 123 | sum_ssim += mssim 124 | 125 | print('PSNR:', sum_psnr / count) 126 | print('SSIM:', sum_ssim / count) 127 | print('MSE:', sum_mse.item() / count) 128 | 129 | if not args.no_fid: 130 | batch_size = 1 131 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 132 | dims = 2048 133 | fid_value = fid_score.calculate_fid_given_paths([str(gt_path), str(img_path)], batch_size, device, dims) 134 | print('FID:', fid_value) 135 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import math 5 | import torch 6 | import cv2 7 | import time 8 | import torchvision.transforms.functional as F 9 | import numpy as np 10 | from tqdm import tqdm 11 | from mmcv import Config 12 | from torch.utils.data import DataLoader 13 | from model import Generator 14 | from datagen import custom_dataset 15 | 16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | def test_speed(G, i_t, i_s): 20 | num = 50 21 | start_time 
= time.time() 22 | for _ in range(num): 23 | tmp = G(i_t, i_s) 24 | time_cost = (time.time() - start_time) / num 25 | return time_cost 26 | 27 | 28 | class MyDilate(): 29 | def __init__(self) -> None: 30 | tmp_distance = 3 31 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (tmp_distance, tmp_distance)) # MORPH_RECT MORPH_CROSS MORPH_ELLIPSE 32 | self.kernel = kernel 33 | self.iterations = 1 34 | 35 | def __call__(self, img, binary=True): 36 | img = img * 255 37 | if binary: 38 | ret, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY) 39 | dilate_img = cv2.morphologyEx(img, cv2.MORPH_DILATE, self.kernel, iterations=self.iterations) 40 | ret, dilate_img = cv2.threshold(dilate_img, 127, 255, cv2.THRESH_BINARY) 41 | dilate_img = dilate_img[:, :, np.newaxis] / 255 42 | 43 | return dilate_img 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--config', type=str) 49 | parser.add_argument('--input_dir', help='Directory containing xxx_i_s and xxx_i_t with same prefix') 50 | parser.add_argument('--save_dir', help='Directory to save result') 51 | parser.add_argument('--checkpoint', help='checkpoint') 52 | parser.add_argument('--i_t_name', default='i_t.txt') 53 | parser.add_argument('--vis', action='store_true', default=False) 54 | parser.add_argument('--slm', action='store_true', default=False) 55 | parser.add_argument('--speed', action='store_true', default=False) 56 | parser.add_argument('--dilate', action='store_true', default=False) 57 | args = parser.parse_args() 58 | 59 | assert args.input_dir is not None 60 | assert args.save_dir is not None 61 | assert args.checkpoint is not None 62 | if not os.path.exists(args.save_dir): 63 | os.makedirs(args.save_dir) 64 | 65 | cfg = Config.fromfile(args.config) 66 | G = Generator(cfg, in_channels=3).to(device) 67 | checkpoint = torch.load(args.checkpoint) 68 | G.load_state_dict(checkpoint['generator']) 69 | print('Model loaded: {}'.format(args.checkpoint)) 70 | 71 | batch_size = 256 if not args.speed else 1 72 | eval_data = custom_dataset(cfg, data_dir=args.input_dir, i_t_name=args.i_t_name, mode='eval') 73 | eval_loader = DataLoader( 74 | dataset=eval_data, 75 | batch_size=batch_size, 76 | num_workers=16, 77 | shuffle=False, 78 | drop_last=False) 79 | eval_iter = iter(eval_loader) 80 | 81 | G.eval() 82 | total_fps = [] 83 | if args.dilate: 84 | mydilate = MyDilate() 85 | 86 | with torch.no_grad(): 87 | for step in tqdm(range(len(eval_data)), total=math.ceil(len(eval_data)/batch_size)): 88 | try: 89 | inp = eval_iter.next() 90 | except StopIteration: 91 | break 92 | i_t = inp[0].to(device) 93 | i_s = inp[1].to(device) 94 | name_list = inp[2] 95 | 96 | gen_o_b_ori, gen_o_b, gen_o_f, gen_x_t_tps, gen_o_mask_s, gen_o_mask_t = G(i_t, i_s) 97 | 98 | if args.speed: 99 | time_cost = test_speed(G, i_t, i_s) 100 | total_fps.append(1 / time_cost) 101 | print('Params: %s, Inference speed: %fms, FPS: %f, %f' % ( 102 | str(sum(p.numel() for p in G.parameters() if p.requires_grad)), 103 | time_cost * 1000, 1 / time_cost, sum(total_fps) / len(total_fps))) 104 | 105 | gen_o_b_ori = gen_o_b_ori * 255 106 | gen_o_b = gen_o_b * 255 107 | gen_o_f = gen_o_f * 255 108 | gen_x_t_tps = gen_x_t_tps * 255 109 | 110 | for tmp_idx in range(gen_o_f.shape[0]): 111 | name = str(name_list[tmp_idx]) 112 | name, suffix = name.split('.') 113 | 114 | o_mask_s = gen_o_mask_s[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 115 | o_mask_t = gen_o_mask_t[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 116 | o_b_ori = 
gen_o_b_ori[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 117 | o_b = gen_o_b[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 118 | o_f = gen_o_f[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 119 | x_t_tps = gen_x_t_tps[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 120 | 121 | ori_o_mask_s = o_mask_s 122 | if args.dilate: 123 | tmp_i_s = (i_s * 255)[tmp_idx].detach().to('cpu').numpy().transpose(1, 2, 0) 124 | o_mask_s = mydilate(o_mask_s) 125 | o_b = o_mask_s * o_b_ori + (1 - o_mask_s) * tmp_i_s 126 | 127 | if args.slm: 128 | alpha = 0.5 129 | o_f = o_mask_t * o_f + (1 - o_mask_t) * (alpha * o_b + (1 - alpha) * o_f) 130 | 131 | if args.vis: 132 | cv2.imwrite(os.path.join(args.save_dir, name + '_o_f.' + suffix), o_f[:, :, ::-1]) 133 | cv2.imwrite(os.path.join(args.save_dir, name + '_o_b.' + suffix), o_b[:, :, ::-1]) 134 | cv2.imwrite(os.path.join(args.save_dir, name + '_o_b_ori.' + suffix), o_b_ori[:, :, ::-1]) 135 | cv2.imwrite(os.path.join(args.save_dir, name + '_o_mask_s.' + suffix), o_mask_s * 255) 136 | cv2.imwrite(os.path.join(args.save_dir, name + '_o_mask_t.' + suffix), o_mask_t * 255) 137 | cv2.imwrite(os.path.join(args.save_dir, name + '_x_t_tps.' + suffix), x_t_tps[:, :, ::-1]) 138 | else: 139 | cv2.imwrite(os.path.join(args.save_dir, name + '.' + suffix), o_f[:, :, ::-1]) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /train_erase.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import numpy as np 5 | import cv2 6 | import torch 7 | from mmcv import Config 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | from loss import build_generator_erase_loss, build_discriminator_loss 11 | from datagen import erase_dataset 12 | from model_erase import Generator, Discriminator 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | def requires_grad(model, flag=True): 17 | for p in model.parameters(): 18 | p.requires_grad = flag 19 | 20 | 21 | def get_logger(cfg, log_filename='log.txt', log_level=logging.INFO): 22 | logger = logging.getLogger(log_filename) 23 | logger.setLevel(log_level) 24 | formatter = logging.Formatter('[%(asctime)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 25 | if not os.path.exists(cfg.checkpoint_savedir): 26 | os.makedirs(cfg.checkpoint_savedir) 27 | fh = logging.FileHandler(os.path.join(cfg.checkpoint_savedir, log_filename)) 28 | fh.setLevel(log_level) 29 | fh.setFormatter(formatter) 30 | ch = logging.StreamHandler() 31 | ch.setLevel(log_level) 32 | ch.setFormatter(formatter) 33 | logger.addHandler(ch) 34 | logger.addHandler(fh) 35 | 36 | return logger 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--config', type=str) 42 | args = parser.parse_args() 43 | cfg = Config.fromfile(args.config) 44 | 45 | logger = get_logger(cfg) 46 | logger.info('Config path: {}'.format(args.config)) 47 | writer = SummaryWriter(cfg.checkpoint_savedir + 'tensorboard/') 48 | 49 | train_data = erase_dataset(cfg, mode='train') 50 | train_loader = DataLoader( 51 | dataset=train_data, 52 | batch_size=cfg.batch_size, 53 | shuffle=True, 54 | num_workers=cfg.num_workers, 55 | pin_memory=True, 56 | drop_last=True) 57 | eval_data = erase_dataset(cfg, data_dir=cfg.example_data_dir, mode='eval') 58 | eval_loader = DataLoader( 59 | dataset=eval_data, 60 | batch_size=1, 61 | shuffle=False) 62 | 
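    # The training loop below alternates standard GAN updates: D1 is first updated on
    # real pairs (t_b, i_s) vs. fake pairs (o_b, i_s) concatenated along the channel
    # dimension (hence in_channels=6), then requires_grad() flips which network gets
    # gradients and the generator is updated with build_generator_erase_loss from
    # loss.py (GAN term + weighted L2 on the background + dice loss on the predicted mask).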
63 | G = Generator(cfg, in_channels=3).cuda() 64 | D1 = Discriminator(cfg, in_channels=6).cuda() 65 | G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 66 | D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 67 | 68 | try: 69 | checkpoint = torch.load(cfg.ckpt_path) 70 | G.load_state_dict(checkpoint['generator']) 71 | D1.load_state_dict(checkpoint['discriminator1']) 72 | G_solver.load_state_dict(checkpoint['g_optimizer']) 73 | D1_solver.load_state_dict(checkpoint['d1_optimizer']) 74 | logger.info('Model loaded: {}'.format(cfg.ckpt_path)) 75 | except FileNotFoundError: 76 | logger.info('Model not found') 77 | 78 | requires_grad(G, False) 79 | requires_grad(D1, True) 80 | 81 | trainiter = iter(train_loader) 82 | for step in tqdm(range(cfg.max_iter)): 83 | D1_solver.zero_grad() 84 | if ((step + 1) % cfg.save_ckpt_interval == 0): 85 | torch.save( 86 | { 87 | 'generator': G.state_dict(), 88 | 'discriminator1': D1.state_dict(), 89 | 'g_optimizer': G_solver.state_dict(), 90 | 'd1_optimizer': D1_solver.state_dict(), 91 | }, 92 | cfg.checkpoint_savedir + f'train_step-{step+1}.model', 93 | ) 94 | 95 | try: 96 | i_s, t_b, mask_s = trainiter.next() 97 | except StopIteration: 98 | trainiter = iter(train_loader) 99 | i_s, t_b, mask_s = trainiter.next() 100 | 101 | i_s = i_s.cuda() 102 | t_b = t_b.cuda() 103 | mask_s = mask_s.cuda() 104 | labels = [t_b, mask_s] 105 | 106 | o_b_ori, o_b, o_mask_s = G(i_s) 107 | i_db_true = torch.cat((t_b, i_s), dim=1) 108 | i_db_pred = torch.cat((o_b, i_s), dim=1) 109 | o_db_true = D1(i_db_true) 110 | o_db_pred = D1(i_db_pred) 111 | 112 | db_loss = build_discriminator_loss(o_db_true, o_db_pred) 113 | db_loss.backward() 114 | D1_solver.step() 115 | 116 | # Train generator 117 | requires_grad(G, True) 118 | requires_grad(D1, False) 119 | G_solver.zero_grad() 120 | 121 | o_b_ori, o_b, o_mask_s = G(i_s) 122 | i_db_pred = torch.cat((o_b, i_s), dim=1) 123 | o_db_pred = D1(i_db_pred) 124 | 125 | out_g = [o_b, o_mask_s] 126 | out_d = o_db_pred 127 | 128 | g_loss, metrics = build_generator_erase_loss(cfg, out_g, out_d, labels) 129 | g_loss.backward() 130 | G_solver.step() 131 | 132 | requires_grad(G, False) 133 | requires_grad(D1, True) 134 | 135 | if ((step + 1) % cfg.write_log_interval == 0): 136 | loss_str = 'Iter: {}/{} | Gen:{:<10.6f} | D_bg:{:<10.6f} | G_lr:{} | D_lr:{}'.format( 137 | step + 1, cfg.max_iter, 138 | g_loss.item(), 139 | db_loss.item(), 140 | G_solver.param_groups[0]['lr'], 141 | D1_solver.param_groups[0]['lr']) 142 | writer.add_scalar('main/G_loss', g_loss.item(), step) 143 | writer.add_scalar('main/db_loss', db_loss.item(), step) 144 | 145 | logger.info(loss_str) 146 | for name, metric in metrics.items(): 147 | loss_str = ' | '.join(['{:<7}: {:<10.6f}'.format(sub_name, sub_metric) for sub_name, sub_metric in metric.items()]) 148 | for sub_name, sub_metric in metric.items(): 149 | writer.add_scalar(name + '/' + sub_name, sub_metric, step) 150 | logger.info(loss_str) 151 | 152 | if ((step + 1) % cfg.gen_example_interval == 0): 153 | savedir = os.path.join(cfg.example_result_dir, 'iter-' + str(step + 1).zfill(len(str(cfg.max_iter)))) 154 | with torch.no_grad(): 155 | for inp in eval_loader: 156 | i_s = inp[0].cuda() 157 | name = str(inp[1][0]) 158 | name, suffix = name.split('.') 159 | 160 | G.eval() 161 | o_b_ori, o_b, o_mask_s = G(i_s) 162 | G.train() 163 | 164 | if not os.path.exists(savedir): 165 | os.makedirs(savedir) 166 | o_mask_s = 
o_mask_s.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 167 | o_b = o_b.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 168 | cv2.imwrite(os.path.join(savedir, name + '_o_mask_s.' + suffix), o_mask_s * 255) 169 | cv2.imwrite(os.path.join(savedir, name + '_o_b.' + suffix), o_b[:, :, ::-1] * 255) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | gpu_num = torch.cuda.device_count() 5 | epsilon = 1e-8 6 | 7 | def build_discriminator_loss(x_true, x_fake): 8 | d_loss = -torch.mean(torch.log(torch.clamp(x_true, epsilon, 1.0)) + 9 | torch.log(torch.clamp(1.0 - x_fake, epsilon, 1.0))) 10 | return d_loss 11 | 12 | 13 | def build_dice_loss(x_t, x_o): 14 | iflat = x_o.view(-1) 15 | tflat = x_t.view(-1) 16 | intersection = (iflat * tflat).sum() 17 | return 1. - torch.mean((2. * intersection + epsilon) / (iflat.sum() + tflat.sum() + epsilon)) 18 | 19 | 20 | def build_l1_loss(x_t, x_o): 21 | return torch.mean(torch.abs(x_t - x_o)) 22 | 23 | 24 | def build_l2_loss(x_t, x_o): 25 | return torch.mean((x_t - x_o) ** 2) 26 | 27 | 28 | def build_perceptual_loss(x): 29 | l = [] 30 | for i, f in enumerate(x): 31 | l.append(build_l1_loss(f[0], f[1])) 32 | l = torch.stack(l, dim=0) 33 | l = l.sum() 34 | return l 35 | 36 | 37 | def build_gram_matrix(x): 38 | x_shape = x.shape 39 | c, h, w = x_shape[1], x_shape[2], x_shape[3] 40 | matrix = x.view((-1, c, h * w)) 41 | matrix1 = torch.transpose(matrix, 1, 2) 42 | gram = torch.matmul(matrix, matrix1) / (h * w * c) 43 | return gram 44 | 45 | 46 | def build_style_loss(x): 47 | l = [] 48 | for i, f in enumerate(x): 49 | f_shape = f[0].shape[0] * f[0].shape[1] * f[0].shape[2] 50 | f_norm = 1. 
/ f_shape 51 | gram_true = build_gram_matrix(f[0]) 52 | gram_pred = build_gram_matrix(f[1]) 53 | l.append(f_norm * (build_l1_loss(gram_true, gram_pred))) 54 | l = torch.stack(l, dim=0) 55 | l = l.sum() 56 | return l 57 | 58 | 59 | def build_vgg_loss(x): 60 | splited = [] 61 | for i, f in enumerate(x): 62 | splited.append(torch.chunk(f, 2)) 63 | l_per = build_perceptual_loss(splited) 64 | l_style = build_style_loss(splited) 65 | return l_per, l_style 66 | 67 | 68 | def build_gan_loss(x_pred): 69 | gen_loss = -torch.mean(torch.log(torch.clamp(x_pred, epsilon, 1.0))) 70 | return gen_loss 71 | 72 | 73 | def build_recognizer_loss(preds, target): 74 | loss = F.cross_entropy(preds, target, ignore_index=0) 75 | return loss 76 | 77 | 78 | def build_generator_loss(cfg, out_g, out_d, out_vgg, labels): 79 | if cfg.with_recognizer: 80 | o_b, o_f, o_mask_s, o_mask_t, rec_preds = out_g 81 | t_b, t_f, mask_t, mask_s, rec_target = labels 82 | else: 83 | o_b, o_f, o_mask_s, o_mask_t = out_g 84 | t_b, t_f, mask_t, mask_s = labels 85 | o_db_pred, o_df_pred = out_d 86 | o_vgg = out_vgg 87 | 88 | # Background Inpainting module loss 89 | l_b_gan = build_gan_loss(o_db_pred) 90 | l_b_l2 = cfg.lb_beta * build_l2_loss(t_b, o_b) 91 | l_b_mask = cfg.lb_mask * build_dice_loss(mask_s, o_mask_s) 92 | l_b = l_b_gan + l_b_l2 + l_b_mask 93 | 94 | l_f_gan = build_gan_loss(o_df_pred) 95 | l_f_l2 = cfg.lf_theta_1 * build_l2_loss(t_f, o_f) 96 | l_f_vgg_per, l_f_vgg_style = build_vgg_loss(o_vgg) 97 | l_f_vgg_per = cfg.lf_theta_2 * l_f_vgg_per 98 | l_f_vgg_style = cfg.lf_theta_3 * l_f_vgg_style 99 | l_f_mask = cfg.lf_mask * build_dice_loss(mask_t, o_mask_t) 100 | if cfg.with_recognizer: 101 | l_f_rec = cfg.lf_rec * build_recognizer_loss(rec_preds.view(-1, rec_preds.shape[-1]), rec_target.contiguous().view(-1)) 102 | l_f = l_f_gan + l_f_vgg_per + l_f_vgg_style + l_f_l2 + l_f_mask + l_f_rec 103 | else: 104 | l_f = l_f_gan + l_f_vgg_per + l_f_vgg_style + l_f_l2 + l_f_mask 105 | l = cfg.lb * l_b + cfg.lf * l_f 106 | 107 | metrics = {} 108 | metrics['l_b'] = {} 109 | metrics['l_b']['l_b'] = l_b 110 | metrics['l_b']['l_b_gan'] = l_b_gan 111 | metrics['l_b']['l_b_l2'] = l_b_l2 112 | metrics['l_b']['l_b_mask'] = l_b_mask 113 | metrics['l_f'] = {} 114 | metrics['l_f']['l_f'] = l_f 115 | metrics['l_f']['l_f_gan'] = l_f_gan 116 | metrics['l_f']['l_f_l2'] = l_f_l2 117 | metrics['l_f']['l_f_vgg_per'] = l_f_vgg_per 118 | metrics['l_f']['l_f_vgg_style'] = l_f_vgg_style 119 | metrics['l_f']['l_f_mask'] = l_f_mask 120 | if cfg.with_recognizer: 121 | metrics['l_f']['l_f_rec'] = l_f_rec 122 | 123 | return l, metrics 124 | 125 | 126 | def build_generator_loss_with_real(cfg, out_g, out_d, out_vgg, labels): 127 | if cfg.with_recognizer: 128 | o_b, o_f, o_mask_s, o_mask_t, rec_preds = out_g 129 | t_b, t_f, mask_t, mask_s, rec_target = labels 130 | else: 131 | o_b, o_f, o_mask_s, o_mask_t = out_g 132 | t_b, t_f, mask_t, mask_s = labels 133 | o_db_pred, o_df_pred = out_d 134 | o_vgg = out_vgg 135 | 136 | synth_bs = (cfg.batch_size - cfg.real_bs) // gpu_num 137 | # Background Inpainting module loss 138 | l_b_gan = build_gan_loss(o_db_pred) 139 | l_b_l2 = cfg.lb_beta * build_l2_loss(t_b[:synth_bs], o_b[:synth_bs]) 140 | l_b_mask = cfg.lb_mask * build_dice_loss(mask_s[:synth_bs], o_mask_s[:synth_bs]) 141 | l_b = l_b_gan + l_b_l2 + l_b_mask 142 | 143 | l_f_gan = build_gan_loss(o_df_pred) 144 | l_f_l2 = cfg.lf_theta_1 * build_l2_loss(t_f, o_f) 145 | l_f_vgg_per, l_f_vgg_style = build_vgg_loss(o_vgg) 146 | l_f_vgg_per = cfg.lf_theta_2 * l_f_vgg_per 147 | 
l_f_vgg_style = cfg.lf_theta_3 * l_f_vgg_style 148 | l_f_mask = cfg.lf_mask * build_dice_loss(mask_t[:synth_bs], o_mask_t[:synth_bs]) 149 | if cfg.with_recognizer: 150 | l_f_rec = cfg.lf_rec * build_recognizer_loss(rec_preds.view(-1, rec_preds.shape[-1]), rec_target.contiguous().view(-1)) 151 | l_f = l_f_gan + l_f_vgg_per + l_f_vgg_style + l_f_l2 + l_f_mask + l_f_rec 152 | else: 153 | l_f = l_f_gan + l_f_vgg_per + l_f_vgg_style + l_f_l2 + l_f_mask 154 | l = cfg.lb * l_b + cfg.lf * l_f 155 | 156 | metrics = {} 157 | metrics['l_b'] = {} 158 | metrics['l_b']['l_b'] = l_b 159 | metrics['l_b']['l_b_gan'] = l_b_gan 160 | metrics['l_b']['l_b_l2'] = l_b_l2 161 | metrics['l_b']['l_b_mask'] = l_b_mask 162 | metrics['l_f'] = {} 163 | metrics['l_f']['l_f'] = l_f 164 | metrics['l_f']['l_f_gan'] = l_f_gan 165 | metrics['l_f']['l_f_l2'] = l_f_l2 166 | metrics['l_f']['l_f_vgg_per'] = l_f_vgg_per 167 | metrics['l_f']['l_f_vgg_style'] = l_f_vgg_style 168 | metrics['l_f']['l_f_mask'] = l_f_mask 169 | if cfg.with_recognizer: 170 | metrics['l_f']['l_f_rec'] = l_f_rec 171 | 172 | return l, metrics 173 | 174 | def build_generator_erase_loss(cfg, out_g, out_d, labels): 175 | o_b, o_mask_s = out_g 176 | t_b, mask_s = labels 177 | o_db_pred = out_d 178 | 179 | l_b_gan = build_gan_loss(o_db_pred) 180 | l_b_l2 = cfg.lb_beta * build_l2_loss(t_b, o_b) 181 | l_b_mask = cfg.lb_mask * build_dice_loss(mask_s, o_mask_s) 182 | l_b = l_b_gan + l_b_l2 + l_b_mask 183 | l = cfg.lb * l_b 184 | 185 | metrics = {} 186 | metrics['l_b'] = {} 187 | metrics['l_b']['l_b'] = l_b 188 | metrics['l_b']['l_b_gan'] = l_b_gan 189 | metrics['l_b']['l_b_l2'] = l_b_l2 190 | metrics['l_b']['l_b_mask'] = l_b_mask 191 | 192 | return l, metrics -------------------------------------------------------------------------------- /eval_real.py: -------------------------------------------------------------------------------- 1 | import string 2 | import argparse 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | import torch.utils.data 6 | import torch.nn.functional as F 7 | from nltk.metrics.distance import edit_distance 8 | from rec_utils import AttnLabelConverter 9 | from rec_dataset import RawDataset, AlignCollate 10 | from rec_model import Rec_Model 11 | 12 | 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | def main(opt): 16 | """ model configuration """ 17 | converter = AttnLabelConverter(opt.character) 18 | opt.num_class = len(converter.character) 19 | if opt.use_rgb: 20 | opt.input_channel = 3 21 | model = Rec_Model(opt) 22 | 23 | # load model 24 | print('Loading pretrained model from %s' % opt.saved_model) 25 | rec_state_dict = torch.load(opt.saved_model, map_location=device) 26 | if len(rec_state_dict) == 1: 27 | rec_state_dict = rec_state_dict['recognizer'] 28 | rec_state_dict = {k.replace('module.', ''): v for k, v in rec_state_dict.items()} 29 | model.load_state_dict(rec_state_dict) 30 | model = torch.nn.DataParallel(model).to(device) 31 | 32 | # prepare data. 
two demo images from https://github.com/bgshih/crnn#run-demo 33 | AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 34 | demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset 35 | demo_loader = torch.utils.data.DataLoader( 36 | demo_data, batch_size=opt.batch_size, 37 | shuffle=False, 38 | num_workers=int(opt.workers), 39 | collate_fn=AlignCollate_demo, pin_memory=True) 40 | 41 | n_correct = 0 42 | norm_ED = 0 43 | length_of_data = 0 44 | 45 | # predict 46 | model.eval() 47 | with torch.no_grad(): 48 | for image_tensors, image_path_list, image_gts in demo_loader: 49 | batch_size = image_tensors.size(0) 50 | image = image_tensors.to(device) 51 | # For max length prediction 52 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 53 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 54 | 55 | if 'CTC' in opt.Prediction: 56 | preds = model(image, text_for_pred) 57 | 58 | # Select max probabilty (greedy decoding) then decode index to character 59 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 60 | _, preds_index = preds.max(2) 61 | # preds_index = preds_index.view(-1) 62 | preds_str = converter.decode(preds_index, preds_size) 63 | else: 64 | preds = model(image, text_for_pred, is_train=False) 65 | 66 | # select max probabilty (greedy decoding) then decode index to character 67 | _, preds_index = preds.max(2) 68 | preds_str = converter.decode(preds_index, length_for_pred) 69 | 70 | preds_prob = F.softmax(preds, dim=2) 71 | preds_max_prob, _ = preds_prob.max(dim=2) 72 | for img_name, pred, pred_max_prob, gt in zip(image_path_list, preds_str, preds_max_prob, image_gts): 73 | img_name = img_name.split('/')[-1] 74 | if 'Attn' in opt.Prediction: 75 | pred_EOS = pred.find('[s]') 76 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 77 | pred_max_prob = pred_max_prob[:pred_EOS] 78 | 79 | # calculate confidence score (= multiply of pred_max_prob) 80 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 81 | 82 | if not opt.sensitive: 83 | pred = pred.lower() 84 | gt = gt.lower() 85 | length_of_data += 1 86 | if pred == gt: 87 | n_correct += 1 88 | if len(gt) == 0 or len(pred) == 0: 89 | norm_ED += 0 90 | elif len(gt) > len(pred): 91 | norm_ED += 1 - edit_distance(pred, gt) / len(gt) 92 | else: 93 | norm_ED += 1 - edit_distance(pred, gt) / len(pred) 94 | 95 | accuracy = n_correct / float(length_of_data) * 100 96 | norm_ED = norm_ED / float(length_of_data) 97 | 98 | print(f'{opt.image_folder}: Total {length_of_data}\t Acc {accuracy:0.3f}\t normalized_ED {norm_ED:0.3f}') 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--image_folder', help='path to image_folder which contains text images') 104 | parser.add_argument('--gt_file', help='path to gt_file') 105 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 106 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 107 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation") 108 | """ Data processing """ 109 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 110 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 111 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 112 | 
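    # The defaults imgH=32 and imgW=100 match the input size the TPS stage in
    # rec_model.py is built for. A typical invocation (see the README) is:
    #   python eval_real.py --saved_model models/TPS-ResNet-BiLSTM-Attn.pth \
    #       --gt_file datasets/evaluation/Tamper-Scene/i_t.txt --image_folder results-scene/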
parser.add_argument('--use_rgb', action='store_true', help='use rgb input') 113 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 114 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 115 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 116 | """ Model Architecture """ 117 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 118 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet') 119 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 120 | parser.add_argument('--Prediction', type=str, default='Attn', help='Prediction stage. CTC|Attn') 121 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 122 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 123 | parser.add_argument('--output_channel', type=int, default=512, 124 | help='the number of output channel of Feature extractor') 125 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 126 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 127 | 128 | opt = parser.parse_args() 129 | 130 | """ vocab / character number configuration """ 131 | if opt.sensitive: 132 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 133 | 134 | cudnn.benchmark = True 135 | cudnn.deterministic = True 136 | opt.num_gpu = torch.cuda.device_count() 137 | 138 | main(opt) 139 | 140 | -------------------------------------------------------------------------------- /rec_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | 5 | class CTCLabelConverter(object): 6 | """ Convert between text-label and text-index """ 7 | 8 | def __init__(self, character): 9 | # character (str): set of the possible characters. 10 | dict_character = list(character) 11 | 12 | self.dict = {} 13 | for i, char in enumerate(dict_character): 14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 15 | self.dict[char] = i + 1 16 | 17 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 18 | 19 | def encode(self, text, batch_max_length=25): 20 | """convert text-label into text-index. 21 | input: 22 | text: text labels of each image. [batch_size] 23 | batch_max_length: max length of text label in the batch. 25 by default 24 | 25 | output: 26 | text: text index for CTCLoss. [batch_size, batch_max_length] 27 | length: length of each text. [batch_size] 28 | """ 29 | length = [len(s) for s in text] 30 | 31 | # The index used for padding (=0) would not affect the CTC loss calculation. 32 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 33 | for i, t in enumerate(text): 34 | text = list(t) 35 | text = [self.dict[char] for char in text] 36 | batch_text[i][:len(text)] = torch.LongTensor(text) 37 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 38 | 39 | def decode(self, text_index, length): 40 | """ convert text-index into text-label. 
""" 41 | texts = [] 42 | for index, l in enumerate(length): 43 | t = text_index[index, :] 44 | 45 | char_list = [] 46 | for i in range(l): 47 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 48 | char_list.append(self.character[t[i]]) 49 | text = ''.join(char_list) 50 | 51 | texts.append(text) 52 | return texts 53 | 54 | 55 | class CTCLabelConverterForBaiduWarpctc(object): 56 | """ Convert between text-label and text-index for baidu warpctc """ 57 | 58 | def __init__(self, character): 59 | # character (str): set of the possible characters. 60 | dict_character = list(character) 61 | 62 | self.dict = {} 63 | for i, char in enumerate(dict_character): 64 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 65 | self.dict[char] = i + 1 66 | 67 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 68 | 69 | def encode(self, text, batch_max_length=25): 70 | """convert text-label into text-index. 71 | input: 72 | text: text labels of each image. [batch_size] 73 | output: 74 | text: concatenated text index for CTCLoss. 75 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 76 | length: length of each text. [batch_size] 77 | """ 78 | length = [len(s) for s in text] 79 | text = ''.join(text) 80 | text = [self.dict[char] for char in text] 81 | 82 | return (torch.IntTensor(text), torch.IntTensor(length)) 83 | 84 | def decode(self, text_index, length): 85 | """ convert text-index into text-label. """ 86 | texts = [] 87 | index = 0 88 | for l in length: 89 | t = text_index[index:index + l] 90 | 91 | char_list = [] 92 | for i in range(l): 93 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 94 | char_list.append(self.character[t[i]]) 95 | text = ''.join(char_list) 96 | 97 | texts.append(text) 98 | index += l 99 | return texts 100 | 101 | 102 | class AttnLabelConverter(object): 103 | """ Convert between text-label and text-index """ 104 | 105 | def __init__(self, character): 106 | # character (str): set of the possible characters. 107 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 108 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 109 | list_character = list(character) 110 | self.character = list_token + list_character 111 | 112 | self.dict = {} 113 | for i, char in enumerate(self.character): 114 | # print(i, char) 115 | self.dict[char] = i 116 | 117 | def encode(self, text, batch_max_length=25): 118 | """ convert text-label into text-index. 119 | input: 120 | text: text labels of each image. [batch_size] 121 | batch_max_length: max length of text label in the batch. 25 by default 122 | 123 | output: 124 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 125 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 126 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 127 | """ 128 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 129 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 130 | batch_max_length += 1 131 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 
132 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 133 | for i, t in enumerate(text): 134 | text = list(t) 135 | text.append('[s]') 136 | text = [self.dict[char] for char in text] 137 | try: 138 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 139 | except: 140 | print(text) 141 | print(len(text)) 142 | print(len(batch_text[i])) 143 | # return (batch_text.to(device), torch.IntTensor(length).to(device)) 144 | return (batch_text, torch.IntTensor(length)) 145 | 146 | def decode(self, text_index, length): 147 | """ convert text-index into text-label. """ 148 | texts = [] 149 | for index, l in enumerate(length): 150 | text = ''.join([self.character[i] for i in text_index[index, :]]) 151 | texts.append(text) 152 | return texts 153 | 154 | 155 | class Averager(object): 156 | """Compute average for torch.Tensor, used for loss average.""" 157 | 158 | def __init__(self): 159 | self.reset() 160 | 161 | def add(self, v): 162 | count = v.data.numel() 163 | v = v.data.sum() 164 | self.n_count += count 165 | self.sum += v 166 | 167 | def reset(self): 168 | self.n_count = 0 169 | self.sum = 0 170 | 171 | def val(self): 172 | res = 0 173 | if self.n_count != 0: 174 | res = self.sum / float(self.n_count) 175 | return res 176 | -------------------------------------------------------------------------------- /tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TPSSpatialTransformer(nn.Module): 8 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 9 | # margins: (x, y) x, y in [0, 1) 10 | super(TPSSpatialTransformer, self).__init__() 11 | self.output_image_size = output_image_size 12 | self.num_control_points = num_control_points 13 | self.margins = margins 14 | 15 | # self.source_ctrl_points = torch.Tensor([ 16 | # [-1, -1], 17 | # [-0.5, -1], 18 | # [0, -1], 19 | # [0.5, -1], 20 | # [1, -1], 21 | # [-1, 1], 22 | # [-0.5, 1], 23 | # [0, 1], 24 | # [0.5, 1], 25 | # [1, 1]]) 26 | self.source_ctrl_points = self.build_output_control_points(num_control_points, margins) 27 | 28 | def build_output_control_points(self, num_control_points, margins): 29 | margin_x, margin_y = margins 30 | num_ctrl_pts_per_side = num_control_points // 2 31 | ctrl_pts_x = np.linspace(-1.0 + margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 32 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * (-1.0 + margin_y) 33 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 34 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 35 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 36 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 37 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 38 | return output_ctrl_pts 39 | 40 | def b_inv(self, b_mat): 41 | eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat) 42 | b_inv, _ = torch.solve(eye, b_mat) 43 | return b_inv 44 | 45 | def _repeat(self, x, n_repeats): 46 | rep = torch.unsqueeze(torch.ones(n_repeats), 1).transpose(0, 1) 47 | x = torch.matmul(x.reshape(-1, 1).int(), rep.int()) 48 | return x.reshape(-1) 49 | 50 | def _interpolate(self, im, x, y): 51 | # constants 52 | num_batch, height, width, channels = im.shape 53 | 54 | x = x.float() 55 | y = y.float() 56 | out_height, out_width = self.output_image_size 57 | height_f = 
torch.tensor(height, dtype=torch.float32) 58 | width_f = torch.tensor(width, dtype=torch.float32) 59 | zero = torch.tensor(0, dtype=torch.int32) 60 | max_y = torch.tensor(height - 1, dtype=torch.int32) 61 | max_x = torch.tensor(width - 1, dtype=torch.int32) 62 | 63 | # scale indices from [-1, 1] to [0, width/height] 64 | x = (x + 1.0) * (width_f) / 2.0 65 | y = (y + 1.0) * (height_f) / 2.0 66 | 67 | # do sampling 68 | x0 = torch.floor(x).int() 69 | x1 = x0 + 1 70 | y0 = torch.floor(y).int() 71 | y1 = y0 + 1 72 | 73 | x0 = torch.clamp(x0, min=zero, max=max_x) 74 | x1 = torch.clamp(x1, min=zero, max=max_x) 75 | y0 = torch.clamp(y0, min=zero, max=max_y) 76 | y1 = torch.clamp(y1, min=zero, max=max_y) 77 | 78 | dim2 = width 79 | dim1 = width * height 80 | # base = _repeat(torch.range(0, num_batch-1)*dim1, out_height*out_width) 81 | base = self._repeat(torch.arange(0, num_batch) * dim1, out_height * out_width).cuda() 82 | 83 | base_y0 = base + y0 * dim2 84 | base_y1 = base + y1 * dim2 85 | idx_a = base_y0 + x0 86 | idx_b = base_y1 + x0 87 | idx_c = base_y0 + x1 88 | idx_d = base_y1 + x1 89 | 90 | # use indices to lookup pixels in the flat image and restore channels dim 91 | im_flat = im.reshape(-1, channels) 92 | im_flat = im_flat.float() 93 | 94 | # tmp = idx_a.unsqueeze(1).long() 95 | idx_a = idx_a.unsqueeze(1).long() 96 | idx_b = idx_b.unsqueeze(1).long() 97 | idx_c = idx_c.unsqueeze(1).long() 98 | idx_d = idx_d.unsqueeze(1).long() 99 | if channels != 1: 100 | tmp_idx_a = idx_a.long() 101 | tmp_idx_b = idx_b.long() 102 | tmp_idx_c = idx_c.long() 103 | tmp_idx_d = idx_d.long() 104 | for i in range(channels - 1): 105 | idx_a = torch.cat((idx_a, tmp_idx_a), 1) 106 | idx_b = torch.cat((idx_b, tmp_idx_b), 1) 107 | idx_c = torch.cat((idx_c, tmp_idx_c), 1) 108 | idx_d = torch.cat((idx_d, tmp_idx_d), 1) 109 | 110 | Ia = torch.gather(im_flat, 0, idx_a) 111 | Ib = torch.gather(im_flat, 0, idx_b) 112 | Ic = torch.gather(im_flat, 0, idx_c) 113 | Id = torch.gather(im_flat, 0, idx_d) 114 | 115 | # and finally calculate interpolated values 116 | x0_f = x0.float() 117 | x1_f = x1.float() 118 | y0_f = y0.float() 119 | y1_f = y1.float() 120 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1) 121 | wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1) 122 | wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1) 123 | wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1) 124 | 125 | output = wa * Ia + wb * Ib + wc * Ic + wd * Id 126 | return output 127 | 128 | def solve_system(self, target_ctrl_points): 129 | """Thin Plate Spline Spatial Transformer layer 130 | TPS control points are arranged in arbitrary positions given by `coord`. 131 | coord : float Tensor [num_batch, num_point, 2] 132 | Relative coordinate of the control points. 133 | vec : float Tensor [num_batch, num_point, 2] 134 | The vector on the control points. 
135 | """ 136 | coord = target_ctrl_points 137 | num_batch = coord.shape[0] 138 | num_point = self.num_control_points 139 | 140 | vec = torch.ones(num_batch)[:, None, None] * self.source_ctrl_points 141 | vec = vec.reshape(num_batch, num_point, 2).cuda() 142 | p = torch.cat([torch.ones([num_batch, num_point, 1]).cuda(), coord], 2) # [bn, pn, 3] 143 | 144 | p_1 = torch.reshape(p, [num_batch, -1, 1, 3]) # [bn, pn, 1, 3] 145 | p_2 = torch.reshape(p, [num_batch, 1, -1, 3]) # [bn, 1, pn, 3] 146 | d = p_1 - p_2 # [bn, pn, pn, 3] 147 | d2 = torch.sum(torch.pow(d, 2), 3) # [bn, pn, pn] 148 | r = d2 * torch.log(d2 + 1e-6) # [bn, pn, pn] 149 | 150 | W_0 = torch.cat([p, r], 2) # [bn, pn, 3+pn] 151 | W_1 = torch.cat([torch.zeros([num_batch, 3, 3]).cuda(), torch.transpose(p, 2, 1)], 2) # [bn, 3, pn+3] 152 | W = torch.cat([W_0, W_1], 1) # [bn, pn+3, pn+3] 153 | W_inv = self.b_inv(W) 154 | 155 | tp = F.pad(vec, (0, 0, 0, 3)) 156 | 157 | tp = tp.squeeze(1) # [bn, pn+3, 2] 158 | T = torch.matmul(W_inv, tp) # [bn, pn+3, 2] 159 | T = torch.transpose(T, 2, 1) # [bn, 2, pn+3] 160 | 161 | return T 162 | 163 | def _meshgrid(self, height, width, coord): 164 | x_t = torch.linspace(-1.0, 1.0, steps=width).reshape(1, width).expand(height, width) 165 | y_t = torch.linspace(-1.0, 1.0, steps=height).reshape(height, 1).expand(height, width) 166 | x_t_flat = x_t.reshape(1, 1, -1).cuda() 167 | y_t_flat = y_t.reshape(1, 1, -1).cuda() 168 | 169 | num_batch = coord.shape[0] 170 | px = torch.unsqueeze(coord[:, :, 0], 2) # [bn, pn, 1] 171 | py = torch.unsqueeze(coord[:, :, 1], 2) # [bn, pn, 1] 172 | 173 | d2 = torch.pow(x_t_flat - px, 2) + torch.pow(y_t_flat - py, 2) 174 | 175 | r = d2 * torch.log(d2 + 1e-6) # [bn, pn, h*w] 176 | x_t_flat_g = x_t_flat.expand(num_batch, x_t_flat.shape[1], x_t_flat.shape[2]) 177 | y_t_flat_g = y_t_flat.expand(num_batch, y_t_flat.shape[1], y_t_flat.shape[2]) 178 | 179 | grid = torch.cat((torch.ones(x_t_flat_g.shape).cuda(), x_t_flat_g, y_t_flat_g, r), 1) 180 | return grid 181 | 182 | def forward(self, input_dim, coord): 183 | T = self.solve_system(coord) 184 | input_dim = input_dim.permute(0, 2, 3, 1) 185 | num_batch, height, width, num_channels = input_dim.shape 186 | 187 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 188 | out_height, out_width = self.output_image_size 189 | grid = self._meshgrid(out_height, out_width, coord) # [2, h*w] 190 | # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s) 191 | # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w] 192 | T_g = torch.matmul(T, grid) 193 | x_s = torch.unsqueeze(T_g[:, 0, :], 1) 194 | y_s = torch.unsqueeze(T_g[:, 1, :], 1) 195 | x_s_flat = x_s.reshape(-1) 196 | y_s_flat = y_s.reshape(-1) 197 | 198 | input_transformed = self._interpolate(input_dim, x_s_flat, y_s_flat) 199 | 200 | output = input_transformed.reshape(num_batch, out_height, out_width, num_channels) 201 | output = output.permute(0, 3, 1, 2) 202 | return output, None 203 | 204 | def point_transform(point, T, coord): 205 | point = torch.Tensor(point.reshape([1, 1, 2])) 206 | d2 = torch.sum(torch.pow(point - coord, 2), 2) 207 | r = d2 * torch.log(d2 + 1e-6) 208 | q = torch.Tensor(np.array([[1, point[0, 0, 0], point[0, 0, 1]]])) 209 | x = torch.cat([q, r], 1) 210 | point_T = torch.matmul(T, torch.transpose(x.unsqueeze(1), 2, 1)) 211 | return point_T 212 | -------------------------------------------------------------------------------- /datagen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import itertools 4 | 
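The linear system solved in solve_system of tps_spatial_transformer.py above is the standard thin-plate-spline fit with kernel r^2 log r^2 and a zero-padded affine block. The following single-sample NumPy sketch mirrors that construction; it is illustrative only and not part of the repository, and the 4-point example data is made up.

import numpy as np

def tps_coefficients(coord, vec, eps=1e-6):
    """coord: (pn, 2) predicted control points; vec: (pn, 2) canonical control points."""
    pn = coord.shape[0]
    p = np.concatenate([np.ones((pn, 1)), coord], axis=1)        # (pn, 3)
    d2 = ((coord[:, None, :] - coord[None, :, :]) ** 2).sum(-1)  # pairwise squared distances
    r = d2 * np.log(d2 + eps)                                    # TPS kernel U = r^2 log r^2
    W = np.zeros((pn + 3, pn + 3))
    W[:pn, :3] = p
    W[:pn, 3:] = r
    W[pn:, 3:] = p.T                                             # lower-left 3x3 block stays zero
    tp = np.concatenate([vec, np.zeros((3, 2))], axis=0)         # zero-pad the affine rows
    T = np.linalg.solve(W, tp)                                   # (pn + 3, 2)
    return T.T                                                   # (2, pn + 3), as in solve_system

pts = np.array([[-1., -1.], [1., -1.], [-1., 1.], [1., 1.]])
print(tps_coefficients(pts, pts).shape)   # (2, 7): an identity warp expressed in TPS form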
import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from torch.utils.data.sampler import Sampler 9 | from PIL import Image 10 | import standard_text 11 | 12 | 13 | class TwoStreamBatchSampler(Sampler): 14 | """Iterate two sets of indices 15 | 16 | An 'epoch' is one iteration through the primary indices. 17 | During the epoch, the secondary indices are iterated through 18 | as many times as needed. 19 | """ 20 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 21 | self.primary_indices = primary_indices 22 | self.secondary_indices = secondary_indices 23 | self.secondary_batch_size = secondary_batch_size 24 | self.primary_batch_size = batch_size - secondary_batch_size 25 | 26 | assert len(self.primary_indices) >= self.primary_batch_size > 0 27 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 28 | 29 | def __iter__(self): 30 | primary_iter = iterate_once(self.primary_indices) 31 | secondary_iter = iterate_eternally(self.secondary_indices) 32 | return ( 33 | primary_batch + secondary_batch 34 | for (primary_batch, secondary_batch) 35 | in zip(grouper(primary_iter, self.primary_batch_size), grouper(secondary_iter, self.secondary_batch_size)) 36 | ) 37 | 38 | def __len__(self): 39 | return len(self.primary_indices) // self.primary_batch_size 40 | 41 | 42 | def iterate_once(iterable): 43 | return np.random.permutation(iterable) 44 | 45 | 46 | def iterate_eternally(indices): 47 | def infinite_shuffles(): 48 | while True: 49 | yield np.random.permutation(indices) 50 | return itertools.chain.from_iterable(infinite_shuffles()) 51 | 52 | 53 | def grouper(iterable, n): 54 | # Collect data into fixed-length chunks or blocks 55 | # grouper('ABCDEFG', 3) --> ABC DEF" 56 | args = [iter(iterable)] * n 57 | return zip(*args) 58 | 59 | 60 | class custom_dataset(Dataset): 61 | def __init__(self, cfg, data_dir=None, i_t_name='i_t.txt', mode='train', with_real_data=False): 62 | self.cfg = cfg 63 | self.mode = mode 64 | self.transform = transforms.Compose([ 65 | transforms.Resize(cfg.data_shape), 66 | transforms.ToTensor(), 67 | ]) 68 | self.std_text = standard_text.Std_Text(cfg.font_path) 69 | 70 | if(self.mode == 'train'): 71 | self.data_dir = cfg.data_dir 72 | if isinstance(self.data_dir, str): 73 | self.data_dir = [self.data_dir] 74 | assert isinstance(self.data_dir, list) 75 | 76 | self.name_list = [] 77 | self.i_t_list = {} 78 | for tmp_data_dir in self.data_dir: 79 | tmp_dataset_name = tmp_data_dir.rsplit('/', 1)[-1] 80 | with open(os.path.join(tmp_data_dir, i_t_name), 'r') as f: 81 | lines = f.readlines() 82 | self.name_list += [os.path.join(tmp_data_dir, '{}', line.strip().split()[0]) for line in lines] 83 | for line in lines: 84 | tmp_key, tmp_val = line.strip().split() 85 | self.i_t_list[tmp_dataset_name + '_' + tmp_key] = tmp_val 86 | 87 | self.len_synth = len(self.name_list) 88 | assert self.len_synth == len(self.i_t_list) 89 | 90 | if with_real_data: 91 | self.real_data_dir = cfg.real_data_dir 92 | if isinstance(self.real_data_dir, str): 93 | self.real_data_dir = [self.real_data_dir] 94 | assert isinstance(self.real_data_dir, list) 95 | 96 | self.real_name_list = [] 97 | self.real_i_t_list = {} 98 | for tmp_data_dir in self.real_data_dir: 99 | tmp_dataset_name = tmp_data_dir.rsplit('/', 1)[-1] 100 | with open(os.path.join(tmp_data_dir, i_t_name), 'r') as f: 101 | lines = f.readlines() 102 | self.real_name_list += [os.path.join(tmp_data_dir, '{}', line.strip().split()[0]) 
for line in lines] 103 | for line in lines: 104 | tmp_key, tmp_val = line.strip().split() 105 | self.real_i_t_list[tmp_dataset_name + '_' + tmp_key] = tmp_val 106 | 107 | self.len_real = len(self.real_name_list) 108 | assert self.len_real == len(self.real_i_t_list) 109 | self.name_list += self.real_name_list 110 | else: 111 | assert data_dir is not None 112 | self.data_dir = data_dir 113 | with open(os.path.join(data_dir, '../' + i_t_name), 'r') as f: 114 | lines = f.readlines() 115 | self.name_list = [line.strip().split()[0] for line in lines] 116 | self.i_t_list = {line.strip().split()[0]: line.strip().split()[1] for line in lines} 117 | 118 | def custom_len(self): 119 | return self.len_synth, self.len_real 120 | 121 | def __len__(self): 122 | return len(self.name_list) 123 | 124 | def __getitem__(self, idx): 125 | img_name = self.name_list[idx] 126 | if self.mode == 'train': 127 | if idx < self.len_synth: 128 | _, tmp_dataset_name, _, tmp_key = img_name.rsplit('/', 3) 129 | tmp_text = self.i_t_list[tmp_dataset_name + '_' + tmp_key] 130 | i_t = self.std_text.draw_text(tmp_text) 131 | i_t = Image.fromarray(np.uint8(i_t)) 132 | i_s = Image.open(img_name.format(self.cfg.i_s_dir)) 133 | if i_s.mode != 'RGB': 134 | i_s = i_s.convert('RGB') 135 | t_b = Image.open(img_name.format(self.cfg.t_b_dir)) 136 | t_f = Image.open(img_name.format(self.cfg.t_f_dir)) 137 | mask_t = Image.open(img_name.format(self.cfg.mask_t_dir)) 138 | mask_s = Image.open(img_name.format(self.cfg.mask_s_dir)) 139 | with open(img_name.format(self.cfg.txt_dir)[:-4] + '.txt', 'r') as f: 140 | lines = f.readlines() 141 | text = lines[0].strip().split()[-1] 142 | text = re.sub("[^0-9a-zA-Z]+", "", text).lower() 143 | i_t = self.transform(i_t) 144 | i_s = self.transform(i_s) 145 | t_b = self.transform(t_b) 146 | t_f = self.transform(t_f) 147 | mask_t = self.transform(mask_t) 148 | mask_s = self.transform(mask_s) 149 | else: 150 | _, tmp_dataset_name, _, tmp_key = img_name.rsplit('/', 3) 151 | tmp_text = self.real_i_t_list[tmp_dataset_name + '_' + tmp_key] 152 | i_t = self.std_text.draw_text(tmp_text) 153 | i_t = Image.fromarray(np.uint8(i_t)) 154 | i_s = Image.open(img_name.format(self.cfg.i_s_dir)) 155 | if i_s.mode != 'RGB': 156 | i_s = i_s.convert('RGB') 157 | with open(img_name.format(self.cfg.txt_dir)[:-4] + '.txt', 'r') as f: 158 | lines = f.readlines() 159 | text = lines[0].strip().split()[-1] 160 | text = re.sub("[^0-9a-zA-Z]+", "", text).lower() 161 | i_t = self.transform(i_t) 162 | i_s = self.transform(i_s) 163 | t_f = i_s 164 | t_b = -1 * torch.ones([3] + self.cfg.data_shape) 165 | mask_t = -1 * torch.ones([1] + self.cfg.data_shape) 166 | mask_s = -1 * torch.ones([1] + self.cfg.data_shape) 167 | 168 | return [i_t, i_s, t_b, t_f, mask_t, mask_s, text] 169 | else: 170 | main_name = img_name 171 | i_s = Image.open(os.path.join(self.data_dir, img_name)) 172 | if i_s.mode != 'RGB': 173 | i_s = i_s.convert('RGB') 174 | tmp_text = self.i_t_list[img_name] 175 | i_t = self.std_text.draw_text(tmp_text) 176 | i_t = Image.fromarray(np.uint8(i_t)) 177 | i_s = self.transform(i_s) 178 | i_t = self.transform(i_t) 179 | 180 | return [i_t, i_s, main_name] 181 | 182 | 183 | class erase_dataset(Dataset): 184 | def __init__(self, cfg, data_dir=None, mode='train'): 185 | self.cfg = cfg 186 | self.mode = mode 187 | self.transform = transforms.Compose([ 188 | transforms.Resize(cfg.data_shape), 189 | transforms.ToTensor() 190 | ]) 191 | if(self.mode == 'train'): 192 | self.data_dir = cfg.data_dir 193 | if isinstance(self.data_dir, str): 194 | 
self.data_dir = [self.data_dir] 195 | assert isinstance(self.data_dir, list) 196 | self.name_list = [] 197 | for tmp_data_dir in self.data_dir: 198 | self.name_list += [os.path.join(tmp_data_dir, '{}', filename) for filename in os.listdir(os.path.join(tmp_data_dir, cfg.i_s_dir))] 199 | else: 200 | assert data_dir is not None 201 | self.data_dir = data_dir 202 | self.name_list = os.listdir(data_dir) 203 | 204 | def __len__(self): 205 | return len(self.name_list) 206 | 207 | def __getitem__(self, idx): 208 | img_name = self.name_list[idx] 209 | if self.mode == 'train': 210 | i_s = Image.open(img_name.format(self.cfg.i_s_dir)) 211 | t_b = Image.open(img_name.format(self.cfg.t_b_dir)) 212 | mask_s = Image.open(img_name.format(self.cfg.mask_s_dir)) 213 | i_s = self.transform(i_s) 214 | t_b = self.transform(t_b) 215 | mask_s = self.transform(mask_s) 216 | 217 | return [i_s, t_b, mask_s] 218 | else: 219 | main_name = img_name 220 | i_s = Image.open(os.path.join(self.data_dir, img_name)) 221 | if i_s.mode != 'RGB': 222 | i_s = i_s.convert('RGB') 223 | i_s = self.transform(i_s) 224 | 225 | return [i_s, main_name] 226 | -------------------------------------------------------------------------------- /pytorch_fid/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 
33 | """ 34 | import os 35 | import pathlib 36 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 37 | 38 | import numpy as np 39 | import torch 40 | import torchvision.transforms as TF 41 | from PIL import Image 42 | from scipy import linalg 43 | from torch.nn.functional import adaptive_avg_pool2d 44 | 45 | try: 46 | from tqdm import tqdm 47 | except ImportError: 48 | # If tqdm is not available, provide a mock version of it 49 | def tqdm(x): 50 | return x 51 | 52 | from pytorch_fid.inception import InceptionV3 53 | 54 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 55 | parser.add_argument('--batch-size', type=int, default=50, 56 | help='Batch size to use') 57 | parser.add_argument('--num-workers', type=int, 58 | help=('Number of processes to use for data loading. ' 59 | 'Defaults to `min(8, num_cpus)`')) 60 | parser.add_argument('--device', type=str, default=None, 61 | help='Device to use. Like cuda, cuda:0 or cpu') 62 | parser.add_argument('--dims', type=int, default=2048, 63 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 64 | help=('Dimensionality of Inception features to use. ' 65 | 'By default, uses pool3 features')) 66 | parser.add_argument('path', type=str, nargs=2, 67 | help=('Paths to the generated images or ' 68 | 'to .npz statistic files')) 69 | 70 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 71 | 'tif', 'tiff', 'webp'} 72 | 73 | 74 | class ImagePathDataset(torch.utils.data.Dataset): 75 | def __init__(self, files, transforms=None): 76 | self.files = files 77 | self.transforms = transforms 78 | 79 | def __len__(self): 80 | return len(self.files) 81 | 82 | def __getitem__(self, i): 83 | path = self.files[i] 84 | img = Image.open(path).convert('RGB') 85 | if self.transforms is not None: 86 | img = self.transforms(img) 87 | return img 88 | 89 | 90 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', 91 | num_workers=1): 92 | """Calculates the activations of the pool_3 layer for all images. 93 | 94 | Params: 95 | -- files : List of image files paths 96 | -- model : Instance of inception model 97 | -- batch_size : Batch size of images for the model to process at once. 98 | Make sure that the number of samples is a multiple of 99 | the batch size, otherwise some samples are ignored. This 100 | behavior is retained to match the original FID score 101 | implementation. 102 | -- dims : Dimensionality of features returned by Inception 103 | -- device : Device to run calculations 104 | -- num_workers : Number of parallel dataloader workers 105 | 106 | Returns: 107 | -- A numpy array of dimension (num images, dims) that contains the 108 | activations of the given tensor when feeding inception with the 109 | query tensor. 110 | """ 111 | model.eval() 112 | 113 | if batch_size > len(files): 114 | print(('Warning: batch size is bigger than the data size. ' 115 | 'Setting batch size to data size')) 116 | batch_size = len(files) 117 | 118 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 119 | dataloader = torch.utils.data.DataLoader(dataset, 120 | batch_size=batch_size, 121 | shuffle=False, 122 | drop_last=False, 123 | num_workers=num_workers) 124 | 125 | pred_arr = np.empty((len(files), dims)) 126 | 127 | start_idx = 0 128 | 129 | for batch in tqdm(dataloader): 130 | batch = batch.to(device) 131 | 132 | with torch.no_grad(): 133 | pred = model(batch)[0] 134 | 135 | # If model output is not scalar, apply global spatial average pooling. 
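The pool3 activations gathered here are reduced to a mean and covariance and then compared with the closed-form Fréchet distance implemented further down in this file. The toy NumPy/SciPy sketch below shows that reduction on made-up low-dimensional "activations"; all sizes and data are hypothetical and the snippet is not part of fid_score.py.

import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
act1 = rng.normal(size=(500, 2))             # stand-in for real-image activations
act2 = rng.normal(loc=0.5, size=(500, 2))    # stand-in for generated-image activations

mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)

diff = mu1 - mu2
covmean = linalg.sqrtm(sigma1.dot(sigma2))   # matrix square root of the covariance product
fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean.real)
print(fid)  # small positive number that grows as the two distributions drift apart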
136 | # This happens if you choose a dimensionality not equal 2048. 137 | if pred.size(2) != 1 or pred.size(3) != 1: 138 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 139 | 140 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 141 | 142 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 143 | 144 | start_idx = start_idx + pred.shape[0] 145 | 146 | return pred_arr 147 | 148 | 149 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 150 | """Numpy implementation of the Frechet Distance. 151 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 152 | and X_2 ~ N(mu_2, C_2) is 153 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 154 | 155 | Stable version by Dougal J. Sutherland. 156 | 157 | Params: 158 | -- mu1 : Numpy array containing the activations of a layer of the 159 | inception net (like returned by the function 'get_predictions') 160 | for generated samples. 161 | -- mu2 : The sample mean over activations, precalculated on an 162 | representative data set. 163 | -- sigma1: The covariance matrix over activations for generated samples. 164 | -- sigma2: The covariance matrix over activations, precalculated on an 165 | representative data set. 166 | 167 | Returns: 168 | -- : The Frechet Distance. 169 | """ 170 | 171 | mu1 = np.atleast_1d(mu1) 172 | mu2 = np.atleast_1d(mu2) 173 | 174 | sigma1 = np.atleast_2d(sigma1) 175 | sigma2 = np.atleast_2d(sigma2) 176 | 177 | assert mu1.shape == mu2.shape, \ 178 | 'Training and test mean vectors have different lengths' 179 | assert sigma1.shape == sigma2.shape, \ 180 | 'Training and test covariances have different dimensions' 181 | 182 | diff = mu1 - mu2 183 | 184 | # Product might be almost singular 185 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 186 | if not np.isfinite(covmean).all(): 187 | msg = ('fid calculation produces singular product; ' 188 | 'adding %s to diagonal of cov estimates') % eps 189 | print(msg) 190 | offset = np.eye(sigma1.shape[0]) * eps 191 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 192 | 193 | # Numerical error might give slight imaginary component 194 | if np.iscomplexobj(covmean): 195 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 196 | m = np.max(np.abs(covmean.imag)) 197 | raise ValueError('Imaginary component {}'.format(m)) 198 | covmean = covmean.real 199 | 200 | tr_covmean = np.trace(covmean) 201 | 202 | return (diff.dot(diff) + np.trace(sigma1) 203 | + np.trace(sigma2) - 2 * tr_covmean) 204 | 205 | 206 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 207 | device='cpu', num_workers=1): 208 | """Calculation of the statistics used by the FID. 209 | Params: 210 | -- files : List of image files paths 211 | -- model : Instance of inception model 212 | -- batch_size : The images numpy array is split into batches with 213 | batch size batch_size. A reasonable batch size 214 | depends on the hardware. 215 | -- dims : Dimensionality of features returned by Inception 216 | -- device : Device to run calculations 217 | -- num_workers : Number of parallel dataloader workers 218 | 219 | Returns: 220 | -- mu : The mean over samples of the activations of the pool_3 layer of 221 | the inception model. 222 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 223 | the inception model. 
224 | """ 225 | act = get_activations(files, model, batch_size, dims, device, num_workers) 226 | mu = np.mean(act, axis=0) 227 | sigma = np.cov(act, rowvar=False) 228 | return mu, sigma 229 | 230 | 231 | def compute_statistics_of_path(path, model, batch_size, dims, device, 232 | num_workers=1): 233 | if path.endswith('.npz'): 234 | with np.load(path) as f: 235 | m, s = f['mu'][:], f['sigma'][:] 236 | else: 237 | path = pathlib.Path(path) 238 | files = sorted([file for ext in IMAGE_EXTENSIONS 239 | for file in path.glob('*.{}'.format(ext))]) 240 | m, s = calculate_activation_statistics(files, model, batch_size, 241 | dims, device, num_workers) 242 | 243 | return m, s 244 | 245 | 246 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): 247 | """Calculates the FID of two paths""" 248 | for p in paths: 249 | if not os.path.exists(p): 250 | raise RuntimeError('Invalid path: %s' % p) 251 | 252 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 253 | 254 | model = InceptionV3([block_idx]).to(device) 255 | 256 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 257 | dims, device, num_workers) 258 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 259 | dims, device, num_workers) 260 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 261 | 262 | return fid_value 263 | 264 | 265 | def main(): 266 | args = parser.parse_args() 267 | 268 | if args.device is None: 269 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 270 | else: 271 | device = torch.device(args.device) 272 | 273 | if args.num_workers is None: 274 | num_avail_cpus = len(os.sched_getaffinity(0)) 275 | num_workers = min(num_avail_cpus, 8) 276 | else: 277 | num_workers = args.num_workers 278 | 279 | fid_value = calculate_fid_given_paths(args.path, 280 | args.batch_size, 281 | device, 282 | args.dims, 283 | num_workers) 284 | print('FID: ', fid_value) 285 | 286 | 287 | if __name__ == '__main__': 288 | main() 289 | -------------------------------------------------------------------------------- /pytorch_fid/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 14 | # FID_WEIGHTS_URL = 'models/pt_inception-2015-12-05-6726825d.pth' 15 | 16 | 17 | class InceptionV3(nn.Module): 18 | """Pretrained InceptionV3 network returning feature maps""" 19 | 20 | # Index of default block of inception to return, 21 | # corresponds to output of final average pooling 22 | DEFAULT_BLOCK_INDEX = 3 23 | 24 | # Maps feature dimensionality to their output blocks indices 25 | BLOCK_INDEX_BY_DIM = { 26 | 64: 0, # First max pooling features 27 | 192: 1, # Second max pooling featurs 28 | 768: 2, # Pre-aux classifier features 29 | 2048: 3 # Final average pooling features 30 | } 31 | 32 | def __init__(self, 33 | output_blocks=(DEFAULT_BLOCK_INDEX,), 34 | resize_input=True, 35 | normalize_input=True, 36 | requires_grad=False, 37 | use_fid_inception=True): 38 | """Build pretrained InceptionV3 
39 | 40 | Parameters 41 | ---------- 42 | output_blocks : list of int 43 | Indices of blocks to return features of. Possible values are: 44 | - 0: corresponds to output of first max pooling 45 | - 1: corresponds to output of second max pooling 46 | - 2: corresponds to output which is fed to aux classifier 47 | - 3: corresponds to output of final average pooling 48 | resize_input : bool 49 | If true, bilinearly resizes input to width and height 299 before 50 | feeding input to model. As the network without fully connected 51 | layers is fully convolutional, it should be able to handle inputs 52 | of arbitrary size, so resizing might not be strictly needed 53 | normalize_input : bool 54 | If true, scales the input from range (0, 1) to the range the 55 | pretrained Inception network expects, namely (-1, 1) 56 | requires_grad : bool 57 | If true, parameters of the model require gradients. Possibly useful 58 | for finetuning the network 59 | use_fid_inception : bool 60 | If true, uses the pretrained Inception model used in Tensorflow's 61 | FID implementation. If false, uses the pretrained Inception model 62 | available in torchvision. The FID Inception model has different 63 | weights and a slightly different structure from torchvision's 64 | Inception model. If you want to compute FID scores, you are 65 | strongly advised to set this parameter to true to get comparable 66 | results. 67 | """ 68 | super(InceptionV3, self).__init__() 69 | 70 | self.resize_input = resize_input 71 | self.normalize_input = normalize_input 72 | self.output_blocks = sorted(output_blocks) 73 | self.last_needed_block = max(output_blocks) 74 | 75 | assert self.last_needed_block <= 3, \ 76 | 'Last possible output block index is 3' 77 | 78 | self.blocks = nn.ModuleList() 79 | 80 | if use_fid_inception: 81 | inception = fid_inception_v3() 82 | else: 83 | inception = _inception_v3(pretrained=True) 84 | 85 | # Block 0: input to maxpool1 86 | block0 = [ 87 | inception.Conv2d_1a_3x3, 88 | inception.Conv2d_2a_3x3, 89 | inception.Conv2d_2b_3x3, 90 | nn.MaxPool2d(kernel_size=3, stride=2) 91 | ] 92 | self.blocks.append(nn.Sequential(*block0)) 93 | 94 | # Block 1: maxpool1 to maxpool2 95 | if self.last_needed_block >= 1: 96 | block1 = [ 97 | inception.Conv2d_3b_1x1, 98 | inception.Conv2d_4a_3x3, 99 | nn.MaxPool2d(kernel_size=3, stride=2) 100 | ] 101 | self.blocks.append(nn.Sequential(*block1)) 102 | 103 | # Block 2: maxpool2 to aux classifier 104 | if self.last_needed_block >= 2: 105 | block2 = [ 106 | inception.Mixed_5b, 107 | inception.Mixed_5c, 108 | inception.Mixed_5d, 109 | inception.Mixed_6a, 110 | inception.Mixed_6b, 111 | inception.Mixed_6c, 112 | inception.Mixed_6d, 113 | inception.Mixed_6e, 114 | ] 115 | self.blocks.append(nn.Sequential(*block2)) 116 | 117 | # Block 3: aux classifier to final avgpool 118 | if self.last_needed_block >= 3: 119 | block3 = [ 120 | inception.Mixed_7a, 121 | inception.Mixed_7b, 122 | inception.Mixed_7c, 123 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 124 | ] 125 | self.blocks.append(nn.Sequential(*block3)) 126 | 127 | for param in self.parameters(): 128 | param.requires_grad = requires_grad 129 | 130 | def forward(self, inp): 131 | """Get Inception feature maps 132 | 133 | Parameters 134 | ---------- 135 | inp : torch.autograd.Variable 136 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 137 | range (0, 1) 138 | 139 | Returns 140 | ------- 141 | List of torch.autograd.Variable, corresponding to the selected output 142 | block, sorted ascending by index 143 | """ 144 | outp = [] 145 | x = inp 146 | 147 | if self.resize_input: 148 | x = F.interpolate(x, 149 | size=(299, 299), 150 | mode='bilinear', 151 | align_corners=False) 152 | 153 | if self.normalize_input: 154 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 155 | 156 | for idx, block in enumerate(self.blocks): 157 | x = block(x) 158 | if idx in self.output_blocks: 159 | outp.append(x) 160 | 161 | if idx == self.last_needed_block: 162 | break 163 | 164 | return outp 165 | 166 | 167 | def _inception_v3(*args, **kwargs): 168 | """Wraps `torchvision.models.inception_v3` 169 | 170 | Skips default weight inititialization if supported by torchvision version. 171 | See https://github.com/mseitzer/pytorch-fid/issues/28. 172 | """ 173 | try: 174 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 175 | except ValueError: 176 | # Just a caution against weird version strings 177 | version = (0,) 178 | 179 | if version >= (0, 6): 180 | kwargs['init_weights'] = False 181 | 182 | return torchvision.models.inception_v3(*args, **kwargs) 183 | 184 | 185 | def fid_inception_v3(): 186 | """Build pretrained Inception model for FID computation 187 | 188 | The Inception model for FID computation uses a different set of weights 189 | and has a slightly different structure than torchvision's Inception. 190 | 191 | This method first constructs torchvision's Inception and then patches the 192 | necessary parts that are different in the FID Inception model. 193 | """ 194 | inception = _inception_v3(num_classes=1008, 195 | aux_logits=False, 196 | pretrained=False) 197 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 198 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 199 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 200 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 201 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 202 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 203 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 204 | inception.Mixed_7b = FIDInceptionE_1(1280) 205 | inception.Mixed_7c = FIDInceptionE_2(2048) 206 | 207 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 208 | # state_dict = torch.load(FID_WEIGHTS_URL) 209 | 210 | inception.load_state_dict(state_dict) 211 | return inception 212 | 213 | 214 | class FIDInceptionA(torchvision.models.inception.InceptionA): 215 | """InceptionA block patched for FID computation""" 216 | def __init__(self, in_channels, pool_features): 217 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 218 | 219 | def forward(self, x): 220 | branch1x1 = self.branch1x1(x) 221 | 222 | branch5x5 = self.branch5x5_1(x) 223 | branch5x5 = self.branch5x5_2(branch5x5) 224 | 225 | branch3x3dbl = self.branch3x3dbl_1(x) 226 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 227 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 228 | 229 | # Patch: Tensorflow's average pool does not use the padded zero's in 230 | # its average calculation 231 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 232 | count_include_pad=False) 233 | branch_pool = self.branch_pool(branch_pool) 234 | 235 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 236 | return torch.cat(outputs, 1) 237 | 238 | 239 | class 
FIDInceptionC(torchvision.models.inception.InceptionC): 240 | """InceptionC block patched for FID computation""" 241 | def __init__(self, in_channels, channels_7x7): 242 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 243 | 244 | def forward(self, x): 245 | branch1x1 = self.branch1x1(x) 246 | 247 | branch7x7 = self.branch7x7_1(x) 248 | branch7x7 = self.branch7x7_2(branch7x7) 249 | branch7x7 = self.branch7x7_3(branch7x7) 250 | 251 | branch7x7dbl = self.branch7x7dbl_1(x) 252 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 253 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 254 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 255 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 256 | 257 | # Patch: Tensorflow's average pool does not use the padded zero's in 258 | # its average calculation 259 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 260 | count_include_pad=False) 261 | branch_pool = self.branch_pool(branch_pool) 262 | 263 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 264 | return torch.cat(outputs, 1) 265 | 266 | 267 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 268 | """First InceptionE block patched for FID computation""" 269 | def __init__(self, in_channels): 270 | super(FIDInceptionE_1, self).__init__(in_channels) 271 | 272 | def forward(self, x): 273 | branch1x1 = self.branch1x1(x) 274 | 275 | branch3x3 = self.branch3x3_1(x) 276 | branch3x3 = [ 277 | self.branch3x3_2a(branch3x3), 278 | self.branch3x3_2b(branch3x3), 279 | ] 280 | branch3x3 = torch.cat(branch3x3, 1) 281 | 282 | branch3x3dbl = self.branch3x3dbl_1(x) 283 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 284 | branch3x3dbl = [ 285 | self.branch3x3dbl_3a(branch3x3dbl), 286 | self.branch3x3dbl_3b(branch3x3dbl), 287 | ] 288 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 289 | 290 | # Patch: Tensorflow's average pool does not use the padded zero's in 291 | # its average calculation 292 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 293 | count_include_pad=False) 294 | branch_pool = self.branch_pool(branch_pool) 295 | 296 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 297 | return torch.cat(outputs, 1) 298 | 299 | 300 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 301 | """Second InceptionE block patched for FID computation""" 302 | def __init__(self, in_channels): 303 | super(FIDInceptionE_2, self).__init__(in_channels) 304 | 305 | def forward(self, x): 306 | branch1x1 = self.branch1x1(x) 307 | 308 | branch3x3 = self.branch3x3_1(x) 309 | branch3x3 = [ 310 | self.branch3x3_2a(branch3x3), 311 | self.branch3x3_2b(branch3x3), 312 | ] 313 | branch3x3 = torch.cat(branch3x3, 1) 314 | 315 | branch3x3dbl = self.branch3x3dbl_1(x) 316 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 317 | branch3x3dbl = [ 318 | self.branch3x3dbl_3a(branch3x3dbl), 319 | self.branch3x3dbl_3b(branch3x3dbl), 320 | ] 321 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 322 | 323 | # Patch: The FID Inception model uses max pooling instead of average 324 | # pooling. This is likely an error in this specific Inception 325 | # implementation, as other Inception models use average pooling here 326 | # (which matches the description in the paper). 
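The pooling patches in these FIDInception blocks only change how border pixels are treated; the effect of count_include_pad=False is easy to verify on a dummy tensor. The snippet below is an illustration, not part of inception.py.

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 3, 3)
incl = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=True)
excl = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
print(incl[0, 0, 0, 0].item())  # 4/9 ≈ 0.444: padded zeros are counted in the corner average
print(excl[0, 0, 0, 0].item())  # 1.0: only the four real pixels at the corner are averaged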
327 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 328 | branch_pool = self.branch_pool(branch_pool) 329 | 330 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 331 | return torch.cat(outputs, 1) 332 | -------------------------------------------------------------------------------- /model_erase.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision.models import vgg19 5 | 6 | 7 | class Conv_bn_block(torch.nn.Module): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__() 10 | self._conv = torch.nn.Conv2d(*args, **kwargs) 11 | self._bn = torch.nn.BatchNorm2d(kwargs['out_channels']) 12 | 13 | def forward(self, input): 14 | return F.relu(self._bn(self._conv(input))) 15 | 16 | 17 | class Res_block(torch.nn.Module): 18 | def __init__(self, in_channels): 19 | super().__init__() 20 | self._conv1 = torch.nn.Conv2d(in_channels, in_channels//4, kernel_size=1, stride=1) 21 | self._conv2 = torch.nn.Conv2d(in_channels//4, in_channels//4, kernel_size=3, stride=1, padding=1) 22 | self._conv3 = torch.nn.Conv2d(in_channels//4, in_channels, kernel_size=1, stride=1) 23 | self._bn = torch.nn.BatchNorm2d(in_channels) 24 | 25 | def forward(self, x): 26 | xin = x 27 | x = F.relu(self._conv1(x)) 28 | x = F.relu(self._conv2(x)) 29 | x = self._conv3(x) 30 | x = torch.add(xin, x) 31 | x = F.relu(self._bn(x)) 32 | 33 | return x 34 | 35 | 36 | class encoder_net(torch.nn.Module): 37 | def __init__(self, in_channels, get_feature_map=False): 38 | super().__init__() 39 | self.cnum = 32 40 | self.get_feature_map = get_feature_map 41 | self._conv1_1 = Conv_bn_block( 42 | in_channels=in_channels, 43 | out_channels=self.cnum, 44 | kernel_size=3, 45 | stride=1, 46 | padding=1) 47 | self._conv1_2 = Conv_bn_block( 48 | in_channels=self.cnum, 49 | out_channels=self.cnum, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1) 53 | 54 | # -------------------------- 55 | self._pool1 = torch.nn.Conv2d( 56 | in_channels=self.cnum, 57 | out_channels=2*self.cnum, 58 | kernel_size=3, 59 | stride=2, 60 | padding=1) 61 | self._conv2_1 = Conv_bn_block( 62 | in_channels=2*self.cnum, 63 | out_channels=2*self.cnum, 64 | kernel_size=3, 65 | stride=1, 66 | padding=1) 67 | self._conv2_2 = Conv_bn_block( 68 | in_channels=2*self.cnum, 69 | out_channels=2*self.cnum, 70 | kernel_size=3, 71 | stride=1, 72 | padding=1) 73 | 74 | # --------------------------- 75 | self._pool2 = torch.nn.Conv2d( 76 | in_channels=2*self.cnum, 77 | out_channels=4*self.cnum, 78 | kernel_size=3, 79 | stride=2, 80 | padding=1) 81 | self._conv3_1 = Conv_bn_block( 82 | in_channels=4*self.cnum, 83 | out_channels=4*self.cnum, 84 | kernel_size=3, 85 | stride=1, 86 | padding=1) 87 | 88 | self._conv3_2 = Conv_bn_block( 89 | in_channels=4*self.cnum, 90 | out_channels=4*self.cnum, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1) 94 | 95 | # --------------------------- 96 | self._pool3 = torch.nn.Conv2d( 97 | in_channels=4*self.cnum, 98 | out_channels=8*self.cnum, 99 | kernel_size=3, 100 | stride=2, 101 | padding=1) 102 | self._conv4_1 = Conv_bn_block( 103 | in_channels=8*self.cnum, 104 | out_channels=8*self.cnum, 105 | kernel_size=3, 106 | stride=1, 107 | padding=1) 108 | self._conv4_2 = Conv_bn_block( 109 | in_channels=8*self.cnum, 110 | out_channels=8*self.cnum, 111 | kernel_size=3, 112 | stride=1, 113 | padding=1) 114 | 115 | def forward(self, x): 116 | x = self._conv1_1(x) 117 | x = self._conv1_2(x) 118 | x = 
F.relu(self._pool1(x)) 119 | x = self._conv2_1(x) 120 | x = self._conv2_2(x) 121 | f1 = x 122 | x = F.relu(self._pool2(x)) 123 | x = self._conv3_1(x) 124 | x = self._conv3_2(x) 125 | f2 = x 126 | x = F.relu(self._pool3(x)) 127 | x = self._conv4_1(x) 128 | x = self._conv4_2(x) 129 | if self.get_feature_map: 130 | return x, [f2, f1] 131 | else: 132 | return x 133 | 134 | 135 | class build_res_block(torch.nn.Module): 136 | def __init__(self, in_channels): 137 | super().__init__() 138 | self._block1 = Res_block(in_channels) 139 | self._block2 = Res_block(in_channels) 140 | self._block3 = Res_block(in_channels) 141 | self._block4 = Res_block(in_channels) 142 | 143 | def forward(self, x): 144 | x = self._block1(x) 145 | x = self._block2(x) 146 | x = self._block3(x) 147 | x = self._block4(x) 148 | return x 149 | 150 | 151 | class decoder_net(torch.nn.Module): 152 | def __init__(self, in_channels, get_feature_map=False, mt=1, fn_mt=[1, 1, 1]): 153 | super().__init__() 154 | if isinstance(fn_mt, int): 155 | fn_mt = [fn_mt for _ in range(3)] 156 | assert isinstance(fn_mt, list) and len(fn_mt) == 3 157 | 158 | self.cnum = 32 159 | self.get_feature_map = get_feature_map 160 | self._conv1_1 = Conv_bn_block(in_channels=int(fn_mt[0] * in_channels), out_channels=8*self.cnum, kernel_size=3, stride=1, padding=1) 161 | self._conv1_2 = Conv_bn_block(in_channels=8*self.cnum, out_channels=8*self.cnum, kernel_size=3, stride=1, padding=1) 162 | 163 | # ----------------- 164 | self._deconv1 = torch.nn.ConvTranspose2d(8*self.cnum, 4*self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 165 | self._conv2_1 = Conv_bn_block(in_channels=int(fn_mt[1]*mt*4*self.cnum), out_channels=4*self.cnum, kernel_size=3, stride=1, padding=1) 166 | self._conv2_2 = Conv_bn_block(in_channels=4*self.cnum, out_channels=4*self.cnum, kernel_size=3, stride=1, padding=1) 167 | 168 | # ----------------- 169 | self._deconv2 = torch.nn.ConvTranspose2d(4*self.cnum, 2*self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 170 | self._conv3_1 = Conv_bn_block(in_channels=int(fn_mt[2]*mt*2*self.cnum), out_channels=2*self.cnum, kernel_size=3, stride=1, padding=1) 171 | self._conv3_2 = Conv_bn_block(in_channels=2*self.cnum, out_channels=2*self.cnum, kernel_size=3, stride=1, padding=1) 172 | 173 | # ---------------- 174 | self._deconv3 = torch.nn.ConvTranspose2d(2*self.cnum, self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 175 | self._conv4_1 = Conv_bn_block(in_channels=self.cnum, out_channels=self.cnum, kernel_size=3, stride=1, padding=1) 176 | self._conv4_2 = Conv_bn_block(in_channels=self.cnum, out_channels=self.cnum, kernel_size=3, stride=1, padding=1) 177 | 178 | def forward(self, x, fuse=None, detach_flag=False): 179 | if fuse and fuse[0] is not None: 180 | if detach_flag: 181 | x = torch.cat((x, fuse[0].detach()), dim=1) 182 | else: 183 | x = torch.cat((x, fuse[0]), dim=1) 184 | x = self._conv1_1(x) 185 | x = self._conv1_2(x) 186 | f1 = x 187 | x = F.relu(self._deconv1(x)) 188 | if fuse and fuse[1] is not None: 189 | if detach_flag: 190 | x = torch.cat((x, fuse[1].detach()), dim=1) 191 | else: 192 | x = torch.cat((x, fuse[1]), dim=1) 193 | x = self._conv2_1(x) 194 | x = self._conv2_2(x) 195 | f2 = x 196 | x = F.relu(self._deconv2(x)) 197 | if fuse and fuse[2] is not None: 198 | if detach_flag: 199 | x = torch.cat((x, fuse[2].detach()), dim=1) 200 | else: 201 | x = torch.cat((x, fuse[2]), dim=1) 202 | x = self._conv3_1(x) 203 | x = self._conv3_2(x) 204 | f3 = x 205 | x = F.relu(self._deconv3(x)) 206 | x = 
self._conv4_1(x) 207 | x = self._conv4_2(x) 208 | if self.get_feature_map: 209 | return x, [f1, f2, f3] 210 | else: 211 | return x 212 | 213 | 214 | class PSPModule(torch.nn.Module): 215 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6), norm_layer=torch.nn.BatchNorm2d): 216 | super(PSPModule, self).__init__() 217 | self.stages = [] 218 | self.stages = torch.nn.ModuleList([self._make_stage(features, out_features, size, norm_layer) for size in sizes]) 219 | self.bottleneck = torch.nn.Sequential( 220 | torch.nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=1, padding=0, dilation=1, bias=False), 221 | norm_layer(out_features), 222 | torch.nn.ReLU(), 223 | torch.nn.Dropout2d(0.1) 224 | ) 225 | 226 | def _make_stage(self, features, out_features, size, norm_layer): 227 | prior = torch.nn.AdaptiveAvgPool2d(output_size=(size, size)) 228 | conv = torch.nn.Conv2d(features, out_features, kernel_size=1, bias=False) 229 | bn = norm_layer(out_features) 230 | return torch.nn.Sequential(prior, conv, bn) 231 | 232 | def forward(self, feats): 233 | h, w = feats.size(2), feats.size(3) 234 | priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] 235 | bottle = self.bottleneck(torch.cat(priors, 1)) 236 | return bottle 237 | 238 | 239 | class background_reconstruction_module(torch.nn.Module): 240 | def __init__(self, in_channels): 241 | super().__init__() 242 | self.cnum = 32 243 | self._encoder = encoder_net(in_channels, get_feature_map=True) 244 | self._res = build_res_block(8*self.cnum) 245 | self._decoder = decoder_net(8*self.cnum, get_feature_map=True, mt=2) 246 | self._out = torch.nn.Conv2d(self.cnum, 3, kernel_size=3, stride=1, padding=1) 247 | self._mask_s_decoder = decoder_net(8*self.cnum) 248 | self._mask_s_out = torch.nn.Conv2d(self.cnum, 1, kernel_size=3, stride=1, padding=1) 249 | self.ppm = PSPModule(8*self.cnum, out_features=8*self.cnum) 250 | 251 | def forward(self, x): 252 | x, f_encoder = self._encoder(x) 253 | x = self._res(x) 254 | x = self.ppm(x) 255 | mask_s = self._mask_s_decoder(x, fuse=None) 256 | mask_s_out = torch.sigmoid(self._mask_s_out(mask_s)) 257 | 258 | x, fs = self._decoder(x, fuse=[None] + f_encoder) 259 | x = torch.sigmoid(self._out(x)) 260 | 261 | return x, fs, mask_s_out 262 | 263 | 264 | class Generator(torch.nn.Module): 265 | def __init__(self, cfg, in_channels): 266 | super().__init__() 267 | self.cfg = cfg 268 | self.cnum = 32 269 | self.brm = background_reconstruction_module(in_channels) 270 | 271 | def forward(self, i_s): 272 | o_b, fuse, o_mask_s = self.brm(i_s) 273 | o_b_ori = o_b 274 | o_b = o_mask_s * o_b + (1 - o_mask_s) * i_s 275 | 276 | return o_b_ori, o_b, o_mask_s 277 | 278 | 279 | 280 | class Discriminator(torch.nn.Module): 281 | def __init__(self, cfg, in_channels): 282 | super().__init__() 283 | self.cfg = cfg 284 | self.cnum = 32 285 | self._conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1) 286 | self._conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 287 | self._conv3 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 288 | self._conv4 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 289 | self._conv5 = torch.nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1) 290 | self._conv2_bn = torch.nn.BatchNorm2d(128) 291 | self._conv3_bn = torch.nn.BatchNorm2d(256) 292 | self._conv4_bn = torch.nn.BatchNorm2d(512) 293 | self._conv5_bn = torch.nn.BatchNorm2d(1) 294 | 
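The soft composition in Generator.forward above keeps the predicted background only where the stroke mask is active and falls back to the input image elsewhere. A quick shape-level sanity check on dummy tensors follows; the batch and crop sizes are hypothetical and the snippet is not part of model_erase.py.

import torch

i_s = torch.rand(2, 3, 64, 256)        # hypothetical input crops
o_b_raw = torch.rand(2, 3, 64, 256)    # raw background prediction
o_mask_s = torch.rand(2, 1, 64, 256)   # soft text-stroke mask in [0, 1]

o_b = o_mask_s * o_b_raw + (1 - o_mask_s) * i_s   # the 1-channel mask broadcasts over RGB
assert o_b.shape == i_s.shape
# Where o_mask_s ~ 0 the output is the untouched input; where ~ 1 it is the predicted background.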
self.init_weights() 295 | 296 | def init_weights(self): 297 | for m in self.modules(): 298 | if isinstance(m, torch.nn.Conv2d): 299 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 300 | m.weight.data.normal_(0, math.sqrt(2. / n)) 301 | if m.bias is not None: 302 | m.bias.data.zero_() 303 | elif isinstance(m, torch.nn.BatchNorm2d): 304 | m.weight.data.fill_(1) 305 | m.bias.data.zero_() 306 | elif isinstance(m, torch.nn.Linear): 307 | m.weight.data.normal_(0, 0.001) 308 | m.bias.data.zero_() 309 | 310 | def forward(self, x): 311 | x = F.relu(self._conv1(x)) 312 | x = self._conv2(x) 313 | x = F.relu(self._conv2_bn(x)) 314 | x = self._conv3(x) 315 | x = F.relu(self._conv3_bn(x)) 316 | x = self._conv4(x) 317 | x = F.relu(self._conv4_bn(x)) 318 | x = self._conv5(x) 319 | x = self._conv5_bn(x) 320 | x = torch.sigmoid(x) 321 | 322 | return x 323 | 324 | 325 | class Vgg19(torch.nn.Module): 326 | def __init__(self, vgg19_weights): 327 | super(Vgg19, self).__init__() 328 | # features = list(vgg19(pretrained = True).features) 329 | vgg = vgg19(pretrained=False) 330 | params = torch.load(vgg19_weights) 331 | vgg.load_state_dict(params) 332 | features = list(vgg.features) 333 | self.features = torch.nn.ModuleList(features).eval() 334 | 335 | def forward(self, x): 336 | results = [] 337 | for ii, model in enumerate(self.features): 338 | x = model(x) 339 | 340 | if ii in {1, 6, 11, 20, 29}: 341 | results.append(x) 342 | return results 343 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torchvision.transforms.functional as F 8 | from mmcv import Config 9 | from tqdm import tqdm 10 | from torch.utils.data import DataLoader 11 | from loss import build_generator_loss, build_discriminator_loss, build_generator_loss_with_real 12 | from datagen import custom_dataset, TwoStreamBatchSampler 13 | from model import Generator, Discriminator, Vgg19 14 | from rec_model import Rec_Model 15 | from rec_utils import AttnLabelConverter 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | 19 | def requires_grad(model, flag=True): 20 | for p in model.parameters(): 21 | p.requires_grad = flag 22 | 23 | 24 | def rgb2grey(img): 25 | img = 0.299 * img[:, 0, :, :] + 0.587 * img[:, 1, :, :] + 0.114 * img[:, 2, :, :] 26 | img = img.unsqueeze(1) 27 | 28 | return img 29 | 30 | 31 | def get_logger(cfg, log_filename='log.txt', log_level=logging.INFO): 32 | logger = logging.getLogger(log_filename) 33 | logger.setLevel(log_level) 34 | formatter = logging.Formatter('[%(asctime)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 35 | if not os.path.exists(cfg.checkpoint_savedir): 36 | os.makedirs(cfg.checkpoint_savedir) 37 | fh = logging.FileHandler(os.path.join(cfg.checkpoint_savedir, log_filename)) 38 | fh.setLevel(log_level) 39 | fh.setFormatter(formatter) 40 | ch = logging.StreamHandler() 41 | ch.setLevel(log_level) 42 | ch.setFormatter(formatter) 43 | logger.addHandler(ch) 44 | logger.addHandler(fh) 45 | 46 | return logger 47 | 48 | 49 | def main(): 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--config', type=str) 52 | args = parser.parse_args() 53 | cfg = Config.fromfile(args.config) 54 | gpu_num = torch.cuda.device_count() 55 | 56 | logger = get_logger(cfg) 57 | logger.info('Config path: {}'.format(args.config)) 58 | writer = 
SummaryWriter(cfg.checkpoint_savedir + 'tensorboard/') 59 | 60 | train_data = custom_dataset(cfg, mode='train', with_real_data=cfg.with_real_data) 61 | if cfg.with_real_data: 62 | len_synth, len_real = train_data.custom_len() 63 | synth_idxs = list(range(len_synth)) 64 | real_idxs = list(range(len_synth, len_synth + len_real)) 65 | batch_sampler = TwoStreamBatchSampler(synth_idxs, real_idxs, cfg.batch_size, cfg.real_bs) # default: shuffle = True, drop_last = True 66 | train_loader = DataLoader( 67 | dataset=train_data, 68 | batch_sampler=batch_sampler, 69 | num_workers=cfg.num_workers, 70 | pin_memory=True) 71 | else: 72 | train_loader = DataLoader( 73 | dataset=train_data, 74 | batch_size=cfg.batch_size, 75 | shuffle=True, 76 | num_workers=cfg.num_workers, 77 | pin_memory=True, 78 | drop_last=True) 79 | eval_data = custom_dataset(cfg, data_dir=cfg.example_data_dir, mode='eval') 80 | eval_loader = DataLoader( 81 | dataset=eval_data, 82 | batch_size=1, 83 | shuffle=False) 84 | 85 | G = Generator(cfg, in_channels=3).cuda() 86 | D1 = Discriminator(cfg, in_channels=6).cuda() 87 | D2 = Discriminator(cfg, in_channels=6).cuda() 88 | vgg_features = Vgg19(cfg.vgg19_weights).cuda() 89 | if cfg.with_recognizer: 90 | converter = AttnLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz') 91 | Recognizer = Rec_Model(cfg) 92 | rec_state_dict = torch.load(cfg.rec_ckpt_path, map_location='cpu') 93 | if len(rec_state_dict) == 1: 94 | rec_state_dict = rec_state_dict['recognizer'] 95 | rec_state_dict = {k.replace('module.', ''): v for k, v in rec_state_dict.items()} 96 | Recognizer.cuda() 97 | Recognizer.load_state_dict(rec_state_dict) 98 | logger.info('Recognizer module loaded: {}'.format(cfg.rec_ckpt_path)) 99 | G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 100 | D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 101 | D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 102 | if cfg.with_recognizer and cfg.train_recognizer: 103 | Rec_solver = torch.optim.Adam(Recognizer.parameters(), lr=cfg.rec_lr_weight * cfg.learning_rate, betas=(cfg.beta1, cfg.beta2)) 104 | 105 | if os.path.exists(cfg.ckpt_path): 106 | checkpoint = torch.load(cfg.ckpt_path, map_location='cpu') 107 | G.load_state_dict(checkpoint['generator']) 108 | D1.load_state_dict(checkpoint['discriminator1']) 109 | D2.load_state_dict(checkpoint['discriminator2']) 110 | G_solver.load_state_dict(checkpoint['g_optimizer']) 111 | D1_solver.load_state_dict(checkpoint['d1_optimizer']) 112 | D2_solver.load_state_dict(checkpoint['d2_optimizer']) 113 | logger.info('Model loaded: {}'.format(cfg.ckpt_path)) 114 | else: 115 | logger.info('Model not found') 116 | if os.path.exists(cfg.inpaint_ckpt_path): 117 | checkpoint = torch.load(cfg.inpaint_ckpt_path, map_location='cpu') 118 | G.load_state_dict(checkpoint['generator'], strict=False) 119 | logger.info('Inpainting module loaded: {}'.format(cfg.inpaint_ckpt_path)) 120 | else: 121 | logger.info('Inpainting module not found') 122 | 123 | if gpu_num > 1: 124 | logger.info('Parallel Computing. 
Using {} GPUs.'.format(gpu_num)) 125 | G = torch.nn.DataParallel(G, device_ids=range(gpu_num)) 126 | D1 = torch.nn.DataParallel(D1, device_ids=range(gpu_num)) 127 | D2 = torch.nn.DataParallel(D2, device_ids=range(gpu_num)) 128 | vgg_features = torch.nn.DataParallel(vgg_features, device_ids=range(gpu_num)) 129 | if cfg.with_recognizer: 130 | Recognizer = torch.nn.DataParallel(Recognizer, device_ids=range(gpu_num)) 131 | 132 | # Train discriminator 133 | requires_grad(G, False) 134 | requires_grad(D1, True) 135 | requires_grad(D2, True) 136 | 137 | trainiter = iter(train_loader) 138 | for step in tqdm(range(cfg.max_iter)): 139 | D1_solver.zero_grad() 140 | D2_solver.zero_grad() 141 | 142 | if ((step + 1) % cfg.save_ckpt_interval == 0): 143 | torch.save( 144 | { 145 | 'generator': G.module.state_dict(), 146 | 'discriminator1': D1.module.state_dict(), 147 | 'discriminator2': D2.module.state_dict(), 148 | 'g_optimizer': G_solver.state_dict(), 149 | 'd1_optimizer': D1_solver.state_dict(), 150 | 'd2_optimizer': D2_solver.state_dict(), 151 | }, 152 | cfg.checkpoint_savedir + f'train_step-{step + 1}.model', 153 | ) 154 | if cfg.with_recognizer: 155 | torch.save({'recognizer': Recognizer.module.state_dict()}, cfg.checkpoint_savedir + 'best_recognizer.model') 156 | 157 | try: 158 | i_t, i_s, t_b, t_f, mask_t, mask_s, texts = trainiter.next() 159 | except StopIteration: 160 | trainiter = iter(train_loader) 161 | i_t, i_s, t_b, t_f, mask_t, mask_s, texts = trainiter.next() 162 | i_t = i_t.cuda() 163 | i_s = i_s.cuda() 164 | t_b = t_b.cuda() 165 | t_f = t_f.cuda() 166 | mask_t = mask_t.cuda() 167 | mask_s = mask_s.cuda() 168 | 169 | if cfg.with_recognizer: 170 | texts, texts_length = converter.encode(texts, batch_max_length=34) 171 | texts = texts.cuda() 172 | rec_target = texts[:, 1:] 173 | labels = [t_b, t_f, mask_t, mask_s, rec_target] 174 | else: 175 | labels = [t_b, t_f, mask_t, mask_s] 176 | 177 | o_b_ori, o_b, o_f, x_t_tps, o_mask_s, o_mask_t = G(i_t, i_s) 178 | 179 | if cfg.with_real_data: 180 | i_db_true = torch.cat((t_b[:(cfg.batch_size - cfg.real_bs) // gpu_num], i_s[:(cfg.batch_size - cfg.real_bs) // gpu_num]), dim=1) 181 | i_db_pred = torch.cat((o_b[:(cfg.batch_size - cfg.real_bs) // gpu_num], i_s[:(cfg.batch_size - cfg.real_bs) // gpu_num]), dim=1) 182 | else: 183 | i_db_true = torch.cat((t_b, i_s), dim=1) 184 | i_db_pred = torch.cat((o_b, i_s), dim=1) 185 | o_db_true = D1(i_db_true) 186 | o_db_pred = D1(i_db_pred) 187 | i_df_true = torch.cat((t_f, i_t), dim=1) 188 | i_df_pred = torch.cat((o_f, i_t), dim=1) 189 | o_df_true = D2(i_df_true) 190 | o_df_pred = D2(i_df_pred) 191 | 192 | db_loss = build_discriminator_loss(o_db_true, o_db_pred) 193 | df_loss = build_discriminator_loss(o_df_true, o_df_pred) 194 | db_loss.backward() 195 | df_loss.backward() 196 | D1_solver.step() 197 | D2_solver.step() 198 | 199 | # Train generator 200 | requires_grad(G, True) 201 | requires_grad(D1, False) 202 | requires_grad(D2, False) 203 | 204 | G_solver.zero_grad() 205 | if cfg.with_recognizer and cfg.train_recognizer: 206 | Rec_solver.zero_grad() 207 | o_b_ori, o_b, o_f, x_t_tps, o_mask_s, o_mask_t = G(i_t, i_s) 208 | 209 | if cfg.with_real_data: 210 | i_db_pred = torch.cat((o_b[:(cfg.batch_size - cfg.real_bs) // gpu_num], i_s[:(cfg.batch_size - cfg.real_bs) // gpu_num]), dim=1) 211 | else: 212 | i_db_pred = torch.cat((o_b, i_s), dim=1) 213 | i_df_pred = torch.cat((o_f, i_t), dim=1) 214 | o_db_pred = D1(i_db_pred) 215 | o_df_pred = D2(i_df_pred) 216 | i_vgg = torch.cat((t_f, o_f), dim=0) 217 | out_vgg = 
vgg_features(i_vgg) 218 | if cfg.with_recognizer: 219 | if cfg.use_rgb: 220 | tmp_o_f = o_f 221 | tmp_t_f = t_f 222 | else: 223 | tmp_o_f = rgb2grey(o_f) 224 | tmp_t_f = rgb2grey(t_f) 225 | rec_preds = Recognizer(tmp_o_f, texts[:, :-1], is_train=False) 226 | out_g = [o_b, o_f, o_mask_s, o_mask_t, rec_preds] 227 | else: 228 | out_g = [o_b, o_f, o_mask_s, o_mask_t] 229 | out_d = [o_db_pred, o_df_pred] 230 | 231 | if cfg.with_real_data: 232 | g_loss, metrics = build_generator_loss_with_real(cfg, out_g, out_d, out_vgg, labels) 233 | else: 234 | g_loss, metrics = build_generator_loss(cfg, out_g, out_d, out_vgg, labels) 235 | g_loss.backward() 236 | G_solver.step() 237 | if cfg.with_recognizer and cfg.train_recognizer: 238 | Rec_solver.step() 239 | 240 | requires_grad(G, False) 241 | requires_grad(D1, True) 242 | requires_grad(D2, True) 243 | 244 | if ((step + 1) % cfg.write_log_interval == 0): 245 | loss_str = 'Iter: {}/{} | Gen:{:<10.6f} | D_bg:{:<10.6f} | D_fus:{:<10.6f} | G_lr:{} | D_lr:{}'.format( 246 | step + 1, cfg.max_iter, 247 | g_loss.item(), 248 | db_loss.item(), 249 | df_loss.item(), 250 | G_solver.param_groups[0]['lr'], 251 | D1_solver.param_groups[0]['lr']) 252 | writer.add_scalar('main/G_loss', g_loss.item(), step) 253 | writer.add_scalar('main/db_loss', db_loss.item(), step) 254 | writer.add_scalar('main/df_loss', df_loss.item(), step) 255 | 256 | logger.info(loss_str) 257 | for name, metric in metrics.items(): 258 | loss_str = ' | '.join(['{:<7}: {:<10.6f}'.format(sub_name, sub_metric) for sub_name, sub_metric in metric.items()]) 259 | for sub_name, sub_metric in metric.items(): 260 | writer.add_scalar(name + '/' + sub_name, sub_metric, step) 261 | logger.info(loss_str) 262 | 263 | if ((step + 1) % cfg.gen_example_interval == 0): 264 | savedir = os.path.join(cfg.example_result_dir, 'iter-' + str(step + 1).zfill(len(str(cfg.max_iter)))) 265 | with torch.no_grad(): 266 | for inp in eval_loader: 267 | i_t = inp[0].cuda() 268 | i_s = inp[1].cuda() 269 | name = str(inp[2][0]) 270 | name, suffix = name.split('.') 271 | 272 | G.eval() 273 | o_b_ori, o_b, o_f, x_t_tps, o_mask_s, o_mask_t = G(i_t, i_s) 274 | G.train() 275 | 276 | if not os.path.exists(savedir): 277 | os.makedirs(savedir) 278 | o_mask_s = o_mask_s.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 279 | o_mask_t = o_mask_t.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 280 | x_t_tps = x_t_tps.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 281 | o_b_ori = o_b_ori.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 282 | o_b = o_b.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 283 | o_f = o_f.detach().squeeze(0).to('cpu').numpy().transpose(1, 2, 0) 284 | cv2.imwrite(os.path.join(savedir, name + '_o_f.' + suffix), o_f[:, :, ::-1] * 255) 285 | cv2.imwrite(os.path.join(savedir, name + '_o_b_ori.' + suffix), o_b_ori[:, :, ::-1] * 255) 286 | cv2.imwrite(os.path.join(savedir, name + '_o_b.' + suffix), o_b[:, :, ::-1] * 255) 287 | cv2.imwrite(os.path.join(savedir, name + '_o_mask_s.' + suffix), o_mask_s * 255) 288 | cv2.imwrite(os.path.join(savedir, name + '_o_mask_t.' + suffix), o_mask_t * 255) 289 | cv2.imwrite(os.path.join(savedir, name + '_x_t_tps.' 
+ suffix), x_t_tps[:, :, ::-1] * 255) 290 | 291 | 292 | if __name__ == '__main__': 293 | main() 294 | -------------------------------------------------------------------------------- /rec_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | from PIL import Image 11 | import numpy as np 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from torch._utils import _accumulate 14 | import torchvision.transforms as transforms 15 | 16 | 17 | class Batch_Balanced_Dataset(object): 18 | 19 | def __init__(self, opt): 20 | """ 21 | Modulate the data ratio in the batch. 22 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 23 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 24 | """ 25 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 26 | dashed_line = '-' * 80 27 | print(dashed_line) 28 | log.write(dashed_line + '\n') 29 | print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 30 | log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') 31 | assert len(opt.select_data) == len(opt.batch_ratio) 32 | 33 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 34 | self.data_loader_list = [] 35 | self.dataloader_iter_list = [] 36 | batch_size_list = [] 37 | Total_batch_size = 0 38 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 39 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 40 | print(dashed_line) 41 | log.write(dashed_line + '\n') 42 | _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) 43 | total_number_dataset = len(_dataset) 44 | log.write(_dataset_log) 45 | 46 | """ 47 | The total number of data can be modified with opt.total_data_usage_ratio. 48 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 49 | See 4.2 section in our paper. 
50 | """ 51 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) 52 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 53 | indices = range(total_number_dataset) 54 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 55 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 56 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' 57 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' 58 | print(selected_d_log) 59 | log.write(selected_d_log + '\n') 60 | batch_size_list.append(str(_batch_size)) 61 | Total_batch_size += _batch_size 62 | 63 | _data_loader = torch.utils.data.DataLoader( 64 | _dataset, batch_size=_batch_size, 65 | shuffle=True, 66 | num_workers=int(opt.workers), 67 | collate_fn=_AlignCollate, pin_memory=True) 68 | self.data_loader_list.append(_data_loader) 69 | self.dataloader_iter_list.append(iter(_data_loader)) 70 | 71 | Total_batch_size_log = f'{dashed_line}\n' 72 | batch_size_sum = '+'.join(batch_size_list) 73 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' 74 | Total_batch_size_log += f'{dashed_line}' 75 | opt.batch_size = Total_batch_size 76 | 77 | print(Total_batch_size_log) 78 | log.write(Total_batch_size_log + '\n') 79 | log.close() 80 | 81 | def get_batch(self): 82 | balanced_batch_images = [] 83 | balanced_batch_texts = [] 84 | 85 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 86 | try: 87 | image, text = data_loader_iter.next() 88 | balanced_batch_images.append(image) 89 | balanced_batch_texts += text 90 | except StopIteration: 91 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 92 | image, text = self.dataloader_iter_list[i].next() 93 | balanced_batch_images.append(image) 94 | balanced_batch_texts += text 95 | except ValueError: 96 | pass 97 | 98 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 99 | 100 | return balanced_batch_images, balanced_batch_texts 101 | 102 | 103 | def hierarchical_dataset(root, opt, select_data='/'): 104 | """ select_data='/' contains all sub-directory of root directory """ 105 | dataset_list = [] 106 | dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' 107 | print(dataset_log) 108 | dataset_log += '\n' 109 | for dirpath, dirnames, filenames in os.walk(root+'/'): 110 | if not dirnames: 111 | select_flag = False 112 | for selected_d in select_data: 113 | if selected_d in dirpath: 114 | select_flag = True 115 | break 116 | 117 | if select_flag: 118 | dataset = LmdbDataset(dirpath, opt) 119 | sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' 120 | print(sub_dataset_log) 121 | dataset_log += f'{sub_dataset_log}\n' 122 | dataset_list.append(dataset) 123 | 124 | concatenated_dataset = ConcatDataset(dataset_list) 125 | 126 | return concatenated_dataset, dataset_log 127 | 128 | 129 | class LmdbDataset(Dataset): 130 | 131 | def __init__(self, root, opt): 132 | 133 | self.root = root 134 | self.opt = opt 135 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 136 | if not self.env: 137 | print('cannot create lmdb from %s' % (root)) 138 | sys.exit(0) 139 | 140 | with self.env.begin(write=False) as txn: 141 | nSamples = int(txn.get('num-samples'.encode())) 142 | self.nSamples = nSamples 
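# (added note) The block below builds filtered_index_list. With data_filtering_off
# every sample is kept (lmdb keys are 1-indexed, hence index + 1); otherwise labels
# longer than opt.batch_max_length or containing characters outside opt.character
# are skipped, and self.nSamples is reset to the filtered count.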
143 | 144 | if self.opt.data_filtering_off: 145 | # for fast check or benchmark evaluation with no filtering 146 | self.filtered_index_list = [index + 1 for index in range(self.nSamples)] 147 | else: 148 | """ Filtering part 149 | If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, 150 | use --data_filtering_off and only evaluate on alphabets and digits. 151 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 152 | 153 | And if you want to evaluate them with the model trained with --sensitive option, 154 | use --sensitive and --data_filtering_off, 155 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 156 | """ 157 | self.filtered_index_list = [] 158 | for index in range(self.nSamples): 159 | index += 1 # lmdb starts with 1 160 | label_key = 'label-%09d'.encode() % index 161 | label = txn.get(label_key).decode('utf-8') 162 | 163 | if len(label) > self.opt.batch_max_length: 164 | # print(f'The length of the label is longer than max_length: length 165 | # {len(label)}, {label} in dataset {self.root}') 166 | continue 167 | 168 | # By default, images containing characters which are not in opt.character are filtered. 169 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 170 | out_of_char = f'[^{self.opt.character}]' 171 | if re.search(out_of_char, label.lower()): 172 | continue 173 | 174 | self.filtered_index_list.append(index) 175 | 176 | self.nSamples = len(self.filtered_index_list) 177 | 178 | def __len__(self): 179 | return self.nSamples 180 | 181 | def __getitem__(self, index): 182 | assert index <= len(self), 'index range error' 183 | index = self.filtered_index_list[index] 184 | 185 | with self.env.begin(write=False) as txn: 186 | label_key = 'label-%09d'.encode() % index 187 | label = txn.get(label_key).decode('utf-8') 188 | img_key = 'image-%09d'.encode() % index 189 | imgbuf = txn.get(img_key) 190 | 191 | buf = six.BytesIO() 192 | buf.write(imgbuf) 193 | buf.seek(0) 194 | try: 195 | if self.opt.use_rgb: 196 | img = Image.open(buf).convert('RGB') # for color image 197 | else: 198 | img = Image.open(buf).convert('L') 199 | 200 | except IOError: 201 | print(f'Corrupted image for {index}') 202 | # make dummy image and dummy label for corrupted image. 
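# (added note) The placeholder image and label below keep the DataLoader from
# crashing on a corrupted sample; the dummy label is then lower-cased (unless
# opt.sensitive) and character-filtered further down, exactly like a real label.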
203 | if self.opt.use_rgb: 204 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 205 | else: 206 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 207 | label = '[dummy_label]' 208 | 209 | if not self.opt.sensitive: 210 | label = label.lower() 211 | 212 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 213 | out_of_char = f'[^{self.opt.character}]' 214 | label = re.sub(out_of_char, '', label) 215 | 216 | return (img, label) 217 | 218 | 219 | class RawDataset(Dataset): 220 | 221 | def __init__(self, root, opt): 222 | self.opt = opt 223 | 224 | with open(opt.gt_file) as f: 225 | lines = f.readlines() 226 | self.gt_list = {line.strip().split()[0]: line.strip().split()[1] for line in lines} 227 | 228 | self.image_path_list = [] 229 | for dirpath, dirnames, filenames in os.walk(root): 230 | for name in filenames: 231 | if name not in self.gt_list: 232 | continue 233 | _, ext = os.path.splitext(name) 234 | ext = ext.lower() 235 | if ext == '.jpg' or ext == '.jpeg' or ext == '.png': 236 | self.image_path_list.append(os.path.join(dirpath, name)) 237 | 238 | self.image_path_list = natsorted(self.image_path_list) 239 | self.nSamples = len(self.image_path_list) 240 | 241 | 242 | def __len__(self): 243 | return self.nSamples 244 | 245 | def __getitem__(self, index): 246 | 247 | try: 248 | if self.opt.use_rgb: 249 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 250 | else: 251 | img = Image.open(self.image_path_list[index]).convert('L') 252 | 253 | except IOError: 254 | print(f'Corrupted image for {index}') 255 | # make dummy image and dummy label for corrupted image. 256 | if self.opt.use_rgb: 257 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 258 | else: 259 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 260 | gt = self.gt_list[self.image_path_list[index].split('/')[-1]] 261 | 262 | return (img, self.image_path_list[index], gt) 263 | 264 | 265 | class ResizeNormalize(object): 266 | 267 | def __init__(self, size, interpolation=Image.BICUBIC): 268 | self.size = size 269 | self.interpolation = interpolation 270 | self.toTensor = transforms.ToTensor() 271 | 272 | def __call__(self, img): 273 | img = img.resize(self.size, self.interpolation) 274 | img = self.toTensor(img) 275 | img.sub_(0.5).div_(0.5) 276 | return img 277 | 278 | 279 | class NormalizePAD(object): 280 | 281 | def __init__(self, max_size, PAD_type='right'): 282 | self.toTensor = transforms.ToTensor() 283 | self.max_size = max_size 284 | self.max_width_half = math.floor(max_size[2] / 2) 285 | self.PAD_type = PAD_type 286 | 287 | def __call__(self, img): 288 | img = self.toTensor(img) 289 | img.sub_(0.5).div_(0.5) 290 | c, h, w = img.size() 291 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 292 | Pad_img[:, :, :w] = img # right pad 293 | if self.max_size[2] != w: # add border Pad 294 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 295 | 296 | return Pad_img 297 | 298 | 299 | class AlignCollate(object): 300 | 301 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): 302 | self.imgH = imgH 303 | self.imgW = imgW 304 | self.keep_ratio_with_pad = keep_ratio_with_pad 305 | 306 | def __call__(self, batch): 307 | batch = filter(lambda x: x is not None, batch) 308 | images, labels, gt = zip(*batch) 309 | 310 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 311 | resized_max_w = self.imgW 312 | input_channel = 3 if images[0].mode == 'RGB' else 1 313 | transform = 
NormalizePAD((input_channel, self.imgH, resized_max_w)) 314 | 315 | resized_images = [] 316 | for image in images: 317 | w, h = image.size 318 | ratio = w / float(h) 319 | if math.ceil(self.imgH * ratio) > self.imgW: 320 | resized_w = self.imgW 321 | else: 322 | resized_w = math.ceil(self.imgH * ratio) 323 | 324 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 325 | resized_images.append(transform(resized_image)) 326 | # resized_image.save('./image_test/%d_test.jpg' % w) 327 | 328 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 329 | 330 | else: 331 | transform = ResizeNormalize((self.imgW, self.imgH)) 332 | image_tensors = [transform(image) for image in images] 333 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 334 | 335 | return image_tensors, labels, gt 336 | 337 | 338 | def tensor2im(image_tensor, imtype=np.uint8): 339 | image_numpy = image_tensor.cpu().float().numpy() 340 | if image_numpy.shape[0] == 1: 341 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 342 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 343 | return image_numpy.astype(imtype) 344 | 345 | 346 | def save_image(image_numpy, image_path): 347 | image_pil = Image.fromarray(image_numpy) 348 | image_pil.save(image_path) 349 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | import torchvision.transforms.functional as tf 5 | import numpy as np 6 | from torchvision.models import vgg19 7 | from PIL import Image 8 | import torch.nn.functional as F 9 | from tps_spatial_transformer import TPSSpatialTransformer 10 | 11 | 12 | class Conv_bn_block(torch.nn.Module): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__() 15 | self._conv = torch.nn.Conv2d(*args, **kwargs) 16 | self._bn = torch.nn.BatchNorm2d(kwargs['out_channels']) 17 | 18 | def forward(self, input): 19 | return F.relu(self._bn(self._conv(input))) 20 | 21 | 22 | class Res_block(torch.nn.Module): 23 | def __init__(self, in_channels): 24 | super().__init__() 25 | self._conv1 = torch.nn.Conv2d(in_channels, in_channels//4, kernel_size=1, stride=1) 26 | self._conv2 = torch.nn.Conv2d(in_channels//4, in_channels//4, kernel_size=3, stride=1, padding=1) 27 | self._conv3 = torch.nn.Conv2d(in_channels//4, in_channels, kernel_size=1, stride=1) 28 | self._bn = torch.nn.BatchNorm2d(in_channels) 29 | 30 | def forward(self, x): 31 | xin = x 32 | x = F.relu(self._conv1(x)) 33 | x = F.relu(self._conv2(x)) 34 | x = self._conv3(x) 35 | x = torch.add(xin, x) 36 | x = F.relu(self._bn(x)) 37 | 38 | return x 39 | 40 | 41 | class encoder_net(torch.nn.Module): 42 | def __init__(self, in_channels, get_feature_map=False): 43 | super().__init__() 44 | self.cnum = 32 45 | self.get_feature_map = get_feature_map 46 | self._conv1_1 = Conv_bn_block( 47 | in_channels=in_channels, 48 | out_channels=self.cnum, 49 | kernel_size=3, 50 | stride=1, 51 | padding=1) 52 | self._conv1_2 = Conv_bn_block( 53 | in_channels=self.cnum, 54 | out_channels=self.cnum, 55 | kernel_size=3, 56 | stride=1, 57 | padding=1) 58 | 59 | # -------------------------- 60 | self._pool1 = torch.nn.Conv2d( 61 | in_channels=self.cnum, 62 | out_channels=2*self.cnum, 63 | kernel_size=3, 64 | stride=2, 65 | padding=1) 66 | self._conv2_1 = Conv_bn_block( 67 | in_channels=2*self.cnum, 68 | out_channels=2*self.cnum, 69 | kernel_size=3, 70 | stride=1, 71 | 
padding=1) 72 | self._conv2_2 = Conv_bn_block( 73 | in_channels=2*self.cnum, 74 | out_channels=2*self.cnum, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1) 78 | 79 | # --------------------------- 80 | self._pool2 = torch.nn.Conv2d( 81 | in_channels=2*self.cnum, 82 | out_channels=4*self.cnum, 83 | kernel_size=3, 84 | stride=2, 85 | padding=1) 86 | self._conv3_1 = Conv_bn_block( 87 | in_channels=4*self.cnum, 88 | out_channels=4*self.cnum, 89 | kernel_size=3, 90 | stride=1, 91 | padding=1) 92 | 93 | self._conv3_2 = Conv_bn_block( 94 | in_channels=4*self.cnum, 95 | out_channels=4*self.cnum, 96 | kernel_size=3, 97 | stride=1, 98 | padding=1) 99 | 100 | # --------------------------- 101 | self._pool3 = torch.nn.Conv2d( 102 | in_channels=4*self.cnum, 103 | out_channels=8*self.cnum, 104 | kernel_size=3, 105 | stride=2, 106 | padding=1) 107 | self._conv4_1 = Conv_bn_block( 108 | in_channels=8*self.cnum, 109 | out_channels=8*self.cnum, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1) 113 | self._conv4_2 = Conv_bn_block( 114 | in_channels=8*self.cnum, 115 | out_channels=8*self.cnum, 116 | kernel_size=3, 117 | stride=1, 118 | padding=1) 119 | 120 | def forward(self, x): 121 | x = self._conv1_1(x) 122 | x = self._conv1_2(x) 123 | x = F.relu(self._pool1(x)) 124 | x = self._conv2_1(x) 125 | x = self._conv2_2(x) 126 | f1 = x 127 | x = F.relu(self._pool2(x)) 128 | x = self._conv3_1(x) 129 | x = self._conv3_2(x) 130 | f2 = x 131 | x = F.relu(self._pool3(x)) 132 | x = self._conv4_1(x) 133 | x = self._conv4_2(x) 134 | if self.get_feature_map: 135 | return x, [f2, f1] 136 | else: 137 | return x 138 | 139 | 140 | class build_res_block(torch.nn.Module): 141 | def __init__(self, in_channels): 142 | super().__init__() 143 | self._block1 = Res_block(in_channels) 144 | self._block2 = Res_block(in_channels) 145 | self._block3 = Res_block(in_channels) 146 | self._block4 = Res_block(in_channels) 147 | 148 | def forward(self, x): 149 | x = self._block1(x) 150 | x = self._block2(x) 151 | x = self._block3(x) 152 | x = self._block4(x) 153 | return x 154 | 155 | 156 | class decoder_net(torch.nn.Module): 157 | def __init__(self, in_channels, get_feature_map=False, mt=1, fn_mt=[1, 1, 1]): 158 | super().__init__() 159 | if isinstance(fn_mt, int): 160 | fn_mt = [fn_mt for _ in range(3)] 161 | assert isinstance(fn_mt, list) and len(fn_mt) == 3 162 | 163 | self.cnum = 32 164 | self.get_feature_map = get_feature_map 165 | self._conv1_1 = Conv_bn_block(in_channels=int(fn_mt[0] * in_channels), out_channels=8*self.cnum, kernel_size=3, stride=1, padding=1) 166 | self._conv1_2 = Conv_bn_block(in_channels=8*self.cnum, out_channels=8*self.cnum, kernel_size=3, stride=1, padding=1) 167 | 168 | # ----------------- 169 | self._deconv1 = torch.nn.ConvTranspose2d(8*self.cnum, 4*self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 170 | self._conv2_1 = Conv_bn_block(in_channels=int(fn_mt[1]*mt*4*self.cnum), out_channels=4*self.cnum, kernel_size=3, stride=1, padding=1) 171 | self._conv2_2 = Conv_bn_block(in_channels=4*self.cnum, out_channels=4*self.cnum, kernel_size=3, stride=1, padding=1) 172 | 173 | # ----------------- 174 | self._deconv2 = torch.nn.ConvTranspose2d(4*self.cnum, 2*self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 175 | self._conv3_1 = Conv_bn_block(in_channels=int(fn_mt[2]*mt*2*self.cnum), out_channels=2*self.cnum, kernel_size=3, stride=1, padding=1) 176 | self._conv3_2 = Conv_bn_block(in_channels=2*self.cnum, out_channels=2*self.cnum, kernel_size=3, stride=1, padding=1) 177 | 178 | # 
---------------- 179 | self._deconv3 = torch.nn.ConvTranspose2d(2*self.cnum, self.cnum, kernel_size=3, stride=2, padding=1, output_padding=1) 180 | self._conv4_1 = Conv_bn_block(in_channels=self.cnum, out_channels=self.cnum, kernel_size=3, stride=1, padding=1) 181 | self._conv4_2 = Conv_bn_block(in_channels=self.cnum, out_channels=self.cnum, kernel_size=3, stride=1, padding=1) 182 | 183 | def forward(self, x, fuse=None, detach_flag=False): 184 | if fuse and fuse[0] is not None: 185 | if detach_flag: 186 | x = torch.cat((x, fuse[0].detach()), dim=1) 187 | else: 188 | x = torch.cat((x, fuse[0]), dim=1) 189 | x = self._conv1_1(x) 190 | x = self._conv1_2(x) 191 | f1 = x 192 | x = F.relu(self._deconv1(x)) 193 | if fuse and fuse[1] is not None: 194 | if detach_flag: 195 | x = torch.cat((x, fuse[1].detach()), dim=1) 196 | else: 197 | x = torch.cat((x, fuse[1]), dim=1) 198 | x = self._conv2_1(x) 199 | x = self._conv2_2(x) 200 | f2 = x 201 | x = F.relu(self._deconv2(x)) 202 | if fuse and fuse[2] is not None: 203 | if detach_flag: 204 | x = torch.cat((x, fuse[2].detach()), dim=1) 205 | else: 206 | x = torch.cat((x, fuse[2]), dim=1) 207 | x = self._conv3_1(x) 208 | x = self._conv3_2(x) 209 | f3 = x 210 | x = F.relu(self._deconv3(x)) 211 | x = self._conv4_1(x) 212 | x = self._conv4_2(x) 213 | if self.get_feature_map: 214 | return x, [f1, f2, f3] 215 | else: 216 | return x 217 | 218 | 219 | class PSPModule(torch.nn.Module): 220 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6), norm_layer=torch.nn.BatchNorm2d): 221 | super(PSPModule, self).__init__() 222 | self.stages = [] 223 | self.stages = torch.nn.ModuleList([self._make_stage(features, out_features, size, norm_layer) for size in sizes]) 224 | self.bottleneck = torch.nn.Sequential( 225 | torch.nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=1, padding=0, dilation=1, bias=False), 226 | norm_layer(out_features), 227 | torch.nn.ReLU(), 228 | torch.nn.Dropout2d(0.1) 229 | ) 230 | 231 | def _make_stage(self, features, out_features, size, norm_layer): 232 | prior = torch.nn.AdaptiveAvgPool2d(output_size=(size, size)) 233 | conv = torch.nn.Conv2d(features, out_features, kernel_size=1, bias=False) 234 | bn = norm_layer(out_features) 235 | return torch.nn.Sequential(prior, conv, bn) 236 | 237 | def forward(self, feats): 238 | h, w = feats.size(2), feats.size(3) 239 | priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] 240 | bottle = self.bottleneck(torch.cat(priors, 1)) 241 | return bottle 242 | 243 | 244 | class text_modification_module(torch.nn.Module): 245 | def __init__(self, cfg, in_channels, num_ctrlpoints, margins, stn_activation=None): 246 | super().__init__() 247 | self.cfg = cfg 248 | self.num_ctrlpoints = num_ctrlpoints 249 | self.margins = margins 250 | self.stn_activation = stn_activation 251 | self.cnum = 32 252 | self._t_encoder = encoder_net(in_channels) 253 | self._t_res = build_res_block(8*self.cnum) 254 | self._s_encoder = encoder_net(in_channels) 255 | self._s_res = build_res_block(8*self.cnum) 256 | self._mask_decoder = decoder_net(16*self.cnum, fn_mt=[1.5, 2, 2]) 257 | self._mask_out = torch.nn.Conv2d(self.cnum, 1, kernel_size=3, stride=1, padding=1) 258 | self._t_decoder = decoder_net(16*self.cnum, fn_mt=[1.5, 2, 2]) 259 | self._t_cbr = Conv_bn_block(in_channels=2*self.cnum, out_channels=2*self.cnum, kernel_size=3, stride=1, padding=1) 260 | self._t_out = torch.nn.Conv2d(2*self.cnum, 3, kernel_size=3, stride=1, 
padding=1) 261 | self.ppm = PSPModule(16*self.cnum, out_features=16*self.cnum) 262 | 263 | if cfg.TPS_ON: 264 | self.stn_fc1 = torch.nn.Sequential( 265 | torch.nn.Linear(8*32*256, 512), 266 | torch.nn.BatchNorm1d(512), 267 | torch.nn.ReLU(inplace=True)) 268 | self.stn_fc2 = torch.nn.Linear(512, num_ctrlpoints * 2) 269 | self.tps = TPSSpatialTransformer(output_image_size=cfg.tps_outputsize, num_control_points=num_ctrlpoints, margins=cfg.tps_margins) 270 | self.init_weights(self.stn_fc1) 271 | self.init_stn(self.stn_fc2, margins) 272 | 273 | def init_weights(self, module): 274 | for m in module.modules(): 275 | if isinstance(m, torch.nn.Conv2d): 276 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 277 | m.weight.data.normal_(0, math.sqrt(2. / n)) 278 | if m.bias is not None: 279 | m.bias.data.zero_() 280 | elif isinstance(m, torch.nn.BatchNorm2d): 281 | m.weight.data.fill_(1) 282 | m.bias.data.zero_() 283 | elif isinstance(m, torch.nn.Linear): 284 | m.weight.data.normal_(0, 0.001) 285 | m.bias.data.zero_() 286 | 287 | def init_stn(self, stn_fc2, margins=(0.01, 0.01)): 288 | margin = margins[0] 289 | sampling_num_per_side = int(self.num_ctrlpoints / 2) 290 | ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side) 291 | ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 292 | ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin) 293 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 294 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 295 | ctrl_points = np.concatenate( 296 | [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 297 | if self.stn_activation is None: 298 | pass 299 | elif self.stn_activation == 'sigmoid': 300 | ctrl_points = -np.log(1. / ctrl_points - 1.) 301 | elif self.stn_activation == 'tanh': 302 | ctrl_points = ctrl_points * 2 - 1 303 | ctrl_points = np.log((1 + ctrl_points) / (1 - ctrl_points)) / 2 304 | stn_fc2.weight.data.zero_() 305 | stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 306 | 307 | def forward(self, x_t, x_s, fuse): 308 | x_s = self._s_encoder(x_s) 309 | x_s = self._s_res(x_s) 310 | x_t_tps = x_t 311 | if self.cfg.TPS_ON: 312 | batch_size, _, h, w = x_s.size() 313 | ctrl_points = x_s.reshape(batch_size, -1) 314 | ctrl_points = self.stn_fc1(ctrl_points) 315 | ctrl_points = self.stn_fc2(0.1 * ctrl_points) 316 | if self.stn_activation == 'sigmoid': 317 | ctrl_points = F.sigmoid(ctrl_points) 318 | elif self.stn_activation == 'tanh': 319 | ctrl_points = torch.tanh(ctrl_points) 320 | 321 | ctrl_points = ctrl_points.view(-1, self.num_ctrlpoints, 2) 322 | x_t, _ = self.tps(x_t, ctrl_points) 323 | x_t_tps = x_t 324 | x_t = self._t_encoder(x_t) 325 | x_t = self._t_res(x_t) 326 | x = torch.cat((x_t, x_s), dim=1) 327 | x = self.ppm(x) 328 | 329 | mask_t = self._mask_decoder(x, fuse=fuse, detach_flag=True) 330 | mask_t_out = torch.sigmoid(self._mask_out(mask_t)) 331 | 332 | o_f = self._t_decoder(x, fuse=fuse, detach_flag=True) 333 | o_f = torch.cat((o_f, mask_t), dim=1) 334 | o_f = self._t_cbr(o_f) 335 | o_f_out = torch.sigmoid(self._t_out(o_f)) 336 | 337 | return mask_t_out, o_f_out, x_t_tps 338 | 339 | 340 | class background_reconstruction_module(torch.nn.Module): 341 | def __init__(self, in_channels): 342 | super().__init__() 343 | self.cnum = 32 344 | self._encoder = encoder_net(in_channels, get_feature_map=True) 345 | self._res = build_res_block(8*self.cnum) 346 | self._decoder = decoder_net(8*self.cnum, get_feature_map=True, mt=2) 347 | self._out = torch.nn.Conv2d(self.cnum, 3, 
kernel_size=3, stride=1, padding=1) 348 | self._mask_s_decoder = decoder_net(8*self.cnum) 349 | self._mask_s_out = torch.nn.Conv2d(self.cnum, 1, kernel_size=3, stride=1, padding=1) 350 | self.ppm = PSPModule(8*self.cnum, out_features=8*self.cnum) 351 | 352 | def forward(self, x): 353 | x, f_encoder = self._encoder(x) 354 | x = self._res(x) 355 | x = self.ppm(x) 356 | mask_s = self._mask_s_decoder(x, fuse=None) 357 | mask_s_out = torch.sigmoid(self._mask_s_out(mask_s)) 358 | 359 | x, fs = self._decoder(x, fuse=[None] + f_encoder) 360 | x = torch.sigmoid(self._out(x)) 361 | 362 | return x, fs, mask_s_out 363 | 364 | 365 | def random_transform(cfg, i_s): 366 | i_s_aug = i_s 367 | vflip_rate = cfg.vflip_rate 368 | hflip_rate = cfg.hflip_rate 369 | angle_range = cfg.angle_range 370 | if random.random() < hflip_rate: 371 | i_s_aug = tf.hflip(i_s_aug) 372 | if random.random() < vflip_rate: 373 | i_s_aug = tf.vflip(i_s_aug) 374 | if len(angle_range) > 0: 375 | angle = random.randint(*random.choice(angle_range)) 376 | i_s_aug = tf.rotate(i_s_aug, angle=angle, resample=Image.BILINEAR, expand=False) 377 | i_s_aug[:cfg.batch_size - cfg.real_bs] = i_s[:cfg.batch_size - cfg.real_bs] 378 | 379 | return i_s_aug 380 | 381 | 382 | class Generator(torch.nn.Module): 383 | def __init__(self, cfg, in_channels): 384 | super().__init__() 385 | self.cfg = cfg 386 | self.cnum = 32 387 | self.tmm = text_modification_module(cfg, in_channels, cfg.num_control_points, cfg.tps_margins, cfg.stn_activation) 388 | self.brm = background_reconstruction_module(in_channels) 389 | 390 | def forward(self, i_t, i_s): 391 | o_b, fuse, o_mask_s = self.brm(i_s) 392 | o_b_ori = o_b 393 | o_b = o_mask_s * o_b + (1 - o_mask_s) * i_s 394 | i_s_new = i_s * o_mask_s.detach() 395 | if self.training: 396 | i_s_new = random_transform(self.cfg, i_s_new) 397 | o_mask_t, o_f, x_t_tps = self.tmm(i_t, i_s_new, fuse=fuse) 398 | 399 | return o_b_ori, o_b, o_f, x_t_tps, o_mask_s, o_mask_t 400 | 401 | 402 | 403 | class Discriminator(torch.nn.Module): 404 | def __init__(self, cfg, in_channels): 405 | super().__init__() 406 | self.cfg = cfg 407 | self.cnum = 32 408 | self._conv1 = torch.nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1) 409 | self._conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1) 410 | self._conv3 = torch.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 411 | self._conv4 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 412 | self._conv5 = torch.nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1) 413 | self._conv2_bn = torch.nn.BatchNorm2d(128) 414 | self._conv3_bn = torch.nn.BatchNorm2d(256) 415 | self._conv4_bn = torch.nn.BatchNorm2d(512) 416 | self._conv5_bn = torch.nn.BatchNorm2d(1) 417 | self.init_weights() 418 | 419 | def init_weights(self): 420 | for m in self.modules(): 421 | if isinstance(m, torch.nn.Conv2d): 422 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 423 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 424 | if m.bias is not None: 425 | m.bias.data.zero_() 426 | elif isinstance(m, torch.nn.BatchNorm2d): 427 | m.weight.data.fill_(1) 428 | m.bias.data.zero_() 429 | elif isinstance(m, torch.nn.Linear): 430 | m.weight.data.normal_(0, 0.001) 431 | m.bias.data.zero_() 432 | 433 | def forward(self, x): 434 | x = F.relu(self._conv1(x)) 435 | x = self._conv2(x) 436 | x = F.relu(self._conv2_bn(x)) 437 | x = self._conv3(x) 438 | x = F.relu(self._conv3_bn(x)) 439 | x = self._conv4(x) 440 | x = F.relu(self._conv4_bn(x)) 441 | x = self._conv5(x) 442 | x = self._conv5_bn(x) 443 | x = torch.sigmoid(x) 444 | 445 | return x 446 | 447 | 448 | class Vgg19(torch.nn.Module): 449 | def __init__(self, vgg19_weights): 450 | super(Vgg19, self).__init__() 451 | # features = list(vgg19(pretrained = True).features) 452 | vgg = vgg19(pretrained=False) 453 | params = torch.load(vgg19_weights) 454 | vgg.load_state_dict(params) 455 | features = list(vgg.features) 456 | self.features = torch.nn.ModuleList(features).eval() 457 | 458 | def forward(self, x): 459 | results = [] 460 | for ii, model in enumerate(self.features): 461 | x = model(x) 462 | 463 | if ii in {1, 6, 11, 20, 29}: 464 | results.append(x) 465 | return results 466 | -------------------------------------------------------------------------------- /rec_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | import time 7 | 8 | ## feature extraction 9 | 10 | class ResNet_FeatureExtractor(nn.Module): 11 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 12 | 13 | def __init__(self, input_channel, output_channel=512): 14 | super(ResNet_FeatureExtractor, self).__init__() 15 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 16 | 17 | def forward(self, input): 18 | return self.ConvNet(input) 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = self._conv3x3(inplanes, planes) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.conv2 = self._conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def _conv3x3(self, in_planes, out_planes, stride=1): 34 | "3x3 convolution with padding" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 36 | padding=1, bias=False) 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | class ResNet(nn.Module): 56 | 57 | def __init__(self, input_channel, output_channel, block, layers): 58 | super(ResNet, self).__init__() 59 | 60 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 61 | 62 | self.inplanes = int(output_channel / 8) 63 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 64 | kernel_size=3, stride=1, padding=1, bias=False) 65 | self.bn0_1 = 
nn.BatchNorm2d(int(output_channel / 16)) 66 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 67 | kernel_size=3, stride=1, padding=1, bias=False) 68 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 72 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 73 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 74 | 0], kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 76 | 77 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 78 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 79 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 80 | 1], kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 82 | 83 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 84 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 85 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 86 | 2], kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 88 | 89 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 90 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 91 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 92 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 93 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 94 | 3], kernel_size=2, stride=1, padding=0, bias=False) 95 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 96 | 97 | def _make_layer(self, block, planes, blocks, stride=1): 98 | downsample = None 99 | if stride != 1 or self.inplanes != planes * block.expansion: 100 | downsample = nn.Sequential( 101 | nn.Conv2d(self.inplanes, planes * block.expansion, 102 | kernel_size=1, stride=stride, bias=False), 103 | nn.BatchNorm2d(planes * block.expansion), 104 | ) 105 | 106 | layers = [] 107 | layers.append(block(self.inplanes, planes, stride, downsample)) 108 | self.inplanes = planes * block.expansion 109 | for i in range(1, blocks): 110 | layers.append(block(self.inplanes, planes)) 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | x = self.conv0_1(x) 116 | x = self.bn0_1(x) 117 | x = self.relu(x) 118 | x = self.conv0_2(x) 119 | x = self.bn0_2(x) 120 | x = self.relu(x) 121 | 122 | x2 = self.maxpool1(x) 123 | x2 = self.layer1(x2) 124 | x2 = self.conv1(x2) 125 | x2 = self.bn1(x2) 126 | x2 = self.relu(x2) 127 | 128 | x3 = self.maxpool2(x2) 129 | x3 = self.layer2(x3) 130 | x3 = self.conv2(x3) 131 | x3 = self.bn2(x3) 132 | x3 = self.relu(x3) 133 | 134 | x4 = self.maxpool3(x3) 135 | x4 = self.layer3(x4) 136 | x4 = self.conv3(x4) 137 | x4 = self.bn3(x4) 138 | x4 = self.relu(x4) 139 | 140 | x5 = self.layer4(x4) 141 | x5 = self.conv4_1(x5) 142 | x5 = self.bn4_1(x5) 143 | x5 = self.relu(x5) 144 | x5 = self.conv4_2(x5) 145 | x5 = self.bn4_2(x5) 146 | x5 = self.relu(x5) 147 | 148 | return x5 149 | 150 | 151 | ## prediction 152 | 153 | class Attention(nn.Module): 154 | 155 | def __init__(self, input_size, hidden_size, num_classes): 156 | super(Attention, self).__init__() 157 | self.attention_cell = AttentionCell(input_size, 
hidden_size, num_classes) 158 | self.hidden_size = hidden_size 159 | self.num_classes = num_classes 160 | self.generator = nn.Linear(hidden_size, num_classes) 161 | 162 | def _char_to_onehot(self, input_char, onehot_dim=38): 163 | input_char = input_char.unsqueeze(1) 164 | batch_size = input_char.size(0) 165 | # one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 166 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().cuda() 167 | one_hot = one_hot.scatter_(1, input_char, 1) 168 | return one_hot 169 | 170 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 171 | """ 172 | input: 173 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 174 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 175 | output: probability distribution at each step [batch_size x num_steps x num_classes] 176 | """ 177 | batch_size = batch_H.size(0) 178 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 179 | 180 | # output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 181 | # hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 182 | # torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 183 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).cuda() 184 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).cuda(), 185 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).cuda()) 186 | 187 | if is_train: 188 | for i in range(num_steps): 189 | # one-hot vectors for a i-th char. in a batch 190 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 191 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 192 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 193 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 194 | probs = self.generator(output_hiddens) 195 | 196 | else: 197 | # targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 198 | # probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) 199 | targets = torch.LongTensor(batch_size).fill_(0).cuda() # [GO] token 200 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).cuda() 201 | 202 | for i in range(num_steps): 203 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 204 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 205 | probs_step = self.generator(hidden[0]) 206 | probs[:, i, :] = probs_step 207 | _, next_input = probs_step.max(1) 208 | targets = next_input 209 | 210 | return probs # batch_size x num_steps x num_classes 211 | 212 | class AttentionCell(nn.Module): 213 | 214 | def __init__(self, input_size, hidden_size, num_embeddings): 215 | super(AttentionCell, self).__init__() 216 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 217 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 218 | self.score = nn.Linear(hidden_size, 1, bias=False) 219 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 220 | self.hidden_size = hidden_size 221 | 222 | def forward(self, prev_hidden, batch_H, char_onehots): 223 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 224 | batch_H_proj = self.i2h(batch_H) 225 | 
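# (added note) Additive (Bahdanau-style) attention: the projected encoder states
# batch_H_proj and the projected previous hidden state are summed, squashed with
# tanh and scored per encoder step; softmax over the step dimension yields alpha,
# which weights batch_H into a context vector that is concatenated with the one-hot
# previous character and fed to the LSTM cell.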
prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 226 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 227 | 228 | alpha = F.softmax(e, dim=1) 229 | context = torch.bmm(alpha.permute(0, 2, 1).contiguous(), batch_H).squeeze(1) # batch_size x num_channel 230 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 231 | cur_hidden = self.rnn(concat_context, prev_hidden) 232 | return cur_hidden, alpha 233 | 234 | 235 | ## sequence modeling 236 | 237 | class BidirectionalLSTM(nn.Module): 238 | 239 | def __init__(self, input_size, hidden_size, output_size): 240 | super(BidirectionalLSTM, self).__init__() 241 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 242 | self.linear = nn.Linear(hidden_size * 2, output_size) 243 | # self.linear1 = nn.Linear(input_size, output_size) 244 | 245 | def forward(self, input): 246 | """ 247 | input : visual feature [batch_size x T x input_size] 248 | output : contextual feature [batch_size x T x output_size] 249 | """ 250 | self.rnn.flatten_parameters() 251 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 252 | output = self.linear(recurrent) # batch_size x T x output_size 253 | # output = self.linear1(input) 254 | return output 255 | 256 | 257 | ## transformation 258 | 259 | class TPS_SpatialTransformerNetwork(nn.Module): 260 | """ Rectification Network of RARE, namely TPS based STN """ 261 | 262 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 263 | """ Based on RARE TPS 264 | input: 265 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 266 | I_size : (height, width) of the input image I 267 | I_r_size : (height, width) of the rectified image I_r 268 | I_channel_num : the number of channels of the input image I 269 | output: 270 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 271 | """ 272 | super(TPS_SpatialTransformerNetwork, self).__init__() 273 | self.F = F 274 | self.I_size = I_size 275 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 276 | self.I_channel_num = I_channel_num 277 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 278 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 279 | 280 | def forward(self, batch_I): 281 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 282 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 283 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 284 | 285 | if torch.__version__ > "1.2.0": 286 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 287 | else: 288 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 289 | 290 | return batch_I_r 291 | 292 | class LocalizationNetwork(nn.Module): 293 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 294 | 295 | def __init__(self, F, I_channel_num): 296 | super(LocalizationNetwork, self).__init__() 297 | self.F = F 298 | self.I_channel_num = I_channel_num 299 | self.conv = nn.Sequential( 300 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 301 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 302 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 303 | 
nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 304 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 305 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 306 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 307 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 308 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 309 | ) 310 | 311 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 312 | self.localization_fc2 = nn.Linear(256, self.F * 2) 313 | 314 | # Init fc2 in LocalizationNetwork 315 | self.localization_fc2.weight.data.fill_(0) 316 | """ see RARE paper Fig. 6 (a) """ 317 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 318 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 319 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 320 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 321 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 322 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 323 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 324 | 325 | def forward(self, batch_I): 326 | """ 327 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 328 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 329 | """ 330 | batch_size = batch_I.size(0) 331 | features = self.conv(batch_I).view(batch_size, -1) 332 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 333 | return batch_C_prime 334 | 335 | class GridGenerator(nn.Module): 336 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 337 | 338 | def __init__(self, F, I_r_size): 339 | """ Generate P_hat and inv_delta_C for later """ 340 | super(GridGenerator, self).__init__() 341 | self.eps = 1e-6 342 | self.I_r_height, self.I_r_width = I_r_size 343 | self.F = F 344 | self.C = self._build_C(self.F) # F x 2 345 | self.P = self._build_P(self.I_r_width, self.I_r_height) 346 | ## for multi-gpu, you need register buffer 347 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 348 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 349 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 350 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 351 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 352 | 353 | def _build_C(self, F): 354 | """ Return coordinates of fiducial points in I_r; C """ 355 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 356 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 357 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 358 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 359 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 360 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 361 | return C # F x 2 362 | 363 | def _build_inv_delta_C(self, F, C): 364 | """ Return inv_delta_C which is needed to calculate T """ 365 | hat_C = np.zeros((F, F), dtype=float) # F x F 366 | for i in range(0, F): 367 | for j in range(i, F): 368 | r = np.linalg.norm(C[i] - C[j]) 369 | hat_C[i, j] = r 370 | hat_C[j, i] = r 371 | np.fill_diagonal(hat_C, 
1) 372 | hat_C = (hat_C ** 2) * np.log(hat_C) 373 | # print(C.shape, hat_C.shape) 374 | delta_C = np.concatenate( # F+3 x F+3 375 | [ 376 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 377 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 378 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 379 | ], 380 | axis=0 381 | ) 382 | inv_delta_C = np.linalg.inv(delta_C) 383 | return inv_delta_C # F+3 x F+3 384 | 385 | def _build_P(self, I_r_width, I_r_height): 386 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 387 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 388 | P = np.stack( # self.I_r_width x self.I_r_height x 2 389 | np.meshgrid(I_r_grid_x, I_r_grid_y), 390 | axis=2 391 | ) 392 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 393 | 394 | def _build_P_hat(self, F, C, P): 395 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 396 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 397 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 398 | P_diff = P_tile - C_tile # n x F x 2 399 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 400 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 401 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 402 | return P_hat # n x F+3 403 | 404 | def build_P_prime(self, batch_C_prime): 405 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 406 | batch_size = batch_C_prime.size(0) 407 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 408 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 409 | # batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 410 | # batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 411 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 412 | batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 413 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 414 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 415 | return batch_P_prime # batch_size x n x 2 416 | 417 | --------------------------------------------------------------------------------
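Added note: the modules above are wired together by train.py for training; for inference the repository's own prediction and evaluation scripts are authoritative. The snippet below is only a minimal illustrative sketch of driving the Generator from model.py directly. Assumptions are marked in the comments: the checkpoint path and the rendered-target image are hypothetical, the 64x256 input size is inferred from the 8*32*256 flattened feature in text_modification_module, and the real preprocessing lives in datagen.py and the prediction script.

import cv2
import numpy as np
import torch
from mmcv import Config
from model import Generator

cfg = Config.fromfile('configs/mostel-train.py')             # must define TPS_ON, num_control_points, tps_margins, ...
G = Generator(cfg, in_channels=3).cuda().eval()              # eval() also disables random_transform in forward()
ckpt = torch.load('train_step-XXXXXX.model', map_location='cpu')   # hypothetical checkpoint path
G.load_state_dict(ckpt['generator'])

def load_image(path, size=(256, 64)):                        # (W, H); 64x256 is an assumption, see note above
    img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, size).astype(np.float32) / 255.0
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).cuda()

i_s = load_image('demo_img/imgs/001.png')                    # source scene-text crop
i_t = load_image('rendered_target.png')                      # target text rendered on a plain background (hypothetical file)

with torch.no_grad():
    o_b_ori, o_b, o_f, x_t_tps, o_mask_s, o_mask_t = G(i_t, i_s)

o_f = o_f.squeeze(0).cpu().numpy().transpose(1, 2, 0)        # CHW -> HWC, values in [0, 1]
cv2.imwrite('edited.png', o_f[:, :, ::-1] * 255)             # RGB -> BGR, same convention as train.py's example dump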