├── LICENSE ├── Misc ├── overview.png └── show │ ├── SDL2.dll │ ├── SDL2_image.dll │ ├── instructions.md │ ├── large+.png │ ├── large-.png │ ├── large.exe │ ├── libbrotlicommon.dll │ ├── libbrotlidec.dll │ ├── libbrotlienc.dll │ ├── libgcc_s_seh-1.dll │ ├── libjpeg-8.dll │ ├── libjxl.dll │ ├── liblzma-5.dll │ ├── libpng16-16.dll │ ├── libstdc++-6.dll │ ├── libtiff-5.dll │ ├── libwebp-7.dll │ ├── libwinpthread-1.dll │ ├── libzstd.dll │ ├── single+.png │ ├── single-.png │ ├── single.exe │ └── zlib1.dll ├── README.md ├── STR_modules ├── feature_extraction.py ├── model.py ├── prediction.py ├── sequence_modeling.py └── transformation.py ├── STR_test.py ├── STR_train.py ├── baselines ├── Impact.ttf ├── __init__.py ├── fawa.py ├── test_black_models.py ├── transfer.py ├── transfer_attacker.py └── wm_attacker.py ├── comtest ├── test_ali.py ├── test_azure.py ├── test_baidu.py └── test_huawei.py ├── data └── protego │ └── up │ └── 5.png ├── dataset.py ├── environment.yaml ├── gen_udp.py ├── main.py ├── models ├── GAN_models.py ├── enhancement_layers │ ├── __init__.py │ ├── combined.py │ ├── identity.py │ └── transform.py └── enhancer.py ├── protego.py ├── test_udp.py ├── train_protego.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 YanruHe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Misc/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/overview.png -------------------------------------------------------------------------------- /Misc/show/SDL2.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/SDL2.dll -------------------------------------------------------------------------------- /Misc/show/SDL2_image.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/SDL2_image.dll -------------------------------------------------------------------------------- /Misc/show/instructions.md: -------------------------------------------------------------------------------- 1 | We provide two executable presentation scripts(single and large) that need to run on the windows system. 2 | 3 | For a single text image, please run the following command: 4 | ``` 5 | .\single.exe .\single-.png .\single+.png 6 | ``` 7 | 8 | For large-scale text image, please run the following command: 9 | ``` 10 | .\large.exe .\large-.png .\large+.png 11 | ``` -------------------------------------------------------------------------------- /Misc/show/large+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large+.png -------------------------------------------------------------------------------- /Misc/show/large-.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large-.png -------------------------------------------------------------------------------- /Misc/show/large.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large.exe -------------------------------------------------------------------------------- /Misc/show/libbrotlicommon.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlicommon.dll -------------------------------------------------------------------------------- /Misc/show/libbrotlidec.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlidec.dll -------------------------------------------------------------------------------- /Misc/show/libbrotlienc.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlienc.dll -------------------------------------------------------------------------------- /Misc/show/libgcc_s_seh-1.dll: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libgcc_s_seh-1.dll -------------------------------------------------------------------------------- /Misc/show/libjpeg-8.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libjpeg-8.dll -------------------------------------------------------------------------------- /Misc/show/libjxl.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libjxl.dll -------------------------------------------------------------------------------- /Misc/show/liblzma-5.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/liblzma-5.dll -------------------------------------------------------------------------------- /Misc/show/libpng16-16.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libpng16-16.dll -------------------------------------------------------------------------------- /Misc/show/libstdc++-6.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libstdc++-6.dll -------------------------------------------------------------------------------- /Misc/show/libtiff-5.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libtiff-5.dll -------------------------------------------------------------------------------- /Misc/show/libwebp-7.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libwebp-7.dll -------------------------------------------------------------------------------- /Misc/show/libwinpthread-1.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libwinpthread-1.dll -------------------------------------------------------------------------------- /Misc/show/libzstd.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libzstd.dll -------------------------------------------------------------------------------- /Misc/show/single+.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single+.png -------------------------------------------------------------------------------- /Misc/show/single-.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single-.png -------------------------------------------------------------------------------- 
/Misc/show/single.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single.exe -------------------------------------------------------------------------------- /Misc/show/zlib1.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/zlib1.dll -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProTegO (ACM MM 2023) 2 | This repository provides the official PyTorch implementation of the conference paper "ProTegO: Protect Text Content against OCR Extraction Attack". 3 | 4 | ## Introduction 5 |

6 | The overview of ProTegO. It consists of four main parts: preprocessing of underpainting, adversarial underpainting 7 | generation, robustness enhancement, and visual compensation. The whole pipeline can be trained end-to-end. 8 | 9 | ## Environment Setup 10 | This code is tested with Python 3.8, PyTorch 1.10, and CUDA 11.3, and requires the following dependencies: 11 | 12 | * opencv-python = 4.6.0.66 13 | * nltk = 3.8.1 14 | * trdg = 1.8.0 15 | 16 | To set up a conda environment, please use the following commands: 17 | ``` 18 | conda env create -f environment.yaml 19 | conda activate protego 20 | ``` 21 | 22 | 23 | ## Preparation 24 | 25 | ### Dataset 26 | For training the 5 STR models, we use the [MJSynth (MJ)](https://www.robots.ox.ac.uk/~vgg/data/text/) dataset, but we do not use the LMDB format; please refer to ```dataset.py``` for details. 27 | 28 | The overall directory structure should be: 29 | ``` 30 | │data/ 31 | ├── STR/ 32 | │ │ │──train/ 33 | │ │ │   ├── 1/ 34 | │ │ │ │   ├── 1/ 35 | | | | | ├──3_JERKIER_41423.png 36 | | | | | ├──....... 37 | │ │ │──test/ 38 | │ │ │   ├── 2698/ 39 | │ │ │   │   ├── 1/ 40 | │ | | | ├──108_mahavira_46071.png 41 | │ | | | ├──....... 42 | │ │ │──val/ 43 | │ │ │   ├── 2425/ 44 | │ │ │   │   ├── 1/ 45 | │ | | | ├──1_Hey_36013.png 46 | │ | | | ├──....... 47 | ``` 48 | 49 | For our method ProTegO, we use [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator) to generate text image samples. We also provide an off-the-shelf dataset [here](https://drive.google.com/drive/folders/1TdkwhSM-CMvxG0vcZk6ImAVR-_u9mZU4?usp=drive_link). 50 | 51 | The overall directory structure should be: 52 | ``` 53 | │data/ 54 | ├── protego/ 55 | │ │ │──train/ 56 | │ │ │   ├── arial 57 | │ │ │   │ ├── 0_01234_0.png 58 | │ │ │   │ ├── ....... 59 | │ │ │   ├── georgia 60 | │ │ │   │ ├── 0_01234_0.png 61 | │ │ │   │ ├── ....... 62 | │ │ │   ├── ...... 63 | │ │ │──test/ 64 | │ │ │   ├──0_unobtainable_0.png 65 | │ │ │   ├──....... 66 | │ │ │──up/ 67 | │ │ │   ├──5.png 68 | │ │ │   ├──....... 69 | ``` 70 | 71 | ### STR Models 72 | You can directly download the pretrained models used in our paper and put them at ```STR_modules/downloads_models/```. Their checkpoints can be downloaded [here](https://drive.google.com/drive/folders/1ciaBucPd1u0qDTjJe3DuxLvMIMf9R1zu?usp=drive_link). 73 | 74 | 75 | You can also train and test your own models with [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) by running the following commands: 76 | 77 | ``` 78 | # take STARNet for example: 79 | python STR_train.py --exp_name STARNet --train_mode --train_data data/STR/train --valid_data data/STR/val \ 80 | --batch_size 128 --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC 81 | 82 | python STR_test.py --output /res-STR/test --eval_data data/STR/test --name STARNet --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \ 83 | --saved_model ./STR_modules/saved_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive/best_accuracy.pth 84 | ``` 85 | 86 | In addition, we take [DBNet++](https://arxiv.org/pdf/2202.10304.pdf) as our guided network; you can [download](https://drive.google.com/file/d/1gmVd5hForfSAqV97PvwcxCzw9H7jaJMU/view?usp=drive_link) it and put it at ```models/```. 
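For reference, a downloaded checkpoint can also be restored programmatically with the classes in ```STR_modules/```. The snippet below is only a minimal sketch that mirrors the option handling and loading logic of ```STR_test.py``` for the STARNet (TPS-ResNet-BiLSTM-CTC) configuration; the checkpoint filename and the case-sensitive character set are assumptions that you may need to adjust to the files you actually downloaded.

```
# Minimal sketch (not an official script): load a pretrained STR checkpoint
# for inference, mirroring STR_test.py.
import string
from argparse import Namespace

import torch

from STR_modules.model import Model
from STR_modules.prediction import CTCLabelConverter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Options for the STARNet configuration; adjust to match your checkpoint.
opt = Namespace(Transformation='TPS', FeatureExtraction='ResNet',
                SequenceModeling='BiLSTM', Prediction='CTC',
                imgH=32, imgW=100, num_fiducial=20, input_channel=3,
                output_channel=512, hidden_size=256, batch_max_length=25,
                character=string.printable[:62])  # 62 chars (0-9, a-z, A-Z), as in STR_test.py

converter = CTCLabelConverter(opt.character)
opt.num_class = len(converter.character)

model = Model(opt).to(device)
ckpt = 'STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth'  # assumed filename
model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
model.eval()

# Dummy RGB input of size imgH x imgW; replace with a real preprocessed text image.
image = torch.rand(1, opt.input_channel, opt.imgH, opt.imgW, device=device)
text_for_pred = torch.zeros(1, opt.batch_max_length + 1, dtype=torch.long, device=device)

with torch.no_grad():
    preds = model(image, text_for_pred)           # [batch, T, num_class]
    preds_size = torch.IntTensor([preds.size(1)])
    _, preds_index = preds.max(2)                 # greedy decoding
    print(converter.decode(preds_index.data, preds_size.data))
```

For attention-based checkpoints, use ```AttnLabelConverter``` and ```Prediction='Attn'``` instead, as in ```STR_test.py```.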
87 | 88 | 89 | 90 | ## Usage 91 | * First, given a preset underpainting, you need to follow the instructions in Section 3.3 of our paper, then adopt the [APCA](https://github.com/Myndex/apca-w3) algorithm to obtain a suitable underpainting. We also provide some pre-processed underpainting styles, which you can [download](https://drive.google.com/drive/folders/1oOclGkU-9yQBNVAM7wEmT_X6Vxu9Imlf?usp=drive_link) and put at ```data/protego/up/```. 92 | 93 | * Next, to train and test ProTegO in the white-box setting, you can directly run the following commands: 94 | ``` 95 | # to train with underpainting 5a (style 5 and white font): 96 | python main.py --train_path data/protego/train --test_path data/protego/test --use_eh --use_guide --dark 97 | 98 | # to train with underpainting 5b (style 5 and black font): 99 | python main.py --train_path data/protego/train --test_path data/protego/test --use_eh --use_guide 100 | ``` 101 | 102 | * For black-box model evaluation, please run the following command: 103 | ``` 104 | # take the model CRNN as an example: 105 | python test_udp.py --output res-XX --str_model STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth --STR_name CRNN --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC 106 | ``` 107 | 108 | * For commercial OCR service evaluation, we take the [Baidu OCR API](https://cloud.baidu.com/product/ocr.html) as an example: 109 | ``` 110 | python ./comtest/test_baidu.py 111 | ``` 112 | 113 | Evaluation scripts for other commercial OCR APIs can be found in ```comtest/```, and more usage details are available on their official websites, such as [Microsoft Azure](https://learn.microsoft.com/zh-cn/azure/ai-services/computer-vision/overview-ocr), [Alibaba Cloud](https://ai.aliyun.com/ocr?spm=a2c4g.252763.0.0.32d53d80xtz0ZX), and [Huawei](https://support.huaweicloud.com/ocr/index.html). 114 | 115 | For the built-in commercial text recognition tools provided by smartphones or applications (like WeChat), you can use our provided scripts in ```Misc/show/``` to alternately display the two frames of the protected text image at the current refresh rate 116 | of the monitor (a rough Python approximation is also sketched at the end of this section). Then you can take random manual screenshots and test them with the online OCR tool. 117 | 118 | * For baseline method evaluation, we provide [FAWA]() and two general transfer-based methods ([SINIFGSM, VMIFGSM]()) for testing; please run the following commands: 119 | ``` 120 | # For FAWA: 121 | python baselines/fawa.py 122 | 123 | # For SINIFGSM or VMIFGSM: 124 | python baselines/transfer.py --name SINIFGSM 125 | python baselines/transfer.py --name VMIFGSM 126 | 127 | # For black-box model testing, take the model CRNN and the method FAWA as an example: 128 | python baselines/test_black_models.py --attack_name fawa --adv_img xx/res-baselines/fawa-eps40/STARNet/wmadv \ 129 | --str_model STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth \ 130 | --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC 131 | ``` 132 | 133 | 134 | Since the code of the two methods [What machines see is not what they get (CVPR'20)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xu_What_Machines_See_Is_Not_What_They_Get_Fooling_Scene_CVPR_2020_paper.pdf) and [The Best Protection is Attack (TIFS'23)](https://ieeexplore.ieee.org/abstract/document/10045728) is not open source, please contact their authors. 
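The frame-alternation display performed by the executables in ```Misc/show/``` can be roughly approximated in Python as well. The sketch below is only an illustration and not part of the repository: it assumes ```opencv-python``` is installed and the two frames from ```Misc/show/``` are available, and it simply swaps the frames on every ```waitKey``` tick, so the effective switching rate depends on the OS and display rather than being locked to the monitor refresh rate; use the provided ```single.exe```/```large.exe``` for accurate timing.

```
# Rough illustration only: alternately display the two complementary frames
# of a protected text image, similar in spirit to Misc/show/single.exe.
import cv2

frames = [cv2.imread('Misc/show/single-.png'),
          cv2.imread('Misc/show/single+.png')]
assert all(f is not None for f in frames), 'frame images not found'

cv2.namedWindow('ProTegO show', cv2.WINDOW_AUTOSIZE)
i = 0
while True:
    cv2.imshow('ProTegO show', frames[i % 2])
    i += 1
    # waitKey(1) waits about 1 ms between swaps; press 'q' to quit.
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
cv2.destroyAllWindows()
```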
135 | 136 | 137 | 138 | ## Acknowledgement 139 | This repo is partially based on [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark), [FAWA](https://github.com/strongman1995/Fast-Adversarial-Watermark-Attack-on-OCR), [What machines see is not what they get](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xu_What_Machines_See_Is_Not_What_They_Get_Fooling_Scene_CVPR_2020_paper.pdf), [The Best Protection is Attack](https://ieeexplore.ieee.org/abstract/document/10045728), [SINIFGSM](https://arxiv.org/pdf/1908.06281.pdf) and [VMIFGSM](https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Enhancing_the_Transferability_of_Adversarial_Attacks_Through_Variance_Tuning_CVPR_2021_paper.pdf). Thanks for their impressive works! 140 | 141 | ## Citation 142 | If you find this work useful for your research, please cite [our paper](https://dl.acm.org/doi/abs/10.1145/3581783.3612076): 143 | ``` 144 | @inproceedings{he2023protego, 145 | title={ProTegO: Protect Text Content against OCR Extraction Attack}, 146 | author={He, Yanru and Chen, Kejiang and Chen, Guoqiang and Ma, Zehua and Zhang, Kui and Zhang, Jie and Bian, Huanyu and Fang, Han and Zhang, Weiming and Yu, Nenghai}, 147 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia}, 148 | pages={7424--7434}, 149 | year={2023} 150 | } 151 | ``` 152 | 153 | ## License 154 | The code is released under MIT License (see LICENSE file for details). 155 | 156 | 157 | -------------------------------------------------------------------------------- /STR_modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VGG_FeatureExtractor(nn.Module): 6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 7 | 8 | def __init__(self, input_channel, output_channel=512): 9 | super(VGG_FeatureExtractor, self).__init__() 10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 12 | self.ConvNet = nn.Sequential( 13 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 14 | nn.MaxPool2d(2, 2), # 64x16x50 15 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 16 | nn.MaxPool2d(2, 2), # 128x8x25 17 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 18 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 19 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 20 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 21 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 22 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 24 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 25 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 26 | 27 | def forward(self, input): 28 | return self.ConvNet(input) 29 | 30 | 31 | class ResNet_FeatureExtractor(nn.Module): 32 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 33 | 34 | def __init__(self, input_channel, output_channel=512): 35 | super(ResNet_FeatureExtractor, self).__init__() 36 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 
2, 5, 3]) 37 | 38 | def forward(self, input): 39 | return self.ConvNet(input) 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = self._conv3x3(inplanes, planes) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = self._conv3x3(planes, planes) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def _conv3x3(self, in_planes, out_planes, stride=1): 55 | "3x3 convolution with padding" 56 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 57 | padding=1, bias=False) 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class ResNet(nn.Module): 78 | 79 | def __init__(self, input_channel, output_channel, block, layers): 80 | super(ResNet, self).__init__() 81 | 82 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 83 | 84 | self.inplanes = int(output_channel / 8) 85 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 86 | kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 88 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 89 | kernel_size=3, stride=1, padding=1, bias=False) 90 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 91 | self.relu = nn.ReLU(inplace=True) 92 | 93 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 94 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 95 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 96 | 0], kernel_size=3, stride=1, padding=1, bias=False) 97 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 98 | 99 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 100 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 101 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 102 | 1], kernel_size=3, stride=1, padding=1, bias=False) 103 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 104 | 105 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 106 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 107 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 108 | 2], kernel_size=3, stride=1, padding=1, bias=False) 109 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 110 | 111 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 112 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 113 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 114 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 115 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 116 | 3], kernel_size=2, stride=1, padding=0, bias=False) 117 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1): 120 | downsample = None 121 | if stride != 1 or 
self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv0_1(x) 138 | x = self.bn0_1(x) 139 | x = self.relu(x) 140 | x = self.conv0_2(x) 141 | x = self.bn0_2(x) 142 | x = self.relu(x) 143 | 144 | x = self.maxpool1(x) 145 | x = self.layer1(x) 146 | x = self.conv1(x) 147 | x = self.bn1(x) 148 | x = self.relu(x) 149 | 150 | x = self.maxpool2(x) 151 | x = self.layer2(x) 152 | x = self.conv2(x) 153 | x = self.bn2(x) 154 | x = self.relu(x) 155 | 156 | x = self.maxpool3(x) 157 | x = self.layer3(x) 158 | x = self.conv3(x) 159 | x = self.bn3(x) 160 | x = self.relu(x) 161 | 162 | x = self.layer4(x) 163 | x = self.conv4_1(x) 164 | x = self.bn4_1(x) 165 | x = self.relu(x) 166 | x = self.conv4_2(x) 167 | x = self.bn4_2(x) 168 | x = self.relu(x) 169 | 170 | return x 171 | -------------------------------------------------------------------------------- /STR_modules/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from STR_modules.transformation import TPS_SpatialTransformerNetwork 4 | from STR_modules.feature_extraction import VGG_FeatureExtractor, ResNet_FeatureExtractor 5 | from STR_modules.sequence_modeling import BidirectionalLSTM 6 | from STR_modules.prediction import Attention 7 | 8 | 9 | class Model(nn.Module): 10 | 11 | def __init__(self, opt): 12 | super(Model, self).__init__() 13 | self.opt = opt 14 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 15 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 16 | 17 | """ Transformation : output is rectified image [batch_size x I_channel_num x I_r_height x I_r_width] """ 18 | if opt.Transformation == 'TPS': 19 | self.Transformation = TPS_SpatialTransformerNetwork( 20 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) 21 | else: 22 | print('No Transformation module specified') 23 | 24 | 25 | """ FeatureExtraction """ 26 | if opt.FeatureExtraction == 'VGG': 27 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) # 512x1x24 28 | elif opt.FeatureExtraction == 'ResNet': 29 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 30 | else: 31 | raise Exception('No FeatureExtraction module specified') 32 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 33 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 34 | 35 | """ Sequence modeling""" 36 | if opt.SequenceModeling == 'BiLSTM': 37 | self.SequenceModeling = nn.Sequential( 38 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 39 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) # hidden_size = 256, output is [batch_size x T x output_size] 40 | self.SequenceModeling_output = opt.hidden_size 41 | else: 42 | print('No SequenceModeling module specified') 43 | self.SequenceModeling_output = self.FeatureExtraction_output 44 | 45 | """ Prediction """ 46 | if 
opt.Prediction == 'CTC': 47 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 48 | elif opt.Prediction == 'Attn': 49 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) # batch_size x num_steps x num_classes 50 | else: 51 | raise Exception('Prediction is neither CTC or Attn') 52 | 53 | def forward(self, input, text, is_train=True): 54 | """ Transformation stage """ 55 | if not self.stages['Trans'] == "None": 56 | input = self.Transformation(input) 57 | 58 | """ Feature extraction stage """ 59 | visual_feature = self.FeatureExtraction(input) # bx512x1x24 60 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]=bx24x512x1 61 | visual_feature = visual_feature.squeeze(3) # [b, w, c]=bx24x512 62 | 63 | """ Sequence modeling stage """ 64 | if self.stages['Seq'] == 'BiLSTM': 65 | contextual_feature = self.SequenceModeling(visual_feature) # [b, w, hidden]=bx24x256 66 | else: 67 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 68 | 69 | """ Prediction stage """ 70 | if self.stages['Pred'] == 'CTC': 71 | prediction = self.Prediction(contextual_feature.contiguous()) #[b, w, class]=bx24x63 72 | else: 73 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) 74 | 75 | return prediction 76 | -------------------------------------------------------------------------------- /STR_modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | class Attention(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 12 | self.hidden_size = hidden_size 13 | self.num_classes = num_classes 14 | self.generator = nn.Linear(hidden_size, num_classes) 15 | 16 | def _char_to_onehot(self, input_char, onehot_dim=38): 17 | input_char = input_char.unsqueeze(1) 18 | batch_size = input_char.size(0) 19 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 20 | one_hot = one_hot.scatter_(1, input_char, 1) 21 | return one_hot 22 | 23 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 24 | """ 25 | input: 26 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 27 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 28 | output: probability distribution at each step [batch_size x num_steps x num_classes] 29 | """ 30 | batch_size = batch_H.size(0) 31 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 32 | 33 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 34 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 36 | 37 | if is_train: 38 | for i in range(num_steps): 39 | # one-hot vectors for a i-th char. 
in a batch 40 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 41 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 42 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 43 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 44 | probs = self.generator(output_hiddens) 45 | 46 | else: 47 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 48 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) 49 | 50 | for i in range(num_steps): 51 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 52 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 53 | probs_step = self.generator(hidden[0]) 54 | probs[:, i, :] = probs_step 55 | _, next_input = probs_step.max(1) 56 | targets = next_input 57 | 58 | return probs # batch_size x num_steps x num_classes 59 | 60 | 61 | class AttentionCell(nn.Module): 62 | 63 | def __init__(self, input_size, hidden_size, num_embeddings): 64 | super(AttentionCell, self).__init__() 65 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 66 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 67 | self.score = nn.Linear(hidden_size, 1, bias=False) 68 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 69 | self.hidden_size = hidden_size 70 | 71 | def forward(self, prev_hidden, batch_H, char_onehots): 72 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 73 | batch_H_proj = self.i2h(batch_H) 74 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 75 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 76 | 77 | alpha = F.softmax(e, dim=1) 78 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 79 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 80 | cur_hidden = self.rnn(concat_context, prev_hidden) 81 | return cur_hidden, alpha 82 | 83 | 84 | # utils for label convert 85 | class CTCLabelConverter(object): 86 | """ Convert between text-label and text-index """ 87 | 88 | def __init__(self, character): 89 | # character (str): set of the possible characters. 90 | dict_character = list(character) 91 | 92 | self.dict = {} 93 | for i, char in enumerate(dict_character): 94 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 95 | self.dict[char] = i + 1 96 | 97 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 98 | 99 | def encode(self, text, batch_max_length=25): 100 | """convert text-label into text-index. 101 | input: 102 | text: text labels of each image. [batch_size] 103 | batch_max_length: max length of text label in the batch. 25 by default 104 | 105 | output: 106 | text: text index for CTCLoss. [batch_size, batch_max_length] 107 | length: length of each text. [batch_size] 108 | """ 109 | length = [len(s) for s in text] 110 | 111 | # The index used for padding (=0) would not affect the CTC loss calculation. 
112 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 113 | for i, t in enumerate(text): 114 | text = list(t) 115 | text = [self.dict[char] for char in text] 116 | batch_text[i][:len(text)] = torch.LongTensor(text) 117 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 118 | 119 | def decode(self, text_index, length): 120 | """ convert text-index into text-label. """ 121 | texts = [] 122 | for index, l in enumerate(length): 123 | t = text_index[index, :] 124 | 125 | char_list = [] 126 | for i in range(l): 127 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 128 | char_list.append(self.character[t[i]]) 129 | text = ''.join(char_list) 130 | 131 | texts.append(text) 132 | return texts 133 | 134 | class AttnLabelConverter(object): 135 | """ Convert between text-label and text-index """ 136 | 137 | def __init__(self, character): 138 | # character (str): set of the possible characters. 139 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 140 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 141 | list_character = list(character) 142 | self.character = list_token + list_character 143 | 144 | self.dict = {} 145 | for i, char in enumerate(self.character): 146 | # print(i, char) 147 | self.dict[char] = i 148 | 149 | def encode(self, text, batch_max_length=25): 150 | """ convert text-label into text-index. 151 | input: 152 | text: text labels of each image. [batch_size] 153 | batch_max_length: max length of text label in the batch. 25 by default 154 | 155 | output: 156 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 157 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 158 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 159 | """ 160 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 161 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 162 | batch_max_length += 1 163 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 164 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 165 | for i, t in enumerate(text): 166 | text = list(t) 167 | text.append('[s]') 168 | text = [self.dict[char] for char in text] 169 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 170 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 171 | 172 | def decode(self, text_index, length): 173 | """ convert text-index into text-label. 
""" 174 | texts = [] 175 | for index, l in enumerate(length): 176 | text = ''.join([self.character[i] for i in text_index[index, :]]) 177 | texts.append(text) 178 | return texts -------------------------------------------------------------------------------- /STR_modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | -------------------------------------------------------------------------------- /STR_modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | class TPS_SpatialTransformerNetwork(nn.Module): 10 | """ Rectification Network of RARE, namely TPS(thin-plate spline) based STN """ 11 | 12 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 13 | """ Based on RARE TPS 14 | input: 15 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 16 | num_fiducial: number of fiducial points, 20 as default 17 | I_size : (height, width) of the input image I 18 | I_r_size : (height, width) of the rectified image I_r 19 | I_channel_num : the number of channels of the input image I 20 | output: 21 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 22 | """ 23 | super(TPS_SpatialTransformerNetwork, self).__init__() 24 | self.F = F 25 | self.I_size = I_size 26 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 27 | self.I_channel_num = I_channel_num 28 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 29 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 30 | 31 | def forward(self, batch_I): 32 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 33 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 34 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 35 | 36 | if torch.__version__ > "1.2.0": 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 38 | else: 39 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 40 | 41 | return batch_I_r 42 | 43 | 44 | class LocalizationNetwork(nn.Module): 45 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 46 | 47 | def __init__(self, F, I_channel_num): 48 | super(LocalizationNetwork, self).__init__() 49 | self.F = F 50 | self.I_channel_num = I_channel_num 51 | self.conv = nn.Sequential( 52 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 
53 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 55 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 57 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 58 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 59 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 60 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 61 | ) 62 | 63 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 64 | self.localization_fc2 = nn.Linear(256, self.F * 2) 65 | 66 | # Init fc2 in LocalizationNetwork 67 | self.localization_fc2.weight.data.fill_(0) 68 | """ see RARE paper Fig. 6 (a) """ 69 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 70 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 71 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 72 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 73 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 74 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 75 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 76 | 77 | def forward(self, batch_I): 78 | """ 79 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 80 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 81 | """ 82 | batch_size = batch_I.size(0) 83 | features = self.conv(batch_I).view(batch_size, -1) 84 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 85 | return batch_C_prime 86 | 87 | 88 | class GridGenerator(nn.Module): 89 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 90 | 91 | def __init__(self, F, I_r_size): 92 | """ Generate P_hat and inv_delta_C for later """ 93 | super(GridGenerator, self).__init__() 94 | self.eps = 1e-6 95 | self.I_r_height, self.I_r_width = I_r_size 96 | self.F = F 97 | self.C = self._build_C(self.F) # F x 2 98 | self.P = self._build_P(self.I_r_width, self.I_r_height) 99 | ## for multi-gpu, you need register buffer 100 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 101 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 102 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 103 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 104 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 105 | 106 | def _build_C(self, F): 107 | """ Return coordinates of fiducial points in I_r; C """ 108 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 109 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 110 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 111 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 112 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 113 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 114 | return C # F x 2 115 | 116 | def _build_inv_delta_C(self, F, C): 117 | """ Return inv_delta_C which is needed to calculate T """ 118 | hat_C = np.zeros((F, F), dtype=float) # F x F 119 | for i in range(0, F): 120 | for j in range(i, F): 121 | r = 
np.linalg.norm(C[i] - C[j]) 122 | hat_C[i, j] = r 123 | hat_C[j, i] = r 124 | np.fill_diagonal(hat_C, 1) 125 | hat_C = (hat_C ** 2) * np.log(hat_C) 126 | # print(C.shape, hat_C.shape) 127 | delta_C = np.concatenate( # F+3 x F+3 128 | [ 129 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 130 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 131 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 132 | ], 133 | axis=0 134 | ) 135 | inv_delta_C = np.linalg.inv(delta_C) 136 | return inv_delta_C # F+3 x F+3 137 | 138 | def _build_P(self, I_r_width, I_r_height): 139 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 140 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 141 | P = np.stack( # self.I_r_width x self.I_r_height x 2 142 | np.meshgrid(I_r_grid_x, I_r_grid_y), 143 | axis=2 144 | ) 145 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 146 | 147 | def _build_P_hat(self, F, C, P): 148 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 149 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 150 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 151 | P_diff = P_tile - C_tile # n x F x 2 152 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 153 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 154 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 155 | return P_hat # n x F+3 156 | 157 | def build_P_prime(self, batch_C_prime): 158 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 159 | batch_size = batch_C_prime.size(0) 160 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 161 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 162 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 163 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 164 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 165 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 166 | return batch_P_prime # batch_size x n x 2 167 | -------------------------------------------------------------------------------- /STR_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import string 5 | import argparse 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.utils.data 10 | import torch.nn.functional as F 11 | 12 | from STR_modules.prediction import CTCLabelConverter, AttnLabelConverter 13 | from STR_modules.model import Model 14 | from dataset import strdataset, train_dataset_builder 15 | from utils import Averager, Logger 16 | from torchvision import utils as vutils 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def makedirs(path): 22 | if not os.path.exists(path): 23 | try: 24 | os.makedirs(path) 25 | except Exception as e: 26 | print('cannot create dirs: {}'.format(path)) 27 | exit(0) 28 | 29 | def validation(model, criterion, evaluation_loader, converter, opt): 30 | """ validation or evaluation """ 31 | n_correct = 0 32 | 33 | infer_time = 0 34 | valid_loss_avg = Averager() 35 | 36 | for i, data in enumerate(evaluation_loader): 37 | image_tensors, labels = data 38 | image = image_tensors.to(device) 39 | 40 | # For max length prediction 41 | length_for_pred = 
torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device) 42 | text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device) 43 | 44 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) 45 | 46 | start_time = time.time() 47 | if 'CTC' in opt.Prediction: 48 | preds = model(image, text_for_pred) 49 | forward_time = time.time() - start_time 50 | 51 | # Calculate evaluation loss for CTC deocder. 52 | preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) 53 | # permute 'preds' to use CTCloss format 54 | cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) 55 | 56 | # Select max probabilty (greedy decoding) then decode index to character 57 | _, preds_index = preds.max(2) 58 | preds_str = converter.decode(preds_index.data, preds_size.data) 59 | 60 | else: 61 | preds = model(image, text_for_pred, is_train=False) 62 | forward_time = time.time() - start_time 63 | 64 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 65 | target = text_for_loss[:, 1:] # without [GO] Symbol 66 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) 67 | 68 | # select max probabilty (greedy decoding) then decode index to character 69 | _, preds_index = preds.max(2) 70 | preds_str = converter.decode(preds_index, length_for_pred) 71 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 72 | 73 | infer_time += forward_time 74 | valid_loss_avg.add(cost) 75 | 76 | # calculate accuracy & confidence score 77 | preds_prob = F.softmax(preds, dim=2) 78 | preds_max_prob, _ = preds_prob.max(dim=2) 79 | confidence_score_list = [] 80 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 81 | if 'Attn' in opt.Prediction: 82 | gt = gt[:gt.find('[s]')] 83 | pred_EOS = pred.find('[s]') 84 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 85 | pred_max_prob = pred_max_prob[:pred_EOS] 86 | 87 | # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
88 | # if opt.sensitive: 89 | # pred = pred.lower() 90 | # gt = gt.lower() 91 | # alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' 92 | # out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' 93 | # pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) 94 | # gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) 95 | if pred == gt: 96 | n_correct += 1 97 | vutils.save_image(image, "{}/{}_{}_{}.png".format(opt.test_out, i, gt, i)) # 删选正确样本作为测试集 98 | 99 | # if not opt.train_mode: 100 | # print('GoundTruth: %-10s => Prediction: %-10s' % (gt, pred)) 101 | if not opt.train_mode: 102 | print('Success:{},\t GoundTruth:{:20} => Prediction:{:20}'.format(pred == gt, gt, pred)) 103 | 104 | # calculate confidence score (= multiply of pred_max_prob) 105 | try: 106 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 107 | except: 108 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) 109 | confidence_score_list.append(confidence_score) 110 | # print(pred, gt, pred==gt, confidence_score) 111 | 112 | accuracy = n_correct / float(len(evaluation_loader)) * 100 113 | 114 | return valid_loss_avg.val(), accuracy, preds_str, confidence_score_list, labels, infer_time, len(evaluation_loader) 115 | 116 | 117 | def test(opt): 118 | """ save all the print content as log """ 119 | opt.test_out = os.path.join(opt.output, opt.name) 120 | makedirs(opt.test_out) 121 | # log_file= os.path.join(opt.test_out, 'test.log') 122 | # sys.stdout = Logger(log_file) 123 | 124 | """ model configuration """ 125 | if 'CTC' in opt.Prediction: 126 | converter = CTCLabelConverter(opt.character) 127 | else: 128 | converter = AttnLabelConverter(opt.character) 129 | opt.num_class = len(converter.character) 130 | 131 | if opt.rgb: 132 | opt.input_channel = 3 133 | model = Model(opt).to(device) 134 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 135 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 136 | opt.SequenceModeling, opt.Prediction) 137 | # model = torch.nn.DataParallel(model).to(device) 138 | 139 | # load model 140 | print('loading pretrained model from %s' % opt.saved_model) 141 | model.load_state_dict(torch.load(opt.saved_model, map_location=device),strict=False) 142 | # opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) 143 | print(model) 144 | 145 | """ setup loss """ 146 | if 'CTC' in opt.Prediction: 147 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 148 | else: 149 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 150 | 151 | """ evaluation """ 152 | model.eval() 153 | with torch.no_grad(): 154 | # eval_dataset = strdataset(opt.imgH, opt.imgW, opt.eval_data) 155 | eval_dataset = train_dataset_builder(opt.imgH, opt.imgW, opt.eval_data) 156 | evaluation_loader = torch.utils.data.DataLoader( 157 | eval_dataset, batch_size=opt.batch_size, 158 | shuffle=False, num_workers=int(opt.workers), 159 | # drop_last=True, pin_memory=True 160 | ) 161 | 162 | _, accuracy_by_best_model, _, _, _, _, _ = validation( 163 | model, criterion, evaluation_loader, converter, opt) 164 | print('SR:', f'{accuracy_by_best_model:0.2f}', '%') 165 | 166 | 167 | if __name__ == '__main__': 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument('--output', required=True, help='Test output path') 170 | parser.add_argument('--name', required=True, help='Test model name') 171 
| parser.add_argument('--train_mode', action='store_true', help='defalut is Test mode') 172 | parser.add_argument('--eval_data', type=str, required=True, help='path to evaluation dataset') 173 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 174 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 175 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation") 176 | """ Data processing """ 177 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 178 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 179 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 180 | parser.add_argument('--rgb', action='store_false', help='use rgb input') 181 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 182 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode') 183 | """ Model Architecture """ 184 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 185 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 186 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 187 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 188 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 189 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor') 190 | parser.add_argument('--output_channel', type=int, default=512, 191 | help='the number of output channel of Feature extractor') 192 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 193 | 194 | opt = parser.parse_args() 195 | 196 | """ vocab / character number configuration """ 197 | if opt.sensitive: 198 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z) 199 | 200 | cudnn.benchmark = True 201 | cudnn.deterministic = True 202 | 203 | test(opt) -------------------------------------------------------------------------------- /STR_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import string 6 | import argparse 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import numpy as np 14 | 15 | from STR_modules.prediction import CTCLabelConverter, AttnLabelConverter 16 | from STR_modules.model import Model 17 | from STR_test import validation 18 | from dataset import strdataset 19 | from utils import Averager 20 | 21 | 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | def train(opt): 26 | 27 | """ dataset preparation """ 28 | train_dataset = strdataset(opt.imgH, opt.imgW, opt.train_data) 29 | train_loader = torch.utils.data.DataLoader( 30 | train_dataset, batch_size=opt.batch_size, 31 | shuffle=True, num_workers=int(opt.workers), 32 | drop_last=True, pin_memory=True) 33 | 34 | 35 | valid_dataset = strdataset(opt.imgH, opt.imgW, opt.valid_data) 36 | valid_loader = 
torch.utils.data.DataLoader( 37 | valid_dataset, batch_size=opt.batch_size, 38 | shuffle=False, num_workers=int(opt.workers), 39 | drop_last=True, pin_memory=True) 40 | 41 | """ model configuration """ 42 | if 'CTC' in opt.Prediction: 43 | converter = CTCLabelConverter(opt.character) 44 | else: 45 | converter = AttnLabelConverter(opt.character) 46 | opt.num_class = len(converter.character) 47 | 48 | model = Model(opt).to(device) 49 | 50 | # weight initialization 51 | for name, param in model.named_parameters(): 52 | if 'localization_fc2' in name: 53 | print(f'Skip {name} as it is already initialized') 54 | continue 55 | try: 56 | if 'bias' in name: 57 | init.constant_(param, 0.0) 58 | elif 'weight' in name: 59 | init.kaiming_normal_(param) 60 | except Exception as e: # for batchnorm. 61 | if 'weight' in name: 62 | param.data.fill_(1) 63 | continue 64 | 65 | # # data parallel for multi-GPU 66 | # model = torch.nn.DataParallel(model).to(device) 67 | model.train() 68 | if opt.saved_model != '': 69 | print(f'loading pretrained model from {opt.saved_model}') 70 | if opt.FT: 71 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 72 | else: 73 | model.load_state_dict(torch.load(opt.saved_model)) 74 | print("Model:") 75 | print(model) 76 | 77 | """ setup loss """ 78 | if 'CTC' in opt.Prediction: 79 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 80 | else: 81 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 82 | # loss averager 83 | loss_avg = Averager() 84 | 85 | # filter that only require gradient decent 86 | filtered_parameters = [] 87 | params_num = [] 88 | for p in filter(lambda p: p.requires_grad, model.parameters()): 89 | filtered_parameters.append(p) 90 | params_num.append(np.prod(p.size())) 91 | print('Trainable params num : ', sum(params_num)) 92 | 93 | # setup optimizer 94 | if opt.adam: 95 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 96 | else: 97 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 98 | print("Optimizer:") 99 | print(optimizer) 100 | 101 | """ final options """ 102 | # print(opt) 103 | with open(f'./STR_modules/saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 104 | opt_log = '------------ Options -------------\n' 105 | args = vars(opt) 106 | for k, v in args.items(): 107 | opt_log += f'{str(k)}: {str(v)}\n' 108 | opt_log += '---------------------------------------\n' 109 | print(opt_log) 110 | opt_file.write(opt_log) 111 | 112 | """ start training """ 113 | start_iter = 0 114 | if opt.saved_model != '': 115 | try: 116 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 117 | print(f'continue to train, start_iter: {start_iter}') 118 | except: 119 | pass 120 | 121 | start_time = time.time() 122 | best_accuracy = -1 123 | iteration = start_iter 124 | while(True): 125 | # train part 126 | for i, data in enumerate(train_loader): 127 | image_tensors, labels = data 128 | image = image_tensors.to(device) 129 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 130 | batch_size = image.size(0) 131 | 132 | if 'CTC' in opt.Prediction: 133 | preds = model(image, text) 134 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 135 | preds = preds.log_softmax(2).permute(1, 0, 2) 136 | cost = criterion(preds, text, preds_size, length) 137 | else: 138 | preds = model(image, text[:, :-1]) # align with Attention.forward 139 | target = text[:, 1:] # without [GO] Symbol 140 | cost 
= criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 141 | 142 | model.zero_grad() 143 | cost.backward() 144 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 145 | optimizer.step() 146 | 147 | loss_avg.add(cost) 148 | 149 | # validation part 150 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 151 | elapsed_time = time.time() - start_time 152 | # for log 153 | with open(f'./STR_modules/saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 154 | model.eval() 155 | with torch.no_grad(): 156 | valid_loss, current_accuracy, preds, confidence_score, labels, infer_time, length_of_data = validation( 157 | model, criterion, valid_loader, converter, opt) 158 | model.train() 159 | 160 | # training loss and validation loss 161 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 162 | loss_avg.reset() 163 | 164 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}' 165 | 166 | # keep best accuracy model (on valid dataset) 167 | if current_accuracy > best_accuracy: 168 | best_accuracy = current_accuracy 169 | torch.save(model.state_dict(), f'./STR_modules/saved_models/{opt.exp_name}/best_accuracy.pth') 170 | 171 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}' 172 | 173 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' 174 | print(loss_model_log) 175 | log.write(loss_model_log + '\n') 176 | 177 | # show some predicted results 178 | dashed_line = '-' * 80 179 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 180 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 181 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 182 | if 'Attn' in opt.Prediction: 183 | gt = gt[:gt.find('[s]')] 184 | pred = pred[:pred.find('[s]')] 185 | 186 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 187 | predicted_result_log += f'{dashed_line}' 188 | print(predicted_result_log) 189 | log.write(predicted_result_log + '\n') 190 | 191 | # save model per 1e+4 iter. 
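# The checkpoint below is written to ./STR_modules/saved_models/{opt.exp_name}/iter_{iteration+1}.pth;
# note that 1e+4 is a float literal, so (iteration + 1) % 1e+4 evaluates to a float, and the
# comparison with 0 only holds because 0.0 == 0 in Python.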
192 | if (iteration + 1) % 1e+4 == 0: 193 | torch.save( 194 | model.state_dict(), f'./STR_modules/saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 195 | 196 | if (iteration + 1) == opt.num_iter: 197 | print('end the training') 198 | sys.exit() 199 | iteration += 1 200 | 201 | if __name__ == '__main__': 202 | parser = argparse.ArgumentParser() 203 | parser.add_argument('--exp_name', required=True, help='Where to store logs and models') 204 | parser.add_argument('--train_mode', action='store_true', help='defalut is test mode') 205 | parser.add_argument('--train_data', type=str, required=True, help='path to training dataset') 206 | parser.add_argument('--valid_data', type=str, required=True, help='path to validation dataset') 207 | parser.add_argument('--manualSeed', type=int, default=2023, help='for random seed setting') 208 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 209 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 210 | parser.add_argument('--num_iter', type=int, default=5e+4, help='number of iterations to train for') #train 3e+4, finetune 1e+4 211 | parser.add_argument('--valInterval', type=int, default=1000, help='Interval between each validation') 212 | parser.add_argument('--saved_model', type=str, default='', help="path to model to continue training") 213 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 214 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 215 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 216 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 217 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 218 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 219 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5') 220 | """ Data processing """ 221 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 222 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 223 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 224 | # parser.add_argument('--rgb', action='store_false', help='default to use rgb input') 225 | parser.add_argument('--character', type=str, 226 | default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 227 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode') 228 | """ Model Architecture """ 229 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 230 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 231 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 232 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 233 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 234 | parser.add_argument('--input_channel', type=int, default=3, 235 | help='the number of input channel of Feature extractor') 236 | parser.add_argument('--output_channel', type=int, default=512, 237 | help='the number of output channel of Feature extractor') 238 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 239 | 240 | opt = parser.parse_args() 241 | 242 | 243 | if not opt.exp_name: 244 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 245 | # opt.exp_name += f'-Seed{opt.manualSeed}' 246 | if opt.sensitive: 247 | opt.exp_name += f'-sensitive' 248 | print(opt.exp_name) 249 | else: 250 | opt.exp_name = f'{opt.exp_name}-{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 251 | # opt.exp_name += f'-Seed{opt.manualSeed}' 252 | if opt.sensitive: 253 | opt.exp_name += f'-sensitive' 254 | print(opt.exp_name) 255 | 256 | if opt.FT: 257 | pass 258 | else: 259 | os.makedirs(f'./STR_modules/saved_models/{opt.exp_name}', exist_ok=True) 260 | 261 | """ vocab / character number configuration """ 262 | if opt.sensitive: 263 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z) 264 | 265 | """ Seed and GPU setting """ 266 | print("Random Seed: ", opt.manualSeed) 267 | random.seed(opt.manualSeed) 268 | np.random.seed(opt.manualSeed) 269 | torch.manual_seed(opt.manualSeed) 270 | torch.cuda.manual_seed(opt.manualSeed) 271 | 272 | cudnn.benchmark = True 273 | cudnn.deterministic = True 274 | 275 | train(opt) 276 | -------------------------------------------------------------------------------- /baselines/Impact.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/baselines/Impact.ttf -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/baselines/__init__.py -------------------------------------------------------------------------------- /baselines/fawa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # sys.path.append('xx/ProTegO/') 4 | import string 5 | import shutil 6 | import argparse 7 | from PIL import Image 8 | import numpy as np 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.utils.data 13 | from torchvision import utils as vutils 14 | 15 | from dataset import test_dataset_builder 16 | from utils import Logger, np2tensor, tensor2np, get_text_mask, cvt2Image, color_map 17 | from wm_attacker import WM_Attacker 18 | 19 | 20 | def makedirs(path): 21 | if not os.path.exists(path): 22 | try: 23 | os.makedirs(path) 24 | except Exception as e: 25 | print('cannot create dirs: {}'.format(path)) 26 | exit(0) 27 | 28 | def fawa(opt): 29 | """prepare log with model_name """ 30 | model_name = os.path.basename(opt.str_model.split('-')[0]) 31 | print('-----Attack Settings< model_name:{} --iter_num:{} --eps:{} --decay:{} --alpha:{} >' 32 | .format(model_name, opt.iter_num, opt.eps, opt.decay, opt.alpha)) 33 | 34 | save_root = os.path.join(opt.save_attacks, model_name) 35 | makedirs(save_root) 36 
| del_list = os.listdir(save_root) 37 | for f in del_list: 38 | file_path = os.path.join(save_root, f) 39 | if os.path.isfile(file_path): 40 | os.remove(file_path) 41 | elif os.path.isdir(file_path): 42 | shutil.rmtree(file_path) 43 | 44 | wm_save_adv_path = os.path.join(save_root, 'wmadv') 45 | wm_save_per_path = os.path.join(save_root, 'wmper') 46 | makedirs(wm_save_adv_path) 47 | makedirs(wm_save_per_path) 48 | 49 | """ save all the print content as log """ 50 | log_file= os.path.join(save_root, 'train.log') 51 | sys.stdout = Logger(log_file) 52 | 53 | dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.root) 54 | dataloader = torch.utils.data.DataLoader( 55 | dataset, batch_size=opt.batch_size, 56 | shuffle=False, num_workers=4, 57 | drop_last=True,pin_memory=True) 58 | 59 | attacker = WM_Attacker(opt) 60 | time_all, suc, ED_sum = 0, 0, 0 61 | for i, data in enumerate(dataloader, start=0): 62 | label = data[1] 63 | img = data[5] # binary image 64 | img_index = data[2][0] 65 | img_path = data[4][0] 66 | img_tensor = img.to(opt.device) 67 | 68 | """basic attack""" 69 | attacker.init() 70 | adv_img, delta, epoch, preds, flag, time = attacker.basic_attack(img_tensor, label) 71 | print('img-{}_path:{} --iters:{} --Success:{} --prediction:{} --groundtruth:{} --time:{}' 72 | .format(i, img_path, epoch, flag, preds, label[0], time)) 73 | 74 | """wm attack""" 75 | # find position to add watermark 76 | adv_np = tensor2np(adv_img.squeeze(0)) 77 | img_np = tensor2np(img_tensor.squeeze(0)) 78 | pos, frames = attacker.find_wm_pos(adv_np, img_np, True) 79 | # 按面积大小把pos从大到小排个序 80 | new_pos = [] 81 | for _pos in pos: 82 | if len(_pos) > 1: 83 | new_pos.append(sorted(_pos, key=lambda x: (x[3]-x[1])*(x[2]-x[0]), reverse=True)) 84 | else: 85 | new_pos.append(_pos) 86 | pos = new_pos 87 | 88 | # get watermark mask 89 | grayscale = 0 90 | color = (grayscale, grayscale, grayscale) 91 | wm_img = attacker.gen_wm(color) 92 | wm_arr = np.array(wm_img.convert('L')) 93 | bg_mask = ~(wm_arr != 255) 94 | 95 | # # get grayscale watermark 96 | # """ 97 | # 灰度值在 76-226 有对应的彩色水印值,为了增加扰动后还在范围内,128-174, 98 | # *note by hyr:paper setting = 255*0.682 =174 99 | # """ 100 | # grayscale = 174 101 | # color = (grayscale, grayscale, grayscale) 102 | # wm_img = np.array(Image.new(mode="RGB", size=wm_img.size, color=color)) 103 | # wm_img[bg_mask] = 255 104 | # wm_img = Image.fromarray(wm_img) 105 | 106 | # # get color watermark 107 | grayscale = 174 108 | green_v = color_map(grayscale) 109 | color = (255, green_v, 0) 110 | wm_img = np.array(Image.new(mode="RGB", size=wm_img.size, color=color)) 111 | wm_img[bg_mask] = 255 112 | wm_img = Image.fromarray(wm_img) 113 | 114 | 115 | text_img = cvt2Image(img_np) 116 | text_mask = get_text_mask(np.array(text_img)) # bool, 1 channel 117 | rgb_img = Image.new(mode="RGB", size=(text_img.size), color=(255, 255, 255)) # white bg 118 | p = -int(wm_img.size[0] * np.tan(10 * np.pi / 180)) 119 | right_shift = 20 120 | xp = pos[0][0][0]+right_shift if len(pos) != 0 else right_shift 121 | rgb_img.paste(wm_img, box=(xp, p)) # first to add wm 122 | wm_mask = (np.array(rgb_img) != 255) # bool, 3 channel 123 | rgb_img.paste(text_img, mask=cvt2Image(text_mask)) # then add text 124 | 125 | wm0_img = np.array(rgb_img) 126 | wm_img = np2tensor(wm0_img) 127 | wm_mask = np2tensor(wm_mask) 128 | adv_text_mask = ~text_mask 129 | adv_text_mask = np2tensor(adv_text_mask) 130 | 131 | attacker.init() 132 | adv_img_wm, delta, epoch, preds, flag, ED, time = attacker.wm_attack(wm_img, label, wm_mask, 
adv_text_mask) 133 | 134 | print('wmimg-{}_path:{} --iters:{} --Success:{} --prediction:{} --groundtruth:{} --edit_distance:{} --time:{}' 135 | .format(i, img_path, epoch, flag, preds, label[0], ED, time)) 136 | 137 | time_all +=time 138 | if flag: 139 | suc += 1 140 | ED_sum += ED 141 | vutils.save_image(adv_img_wm,"{}/{}_{}_adv.png".format(wm_save_adv_path, os.path.basename(img_index), label[0])) 142 | vutils.save_image(delta*100,"{}/{}_{}_delta.png".format(wm_save_per_path, os.path.basename(img_index), label[0])) 143 | 144 | print('FAWA_Total_attack_time:{} '.format(time_all)) 145 | print('ASR:{:.2f}% '.format((suc/len(dataloader))*100)) 146 | print('Average Edit_distance: {}'.format(ED_sum/suc)) 147 | 148 | 149 | 150 | if __name__ == '__main__': 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument('--root', default='/data/hyr/dataset/PTAU/new/test100-times/',help='path of original text images') 153 | parser.add_argument('--save_attacks', type=str, default='res-baselines/fawa', help='path of save adversarial images') 154 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 155 | parser.add_argument('--iter_num', type=int, default=2000, help='number of iterations') 156 | parser.add_argument('--eps', type=float, default=40/255, help='maximnum perturbation setting in paper') 157 | parser.add_argument('--decay', type=float, default=1.0, help='momentum factor') 158 | parser.add_argument('--alpha', type=float, default=0.05, help='the step size of the iteration') 159 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 160 | 161 | """ Data processing """ 162 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 163 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 164 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 165 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 166 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode') 167 | """ Model Architecture """ 168 | parser.add_argument('--str_model', type=str, help="well-trained models for evaluation", 169 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth') 170 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 171 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet') 172 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 173 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. 
CTC|Attn') 174 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 175 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor') 176 | parser.add_argument('--output_channel', type=int, default=512, 177 | help='the number of output channel of Feature extractor') 178 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 179 | opt = parser.parse_args() 180 | 181 | """ vocab / character number configuration """ 182 | device_type = 'cuda' if torch.cuda.is_available() else 'cpu' 183 | opt.device = torch.device(device_type) 184 | 185 | opt.save_attacks = opt.save_attacks + "-eps" + str(int(opt.eps*255)) 186 | 187 | if opt.sensitive: 188 | opt.character = string.printable[:62] 189 | 190 | cudnn.benchmark = True 191 | cudnn.deterministic = True 192 | 193 | torch.cuda.synchronize() 194 | fawa(opt) 195 | -------------------------------------------------------------------------------- /baselines/test_black_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test baseline methods on 4 Black-box models: CRNN, Rosetta, RARA, TRBA""" 3 | 4 | import os, sys, time, string, argparse, shutil 5 | sys.path.append('/data/hyr/ocr/ProTegO/release/') 6 | import torch 7 | from nltk.metrics import edit_distance 8 | from dataset import test_adv_dataset 9 | from STR_modules.model import Model 10 | from utils import Logger, CTCLabelConverter, AttnLabelConverter 11 | 12 | def makedirs(path): 13 | if not os.path.exists(path): 14 | try: 15 | os.makedirs(path) 16 | except Exception as e: 17 | print('cannot create dirs: {}'.format(path)) 18 | exit(0) 19 | 20 | def process_line(line): 21 | adv_img_path, recog_result = line.split(':') 22 | label, adv_preds = recog_result.split('--->') 23 | adv_preds = adv_preds.strip('\n') 24 | return adv_preds, label, adv_img_path 25 | 26 | def test(opt): 27 | 28 | """ model configuration """ 29 | if 'CTC' in opt.Prediction: 30 | converter = CTCLabelConverter(opt.character) 31 | else: 32 | converter = AttnLabelConverter(opt.character) 33 | opt.num_class = len(converter.character) 34 | 35 | model = Model(opt).to(opt.device) 36 | print('Loading a STR model from \"%s\" as the target model!' 
% opt.str_model) 37 | model.load_state_dict(torch.load(opt.str_model, map_location=opt.device),strict=False) 38 | model.eval() 39 | 40 | """ create new test model output path """ 41 | makedirs(opt.output) 42 | 43 | # str_name = opt.str_model.split('/')[-2].split('-')[0] 44 | str_name = opt.str_model.split('/')[-1].split('-')[0] 45 | test_output_path = os.path.join(opt.output, opt.attack_name, str_name) 46 | attack_success_result = os.path.join(test_output_path , 'attack_success_result.txt') 47 | save_success_adv = os.path.join(test_output_path , 'attack-success-adv') 48 | 49 | makedirs(test_output_path) 50 | makedirs(save_success_adv) 51 | 52 | log_file= os.path.join(test_output_path, 'test.log') 53 | sys.stdout = Logger(log_file) 54 | 55 | test_dataset = test_adv_dataset(opt.imgH, opt.imgW, opt.adv_img) 56 | test_dataloader = torch.utils.data.DataLoader( 57 | test_dataset, 58 | batch_size=opt.batch_size, 59 | shuffle=False, 60 | num_workers=1) 61 | 62 | result = dict() 63 | for i, data in enumerate(test_dataloader): 64 | if opt.b: 65 | adv_img = data[0] 66 | else: 67 | adv_img = data[1] 68 | adv_img= adv_img.to(opt.device) 69 | label = data[2] 70 | adv_index = data[3][0] 71 | adv_path = data[5][0] 72 | 73 | length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(opt.device) 74 | text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(opt.device) 75 | if 'CTC' in opt.Prediction: 76 | preds = model(adv_img, text_for_pred).log_softmax(2) 77 | preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size) 78 | _, preds_index = preds.permute(1, 0, 2).max(2) 79 | preds_index = preds_index.transpose(1, 0).contiguous().view(-1) 80 | preds_output = converter.decode(preds_index.data, preds_size) 81 | preds_output = preds_output[0] 82 | result[adv_index] = '{}:{}--->{}\n'.format(adv_path, label[0], preds_output) 83 | else: # Attention 84 | preds = model(adv_img, text_for_pred, is_train=False) 85 | _, preds_index = preds.max(2) 86 | preds_output = converter.decode(preds_index, length_for_pred) 87 | preds_output = preds_output[0] 88 | preds_output = preds_output[:preds_output.find('[s]')] 89 | result[adv_index] = '{}:{}--->{}\n'.format(adv_path, label[0], preds_output) 90 | result = sorted(result.items(), key=lambda x:x[0]) 91 | with open(attack_success_result, 'w+') as f: 92 | for item in result: 93 | f.write(item[1]) 94 | 95 | # calculate ASR 96 | with open(attack_success_result, 'r') as f: 97 | alladv = f.readlines() 98 | attack_success_num, ED_sum = 0, 0 99 | for line in alladv: 100 | adv_preds, label, adv_img_path = process_line(line) 101 | 102 | if adv_preds != label: 103 | ED_num = edit_distance(label, adv_preds) 104 | attack_success_num += 1 105 | shutil.copy(adv_img_path, save_success_adv) 106 | ED_sum += ED_num 107 | attack_success_rate = attack_success_num / len(test_dataset) 108 | print('ASR:{:.2f} %'.format(attack_success_rate * 100)) 109 | if attack_success_num != 0: 110 | ED_num_avr = ED_sum / attack_success_num 111 | print('Average Edit_distance: {:.2f}'.format(ED_num_avr)) 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--output', type=str, default='res-BlackModelTest/up5a', help='path to save attack results') 117 | parser.add_argument('--attack_name', required=True, help='baseline attack method name') 118 | """ Data processing """ 119 | parser.add_argument('--adv_img', required=True, help='the path of adv_x which generated from STARNet model') 120 | 
parser.add_argument('--b', action='store_true', help='Use binarization processing to adv_img.') 121 | parser.add_argument('--batch_size', type=int, default=1) 122 | parser.add_argument('--img_channel', type=int, default=3, help='the number of input channel of image') 123 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 124 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 125 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 126 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 127 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode') 128 | """ Model Architecture """ 129 | parser.add_argument('--str_model', type=str, help='the model path of the target model', 130 | default='/STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth') 131 | parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS') 132 | parser.add_argument('--FeatureExtraction', type=str, default='VGG', help='FeatureExtraction stage. VGG|RCNN|ResNet') 133 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 134 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn') 135 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 136 | parser.add_argument('--input_channel', type=int, default=3, 137 | help='the number of input channel of Feature extractor') 138 | parser.add_argument('--output_channel', type=int, default=512, 139 | help='the number of output channel of Feature extractor') 140 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 141 | opt = parser.parse_args() 142 | print(opt) 143 | 144 | opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 145 | 146 | """ vocab / character number configuration """ 147 | if opt.sensitive: 148 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z) 149 | 150 | torch.cuda.synchronize() 151 | time_st = time.time() 152 | test(opt) 153 | time_end = time.time() 154 | time_all = time_end - time_st 155 | print('Testing time:{:.2f}'.format(time_all)) -------------------------------------------------------------------------------- /baselines/transfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # sys.path.append('xx/ProTegO/') 4 | import string 5 | import shutil 6 | import argparse 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | from torchvision import utils as vutils 12 | from nltk.metrics import edit_distance 13 | 14 | from transfer_attacker import transfer_Attacker 15 | from dataset import test_dataset_builder 16 | from utils import Logger 17 | 18 | 19 | def makedirs(path): 20 | if not os.path.exists(path): 21 | try: 22 | os.makedirs(path) 23 | except Exception as e: 24 | print('cannot create dirs: {}'.format(path)) 25 | exit(0) 26 | 27 | def transfer_attack(opt): 28 | """prepare log with model_name """ 29 | model_name = os.path.basename(opt.str_model.split('-')[0]) 30 | print('----------------------------------------------------------------------------------------') 31 | print('Start attacking: :{} :{} \t\n:{} :{:2f} :{:2f} :{:2f} 
:{} :{}' 32 | .format(model_name, opt.name, opt.iter_num, opt.eps, opt.alpha, opt.beta, opt.m, opt.N)) 33 | print('----------------------------------------------------------------------------------------') 34 | 35 | save_root = os.path.join(opt.save_attacks, model_name) 36 | makedirs(save_root) 37 | del_list = os.listdir(save_root) 38 | for f in del_list: 39 | file_path = os.path.join(save_root, f) 40 | if os.path.isfile(file_path): 41 | os.remove(file_path) 42 | elif os.path.isdir(file_path): 43 | shutil.rmtree(file_path) 44 | 45 | save_adv_path = os.path.join(save_root, 'adv') 46 | save_per_path = os.path.join(save_root, 'per') 47 | makedirs(save_adv_path) 48 | makedirs(save_per_path) 49 | 50 | """ save all the print content as log """ 51 | log_file= os.path.join(save_root, 'train.log') 52 | sys.stdout = Logger(log_file) 53 | 54 | dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.root) 55 | dataloader = torch.utils.data.DataLoader( 56 | dataset, batch_size=opt.batch_size, 57 | shuffle=False, num_workers=4, 58 | drop_last=True,pin_memory=True) 59 | 60 | # up = up_dataset(opt.up_path) 61 | # up = up.repeat(opt.batch_size,1,1,1) 62 | 63 | attacker = transfer_Attacker(opt) 64 | time_all, suc, ED_sum = 0, 0, 0 65 | for i, data in enumerate(dataloader, start=0): 66 | # image = data[0] 67 | label = data[1] 68 | img_index = data[2][0] 69 | img_path = data[4][0] 70 | mask = data[5] 71 | mask_tensor = mask.to(opt.device) 72 | # img_tensor = torch.mul(mask, up).to(opt.device) 73 | if opt.name == 'SINIFGSM': 74 | adv_img, delta, preds, time, iters = attacker.SINIFGSM(mask_tensor, label) 75 | elif opt.name == 'VMIFGSM': 76 | adv_img, delta, preds, time, iters = attacker.VMIFGSM(mask_tensor, label) 77 | 78 | flag = (preds != label[0]) 79 | ED = edit_distance(label[0], preds) 80 | 81 | print('img-{}_path:{} --Success:{} --prediction:{} --groundtruth:{} --edit_distance:{} --time:{:2f} --iter_num:{}' 82 | .format(i, img_path, flag, preds, label[0], ED, time, iters)) 83 | 84 | time_all +=time 85 | 86 | vutils.save_image(adv_img,"{}/{}_{}_adv.png".format(save_adv_path, os.path.basename(img_index), label[0])) 87 | vutils.save_image(torch.abs(delta*100),"{}/{}_{}_delta.png".format(save_per_path, os.path.basename(img_index), label[0])) 88 | 89 | if flag: 90 | suc += 1 91 | ED_sum +=ED 92 | 93 | 94 | print('{} Total_attack_time:{} '.format(opt.name, time_all)) 95 | print('ASR:{:.2f}% '.format((suc/len(dataloader))*100)) 96 | print('Average Edit_distance: {}'.format(ED_sum/suc)) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--name', required=True, help='attack name [SINIFGSM, VMIFGSM]') 102 | parser.add_argument('--root', default='/data/hyr/dataset/PTAU/new/test100-times/',help='path of images') 103 | parser.add_argument('--save_attacks', type=str, default='res-baselines/', help='path of save adversarial text images') 104 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 105 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 106 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation') 107 | parser.add_argument('--iter_num', type=int, default=30, help='number of max iterations') 108 | parser.add_argument('--decay', type=float, default=1, help='momentum factor') 109 | parser.add_argument('--alpha', type=float, default=2/255, help='step size of each iteration') 110 | parser.add_argument('--beta', type=float, default=2/3, help='the upper 
bound of neighborhood.') 111 | parser.add_argument('--m', type=int, default=5, help='number of scale copies.') 112 | parser.add_argument('--N', type=int, default=5, help='the number of sampled examples in the neighborhood.') 113 | 114 | """ Data processing """ 115 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 116 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 117 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 118 | # parser.add_argument('--rgb', action='store_true', help='use rgb input') 119 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', 120 | help='character label') 121 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode') 122 | """ Model Architecture """ 123 | parser.add_argument('--str_model', type=str, help="well-trained models for evaluation", 124 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth') 125 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 126 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet') 127 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 128 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn') 129 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 130 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor') 131 | parser.add_argument('--output_channel', type=int, default=512, 132 | help='the number of output channel of Feature extractor') 133 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 134 | opt = parser.parse_args() 135 | print(opt) 136 | 137 | """ vocab / character number configuration """ 138 | device_type = 'cuda' if torch.cuda.is_available() else 'cpu' 139 | opt.device = torch.device(device_type) 140 | 141 | if opt.sensitive: 142 | opt.character = string.printable[:62] 143 | 144 | opt.save_attacks = opt.save_attacks + opt.name 145 | opt.save_attacks = opt.save_attacks + "-eps" + str(int(opt.eps*255)) 146 | 147 | cudnn.benchmark = True 148 | cudnn.deterministic = True 149 | 150 | torch.cuda.synchronize() 151 | transfer_attack(opt) -------------------------------------------------------------------------------- /baselines/transfer_attacker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | 5 | from utils import CTCLabelConverter, AttnLabelConverter 6 | from STR_modules.model import Model 7 | 8 | r""" 9 | Base class for transfer-based attack [SI-NI-FGSM, VMI-FGSM], 10 | modified from "https://github.com/Harry24k/adversarial-attacks-pytorch". 11 | 12 | SI-NI-FGSM in the paper 'NESTEROV ACCELERATED GRADIENT AND SCALEINVARIANCE FOR ADVERSARIAL ATTACKS' 13 | [https://arxiv.org/abs/1908.06281], Published as a conference paper at ICLR 2020 14 | 15 | VMI-FGSM in the paper 'Enhancing the Transferability of Adversarial Attacks through Variance Tuning 16 | [https://arxiv.org/abs/2103.15571], Published as a conference paper at CVPR 2021. 
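Both attacks perturb the input text image within an eps-bounded neighborhood and stop early
once the surrogate model's prediction differs from its prediction on the clean input.

Typical usage (a sketch mirroring transfer.py, where `opt` is the argparse namespace built there):

    attacker = transfer_Attacker(opt)
    adv_img, delta, preds, time, iters = attacker.SINIFGSM(mask_tensor, label)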
17 | 18 | Distance Measure : Linf 19 | 20 | Arguments: 21 | eps (float): maximum perturbation. (Default: 40/255) 22 | iter_num (int): number of iterations. (Default: 10) 23 | decay (float): momentum factor. (Default: 1.0) 24 | alpha (float): step size. (Default: 2/255) 25 | beta (float): the upper bound of neighborhood. (Default: 3/2) 26 | m (int): number of scale copies. (Default: 5) 27 | N (int): the number of sampled examples in the neighborhood. (Default: 5) 28 | 29 | """ 30 | 31 | class transfer_Attacker(object): 32 | 33 | def __init__(self, c_para): 34 | self.device = c_para.device 35 | self.batch_size = c_para.batch_size 36 | self.eps = c_para.eps 37 | self.iter_num = c_para.iter_num 38 | self.decay = c_para.decay 39 | self.alpha = c_para.alpha 40 | self.beta = c_para.beta 41 | self.m = c_para.m 42 | self.N = c_para.N 43 | 44 | self.converter = self._load_converter(c_para) 45 | self.model = self._load_model(c_para) 46 | self.Transformation = c_para.Transformation 47 | self.FeatureExtraction = c_para.FeatureExtraction 48 | self.SequenceModeling = c_para.SequenceModeling 49 | self.Prediction = c_para.Prediction 50 | self.batch_max_length = c_para.batch_max_length 51 | 52 | self.criterion = self._load_base_loss(c_para) 53 | self.l2_loss = self._load_l2_loss() 54 | 55 | @staticmethod 56 | def _load_l2_loss(): 57 | return torch.nn.MSELoss() 58 | @staticmethod 59 | def _load_base_loss(c_para): 60 | if c_para.Prediction == 'CTC': 61 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(c_para.device) 62 | else: 63 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(c_para.device) 64 | return criterion 65 | @staticmethod 66 | def _load_converter(c_para): 67 | if 'CTC' in c_para.Prediction: 68 | converter = CTCLabelConverter(c_para.character) 69 | else: 70 | converter = AttnLabelConverter(c_para.character) 71 | c_para.num_class = len(converter.character) 72 | return converter 73 | @staticmethod 74 | def _load_model(c_para): 75 | if not os.path.exists(c_para.str_model): 76 | raise FileNotFoundError("cannot find pth file in {}".format(c_para.str_model)) 77 | # load model 78 | with torch.no_grad(): 79 | model = Model(c_para).to(c_para.device) 80 | model.load_state_dict(torch.load(c_para.str_model, map_location=c_para.device)) 81 | for name, para in model.named_parameters(): 82 | para.requires_grad = False 83 | return model.eval() 84 | 85 | def model_pred(self, img, mode='CTC'): 86 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 87 | # self.model.eval() 88 | with torch.no_grad(): 89 | if mode == 'CTC': 90 | pred = self.model(img, text_for_pred).log_softmax(2) 91 | size = torch.IntTensor([pred.size(1)] * self.batch_size).to(self.device) 92 | _, index = pred.permute(1, 0, 2).max(2) 93 | index = index.transpose(1, 0).contiguous().view(-1) 94 | pred_str = self.converter.decode(index.data, size.data) 95 | pred_str = pred_str[0] 96 | else: # ATTENTION 97 | pred = self.model(img, text_for_pred) 98 | size = torch.IntTensor([pred.size(1)] * self.batch_size) 99 | _, index = pred.max(2) 100 | pred_str = self.converter.decode(index, size) 101 | pred_s = pred_str[0] 102 | pred_str = pred_s[:pred_s.find('[s]')] 103 | # self.model.train() 104 | return pred_str 105 | 106 | def SINIFGSM(self, x, raw_text): 107 | torch.backends.cudnn.enabled=False 108 | x = x.clone().detach().to(self.device) 109 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 110 | text, length = self.converter.encode(raw_text, 
batch_max_length=self.batch_max_length) 111 | 112 | momentum = torch.zeros_like(x).detach().to(self.device) 113 | 114 | adv_x = x.clone().detach() 115 | time_each = 0 116 | t_st = time.time() 117 | pred_org = self.model_pred(x, self.Prediction) 118 | for iters in range(self.iter_num): 119 | adv_x.requires_grad = True 120 | nes_x = adv_x + self.decay*self.alpha*momentum 121 | # Calculate sum the gradients over the scale copies of the input image 122 | adv_grad = torch.zeros_like(x).detach().to(self.device) 123 | for i in torch.arange(self.m): 124 | nes_x = nes_x / torch.pow(2, i) 125 | preds = self.model(nes_x, text_for_pred).log_softmax(2) 126 | # Calculate loss 127 | if 'CTC' in self.Prediction: 128 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device) 129 | preds = preds.permute(1, 0, 2) 130 | cost = self.criterion(preds, text, preds_size, length) 131 | else: # ATTENTION 132 | target_text = text[:, 1:] 133 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1)) 134 | adv_grad += torch.autograd.grad(cost, adv_x, 135 | retain_graph=True, create_graph=False)[0] 136 | adv_grad = adv_grad / self.m 137 | 138 | # Update adversarial images 139 | grad = self.decay*momentum + adv_grad / torch.mean(torch.abs(adv_grad), dim=(1,2,3), keepdim=True) 140 | momentum = grad 141 | adv_x = adv_x.detach() + self.alpha*grad.sign() 142 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps) 143 | adv_x = torch.clamp(x + delta, min=0, max=1).detach() 144 | 145 | pred_adv = self.model_pred(adv_x, self.Prediction) 146 | if pred_adv != pred_org: 147 | break 148 | t_en = time.time() 149 | time_each = t_en - t_st 150 | return adv_x, delta, pred_adv, time_each, iters 151 | 152 | def VMIFGSM(self, x, raw_text): 153 | torch.backends.cudnn.enabled=False 154 | x = x.clone().detach().to(self.device) 155 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 156 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length) 157 | 158 | momentum = torch.zeros_like(x).detach().to(self.device) 159 | v = torch.zeros_like(x).detach().to(self.device) 160 | 161 | adv_x = x.clone().detach() 162 | time_each = 0 163 | t_st = time.time() 164 | pred_org = self.model_pred(x, self.Prediction) 165 | for iters in range(self.iter_num): 166 | adv_x.requires_grad = True 167 | preds = self.model(adv_x, text_for_pred).log_softmax(2) 168 | 169 | # Calculate loss 170 | if 'CTC' in self.Prediction: 171 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device) 172 | preds = preds.permute(1, 0, 2) 173 | cost = self.criterion(preds, text, preds_size, length) 174 | else: # ATTENTION 175 | target_text = text[:, 1:] 176 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1)) 177 | 178 | # Update adversarial images 179 | adv_grad = torch.autograd.grad(cost, adv_x, 180 | retain_graph=False, create_graph=False)[0] 181 | 182 | grad = (adv_grad + v) / torch.mean(torch.abs(adv_grad + v), dim=(1,2,3), keepdim=True) 183 | grad = grad + momentum * self.decay 184 | momentum = grad 185 | 186 | # Calculate Gradient Variance 187 | GV_grad = torch.zeros_like(x).detach().to(self.device) 188 | for _ in range(self.N): 189 | neighbor_x = adv_x.detach() + \ 190 | torch.randn_like(x).uniform_(-self.eps*self.beta, self.eps*self.beta) 191 | neighbor_x.requires_grad = True 192 | preds = self.model(neighbor_x, text_for_pred).log_softmax(2) 193 | 194 | # Calculate loss 195 | 
if 'CTC' in self.Prediction: 196 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device) 197 | preds = preds.permute(1, 0, 2) 198 | cost = self.criterion(preds, text, preds_size, length) 199 | else: # ATTENTION 200 | target_text = text[:, 1:] 201 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1)) 202 | 203 | GV_grad += torch.autograd.grad(cost, neighbor_x, 204 | retain_graph=False, create_graph=False)[0] 205 | # obtaining the gradient variance 206 | v = GV_grad / self.N - adv_grad 207 | 208 | adv_x = adv_x.detach() + self.alpha*grad.sign() 209 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps) 210 | adv_x = torch.clamp(x + delta, min=0, max=1).detach() 211 | pred_adv = self.model_pred(adv_x, self.Prediction) 212 | if pred_adv != pred_org: 213 | break 214 | t_en = time.time() 215 | time_each = t_en - t_st 216 | 217 | 218 | return adv_x, delta, pred_adv, time_each, iters 219 | 220 | -------------------------------------------------------------------------------- /baselines/wm_attacker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import cv2 5 | import torch 6 | 7 | from skimage import morphology 8 | from trdg.generators import GeneratorFromStrings 9 | from nltk.metrics import edit_distance 10 | 11 | from utils import CTCLabelConverter, AttnLabelConverter, RGB2Hex 12 | from STR_modules.model import Model 13 | 14 | r""" 15 | Base class for fawa. 16 | 17 | Distance Measure : Linf 18 | 19 | """ 20 | class WM_Attacker(object): 21 | def __init__(self, c_para): 22 | r""" 23 | Arguments: 24 | c_para : all arguments from Parser which are prepared for 25 | """ 26 | self.device = c_para.device 27 | self.batch_size = c_para.batch_size 28 | self.eps = c_para.eps 29 | self.iter_num = c_para.iter_num 30 | self.decay = c_para.decay 31 | self.alpha = c_para.alpha 32 | 33 | self.converter = self._load_converter(c_para) 34 | self.model = self._load_model(c_para) 35 | self.Transformation = c_para.Transformation 36 | self.FeatureExtraction = c_para.FeatureExtraction 37 | self.SequenceModeling = c_para.SequenceModeling 38 | self.Prediction = c_para.Prediction 39 | self.batch_max_length = c_para.batch_max_length 40 | 41 | self.criterion = self._load_base_loss(c_para) 42 | self.l2_loss = self._load_l2_loss() 43 | 44 | self.img_size = (c_para.batch_size, c_para.input_channel, c_para.imgH, c_para.imgW) 45 | self.best_img = 100 * torch.ones(self.img_size) 46 | self.best_iter = -1 47 | self.best_delta = 100 * torch.ones(self.img_size) 48 | self.preds = '' 49 | self.suc = False 50 | 51 | def init(self): 52 | self.best_img = 100* torch.ones(self.img_size) 53 | self.best_iter = -1 54 | self.best_delta = 100 * torch.ones(self.img_size) 55 | self.preds = '' 56 | self.suc = False 57 | 58 | @staticmethod 59 | def _load_l2_loss(): 60 | return torch.nn.MSELoss() 61 | 62 | @staticmethod 63 | def _load_base_loss(c_para): 64 | if c_para.Prediction == 'CTC': 65 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(c_para.device) 66 | else: 67 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(c_para.device) 68 | return criterion 69 | 70 | @staticmethod 71 | def _load_converter(c_para): 72 | if 'CTC' in c_para.Prediction: 73 | converter = CTCLabelConverter(c_para.character) 74 | else: 75 | converter = AttnLabelConverter(c_para.character) 76 | c_para.num_class = len(converter.character) 77 | return converter 78 | @staticmethod 79 | def 
_load_model(c_para): 80 | if not os.path.exists(c_para.str_model): 81 | raise FileNotFoundError("cannot find pth file in {}".format(c_para.str_model)) 82 | # load model 83 | with torch.no_grad(): 84 | model = Model(c_para).to(c_para.device) 85 | model.load_state_dict(torch.load(c_para.str_model, map_location=c_para.device)) 86 | for name, para in model.named_parameters(): 87 | para.requires_grad = False 88 | return model.eval() 89 | 90 | def model_pred(self, img, mode='CTC'): 91 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 92 | # self.model.eval() 93 | with torch.no_grad(): 94 | if mode == 'CTC': 95 | pred = self.model(img, text_for_pred).log_softmax(2) 96 | size = torch.IntTensor([pred.size(1)] * self.batch_size).to(self.device) 97 | _, index = pred.permute(1, 0, 2).max(2) 98 | index = index.transpose(1, 0).contiguous().view(-1) 99 | pred_str = self.converter.decode(index.data, size.data) 100 | pred_str = pred_str[0] 101 | else: # ATTENTION 102 | pred = self.model(img, text_for_pred) 103 | size = torch.IntTensor([pred.size(1)] * self.batch_size) 104 | _, index = pred.max(2) 105 | pred_str = self.converter.decode(index, size) 106 | pred_s = pred_str[0] 107 | pred_str = pred_s[:pred_s.find('[s]')] 108 | # self.model.train() 109 | return pred_str 110 | 111 | def basic_attack(self, x, raw_text): 112 | x = x.clone().detach().to(self.device) 113 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 114 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length) 115 | pred_org = self.model_pred(x, self.Prediction) #TODO change to raw_text as label 116 | 117 | momentum = torch.zeros_like(x).detach().to(self.device) 118 | 119 | adv_x = x.clone().detach() 120 | 121 | num = 0 122 | time_each = 0 123 | t_st = time.time() 124 | for iter in range(self.iter_num): 125 | adv_x.requires_grad = True 126 | 127 | # erlier stop 128 | pred_adv = self.model_pred(adv_x, self.Prediction) 129 | if pred_adv != pred_org: 130 | num += 1 131 | # print('Best results!') 132 | self.best_iter = iter 133 | self.best_img = adv_x.detach().clone() 134 | self.best_delta = delta.detach().clone() 135 | self.preds = pred_adv 136 | self.suc = True 137 | if num == 1: 138 | break 139 | 140 | # Calculate loss 141 | torch.backends.cudnn.enabled=False 142 | preds = self.model(adv_x, text_for_pred).log_softmax(2) 143 | if 'CTC' in self.Prediction: 144 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device) 145 | preds = preds.permute(1, 0, 2) 146 | cost = self.criterion(preds, text, preds_size, length) 147 | else: # ATTENTION 148 | target_text = text[:, 1:] 149 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1)) 150 | 151 | # Update adversarial images 152 | grad = torch.autograd.grad(cost, adv_x, 153 | retain_graph=False, create_graph=False)[0] 154 | grad = grad / torch.mean(torch.abs(grad), dim=(1,2,3), keepdim=True) 155 | grad = grad + momentum*self.decay 156 | momentum = grad 157 | 158 | adv_x = adv_x.detach() + self.alpha*grad.sign() 159 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps) 160 | adv_x = torch.clamp(x + delta, min=0, max=1).detach() 161 | 162 | if iter == self.iter_num and self.best_iter == -1: 163 | print('[!] 
Attack failed: No optimal results were found in effective iter_num!') 164 | self.best_iter = -1 165 | self.best_img = adv_x.detach().clone() 166 | self.best_delta = delta.data.detach().clone() 167 | self.preds = pred_adv 168 | self.suc = False 169 | 170 | t_end = time.time() 171 | time_each = t_end - t_st 172 | 173 | return self.best_img, self.best_delta, self.best_iter, self.preds, self.suc, time_each 174 | 175 | def find_wm_pos(self, adv_img, input_img, ret_frame_img=False): 176 | pert = np.abs(adv_img - input_img) 177 | pert = (pert > 1e-2) * 255.0 178 | wm_pos_list = [] 179 | frame_img_list = [] 180 | for src in pert: 181 | kernel = np.ones((3, 3), np.uint8) # kernel_size 3*3 182 | dilate = cv2.dilate(src, kernel, iterations=2) 183 | erode = cv2.erode(dilate, kernel, iterations=2) 184 | remove = morphology.remove_small_objects(erode.astype('bool'), min_size=0) 185 | contours, _ = cv2.findContours((remove * 255).astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 186 | wm_pos, frame_img = [], [] 187 | for cont in contours: 188 | left_point = cont.min(axis=1).min(axis=0) 189 | right_point = cont.max(axis=1).max(axis=0) 190 | wm_pos.append(np.hstack((left_point, right_point))) 191 | if ret_frame_img: 192 | img = cv2.rectangle( 193 | (remove * 255).astype('uint8'), (left_point[0], left_point[1]), 194 | (right_point[0], right_point[1]), (255, 255, 255), 2) 195 | frame_img.append(img) 196 | wm_pos_list.append(wm_pos) 197 | frame_img_list.append(frame_img) 198 | 199 | if ret_frame_img: 200 | return (wm_pos_list, frame_img_list) 201 | else: 202 | return wm_pos_list 203 | 204 | def gen_wm(self, RGB): 205 | generator = GeneratorFromStrings( 206 | strings=['ecml'], 207 | count=1, 208 | fonts=['baselines/Impact.ttf'], # TODO change ['Impact.tff']] 209 | language='en', 210 | size=78, # default: 32 211 | skewing_angle=15, 212 | random_skew=False, 213 | blur=0, 214 | random_blur=False, 215 | background_type=1, # gaussian noise (0), plain white (1), quasicrystal (2) or picture (3) 216 | distorsion_type=0, # None(0), Sine wave(1),Cosine wave(2),Random(3) 217 | distorsion_orientation=0, 218 | is_handwritten=False, 219 | width=-1, 220 | alignment=1, 221 | text_color=RGB2Hex(RGB), 222 | orientation=0, 223 | space_width=1.0, 224 | character_spacing=0, 225 | margins=(0, 0, 0, 0), 226 | fit=True, 227 | ) 228 | img_list = [img for img, _ in generator] 229 | return img_list[0] 230 | 231 | def wm_attack(self, wm_x, raw_text, wm_mask, adv_text_mask): 232 | wm_x = wm_x.clone().detach().to(self.device) 233 | 234 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 235 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length) 236 | pred_org = self.model_pred(wm_x, self.Prediction) 237 | 238 | wm_mask = wm_mask.clone().detach().to(self.device) 239 | adv_text_mask = adv_text_mask.clone().detach().to(self.device) 240 | 241 | momentum = torch.zeros_like(wm_x).detach().to(self.device) 242 | 243 | adv_x = wm_x.clone().detach() 244 | 245 | num = 0 246 | ED_num= 0 247 | time_each = 0 248 | t_st = time.time() 249 | for iter in range(self.iter_num): 250 | adv_x.requires_grad = True 251 | 252 | # erlier stop 253 | pred_adv = self.model_pred(adv_x, self.Prediction) 254 | if pred_adv != pred_org: 255 | ED_num = edit_distance(pred_org, pred_adv) 256 | num += 1 257 | # print('Best results!') 258 | self.best_iter = iter 259 | self.best_img = adv_x.detach().clone() 260 | self.best_delta = delta.detach().clone() 261 | self.preds = pred_adv 262 | 
self.suc = True 263 | if num == 1: 264 | break 265 | 266 | # Calculate loss 267 | torch.backends.cudnn.enabled=False 268 | preds = self.model(adv_x, text_for_pred).log_softmax(2) 269 | if 'CTC' in self.Prediction: 270 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device) 271 | preds = preds.permute(1, 0, 2) 272 | cost = self.criterion(preds, text, preds_size, length) 273 | else: # ATTENTION 274 | target_text = text[:, 1:] 275 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1)) 276 | 277 | # Update adversarial images 278 | grad = torch.autograd.grad(cost, adv_x, 279 | retain_graph=False, create_graph=False)[0] 280 | grad = grad / torch.mean(torch.abs(grad), dim=(1,2,3), keepdim=True) 281 | grad = grad + momentum*self.decay 282 | momentum = grad 283 | 284 | adv_x = adv_x.detach() + self.alpha*grad.sign() 285 | delta = torch.clamp(adv_x - wm_x, min=-self.eps, max=self.eps) 286 | delta = torch.mul(delta, adv_text_mask) 287 | delta = torch.mul(delta, wm_mask) 288 | adv_x = torch.clamp(wm_x + delta, min=0, max=1).detach() 289 | 290 | if iter == self.iter_num and self.best_iter == -1: 291 | print('[!] Attack failed: No optimal results were found in effective iter_num!') 292 | ED_num = 0 293 | self.best_iter = -1 294 | self.best_img = adv_x.detach().clone() 295 | self.best_delta = delta.data.detach().clone() 296 | self.preds = pred_adv 297 | self.suc = False 298 | 299 | t_end = time.time() 300 | time_each = t_end - t_st 301 | 302 | return self.best_img, self.best_delta, self.best_iter, self.preds, self.suc, ED_num, time_each 303 | 304 | 305 | -------------------------------------------------------------------------------- /comtest/test_ali.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os,sys 3 | import json 4 | from typing import List 5 | 6 | from alibabacloud_tea_openapi.client import Client as OpenApiClient 7 | from alibabacloud_tea_openapi import models as open_api_models 8 | from alibabacloud_tea_util import models as util_models 9 | from nltk.metrics import edit_distance 10 | from utils import Logger 11 | 12 | def get_file_content(filePath): 13 | with open(filePath, 'rb') as fp: 14 | return fp.read() 15 | 16 | class Sample: 17 | def __init__(self): 18 | pass 19 | 20 | @staticmethod 21 | def create_client( 22 | access_key_id: str, 23 | access_key_secret: str, 24 | ) -> OpenApiClient: 25 | """ 26 | 使用AK&SK初始化账号Client 27 | @param access_key_id: 28 | @param access_key_secret: 29 | @return: Client 30 | @throws Exception 31 | """ 32 | config = open_api_models.Config( 33 | # 必填,您的 AccessKey ID, 34 | access_key_id=access_key_id, 35 | # 必填,您的 AccessKey Secret, 36 | access_key_secret=access_key_secret 37 | ) 38 | # 访问的域名 39 | config.endpoint = f'ocr-api.cn-hangzhou.aliyuncs.com' 40 | return OpenApiClient(config) 41 | 42 | @staticmethod 43 | def create_api_info() -> open_api_models.Params: 44 | """ 45 | API 相关 46 | @param path: params 47 | @return: OpenApi.Params 48 | """ 49 | params = open_api_models.Params( 50 | # 接口名称, 51 | action='RecognizeGeneral', 52 | # 接口版本, 53 | version='2021-07-07', 54 | # 接口协议, 55 | protocol='HTTPS', 56 | # 接口 HTTP 方法, 57 | method='POST', 58 | auth_type='AK', 59 | style='V3', 60 | # 接口 PATH, 61 | pathname=f'/', 62 | # 接口请求体内容格式, 63 | req_body_type='json', 64 | # 接口响应体内容格式, 65 | body_type='json' 66 | ) 67 | return params 68 | 69 | @staticmethod 70 | def main( 71 | imagefiles: List[str], 72 | ) -> None: 73 | # 
工程代码泄露可能会导致AccessKey泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html 74 | client = Sample.create_client('xx', 'xx') 75 | params = Sample.create_api_info() 76 | # runtime options 77 | runtime = util_models.RuntimeOptions() 78 | 79 | fcnt, ED_sum = 0, 0 80 | for imagefile in imagefiles: 81 | img = get_file_content(imagefile) 82 | request = open_api_models.OpenApiRequest(stream=img) 83 | # 复制代码运行请自行打印 API 的返回值 84 | # 返回值为 Map 类型,可从 Map 中获得三类数据:响应体 body、响应头 headers、HTTP 返回的状态码 statusCode 85 | result = client.call_api(params, request, runtime) 86 | data = json.loads(result['body']['Data']) 87 | content = data['content'].strip() 88 | # print(f"file: {file}") 89 | # print(content) 90 | imagename = os.path.basename(imagefile) 91 | true_label = imagename.split("_")[1] 92 | if true_label != content: 93 | fcnt += 1 94 | ED = edit_distance(true_label, content) 95 | else: 96 | ED = 0 97 | print("label:{} ---> result:{}".format(true_label, content)) 98 | ED_sum += ED 99 | 100 | if fcnt != 0: 101 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt} 102 | else: 103 | score = {"DSR":0, "ED":0} 104 | return score 105 | 106 | if __name__ == '__main__': 107 | log_file= os.path.join('./res-comtest/ali', 'up5a.log') 108 | sys.stdout = Logger(log_file) 109 | 110 | # prepare your own test data 111 | file_path = "xx/adv+" 112 | # file_path = "xx/adv-" 113 | 114 | # get file paths 115 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)] 116 | imagefiles = imagefiles[:100] 117 | score = Sample.main(imagefiles) 118 | print(score) -------------------------------------------------------------------------------- /comtest/test_azure.py: -------------------------------------------------------------------------------- 1 | # need pip install azure-cognitiveservices-vision-computervision pillow 2 | # 5 call per minute, 5k call per month!!! 3 | from azure.cognitiveservices.vision.computervision import ComputerVisionClient 4 | from azure.cognitiveservices.vision.computervision.models import OperationStatusCodes 5 | from azure.cognitiveservices.vision.computervision.models import VisualFeatureTypes 6 | from msrest.authentication import CognitiveServicesCredentials 7 | 8 | 9 | import os,sys,time 10 | from nltk.metrics import edit_distance 11 | from utils import Logger 12 | 13 | ''' 14 | Authenticate 15 | Authenticates your credentials and creates a client. 
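The subscription_key and endpoint values below are placeholders ("xx") and must be replaced with
your own Azure Computer Vision credentials; request_interval adds a delay between successive API calls.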
16 | ''' 17 | subscription_key = "xx" 18 | endpoint = "xx" 19 | request_interval = 4 20 | 21 | def main(imagefiles): 22 | computervision_client = ComputerVisionClient(endpoint, CognitiveServicesCredentials(subscription_key)) 23 | 24 | fcnt, ED_sum = 0, 0 25 | for imagefile in imagefiles: 26 | imagename = os.path.basename(imagefile) 27 | true_label = imagename.split("_")[1] 28 | # print(f"file: {file}") 29 | with open(imagefile, 'rb') as image_steam: 30 | read_response = computervision_client.recognize_printed_text_in_stream(image_steam, raw=True) 31 | time.sleep(request_interval) 32 | for region in read_response.output.regions: 33 | lines = region.lines 34 | for line in lines: 35 | line_text = " ".join([word.text for word in line.words]) 36 | if true_label != line_text: 37 | fcnt += 1 38 | ED = edit_distance(true_label, line_text) 39 | else: 40 | ED = 0 41 | print("label:{} ---> result:{}".format(true_label, line_text)) 42 | ED_sum += ED 43 | if fcnt != 0: 44 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt} 45 | else: 46 | score = {"DSR":0, "ED":0} 47 | return score 48 | 49 | 50 | 51 | if __name__ == "__main__": 52 | log_file= os.path.join('./res-comtest/azure', 'up5a.log') 53 | sys.stdout = Logger(log_file) 54 | 55 | # prepare your own test data 56 | file_path = "xx/adv+" 57 | # file_path = "xx/adv-" 58 | 59 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)] 60 | score = main(imagefiles) 61 | print(score) -------------------------------------------------------------------------------- /comtest/test_baidu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os,sys 3 | import time 4 | from aip import AipOcr 5 | from nltk.metrics import edit_distance 6 | from utils import Logger 7 | 8 | APP_ID = 'xx' 9 | API_KEY = 'xx' 10 | SECRET_KEY = 'xx' 11 | request_interval = 1 12 | 13 | def get_file_content(filePath): 14 | with open(filePath, 'rb') as fp: 15 | return fp.read() 16 | 17 | def get_score(jsons): 18 | fcnt, ED_sum = 0, 0 19 | for json in jsons: 20 | fcnt += 0 if json["label"] == json["words_result"][0]["words"] else 1 21 | ED_sum += edit_distance(json["label"], json["words_result"][0]["words"]) 22 | if fcnt != 0: 23 | score = {"DSR":fcnt/len(jsons), "ED":ED_sum/fcnt} 24 | else: 25 | score = {"DSR":fcnt/len(jsons), "ED":0} 26 | return score 27 | 28 | 29 | if __name__ == "__main__": 30 | log_file= os.path.join('./res-comtest/baidu', 'up5a.log') 31 | sys.stdout = Logger(log_file) 32 | # create client 33 | client = AipOcr(APP_ID, API_KEY, SECRET_KEY) 34 | 35 | # prepare your own test data 36 | data_path = "xx/adv+" 37 | # data_path = "xx/adv-" 38 | 39 | imagefiles = [os.path.join(data_path,n) for n in os.listdir(data_path)] 40 | imagefiles = imagefiles[:100] 41 | 42 | # upload to baidu ocr 43 | result_jsons = [] 44 | for imagefile in imagefiles: 45 | b_img = get_file_content(imagefile) 46 | answer = client.basicGeneral(b_img) 47 | imagename = os.path.basename(imagefile) 48 | answer["label"] = imagename.split("_")[1] 49 | try: 50 | print("label:{} ---> result:{}".format( 51 | answer["label"],answer["words_result"][0]["words"] 52 | )) 53 | except: 54 | answer["words_result"].append({'words':""}) 55 | print(answer) 56 | result_jsons.append(answer) 57 | time.sleep(request_interval) 58 | 59 | # calc score 60 | score = get_score(result_jsons) 61 | print(score) 62 | 63 | 64 | ''' 65 | basicGeneral: 通用文字识别(标准版) 66 | return json: 67 | { 68 | "words_result": [ 69 | { 70 | "words": "firm" 71 | } 72 | ], 73 | 
"words_result_num": 1, 74 | "log_id": 1589115807642170340 75 | } 76 | 77 | not found words return json: 78 | return json: 79 | { 80 | "words_result": [], 81 | "words_result_num": 0, 82 | "log_id": 1589115807642170340 83 | } 84 | ''' -------------------------------------------------------------------------------- /comtest/test_huawei.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from huaweicloudsdkcore.auth.credentials import BasicCredentials 3 | from huaweicloudsdkocr.v1.region.ocr_region import OcrRegion 4 | from huaweicloudsdkcore.exceptions import exceptions 5 | from huaweicloudsdkocr.v1 import * 6 | import base64 7 | import os,sys 8 | from nltk.metrics import edit_distance 9 | from utils import Logger 10 | 11 | ak = "xx" 12 | sk = "xx" 13 | 14 | def get_file_content(filePath): 15 | with open(filePath, 'rb') as image_file: 16 | return base64.b64encode(image_file.read()) 17 | 18 | def main(imagefiles): 19 | credentials = BasicCredentials(ak, sk) 20 | 21 | client = OcrClient.new_builder() \ 22 | .with_credentials(credentials) \ 23 | .with_region(OcrRegion.value_of("cn-east-3")) \ 24 | .build() 25 | 26 | fcnt, ED_sum = 0, 0 27 | try: 28 | for imagefile in imagefiles: 29 | encoded_string = get_file_content(imagefile) 30 | # with open(file, "rb") as image_file: 31 | # encoded_string = base64.b64encode(image_file.read()) 32 | request = RecognizeGeneralTextRequest() 33 | request.body = GeneralTextRequestBody(image=encoded_string) 34 | response = client.recognize_general_text(request) 35 | data = response.result.to_dict() 36 | words = [word['words'] for word in data['words_block_list']] 37 | content = "".join(words) 38 | 39 | imagename = os.path.basename(imagefile) 40 | true_label = imagename.split("_")[1] 41 | if true_label != content: 42 | fcnt += 1 43 | ED = edit_distance(true_label, content) 44 | else: 45 | ED = 0 46 | print("label:{} ---> result:{}".format(true_label, content)) 47 | ED_sum += ED 48 | 49 | except exceptions.ClientRequestException as e: 50 | print(e.status_code) 51 | print(e.request_id) 52 | print(e.error_code) 53 | print(e.error_msg) 54 | if fcnt != 0: 55 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt} 56 | else: 57 | score = {"DSR":0, "ED":0} 58 | return score 59 | 60 | if __name__ == "__main__": 61 | log_file= os.path.join('./res-comtest/huawei', 'up5a.log') 62 | sys.stdout = Logger(log_file) 63 | 64 | # prepare your own test data 65 | file_path = "xx/adv+" 66 | # file_path = "xx/adv-" 67 | 68 | # get file paths 69 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)] 70 | imagefiles = imagefiles[:1] 71 | score = main(imagefiles) 72 | print(score) 73 | -------------------------------------------------------------------------------- /data/protego/up/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/data/protego/up/5.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import cv2 4 | import torch.utils.data 5 | from torch.utils.data import Dataset 6 | 7 | 8 | """ProtegO dataset""" 9 | def up_dataset(up_path): 10 | up = cv2.imread(up_path) 11 | up = cv2.resize(up, (100, 32)) 12 | up = cv2.cvtColor(up, cv2.COLOR_BGR2RGB) 13 | up = torch.FloatTensor(up) 14 | up = up / 255 # normalization to 
[0,1] 15 | up = up.permute(2,0,1) # [C, H, W] 16 | 17 | return up 18 | 19 | class train_dataset_builder(Dataset): 20 | def __init__(self, height, width, img_path): 21 | ''' 22 | height: input height to model 23 | width: input width to model 24 | total_img_path: path with all images 25 | seq_len: sequence length 26 | ''' 27 | self.height = height 28 | self.width = width 29 | self.img_path = img_path 30 | self.dataset = [] 31 | 32 | img = [] 33 | for i,j,k in os.walk(self.img_path): 34 | for file in k: 35 | file_name = os.path.join(i ,file) 36 | img.append(file_name) 37 | self.total_img_name = img 38 | 39 | for img_name in self.total_img_name: 40 | _, label, _ = img_name.split('_') 41 | self.dataset.append([img_name, label]) 42 | 43 | 44 | def __getitem__(self, index): 45 | img_name, label = self.dataset[index] 46 | IMG = cv2.imread(img_name) 47 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 48 | 49 | # binarization processing 50 | gray = cv2.cvtColor(IMG, cv2.COLOR_BGR2GRAY) 51 | _, binary = cv2.threshold(gray,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 52 | IMG = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB) 53 | IMG = torch.FloatTensor(IMG) # [H, W, C] 54 | IMG = IMG / 255 # normalization to [0,1] 55 | IMG = IMG.permute(2,0,1) # [C, H, W] 56 | 57 | return IMG, label 58 | 59 | def __len__(self): 60 | return len(self.dataset) 61 | 62 | class test_dataset_builder(Dataset): 63 | def __init__(self, height, width, img_path): 64 | self.height = height 65 | self.width = width 66 | self.img_path = img_path 67 | self.dataset = [] 68 | 69 | img = [] 70 | for i,j,k in os.walk(self.img_path): 71 | for file in k: 72 | file_name = os.path.join(i ,file) 73 | img.append(file_name) 74 | self.total_img_name = img 75 | 76 | for img_name in self.total_img_name: 77 | img_index, label, img_adv = img_name.split('_') 78 | img_adv = img_adv.split('.') 79 | index_or_advlogo = img_adv[0] 80 | self.dataset.append([img_name, label, img_index, index_or_advlogo]) 81 | self.dataset = sorted(self.dataset) 82 | 83 | def __getitem__(self, index): 84 | img_name, label, img_index, index_or_advlogo = self.dataset[index] 85 | IMG = cv2.imread(img_name) 86 | ORG = cv2.resize(IMG, (self.width, self.height)) 87 | 88 | IMG = cv2.cvtColor(ORG, cv2.COLOR_BGR2RGB) 89 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 90 | IMG = IMG / 255 91 | IMG = IMG.permute(2,0,1) # [C, H, W] 92 | 93 | gray = cv2.cvtColor(ORG, cv2.COLOR_BGR2GRAY) 94 | _, binary = cv2.threshold(gray,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 95 | mask = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB) 96 | mask = torch.FloatTensor(mask) # [H, W, C] 97 | mask = mask / 255 # normalization to [0,1] 98 | mask = mask.permute(2,0,1) # [C, H, W] 99 | 100 | return IMG, label, img_index, index_or_advlogo, img_name, mask 101 | 102 | def __len__(self): 103 | return len(self.dataset) 104 | 105 | class test_adv_dataset(Dataset): 106 | def __init__(self, height, width, img_path): 107 | self.height = height 108 | self.width = width 109 | self.img_path = img_path 110 | self.dataset = [] 111 | 112 | img = [] 113 | for i,j,k in os.walk(self.img_path): 114 | for file in k: 115 | file_name = os.path.join(i ,file) 116 | img.append(file_name) 117 | self.total_img_name = img 118 | 119 | for img_name in self.total_img_name: 120 | img_index, label, img_adv = img_name.split('_') 121 | img_adv = img_adv.split('.') 122 | index_or_advlogo = img_adv[0] 123 | self.dataset.append([img_name, label, img_index, index_or_advlogo]) 124 | self.dataset = sorted(self.dataset) 125 | 126 | def 
__getitem__(self, index): 127 | img_name, label, img_index, index_or_advlogo = self.dataset[index] 128 | IMG = cv2.imread(img_name) 129 | IMG = cv2.resize(IMG, (self.width, self.height)) 130 | 131 | # binarization processing 132 | gray = cv2.cvtColor(IMG, cv2.COLOR_BGR2GRAY) 133 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) 134 | img_b = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB) 135 | img_b = torch.FloatTensor(img_b) 136 | img_b = img_b / 255 # normalization to [0,1] 137 | img_b = img_b.permute(2,0,1) # [C, H, W] 138 | 139 | img = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB) 140 | img = torch.FloatTensor(img) 141 | img = img /255 # normalization to [0,1] 142 | img = img.permute(2,0,1) # [C, H, W] 143 | 144 | return img_b, img, label, img_index, index_or_advlogo, img_name 145 | 146 | def __len__(self): 147 | return len(self.dataset) 148 | 149 | 150 | """STR models dataset""" 151 | class strdataset(Dataset): 152 | def __init__(self, height, width, total_img_path): 153 | ''' 154 | height: input height to model 155 | width: input width to model 156 | total_img_path: path with all images 157 | seq_len: sequence length 158 | ''' 159 | self.total_img_path = total_img_path 160 | self.height = height 161 | self.width = width 162 | img = [] 163 | self.dataset = [] 164 | 165 | for i,_,k in os.walk(total_img_path): 166 | for file in k: 167 | file_name = os.path.join(i ,file) 168 | img.append(file_name) 169 | self.total_img_name = img 170 | 171 | for img_name in self.total_img_name: 172 | _, label, _ = img_name.split('_') 173 | self.dataset.append([img_name, label]) 174 | 175 | def __getitem__(self, index): 176 | img_name, label = self.dataset[index] 177 | IMG = cv2.imread(img_name) 178 | IMG = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB) 179 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 180 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 181 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 182 | IMG = IMG.permute(2,0,1) # [C, H, W] 183 | 184 | return IMG, label 185 | 186 | def __len__(self): 187 | return len(self.dataset) -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: protego 2 | channels: 3 | - pytorch 4 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 5 | - conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 11 | dependencies: 12 | - _libgcc_mutex=0.1=main 13 | - _openmp_mutex=4.5=1_gnu 14 | - absl-py=1.0.0=pyhd8ed1ab_0 15 | - aiohttp=3.7.4.post0=py38h7f8727e_2 16 | - async-timeout=3.0.1=py_1000 17 | - attrs=21.4.0=pyhd8ed1ab_0 18 | - backcall=0.2.0=pyhd3eb1b0_0 19 | - blas=1.0=mkl 20 | - blinker=1.4=py_1 21 | - blosc=1.21.0=h8c45485_0 22 | - brotli=1.0.9=he6710b0_2 23 | - brotlipy=0.7.0=py38h497a2fe_1001 24 | - brunsli=0.1=h2531618_0 25 | - bzip2=1.0.8=h7b6447c_0 26 | - c-ares=1.17.1=h27cfd23_0 27 | - ca-certificates=2021.10.8=ha878542_0 28 | - cachetools=5.0.0=pyhd8ed1ab_0 29 | - certifi=2021.10.8=py38h578d9bd_2 30 | - cffi=1.15.0=py38hd667e15_1 31 | - cfitsio=3.470=hf0d0db6_6 32 | - chardet=4.0.0=py38h578d9bd_3 33 | - charls=2.2.0=h2531618_0 34 | - charset-normalizer=2.0.12=pyhd8ed1ab_0 35 | - 
click=8.1.3=py38h578d9bd_0 36 | - cloudpickle=2.0.0=pyhd3eb1b0_0 37 | - cryptography=35.0.0=py38ha5dfef3_0 38 | - cudatoolkit=11.3.1=h2bc3f7f_2 39 | - cycler=0.11.0=pyhd3eb1b0_0 40 | - cytoolz=0.11.0=py38h7b6447c_0 41 | - dask-core=2021.10.0=pyhd3eb1b0_0 42 | - decorator=5.1.0=pyhd3eb1b0_0 43 | - ffmpeg=4.3=hf484d3e_0 44 | - freetype=2.11.0=h70c0345_0 45 | - fsspec=2021.10.1=pyhd3eb1b0_0 46 | - giflib=5.2.1=h7b6447c_0 47 | - gmp=6.2.1=h2531618_2 48 | - gnutls=3.6.15=he1e5248_0 49 | - google-auth=2.6.6=pyh6c4a22f_0 50 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 51 | - idna=3.3=pyhd8ed1ab_0 52 | - imagecodecs=2021.8.26=py38h4cda21f_0 53 | - imageio=2.9.0=pyhd3eb1b0_0 54 | - importlib-metadata=4.11.3=py38h578d9bd_1 55 | - intel-openmp=2021.4.0=h06a4308_3561 56 | - ipython=7.29.0=py38hb070fc8_0 57 | - jedi=0.18.0=py38h06a4308_1 58 | - jpeg=9d=h7f8727e_0 59 | - jxrlib=1.1=h7b6447c_2 60 | - krb5=1.19.2=hac12032_0 61 | - lame=3.100=h7b6447c_0 62 | - lcms2=2.12=h3be6417_0 63 | - ld_impl_linux-64=2.35.1=h7274673_9 64 | - lerc=3.0=h295c915_0 65 | - libaec=1.0.4=he6710b0_1 66 | - libcurl=7.80.0=h0b77cf5_0 67 | - libdeflate=1.8=h7f8727e_5 68 | - libedit=3.1.20210910=h7f8727e_0 69 | - libev=4.33=h7f8727e_1 70 | - libffi=3.3=he6710b0_2 71 | - libgcc-ng=9.3.0=h5101ec6_17 72 | - libgfortran-ng=7.5.0=ha8ba4b0_17 73 | - libgfortran4=7.5.0=ha8ba4b0_17 74 | - libgomp=9.3.0=h5101ec6_17 75 | - libiconv=1.15=h63c8f33_5 76 | - libidn2=2.3.2=h7f8727e_0 77 | - libnghttp2=1.46.0=hce63b2e_0 78 | - libpng=1.6.37=hbc83047_0 79 | - libprotobuf=3.15.8=h780b84a_0 80 | - libssh2=1.9.0=h1ba5d50_1 81 | - libstdcxx-ng=9.3.0=hd4cf53a_17 82 | - libtasn1=4.16.0=h27cfd23_0 83 | - libtiff=4.2.0=h85742a9_0 84 | - libunistring=0.9.10=h27cfd23_0 85 | - libuv=1.40.0=h7b6447c_0 86 | - libwebp=1.2.0=h89dd481_0 87 | - libwebp-base=1.2.0=h27cfd23_0 88 | - libzopfli=1.0.3=he6710b0_0 89 | - locket=0.2.1=py38h06a4308_1 90 | - lz4-c=1.9.3=h295c915_1 91 | - markdown=3.3.7=pyhd8ed1ab_0 92 | - matplotlib-base=3.5.0=py38h3ed280b_0 93 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 94 | - mkl=2021.4.0=h06a4308_640 95 | - mkl-service=2.4.0=py38h7f8727e_0 96 | - mkl_fft=1.3.1=py38hd3c417c_0 97 | - mkl_random=1.2.2=py38h51133e4_0 98 | - multidict=5.2.0=py38h7f8727e_2 99 | - munkres=1.1.4=py_0 100 | - ncurses=6.3=heee7806_1 101 | - nettle=3.7.3=hbbd107a_1 102 | - networkx=2.6.3=pyhd3eb1b0_0 103 | - numpy=1.21.2=py38h20f2e39_0 104 | - numpy-base=1.21.2=py38h79a1101_0 105 | - oauthlib=3.2.0=pyhd8ed1ab_0 106 | - olefile=0.46=pyhd3eb1b0_0 107 | - openh264=2.1.0=hd408876_0 108 | - openjpeg=2.4.0=h3ad879b_0 109 | - openssl=1.1.1n=h7f8727e_0 110 | - packaging=21.3=pyhd3eb1b0_0 111 | - parso=0.8.2=pyhd3eb1b0_0 112 | - partd=1.2.0=pyhd3eb1b0_0 113 | - pexpect=4.8.0=pyhd3eb1b0_3 114 | - pickleshare=0.7.5=pyhd3eb1b0_1003 115 | - pip=21.2.4=py38h06a4308_0 116 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 117 | - ptyprocess=0.7.0=pyhd3eb1b0_2 118 | - pyasn1=0.4.8=py_0 119 | - pycparser=2.21=pyhd8ed1ab_0 120 | - pygments=2.10.0=pyhd3eb1b0_0 121 | - pyjwt=2.3.0=pyhd8ed1ab_1 122 | - pyopenssl=22.0.0=pyhd8ed1ab_0 123 | - pysocks=1.7.1=py38h578d9bd_5 124 | - python=3.8.12=h12debd9_0 125 | - python-dateutil=2.8.2=pyhd3eb1b0_0 126 | - python_abi=3.8=2_cp38 127 | - pytorch=1.10.0=py3.8_cuda11.3_cudnn8.2.0_0 128 | - pytorch-mutex=1.0=cuda 129 | - pyu2f=0.1.5=pyhd8ed1ab_0 130 | - pywavelets=1.1.1=py38h7b6447c_2 131 | - pyyaml=6.0=py38h7f8727e_1 132 | - readline=8.1=h27cfd23_0 133 | - requests=2.27.1=pyhd8ed1ab_0 134 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 135 | - rsa=4.8=pyhd8ed1ab_0 
136 | - setuptools=58.0.4=py38h06a4308_0 137 | - six=1.16.0=pyhd3eb1b0_0 138 | - snappy=1.1.8=he6710b0_0 139 | - sqlite=3.36.0=hc218d9a_0 140 | - tensorboard=2.9.0=pyhd8ed1ab_0 141 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 142 | - tifffile=2021.7.2=pyhd3eb1b0_2 143 | - tk=8.6.11=h1ccaba5_0 144 | - toolz=0.11.2=pyhd3eb1b0_0 145 | - torchaudio=0.10.0=py38_cu113 146 | - torchvision=0.11.1=py38_cu113 147 | - traitlets=5.1.1=pyhd3eb1b0_0 148 | - urllib3=1.26.9=pyhd8ed1ab_0 149 | - wcwidth=0.2.5=pyhd3eb1b0_0 150 | - wheel=0.37.0=pyhd3eb1b0_1 151 | - xz=5.2.5=h7b6447c_0 152 | - yaml=0.2.5=h7b6447c_0 153 | - yarl=1.6.3=py38h27cfd23_0 154 | - zfp=0.5.5=h2531618_6 155 | - zipp=3.8.0=pyhd8ed1ab_0 156 | - zlib=1.2.11=h7b6447c_3 157 | - zstd=1.4.9=haebb681_0 158 | - pip: 159 | - addict==2.4.0 160 | - alibabacloud-credentials==0.3.2 161 | - alibabacloud-endpoint-util==0.0.3 162 | - alibabacloud-gateway-spi==0.0.1 163 | - alibabacloud-openapi-util==0.2.1 164 | - alibabacloud-tea==0.3.1 165 | - alibabacloud-tea-openapi==0.3.7 166 | - alibabacloud-tea-util==0.3.8 167 | - alibabacloud-tea-xml==0.0.2 168 | - anyconfig==0.9.10 169 | - anyio==3.6.1 170 | - arabic-reshaper==2.1.3 171 | - asynctest==0.13.0 172 | - azure-cognitiveservices-vision-computervision==0.9.0 173 | - azure-common==1.1.28 174 | - azure-core==1.26.4 175 | - baidu-aip==2.2.17.0 176 | - beautifulsoup4==4.11.1 177 | - codecov==2.1.12 178 | - colorama==0.4.5 179 | - commonmark==0.9.1 180 | - coverage==6.4.4 181 | - diffimg==0.2.3 182 | - editdistance==0.3.1 183 | - fire==0.5.0 184 | - flake8==5.0.4 185 | - flask==2.2.2 186 | - fonttools==4.28.1 187 | - future==0.18.2 188 | - gevent==21.12.0 189 | - gevent-websocket==0.10.1 190 | - greenlet==1.1.3 191 | - grpcio==1.46.0 192 | - h11==0.12.0 193 | - httpcore==0.15.0 194 | - httpx==0.23.0 195 | - huaweicloudsdkcore==3.1.37 196 | - huaweicloudsdkocr==3.1.37 197 | - imgaug==0.2.8 198 | - iniconfig==1.1.1 199 | - isodate==0.6.1 200 | - isort==5.10.1 201 | - itsdangerous==2.1.2 202 | - jarowinkler==1.2.3 203 | - jinja2==3.1.2 204 | - joblib==1.2.0 205 | - kiwisolver==1.3.2 206 | - kornia==0.6.7 207 | - kwarray==0.6.4 208 | - lanms-neo==1.0.2 209 | - lmdb==1.3.0 210 | - lpips==0.1.4 211 | - markupsafe==2.1.1 212 | - mccabe==0.7.0 213 | - mmcv-full==1.6.2 214 | - mmdet==2.25.2 215 | - mmocr==0.6.2 216 | - model-index==0.1.11 217 | - msrest==0.7.1 218 | - munch==2.5.0 219 | - natsort==8.1.0 220 | - nltk==3.8.1 221 | - opencv-python==4.6.0.66 222 | - openmim==0.3.2 223 | - ordered-set==4.1.0 224 | - pandas==1.4.2 225 | - pillow==9.2.0 226 | - pluggy==1.0.0 227 | - polygon3==3.0.9.1 228 | - protobuf==3.20.1 229 | - py==1.11.0 230 | - pyasn1-modules==0.2.8 231 | - pyclipper==1.1.0.post3 232 | - pycocotools==2.0.5 233 | - pycodestyle==2.9.1 234 | - pyflakes==2.5.0 235 | - pyparsing==3.0.6 236 | - pytest==7.1.3 237 | - pytest-cov==4.0.0 238 | - pytest-runner==6.0.0 239 | - python-bidi==0.4.2 240 | - pytz==2021.3 241 | - rapidfuzz==2.10.2 242 | - regex==2022.10.31 243 | - requests-toolbelt==0.10.1 244 | - rfc3986==1.5.0 245 | - rich==12.5.1 246 | - scikit-image==0.19.3 247 | - scipy==1.4.1 248 | - setuptools-scm==6.0.1 249 | - shapely==1.8.2 250 | - simplejson==3.19.1 251 | - sniffio==1.2.0 252 | - sortedcontainers==2.4.0 253 | - soupsieve==2.3.2.post1 254 | - tabulate==0.8.10 255 | - tensorboard-data-server==0.6.1 256 | - tensorboardx==2.5 257 | - termcolor==2.2.0 258 | - terminaltables==3.1.10 259 | - tomli==2.0.1 260 | - torch-tb-profiler==0.4.0 261 | - torchsummary==1.5.1 262 | - tqdm==4.62.3 263 | - 
trdg==1.8.0 264 | - typing-extensions==4.3.0 265 | - ubelt==1.2.2 266 | - websocket-client==1.3.3 267 | - websockets==10.3 268 | - werkzeug==2.2.2 269 | - wikipedia==1.4.0 270 | - xdoctest==1.1.0 271 | - yapf==0.32.0 272 | - zope-event==4.5.0 273 | - zope-interface==5.4.0 274 | -------------------------------------------------------------------------------- /gen_udp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torchvision import utils as vutils 7 | 8 | import models.GAN_models as G_models 9 | from dataset import test_dataset_builder, up_dataset 10 | 11 | def gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path): 12 | BOX_MIN = 0 13 | BOX_MAX = 255 14 | 15 | # load the well-trained generator 16 | pretrained_generator_path = os.path.join(Generator_path + '/netG_epoch_' + str(opt.epochs) + '.pth') 17 | pretrained_G = G_models.Generator(opt.img_channel).to(device) 18 | pretrained_G.load_state_dict(torch.load(pretrained_generator_path)) 19 | pretrained_G.eval() 20 | 21 | test_dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.test_path) 22 | test_dataloader = torch.utils.data.DataLoader( 23 | test_dataset, batch_size=1, 24 | shuffle=False, num_workers=4) 25 | 26 | up = up_dataset(opt.up_path).to(device) 27 | up = up.repeat(1,1,1,1) 28 | gui_net = torch.load(opt.dt_model).to(device) 29 | gui_net.eval() 30 | 31 | for i, data in enumerate(test_dataloader, 0): 32 | ori_labels = data[1][0] 33 | img_index = data[3][0] 34 | test_img = data[5] 35 | test_img = test_img.to(device) 36 | mask = test_img.detach().to(device) 37 | if opt.dark: 38 | test_img = (1-test_img) + torch.mul(mask, up) 39 | else: 40 | test_img = torch.mul(test_img, up) 41 | 42 | # gen adv_img 43 | test_map = G_models.guided_net(test_img, gui_net) 44 | vutils.save_image(test_map,"{}/{}_{}_map.png".format(map_output_path, img_index, ori_labels)) 45 | 46 | perturbation = pretrained_G(up) 47 | perturbation = torch.clamp(perturbation, -opt.eps, opt.eps) 48 | vutils.save_image(perturbation, "{}/{}_{}_per.png".format(per_output_path, img_index, ori_labels)) 49 | 50 | permap = G_models.guided_net(perturbation, gui_net) 51 | vutils.save_image(permap, "{}/{}_{}_permap.png".format(map_output_path, img_index, ori_labels)) 52 | vutils.save_image(permap*100, "{}/{}_{}_permap100.png".format(map_output_path, img_index, ori_labels)) 53 | perturbation = torch.mul(mask, perturbation) 54 | 55 | """convert float32 to uint8: 56 | Avoid the effect of float32 on 57 | generating fully complementary frames 58 | """ 59 | perturbation_int = (perturbation*255).type(torch.int8) 60 | adv_img1_uint = (test_img*255).type(torch.uint8) - perturbation_int 61 | adv_img1_uint = torch.clamp(adv_img1_uint, BOX_MIN, BOX_MAX) 62 | adv_img2_uint = (test_img*255).type(torch.uint8) + perturbation_int 63 | adv_img2_uint = torch.clamp(adv_img2_uint, BOX_MIN, BOX_MAX) 64 | print((adv_img1_uint + perturbation_int).equal(adv_img2_uint - perturbation_int)) 65 | 66 | adv1_map = G_models.guided_net(adv_img1_uint/255, gui_net) 67 | adv2_map = G_models.guided_net(adv_img2_uint/255, gui_net) 68 | vutils.save_image(adv1_map,"{}/{}_{}_map-.png".format(map_output_path, img_index, ori_labels)) 69 | vutils.save_image(adv2_map,"{}/{}_{}_map+.png".format(map_output_path, img_index, ori_labels)) 70 | 71 | adv_img1_uint = adv_img1_uint.squeeze(0).permute(1,2,0) 72 | adv_img1_uint = 
np.uint8(adv_img1_uint.cpu()) 73 | adv_img1 = cv2.cvtColor(adv_img1_uint, cv2.COLOR_RGB2BGR) 74 | adv_img2_uint = adv_img2_uint.squeeze(0).permute(1,2,0) 75 | adv_img2_uint = np.uint8(adv_img2_uint.cpu()) 76 | adv_img2 = cv2.cvtColor(adv_img2_uint, cv2.COLOR_RGB2BGR) 77 | 78 | cv2.imwrite("{}/{}_{}_adv-.png".format(adv_output_path1, img_index, ori_labels), adv_img1) 79 | cv2.imwrite("{}/{}_{}_adv+.png".format(adv_output_path2, img_index, ori_labels), adv_img2) 80 | 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--test_path', type= str, required=True, help='path of font test dataset') 85 | parser.add_argument('--up_path', type= str, default='data/protego/up/5.png', help='underpaintings path') 86 | parser.add_argument('--dark', action='store_true', help='use dark background and white text.') 87 | parser.add_argument('--dt_model', type=str, default='/models/dbnet++.pth', 88 | help='path of our guided network DBnet++') 89 | parser.add_argument('--batchsize', type= int, default=4, help='batchsize of training ProTegO') 90 | parser.add_argument('--epochs', type= int, default=60, help='epochs of training ProTegO') 91 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation') 92 | parser.add_argument('--use_eh', action='store_true', help='Use enhancement layers') 93 | parser.add_argument('--use_guide', action='store_true', help='use guided network') 94 | parser.add_argument('--img_channel', type=int, default=3, 95 | help='the number of input channel of text images') 96 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 97 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 98 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 99 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 100 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode') 101 | 102 | parser.add_argument('--output', default='res-font', help='the path to save all ouput results') 103 | parser.add_argument('--test_output', default='test-out', help='the path to save output of test results') 104 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial text images') 105 | parser.add_argument('--per_output', default='perturbation', help='the path to save output of adversarial perturbation') 106 | parser.add_argument('--map_output', default='map', help='the path to save mapping results') 107 | parser.add_argument('--train_output', default='train-out', help='the path to save output of intermediate training results') 108 | 109 | parser.add_argument('--saveG', required=True, help='the path to save generator which is used for generated AEs') 110 | 111 | opt = parser.parse_args() 112 | print(opt) 113 | 114 | output_path = opt.output 115 | Generator_path = opt.saveG 116 | 117 | font_name = opt.test_path.split('/')[-1] 118 | test_output_path = os.path.join(output_path, font_name, opt.test_output) 119 | adv_output_path1 = os.path.join(test_output_path, opt.adv_output, 'adv-') 120 | adv_output_path2 = os.path.join(test_output_path, opt.adv_output, 'adv+') 121 | per_output_path = os.path.join(test_output_path, opt.per_output) 122 | map_output_path = os.path.join(test_output_path, opt.map_output) 123 | 124 | 125 | if not os.path.exists(test_output_path): 126 | os.makedirs(test_output_path) 127 | if not 
os.path.exists(adv_output_path1): 128 | os.makedirs(adv_output_path1) 129 | if not os.path.exists(adv_output_path2): 130 | os.makedirs(adv_output_path2) 131 | if not os.path.exists(per_output_path): 132 | os.makedirs(per_output_path) 133 | if not os.path.exists(map_output_path): 134 | os.makedirs(map_output_path) 135 | 136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 137 | gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path) 138 | 139 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import string 6 | import shutil 7 | import argparse 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | import numpy as np 12 | from utils import Logger 13 | from train_protego import run_train 14 | from gen_udp import gen_underpaintings 15 | from test_udp import test_udp 16 | 17 | 18 | """ Basic parameters settings """ 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--manualSeed', type=int, default=3407, 21 | help='Refer to the settings in paper ') 22 | parser.add_argument('--train_path', type= str, required=True, help='path of training dataset') 23 | parser.add_argument('--test_path', type= str, required=True, help='path of test dataset') 24 | parser.add_argument('--up_path', type= str, default= "data/protego/up/5.png", 25 | help='path of the pre-processed underpaintings') 26 | parser.add_argument('--dt_model', type=str, default='models/dbnet++.pth', help='path of our guided network DBnet++') 27 | parser.add_argument('--batchsize', type= int, default=4, help='batchsize of training ProTegO') 28 | parser.add_argument('--epochs', type= int, default=60, help='epochs of training ProTegO') 29 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation') 30 | parser.add_argument('--lambda1', type= float, default=1e-3, help='the weight of hinge_loss') 31 | parser.add_argument('--lambda2', type= float, default=2, help='the weight of guide_loss') 32 | parser.add_argument('--lambda3', type= float, default=1, help='the weight of gan_loss') 33 | parser.add_argument('--lambda4', type= float, default=10, help='the weight of adv_loss') 34 | parser.add_argument('--dark', action='store_true', help='use dark background and white text') 35 | parser.add_argument('--b', action='store_true', help='robust test for both frames of adversarial text images') 36 | parser.add_argument('--use_eh', action='store_true', help='Use enhancement layers') 37 | parser.add_argument('--use_guide', action='store_true', help='use guided network') 38 | 39 | """ Model Architecture """ 40 | parser.add_argument('--str_model', type=str, help="path of pretrainted STR models for evaluation", 41 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth') 42 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 43 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet') 44 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 45 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. 
CTC|Attn') 46 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 47 | parser.add_argument('--input_channel', type=int, default=3, 48 | help='the number of input channel of Feature extractor') 49 | parser.add_argument('--output_channel', type=int, default=512, 50 | help='the number of output channel of Feature extractor') 51 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 52 | 53 | """ Data processing """ 54 | parser.add_argument('--img_channel', type=int, default=3, 55 | help='the number of input channel of text images') 56 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 57 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 58 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 59 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 60 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode') 61 | 62 | """ Output settings """ 63 | parser.add_argument('--output', default=f'{time.strftime("res-%m%d-%H%M")}', 64 | help='the path to save all ouput results') 65 | parser.add_argument('--train_output', default='train-out', help='the path to save intermediate training results') 66 | parser.add_argument('--saveG', default='Generators', help='the path to save generators') 67 | parser.add_argument('--loss', default='losses', help='the path to save all training losses') 68 | parser.add_argument('--test_output', default='test-out', help='the path to save output of test results') 69 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial text images') 70 | parser.add_argument('--per_output', default='perturbation', help='the path to save adversarial perturbation') 71 | parser.add_argument('--map_output', default='map', help='the path to save mapping results') 72 | 73 | opt = parser.parse_args() 74 | # print(opt) 75 | 76 | """ Seed and GPU setting """ 77 | print("Random Seed: ", opt.manualSeed) 78 | random.seed(opt.manualSeed) 79 | np.random.seed(opt.manualSeed) 80 | torch.manual_seed(opt.manualSeed) 81 | torch.cuda.manual_seed(opt.manualSeed) 82 | cudnn.benchmark = True 83 | cudnn.deterministic = True 84 | 85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 86 | """ vocab / character number configuration """ 87 | if opt.sensitive: 88 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z) 89 | 90 | """ output configuration """"" 91 | output_path = opt.output 92 | if not os.path.exists(output_path): 93 | os.makedirs(output_path) 94 | 95 | # del all the output directories and files 96 | del_list = os.listdir(output_path) 97 | for f in del_list: 98 | file_path = os.path.join(output_path, f) 99 | if os.path.isfile(file_path): 100 | os.remove(file_path) 101 | elif os.path.isdir(file_path): 102 | shutil.rmtree(file_path) 103 | 104 | """ save all the print content as log """ 105 | log_file= os.path.join(output_path, 'protego.log') 106 | sys.stdout = Logger(log_file) 107 | 108 | """ make all save directories """ 109 | train_adv_path = os.path.join(output_path, opt.train_output, 'adv') 110 | train_per_path = os.path.join(output_path, opt.train_output, 'per') 111 | Generator_path = os.path.join(output_path,opt.saveG) 112 | loss_path = os.path.join(output_path, opt.loss) 113 | test_output_path = 
os.path.join(output_path, opt.test_output) 114 | adv_output_path1 = os.path.join(test_output_path, opt.adv_output, 'adv-') 115 | adv_output_path2 = os.path.join(test_output_path, opt.adv_output, 'adv+') 116 | per_output_path = os.path.join(test_output_path, opt.per_output) 117 | map_output_path = os.path.join(test_output_path, opt.map_output) 118 | 119 | if not os.path.exists(train_adv_path): 120 | os.makedirs(train_adv_path) 121 | if not os.path.exists(train_per_path): 122 | os.makedirs(train_per_path) 123 | if not os.path.exists(Generator_path): 124 | os.makedirs(Generator_path) 125 | if not os.path.exists(loss_path): 126 | os.makedirs(loss_path) 127 | if not os.path.exists(test_output_path): 128 | os.makedirs(test_output_path) 129 | if not os.path.exists(adv_output_path1): 130 | os.makedirs(adv_output_path1) 131 | if not os.path.exists(adv_output_path2): 132 | os.makedirs(adv_output_path2) 133 | if not os.path.exists(per_output_path): 134 | os.makedirs(per_output_path) 135 | if not os.path.exists(map_output_path): 136 | os.makedirs(map_output_path) 137 | 138 | torch.cuda.synchronize() 139 | time_start = time.time() 140 | 141 | print(opt) 142 | 143 | # train ProTegO 144 | run_train(opt, device, train_adv_path, train_per_path, Generator_path, loss_path) 145 | 146 | # Generate adversarial underpaintings for text images 147 | gen_start = time.time() 148 | gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path) 149 | gen_end = time.time() - gen_start 150 | print('Generation time:' + str(gen_end)) 151 | 152 | # Test ProTegO performance 153 | test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path) 154 | time_end = time.time() 155 | time_sum = time_end - time_start 156 | print('Total time:' + str(time_sum)) -------------------------------------------------------------------------------- /models/GAN_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def guided_net(x, net): 4 | net.eval() 5 | b_out = net.backbone(x) #resnet50 6 | n_out = net.neck(b_out) 7 | out = net.bbox_head.binarize(n_out) 8 | return out 9 | 10 | class Discriminator(nn.Module): 11 | def __init__(self, image_nc): 12 | super(Discriminator, self).__init__() 13 | # syn90k:3*32*100 14 | model = [ 15 | nn.Conv2d(image_nc, 8, kernel_size=(4,20), stride=2, padding=0, bias=True), 16 | nn.LeakyReLU(0.2), 17 | # :8*15*41 18 | nn.Conv2d(8, 16, kernel_size=(4,20), stride=2, padding=0, bias=True), 19 | nn.BatchNorm2d(16), 20 | nn.LeakyReLU(0.2), 21 | # 16*6*11 22 | nn.Conv2d(16, 32, kernel_size=(5,10), stride=2, padding=0, bias=True), 23 | nn.BatchNorm2d(32), 24 | nn.LeakyReLU(0.2), 25 | nn.Conv2d(32, 1, 1), 26 | nn.Sigmoid() 27 | #1*1*1 28 | ] 29 | self.model = nn.Sequential(*model) 30 | 31 | def forward(self, x): 32 | output = self.model(x) 33 | return output.squeeze() 34 | # output = self.model(x.unsqueeze(1)) # x ---> torch.Size([16, 32, 100]); a channel dimension needs to be added; output ---> torch.Size([16, 1, 1, 1]) 35 | # return output.squeeze() 36 | 37 | class Generator(nn.Module): 38 | def __init__(self, image_nc): 39 | super(Generator, self).__init__() 40 | 41 | encoder_lis = [ 42 | # syn90k:image_nc*32*100 43 | nn.Conv2d(image_nc, 8, kernel_size=3, stride=1, padding=0, bias=True), 44 | nn.InstanceNorm2d(8), 45 | nn.ReLU(), 46 | # 8*30*98 47 | nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0, bias=True), 48 | nn.InstanceNorm2d(16), 49 | nn.ReLU(), 50 | # 16*14*48 51 | nn.Conv2d(16, 32, kernel_size=3, 
stride=2, padding=0, bias=True), 52 | nn.InstanceNorm2d(32), 53 | nn.ReLU(), 54 | # 32*6*23 55 | ] 56 | 57 | bottle_neck_lis = [ResnetBlock(32), 58 | ResnetBlock(32), 59 | ResnetBlock(32), 60 | ResnetBlock(32),] 61 | 62 | decoder_lis = [ 63 | # input 32*6*23 64 | nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=0, bias=False), 65 | nn.InstanceNorm2d(16), 66 | nn.ReLU(), 67 | # state size. 16*13*47 68 | nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=0, bias=False), 69 | nn.InstanceNorm2d(8), 70 | nn.ReLU(), 71 | # state size. 8*27*95 72 | nn.ConvTranspose2d(8, image_nc, kernel_size=6, stride=1, padding=0, bias=False), 73 | nn.Tanh() 74 | # state size. image_nc*32*100 75 | ] 76 | 77 | self.encoder = nn.Sequential(*encoder_lis) 78 | self.bottle_neck = nn.Sequential(*bottle_neck_lis) 79 | self.decoder = nn.Sequential(*decoder_lis) 80 | 81 | def forward(self, x): 82 | x = self.encoder(x) 83 | x = self.bottle_neck(x) 84 | x = self.decoder(x) 85 | 86 | return x 87 | 88 | class ResnetBlock(nn.Module): 89 | def __init__(self, dim, padding_type='zero', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False): 90 | super(ResnetBlock, self).__init__() 91 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 92 | 93 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 94 | 95 | conv_block = [] 96 | p = 0 97 | if padding_type == 'reflect': 98 | conv_block += [nn.ReflectionPad2d(1)] 99 | elif padding_type == 'replicate': 100 | conv_block += [nn.ReplicationPad2d(1)] 101 | elif padding_type == 'zero': 102 | p = 1 103 | else: 104 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 105 | 106 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 107 | norm_layer(dim), 108 | nn.ReLU(True)] 109 | if use_dropout: 110 | conv_block += [nn.Dropout(0.5)] 111 | 112 | p = 0 113 | if padding_type == 'reflect': 114 | conv_block += [nn.ReflectionPad2d(1)] 115 | elif padding_type == 'replicate': 116 | conv_block += [nn.ReplicationPad2d(1)] 117 | elif padding_type == 'zero': 118 | p = 1 119 | else: 120 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 121 | 122 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 123 | norm_layer(dim)] 124 | 125 | return nn.Sequential(*conv_block) 126 | 127 | def forward(self, x): 128 | out = x + self.conv_block(x) 129 | return out -------------------------------------------------------------------------------- /models/enhancement_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .identity import Identity -------------------------------------------------------------------------------- /models/enhancement_layers/combined.py: -------------------------------------------------------------------------------- 1 | from .identity import Identity 2 | from .transform import Translate, Resize, D_Binarization, R_Binarization 3 | import torch.nn as nn 4 | import random 5 | 6 | class Combined(nn.Module): 7 | def __init__(self, list=None): 8 | super(Combined, self).__init__() 9 | if list is None: 10 | list = [Identity()] 11 | self.list = list 12 | 13 | def forward(self, adv_image): 14 | id = random.randint(0, len(self.list) - 1) 15 | print(f"[+] Batch Combined {self.list[id]}") 16 | return self.list[id](adv_image) 17 | 18 | -------------------------------------------------------------------------------- /models/enhancement_layers/identity.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Identity(nn.Module): 4 | """ 5 | Identity-mapping noise layer. Does not change the image 6 | """ 7 | def __init__(self): 8 | super(Identity, self).__init__() 9 | 10 | def forward(self, adv_image): 11 | return adv_image 12 | -------------------------------------------------------------------------------- /models/enhancement_layers/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import kornia.geometry.transform as Ktrans 8 | 9 | def image(path): 10 | IMG = cv2.imread(path) 11 | IMG = cv2.resize(IMG, (100, 32)) # resize 12 | 13 | IMG = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB) 14 | IMG = torch.FloatTensor(IMG) # [H, W, C] 15 | IMG = IMG / 255 # normalization to [0,1] 16 | IMG = IMG.permute(2,0,1) # [C, H, W] 17 | 18 | return IMG.unsqueeze(0) # [B, C, H, W] 19 | 20 | # real binarization 21 | class R_Binarization(nn.Module): 22 | def __init__(self): 23 | super(R_Binarization, self).__init__() 24 | 25 | def RB(self, x): 26 | x_np = x.squeeze_(0).permute(1,2,0).cpu().numpy() #[h,w,c] 27 | maxValue = x_np.max() 28 | x_np = x_np * 255 / maxValue 29 | x_uint = np.uint8(x_np) 30 | gray = cv2.cvtColor(x_uint, cv2.COLOR_RGB2GRAY) 31 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU) 32 | x_b = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB) 33 | x_b = torch.FloatTensor(x_b) 34 | x_b = x_b / 255 35 | x_b = x_b.permute(2,0,1).unsqueeze_(0) 36 | return x_b.to(x.device) 37 | 38 | def forward(self, x): 39 | enhance_adv = self.RB(x) 40 | return enhance_adv 41 | 42 | # differentiable binarization 43 | class D_Binarization(nn.Module): 44 | def __init__(self): 45 | super(D_Binarization, self).__init__() 46 | self.k = 20 47 | 48 | def DB(self, x): 49 | y = 0.8 *torch.ones_like(x) 50 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) 51 | 52 | def forward(self, x): 53 | enhance_adv = self.DB(x) 54 | return enhance_adv 55 | 56 | class Translate(nn.Module): 57 | """ 58 | Translate the image. 59 | """ 60 | def __init__(self): 61 | super(Translate, self).__init__() 62 | i = random.choice([0,1]) 63 | if i == 0: 64 | self.translation = torch.tensor([[4., 4.]]) 65 | else: 66 | self.translation = torch.tensor([[-4., -4.]]) 67 | 68 | def forward(self, x): 69 | device = x.device 70 | enhance_adv = Ktrans.translate(x, self.translation.to(device), mode='bilinear', padding_mode='border', align_corners=True) 71 | 72 | return enhance_adv 73 | 74 | class Resize(nn.Module): 75 | """ 76 | Resize the image. 
The target size is 77 | """ 78 | 79 | def __init__(self): 80 | super(Resize, self).__init__() 81 | self.resize_rate = 1.60 82 | self.img_h = 32 83 | self.img_w = 100 84 | 85 | def input_diversity(self, x): 86 | img_resize_h = int(self.img_h * self.resize_rate) 87 | img_resize_w = int(self.img_w * self.resize_rate) 88 | 89 | rnd_h = torch.randint(low=self.img_h, high=img_resize_h, size=(1,), dtype=torch.int32) 90 | rnd_w = torch.randint(low=self.img_w , high=img_resize_w, size=(1,), dtype=torch.int32) 91 | rescaled = F.interpolate(x, size=[rnd_h, rnd_w], mode='bilinear', align_corners=False) 92 | h_rem = img_resize_h - rnd_h 93 | w_rem = img_resize_w - rnd_w 94 | pad_top = torch.randint(low=0, high=h_rem.item(), size=(1,), dtype=torch.int32) 95 | pad_bottom = h_rem - pad_top 96 | pad_left = torch.randint(low=0, high=w_rem.item(), size=(1,), dtype=torch.int32) 97 | pad_right = w_rem - pad_left 98 | 99 | padded = F.pad(rescaled, [pad_left.item(), pad_right.item(), pad_top.item(), pad_bottom.item()], value=0) 100 | padded = F.interpolate(padded, size=[self.img_h, self.img_w]) 101 | return padded 102 | 103 | def forward(self, x): 104 | enhance_adv = self.input_diversity(x) 105 | return enhance_adv -------------------------------------------------------------------------------- /models/enhancer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .enhancement_layers.combined import Combined 3 | from .enhancement_layers.identity import Identity 4 | from .enhancement_layers.transform import Resize, Translate, D_Binarization, R_Binarization 5 | 6 | class Enhancer(nn.Module): 7 | """ 8 | This module allows to combine different enhancement layers into a sequential module. 9 | """ 10 | def __init__(self, layers): 11 | super(Enhancer, self).__init__() 12 | for i in range(len(layers)): 13 | layers[i] = eval(layers[i]) 14 | self.enhance = nn.Sequential(*layers) 15 | 16 | def forward(self, adv_image): 17 | enhance_adv = self.enhance(adv_image) 18 | return enhance_adv 19 | -------------------------------------------------------------------------------- /protego.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import utils as vutils 9 | 10 | import models.GAN_models as G_models 11 | from models.enhancer import Enhancer 12 | from dataset import up_dataset 13 | 14 | """ custom weights initialization called on netG and netD """ 15 | def weights_init(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | nn.init.normal_(m.weight.data, 0.0, 0.02) 19 | elif classname.find('BatchNorm') != -1: 20 | nn.init.normal_(m.weight.data, 1.0, 0.02) 21 | nn.init.constant_(m.bias.data, 0) 22 | 23 | def fix_bn(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('BatchNorm') != -1: 26 | m.eval() 27 | 28 | class framework(): 29 | def __init__(self, 30 | device, 31 | model, 32 | dt_model, 33 | converter, 34 | criterion, 35 | batch_max_length, 36 | up_path, 37 | dark, 38 | batch_size, 39 | image_nc, 40 | height, 41 | width, 42 | eps, 43 | lambda1, 44 | lambda2, 45 | lambda3, 46 | lambda4, 47 | use_eh, 48 | use_guide): 49 | self.device = device 50 | self.model = model 51 | self.dt_model = dt_model 52 | self.converter = converter 53 | self.criterion = criterion 54 | self.batch_max_length = 
batch_max_length 55 | self.up_path = up_path 56 | self.dark = dark 57 | self.batch_size = batch_size 58 | self.image_nc = image_nc 59 | self.height = height 60 | self.width = width 61 | self.box_min = 0 62 | self.box_max = 1 63 | self.c = 0.1 # user-specified bound 64 | self.eps = eps 65 | self.lambda1 = lambda1 66 | self.lambda2 = lambda2 67 | self.lambda3 = lambda3 68 | self.lambda4 = lambda4 69 | self.use_eh = use_eh 70 | self.use_guide = use_guide 71 | 72 | self.model.apply(fix_bn).to(self.device) 73 | self.netG = G_models.Generator(self.image_nc).to(self.device) 74 | self.netD = G_models.Discriminator(self.image_nc).to(self.device) 75 | # initialize all weights 76 | self.netG.apply(weights_init) 77 | self.netD.apply(weights_init) 78 | # initialize optimizers 79 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.001) 80 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=0.001) 81 | 82 | 83 | def train_batch(self, images, labels, up, gui_net, epoch): 84 | # -----------optimize D----------- 85 | if self.dark: 86 | mask = images.detach() 87 | images = (1-images) + torch.mul(mask, up) 88 | else: 89 | mask = images.detach() 90 | images = torch.mul(images, up) 91 | 92 | perturbation = self.netG(up) 93 | perturbation = torch.clamp(perturbation, -self.eps, self.eps) 94 | 95 | if self.use_guide: 96 | map = G_models.guided_net(perturbation, gui_net) 97 | loss_guide = - torch.mean(torch.tanh(map*1000)) 98 | 99 | else: 100 | map = torch.zeros_like(perturbation).to(self.device) 101 | loss_guide = torch.zeros(1).to(self.device) 102 | 103 | perturbation = torch.mul(mask, perturbation) 104 | 105 | if self.use_eh: 106 | mixed_layers = ["Combined([Identity(), Translate(), Resize(), D_Binarization()])"] 107 | print('epoch{}---enhancement_layers{}'.format(epoch, mixed_layers[0])) 108 | enhance = Enhancer(mixed_layers).to(self.device) 109 | adv_images1 = images - perturbation 110 | adv_images1 = enhance(adv_images1) 111 | adv_images1 = torch.clamp(adv_images1, self.box_min, self.box_max) 112 | adv_images2 = images + perturbation 113 | adv_images2 = enhance(adv_images2) 114 | adv_images2 = torch.clamp(adv_images2, self.box_min, self.box_max) 115 | else: 116 | adv_images1 = images - perturbation 117 | adv_images1 = torch.clamp(adv_images1, self.box_min, self.box_max) 118 | adv_images2 = images + perturbation 119 | adv_images2 = torch.clamp(adv_images2, self.box_min, self.box_max) 120 | 121 | self.optimizer_D.zero_grad() 122 | pred_real = self.netD(images) 123 | loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device)) 124 | 125 | pred_fake1 = self.netD(adv_images1.detach()) 126 | loss_D_fake1 = F.mse_loss(pred_fake1, torch.zeros_like(pred_fake1, device=self.device)) 127 | pred_fake2 = self.netD(adv_images2.detach()) 128 | loss_D_fake2 = F.mse_loss(pred_fake2, torch.zeros_like(pred_fake2, device=self.device)) 129 | 130 | loss_D_gan = loss_D_fake1 + loss_D_fake2 + loss_D_real 131 | 132 | loss_D_gan.backward() 133 | self.optimizer_D.step() 134 | # -----------optimize G----------- 135 | self.optimizer_G.zero_grad() 136 | 137 | # the hinge Loss part of L (calculate perturbation norm) 138 | loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1)) 139 | loss_hinge = torch.max(torch.zeros(1, device=self.device), loss_perturb - self.c) 140 | 141 | # the adv Loss part of L 142 | torch.backends.cudnn.enabled=False 143 | """actually, text_for_pred is the param for attention model""" 144 | text_for_pred = 
torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device) 145 | targets, target_len = self.converter.encode(labels, self.batch_max_length) 146 | 147 | preds1 = self.model(adv_images1, text_for_pred) 148 | preds_size1 = torch.IntTensor([preds1.size(1)] * self.batch_size) 149 | preds1 = preds1.log_softmax(2).permute(1, 0, 2) 150 | loss_adv1 = - self.criterion(preds1, targets, preds_size1, target_len) 151 | preds2 = self.model(adv_images2, text_for_pred) 152 | preds_size2 = torch.IntTensor([preds2.size(1)] * self.batch_size) 153 | preds2 = preds2.log_softmax(2).permute(1, 0, 2) 154 | loss_adv2 = - self.criterion(preds2, targets, preds_size2, target_len) 155 | loss_adv = loss_adv1 + loss_adv2 156 | 157 | # cal G's loss in GAN 158 | pred_fake1 = self.netD(adv_images1) 159 | loss_G_gan1 = F.mse_loss(pred_fake1, torch.ones_like(pred_fake1, device=self.device)) 160 | pred_fake2 = self.netD(adv_images2) 161 | loss_G_gan2 = F.mse_loss(pred_fake2, torch.ones_like(pred_fake2, device=self.device)) 162 | loss_G_gan = loss_G_gan1 + loss_G_gan2 163 | loss_G_gan.backward(retain_graph=True) 164 | 165 | loss_G = self.lambda1*loss_hinge + self.lambda2*loss_guide + self.lambda3*loss_G_gan + self.lambda4*loss_adv 166 | self.model.zero_grad() 167 | loss_G.backward() 168 | self.optimizer_G.step() 169 | 170 | return loss_G.item(), loss_D_gan.item(), loss_G_gan.item(), loss_hinge.item(), loss_adv.item(), loss_guide.item(), \ 171 | map, perturbation, adv_images1, adv_images2 172 | 173 | def train(self, train_dataloader, epochs, train_adv_path, train_per_path, Generator_path, loss_path): 174 | 175 | loss_G, loss_D_gan, loss_G_gan, loss_hinge, loss_adv, loss_guide= [], [], [], [], [], [] 176 | 177 | if self.use_eh and self.use_guide: 178 | print("==> Use enhancement and guidance module !") 179 | elif self.use_eh: 180 | print("==> ONLY Use enhancement layer ...") 181 | elif self.use_guide: 182 | print("==> ONLY Use guided network ...") 183 | else: 184 | print("Do not use any trick !") 185 | 186 | up = up_dataset(self.up_path) 187 | up = up.repeat(self.batch_size,1,1,1).to(self.device) 188 | # up = torch.clamp(up, self.eps, 1-self.eps) 189 | # vutils.save_image(up, "{}/up.png".format(loss_path)) # TODO remove when release 190 | 191 | print('Loading text detection model from \"%s\" as our guided net!' 
% self.dt_model) 192 | gui_net = torch.load(self.dt_model).to(self.device) 193 | gui_net.eval() 194 | 195 | for epoch in range(1, epochs+1): 196 | if epoch == 20: 197 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 198 | lr=0.0001) 199 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 200 | lr=0.0001) 201 | if epoch == 40: 202 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 203 | lr=0.00001) 204 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 205 | lr=0.00001) 206 | if epoch == 60: 207 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 208 | lr=0.000001) 209 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 210 | lr=0.000001) 211 | 212 | loss_G_sum, loss_D_gan_sum, loss_G_gan_sum, loss_hinge_sum, loss_adv_sum, loss_guide_sum = 0, 0, 0, 0, 0, 0 213 | 214 | for i_batch, data in enumerate(train_dataloader, start=0): 215 | images, labels = data 216 | images = images.to(self.device) 217 | 218 | loss_G_batch, loss_D_gan_batch, loss_G_gan_batch, \ 219 | loss_hinge_batch, loss_adv_batch, loss_guide_batch,\ 220 | map, perturbation, adv_images1, adv_images2 = self.train_batch(images, labels, up, gui_net, epoch) 221 | 222 | loss_G_sum += loss_G_batch 223 | loss_D_gan_sum += loss_D_gan_batch 224 | loss_G_gan_sum += loss_G_gan_batch 225 | loss_hinge_sum += loss_hinge_batch 226 | loss_adv_sum += loss_adv_batch 227 | loss_guide_sum += loss_guide_batch 228 | 229 | vutils.save_image(adv_images1, "{}/{}_{}_adv-.png".format(train_adv_path, epoch, i_batch)) 230 | vutils.save_image(adv_images2, "{}/{}_{}_adv+.png".format(train_adv_path, epoch, i_batch)) 231 | vutils.save_image(map*1000, "{}/{}_{}map.png".format(train_per_path, epoch, i_batch)) 232 | vutils.save_image(perturbation, "{}/{}_{}per.png".format(train_per_path, epoch, i_batch)) 233 | 234 | # print statistics 235 | batch_size = len(train_dataloader) 236 | print('epoch {}: \nloss G: {}, \n\tloss_G_gan: {}, \n\tloss_hinge: {}, \n\tloss_adv: {}, \n\tloss_guide: {}, \nloss D_gan: {}\n'.format( 237 | epoch, 238 | loss_G_sum/batch_size, 239 | loss_G_gan_sum/batch_size, 240 | loss_hinge_sum/batch_size, 241 | loss_adv_sum/batch_size, 242 | loss_guide_sum/batch_size, 243 | loss_D_gan_sum/batch_size, 244 | )) 245 | 246 | loss_G.append(loss_G_sum / batch_size) 247 | loss_D_gan.append( loss_D_gan_sum / batch_size) 248 | loss_G_gan.append(loss_G_gan_sum / batch_size) 249 | loss_hinge.append(loss_hinge_sum / batch_size) 250 | loss_adv.append(loss_adv_sum / batch_size) 251 | loss_guide.append(loss_guide_sum / batch_size) 252 | 253 | # save generator 254 | if epoch % 2== 0: 255 | netG_file_name = Generator_path + '/netG_epoch_' + str(epoch) + '.pth' 256 | torch.save(self.netG.state_dict(), netG_file_name) 257 | 258 | plt.figure() 259 | plt.plot(loss_G) 260 | plt.title("loss_G") 261 | plt.savefig(loss_path + '/loss_G.png') 262 | 263 | plt.figure() 264 | plt.plot(loss_D_gan) 265 | plt.title("loss_D_gan") 266 | plt.savefig(loss_path + '/loss_D_gan.png') 267 | 268 | plt.figure() 269 | plt.plot(loss_G_gan) 270 | plt.title("loss_G_gan") 271 | plt.savefig(loss_path + '/loss_G_gan.png') 272 | 273 | plt.figure() 274 | plt.plot(loss_hinge) 275 | plt.title("loss_hinge") 276 | plt.savefig(loss_path + '/loss_hinge.png') 277 | 278 | plt.figure() 279 | plt.plot(loss_adv) 280 | plt.title("loss_adv") 281 | plt.savefig(loss_path + '/loss_adv.png') 282 | 283 | plt.figure() 284 | plt.plot(loss_guide) 285 | plt.title("loss_guide") 286 | plt.savefig(loss_path + '/loss_guide.png') 287 | 288 | plt.close('all') 
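# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the complementary-frame construction
# used in train_batch above, where the clamped perturbation is both subtracted
# from and added to the underpainted text image (adv_images1 / adv_images2).
# The tensor names and example values below are illustrative only; the uint8
# round-trip mirrors the integer complementarity construction in gen_udp.py.
import torch

eps = 40 / 255                                   # maximum perturbation, as with --eps
image = 0.3 + 0.4 * torch.rand(1, 3, 32, 100)    # kept away from 0/1 so clamping never triggers here
per = torch.clamp(torch.randn(1, 3, 32, 100) * 0.05, -eps, eps)

per_int = (per * 255).type(torch.int8)           # signed integer perturbation
img_int = (image * 255).type(torch.uint8)        # quantized underpainted text image
frame_minus = torch.clamp(img_int - per_int, 0, 255)   # "adv-" frame
frame_plus = torch.clamp(img_int + per_int, 0, 255)    # "adv+" frame

# The two frames are complementary around the clean image: averaged (or rapidly
# alternated) they reconstruct it, while each single frame keeps the perturbation.
avg = (frame_minus + frame_plus) // 2
print(torch.equal(avg, img_int.to(avg.dtype)))   # True whenever no clamping occurred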
-------------------------------------------------------------------------------- /test_udp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import shutil 5 | import string 6 | import torch 7 | from utils import CTCLabelConverter, AttnLabelConverter, Logger 8 | from dataset import test_adv_dataset 9 | from STR_modules.model import Model 10 | from models.enhancer import Enhancer 11 | from nltk.metrics import edit_distance 12 | 13 | 14 | 15 | def fix_bn(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('BatchNorm') != -1: 18 | m.eval() 19 | def process_line(line): 20 | adv_img_path, recog_result = line.split(':') 21 | label, adv_preds = recog_result.split('--->') 22 | adv_preds = adv_preds.strip('\n') 23 | return adv_preds, label, adv_img_path 24 | 25 | def test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path): 26 | batch_size = 1 27 | save_success_adv = os.path.join(test_output_path , 'attack-success-adv') 28 | save_binary = os.path.join(test_output_path , 'binary_adv') 29 | attack_success_result1 = os.path.join(test_output_path , 'attack_success_result1.txt') 30 | attack_success_result2 = os.path.join(test_output_path , 'attack_success_result2.txt') 31 | if not os.path.exists(save_success_adv): 32 | os.makedirs(save_success_adv) 33 | if not os.path.exists(save_binary): 34 | os.makedirs(save_binary) 35 | 36 | """ model configuration """ 37 | if 'CTC' in opt.Prediction: 38 | converter = CTCLabelConverter(opt.character) 39 | else: 40 | converter = AttnLabelConverter(opt.character) 41 | opt.num_class = len(converter.character) 42 | 43 | model = Model(opt).to(device) 44 | print('Loading a STR model from \"%s\" as the target model!' 
% opt.str_model) 45 | model.load_state_dict(torch.load(opt.str_model, map_location=device),strict=False) 46 | model.eval() 47 | model.apply(fix_bn) 48 | 49 | mixed_layers = ["Combined([Identity(), Translate(), D_Binarization()])"] 50 | robust_test = Enhancer(mixed_layers).to(device) 51 | 52 | test_dataset1= test_adv_dataset(opt.imgH, opt.imgW, adv_output_path1) 53 | test_dataloader1= torch.utils.data.DataLoader( 54 | test_dataset1, 55 | batch_size=batch_size, 56 | shuffle=False, 57 | num_workers=1) 58 | 59 | result1 = dict() # adv- 60 | for i, data in enumerate(test_dataloader1): 61 | adv_img1 = data[1] 62 | label1 = data[2] 63 | adv_index1 = data[3][0] 64 | adv_path1 = data[5][0] 65 | if opt.b: 66 | adv_img1 = robust_test(adv_img1) 67 | else: 68 | adv_img1 69 | 70 | adv_img1= adv_img1.to(device) 71 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 72 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 73 | if 'CTC' in opt.Prediction: 74 | preds1 = model(adv_img1, text_for_pred).log_softmax(2) 75 | preds_size1 = torch.IntTensor([preds1.size(1)] * batch_size) 76 | _, preds_index1 = preds1.permute(1, 0, 2).max(2) 77 | preds_index1 = preds_index1.transpose(1, 0).contiguous().view(-1) 78 | preds_output1 = converter.decode(preds_index1.data, preds_size1) 79 | preds_output1 = preds_output1[0] 80 | result1[adv_index1] = '{}:{}--->{}\n'.format(adv_path1, label1[0], preds_output1) 81 | else: # Attention 82 | preds1 = model(adv_img1, text_for_pred, is_train=False) 83 | _, preds_index1 = preds1.max(2) 84 | preds_output1 = converter.decode(preds_index1, length_for_pred) 85 | preds_output1 = preds_output1[0] 86 | preds_output1 = preds_output1[:preds_output1.find('[s]')] 87 | result1[adv_index1] = '{}:{}--->{}\n'.format(adv_path1, label1[0], preds_output1) 88 | result1 = sorted(result1.items(), key=lambda x:x[0]) 89 | with open(attack_success_result1, 'w+') as f: 90 | for item in result1: 91 | f.write(item[1]) 92 | 93 | 94 | test_dataset2= test_adv_dataset(opt.imgH, opt.imgW, adv_output_path2) 95 | test_dataloader2= torch.utils.data.DataLoader( 96 | test_dataset2, 97 | batch_size=batch_size, 98 | shuffle=False, 99 | num_workers=4) 100 | result2 = dict() # adv+ 101 | for i, data in enumerate(test_dataloader2): 102 | adv_img2= data[1] 103 | label2 = data[2] 104 | adv_index2 = data[3][0] 105 | adv_path2 = data[5][0] 106 | 107 | if opt.b: 108 | adv_img2 = robust_test(adv_img2) 109 | else: 110 | adv_img2 111 | 112 | adv_img2= adv_img2.to(device) 113 | if 'CTC' in opt.Prediction: 114 | preds2 = model(adv_img2, text_for_pred).log_softmax(2) 115 | preds_size2 = torch.IntTensor([preds2.size(1)] * batch_size) 116 | _, preds_index2 = preds2.permute(1, 0, 2).max(2) 117 | preds_index2 = preds_index2.transpose(1, 0).contiguous().view(-1) 118 | preds_output2 = converter.decode(preds_index2.data, preds_size2) 119 | preds_output2 = preds_output2[0] 120 | result2[adv_index2] = '{}:{}--->{}\n'.format(adv_path2, label2[0], preds_output2) 121 | else: 122 | preds2 = model(adv_img2, text_for_pred, is_train=False) 123 | _, preds_index2 = preds2.max(2) 124 | preds_output2 = converter.decode(preds_index2, length_for_pred) 125 | preds_output2 = preds_output2[0] 126 | preds_output2 = preds_output2[:preds_output2.find('[s]')] 127 | result2[adv_index2] = '{}:{}--->{}\n'.format(adv_path2, label2[0], preds_output2) 128 | result2 = sorted(result2.items(), key=lambda x:x[0]) 129 | with open(attack_success_result2, 'w+') as f: 130 | for item in result2: 131 | 
f.write(item[1]) 132 | 133 | # calculate ASR 134 | with open(attack_success_result1, 'r') as f: 135 | alladv1 = f.readlines() 136 | with open(attack_success_result2, 'r') as f: 137 | alladv2 = f.readlines() 138 | 139 | attack_success_num, asc1, asc2 = 0, 0, 0 140 | 141 | ED_num1, ED_num2 = 0, 0 142 | for line1, line2 in zip(alladv1, alladv2): 143 | adv_preds1, label1, adv_img_path1 = process_line(line1) 144 | adv_preds2, label2, adv_img_path2 = process_line(line2) 145 | if adv_preds1 != label1: 146 | asc1 += 1 147 | if adv_preds2 != label2: 148 | asc2 += 1 149 | if adv_preds1 != label1 and adv_preds2 != label2: 150 | ED_num1 += edit_distance(label1, adv_preds1) 151 | ED_num2 += edit_distance(label2, adv_preds2) 152 | attack_success_num += 1 153 | shutil.copy(adv_img_path1, save_success_adv) 154 | shutil.copy(adv_img_path2, save_success_adv) 155 | print("***********Test Finished!***********") 156 | psr1 = asc1 / len(test_dataset1) 157 | psr2 = asc2 / len(test_dataset2) 158 | attack_success_rate = attack_success_num / len(test_dataset1) 159 | print('PSR1:{:.2%}'.format(psr1)) 160 | print('PSR2:{:.2%}'.format(psr2)) 161 | print('PSR:{:.2%}'.format(attack_success_rate)) 162 | if attack_success_num != 0: 163 | ED_num1_avr = ED_num1 / attack_success_num 164 | ED_num2_avr = ED_num2 / attack_success_num 165 | ED_avr = (ED_num1_avr + ED_num2_avr) / 2 166 | print('Average Edit_distance-: {:.2f}'.format(ED_num1_avr)) 167 | print('Average Edit_distance+: {:.2f}'.format(ED_num2_avr)) 168 | print('Average Edit_distance (overall): {:.2f}'.format(ED_avr)) 169 | 170 | 171 | 172 | if __name__ == '__main__': 173 | """transfer attack for black-box models""" 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument('--output', required=True, help='the path of the white-box model (STARNet) results') 176 | parser.add_argument('--STR_name', required=True, help='the path to save the output results of different models') 177 | parser.add_argument('--saveG', default='Generators', help='the path to save the generator used to generate AEs') 178 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial example results') 179 | parser.add_argument('--per_output', default='perturbation', help='the path to save the adversarial perturbation output') 180 | parser.add_argument('--up_output', default='up', help='the path to save underpainting and mapping results') 181 | parser.add_argument('--b', action='store_true', help='Use binarization processing to test AEs.') 182 | """ Data processing """ 183 | parser.add_argument('--img_channel', type=int, default=3, help='the number of input channels of the image') 184 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum label length') 185 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 186 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 187 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='set of characters used in the labels') 188 | parser.add_argument('--sensitive', action='store_false', help='case-sensitive mode is on by default; pass this flag to turn it off') 189 | """ Model Architecture """ 190 | parser.add_argument('--str_model', type=str, required=True, help='the path of the target STR model') 191 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 192 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. 
VGG|RCNN|ResNet') 193 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 194 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 195 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 196 | parser.add_argument('--input_channel', type=int, default=3, 197 | help='the number of input channel of Feature extractor') 198 | parser.add_argument('--output_channel', type=int, default=512, 199 | help='the number of output channel of Feature extractor') 200 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 201 | opt = parser.parse_args() 202 | print(opt) 203 | 204 | """ create new test model output path """ 205 | if opt.b: 206 | test_output_path = os.path.join(opt.output, opt.STR_name, 'RB') 207 | else: 208 | test_output_path = os.path.join(opt.output, opt.STR_name) 209 | if not os.path.exists(test_output_path): 210 | os.makedirs(test_output_path) 211 | 212 | log_file= os.path.join(test_output_path, 'test.log') 213 | sys.stdout = Logger(log_file) 214 | 215 | """ already exist """ 216 | Generator_path = os.path.join(opt.output, opt.saveG) 217 | adv_output_path1 = os.path.join(opt.output, 'test-out', opt.adv_output, 'adv-') 218 | adv_output_path2 = os.path.join(opt.output, 'test-out', opt.adv_output, 'adv+') 219 | per_output_path = os.path.join(opt.output, 'test-out', opt.per_output) 220 | up_output_path = os.path.join(opt.output, 'test-out', opt.up_output) 221 | 222 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 223 | 224 | """ vocab / character number configuration """ 225 | if opt.sensitive: 226 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z) 227 | 228 | test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path) -------------------------------------------------------------------------------- /train_protego.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from protego import framework 4 | from STR_modules.model import Model 5 | 6 | from dataset import train_dataset_builder 7 | from utils import CTCLabelConverter, AttnLabelConverter 8 | 9 | 10 | def run_train(opt, device, train_adv_path, train_per_path, Generator_path, loss_path): 11 | """ data preparing """ 12 | train_dataset = train_dataset_builder(opt.imgH, opt.imgW, opt.train_path) 13 | train_dataloader = torch.utils.data.DataLoader( 14 | train_dataset, batch_size=opt.batchsize, 15 | shuffle=True, num_workers=4, 16 | drop_last=True, pin_memory=True) 17 | 18 | """ model configuration """ 19 | if 'CTC' in opt.Prediction: 20 | converter = CTCLabelConverter(opt.character) 21 | else: 22 | converter = AttnLabelConverter(opt.character) 23 | opt.num_class = len(converter.character) 24 | 25 | model = Model(opt).to(device) 26 | print('Loading STR pretrained model from %s' % opt.str_model) 27 | model.load_state_dict(torch.load(opt.str_model, map_location=device),strict=False) 28 | model.eval() 29 | 30 | """ setup loss """ 31 | if 'CTC' in opt.Prediction: 32 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 33 | else: 34 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 35 | 36 | """ attack setting """ 37 | ProTegO = framework(device, model, opt.dt_model, converter, criterion, opt.batch_max_length, 38 | opt.up_path, opt.dark, opt.batchsize, opt.img_channel, 
opt.imgH, opt.imgW, 39 | opt.eps, opt.lambda1, opt.lambda2, opt.lambda3, opt.lambda4, 40 | opt.use_eh, opt.use_guide) 41 | 42 | # train 43 | ProTegO.train(train_dataloader, opt.epochs, train_adv_path, train_per_path, Generator_path, loss_path) 44 | 45 | 46 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | import sys 4 | import torch 5 | import numpy as np 6 | from PIL import Image 7 | from torchvision import transforms 8 | 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | class Logger(object): 13 | def __init__(self, filename = "train.log"): 14 | self.terminal =sys.stdout 15 | self.log = open(filename,"w") 16 | 17 | def write(self, message): 18 | self.terminal.write(message) 19 | self.log.write(message) 20 | 21 | def flush(self): 22 | pass 23 | 24 | class Averager(object): 25 | """Compute average for torch.Tensor, used for loss average.""" 26 | 27 | def __init__(self): 28 | self.reset() 29 | 30 | def add(self, v): 31 | count = v.data.numel() 32 | v = v.data.sum() 33 | self.n_count += count 34 | self.sum += v 35 | 36 | def reset(self): 37 | self.n_count = 0 38 | self.sum = 0 39 | 40 | def val(self): 41 | res = 0 42 | if self.n_count != 0: 43 | res = self.sum / float(self.n_count) 44 | return res 45 | 46 | """ utils for fawa """ 47 | def tensor2np(img): 48 | trans = transforms.Grayscale(1) 49 | img = trans(img).squeeze(0) 50 | return img.detach().cpu().numpy() 51 | def np2tensor(img: np.array): 52 | if len(img.shape) == 2: 53 | img_tensor = torch.from_numpy(img).float() # bool to float 54 | img_tensor = img_tensor.unsqueeze_(0).repeat(3, 1, 1) 55 | if len(img.shape) == 3: 56 | img_tensor = torch.from_numpy(img).float() 57 | img_tensor = img_tensor.permute(2,0,1) 58 | if torch.max(img_tensor) <= 1: 59 | img_tensor = img_tensor * 255 60 | img_tensor = img_tensor / 255. 61 | return img_tensor.unsqueeze_(0).to(device) # add batch dim 62 | 63 | def get_text_mask(img: np.array): 64 | if img.max() <= 1: 65 | return img < 1 / 1.25 66 | else: 67 | return img < 255 / 1.25 68 | 69 | def cvt2Image(array): 70 | if array.max() <= 0.5: 71 | return Image.fromarray(((array + 0.5) * 255).astype('uint8')) 72 | elif array.max() <= 1: 73 | return Image.fromarray((array * 255).astype('uint8')) 74 | elif array.max() <= 255: 75 | return Image.fromarray(array.astype('uint8')) 76 | 77 | def RGB2Hex(RGB): # RGB is a 3-tuple 78 | color = '#' 79 | for num in RGB: 80 | color += str(hex(num))[-2:].replace('x', '0').upper() 81 | return color 82 | 83 | def color_map(grayscale): 84 | gray_map = (grayscale - 255*0.299 - 0*0.114) / 0.587 85 | return int(gray_map) 86 | 87 | 88 | class CTCLabelConverter(object): 89 | """ Convert between text-label and text-index """ 90 | 91 | def __init__(self, character): 92 | # character (str): set of the possible characters. 93 | dict_character = list(character) 94 | 95 | self.dict = {} 96 | for i, char in enumerate(dict_character): 97 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 98 | self.dict[char] = i + 1 99 | 100 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 101 | 102 | def encode(self, text, batch_max_length=25): 103 | """convert text-label into text-index. 104 | input: 105 | text: text labels of each image. 
Note: in our dataset, the labels are a list of length batch_size 106 | batch_max_length: max length of text label in the batch. 25 by default 107 | 108 | output: 109 | text: text index for CTCLoss. [batch_size, batch_max_length] 110 | length: length of each text. [batch_size] 111 | """ 112 | length = [len(s) for s in text] 113 | 114 | # The index used for padding (=0) would not affect the CTC loss calculation. 115 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 116 | for i, t in enumerate(text): 117 | text = list(t) 118 | text = [self.dict[char] for char in text] # index of each char in the word, shape=[len(text)] (i.e. the word length) 119 | batch_text[i][:len(text)] = torch.LongTensor(text) 120 | 121 | return (batch_text.to(device), torch.IntTensor(length).to(device)) # batch_text: [batch_size, 25]; length: [batch_size] 122 | 123 | def decode(self, text_index, length): 124 | """ convert text-index into text-label. """ 125 | texts = [] 126 | index = 0 127 | for l in length: 128 | t = text_index[index:index + l] 129 | 130 | char_list = [] 131 | for i in range(l): 132 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blanks. 133 | char_list.append(self.character[t[i]]) 134 | text = ''.join(char_list) 135 | 136 | texts.append(text) 137 | index += l 138 | return texts 139 | 140 | class AttnLabelConverter(object): 141 | """ Convert between text-label and text-index """ 142 | 143 | def __init__(self, character): 144 | # character (str): set of the possible characters. 145 | # [GO] for the start token of the attention decoder. [s] for the end-of-sentence token. 146 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 147 | list_character = list(character) 148 | self.character = list_token + list_character 149 | 150 | self.dict = {} 151 | for i, char in enumerate(self.character): 152 | # print(i, char) 153 | self.dict[char] = i 154 | 155 | def encode(self, text, batch_max_length=25): 156 | """ convert text-label into text-index. 157 | input: 158 | text: text labels of each image. [batch_size] 159 | batch_max_length: max length of text label in the batch. 25 by default 160 | 161 | output: 162 | text: the input of the attention decoder. [batch_size x (max_length+2)]; +1 for the [GO] token and +1 for the [s] token. 163 | text[:, 0] is the [GO] token and text is padded with [GO] tokens after the [s] token. 164 | length: the length of the output of the attention decoder, which also counts the [s] token. e.g. [3, 7, ....] [batch_size] 165 | """ 166 | length = [len(s) + 1 for s in text] # +1 for [s] at the end of the sentence. 167 | # batch_max_length = max(length) # this is not allowed for the multi-gpu setting 168 | batch_max_length += 1 169 | # additional +1 for [GO] at the first step. batch_text is padded with [GO] tokens after the [s] token. 170 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 171 | for i, t in enumerate(text): 172 | text = list(t) 173 | text.append('[s]') 174 | text = [self.dict[char] for char in text] 175 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 176 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 177 | 178 | def decode(self, text_index, length): 179 | """ convert text-index into text-label. """ 180 | texts = [] 181 | for index, l in enumerate(length): 182 | text = ''.join([self.character[i] for i in text_index[index, :]]) 183 | texts.append(text) 184 | return texts 185 | 186 | 187 | --------------------------------------------------------------------------------
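A note on the metrics computed by test_udp.py above: each line written to the attack_success_result files has the form `path:label--->prediction`, an image pair only counts as a successful attack when both the adv- and adv+ predictions differ from the ground-truth label, and PSR and the average edit distance are derived from those counts. The following self-contained sketch mirrors that bookkeeping; the sample result lines are made up for illustration, and a small Levenshtein helper is inlined in place of the `nltk.metrics.edit_distance` call the script actually uses.

```
# Sketch of the PSR / edit-distance bookkeeping in test_udp.py (sample lines are hypothetical).

def process_line(line):
    # a result line looks like "<adv_img_path>:<label>--->" + "<prediction>\n"
    adv_img_path, recog_result = line.split(':')
    label, adv_preds = recog_result.split('--->')
    return adv_preds.strip('\n'), label, adv_img_path

def edit_distance(a, b):
    # plain dynamic-programming Levenshtein distance (stand-in for nltk.metrics.edit_distance)
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = cur
    return prev[-1]

lines1 = ['out/adv-/0.png:hello--->hcllo\n', 'out/adv-/1.png:world--->world\n']  # adv- results
lines2 = ['out/adv+/0.png:hello--->helo\n',  'out/adv+/1.png:world--->w0rld\n']  # adv+ results

success, ed1, ed2 = 0, 0, 0
for l1, l2 in zip(lines1, lines2):
    p1, y1, _ = process_line(l1)
    p2, y2, _ = process_line(l2)
    if p1 != y1 and p2 != y2:          # both versions must fool the recognizer
        success += 1
        ed1 += edit_distance(y1, p1)
        ed2 += edit_distance(y2, p2)

psr = success / len(lines1)
print('PSR: {:.2%}'.format(psr))        # 50.00% for the sample lines above
if success:
    print('avg ED-: {:.2f}, avg ED+: {:.2f}'.format(ed1 / success, ed2 / success))
```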
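Similarly, a minimal usage sketch for the two label converters defined in utils.py above. It assumes the code is run inside this repo (so `from utils import ...` resolves) with the environment.yaml dependencies installed; the label 'protego' and the printed shapes are illustrative only.

```
# Sketch: encoding/decoding labels with the converters from utils.py.
import torch
from utils import CTCLabelConverter, AttnLabelConverter

charset = '0123456789abcdefghijklmnopqrstuvwxyz'

# CTC: index 0 is reserved for the '[CTCblank]' token, so characters map to 1..36.
ctc = CTCLabelConverter(charset)
text, length = ctc.encode(['protego'], batch_max_length=25)
print(text.shape)                    # torch.Size([1, 25]), zero-padded label indices
print(ctc.decode(text[0], [7]))      # ['protego'] -- greedy decode drops blanks and repeated indices

# Attention: '[GO]' and '[s]' tokens are prepended, so num_class becomes len(charset) + 2.
attn = AttnLabelConverter(charset)
text, length = attn.encode(['protego'], batch_max_length=25)
print(text.shape)                    # torch.Size([1, 27]); position 0 is the [GO] token, '[s]' is appended
# AttnLabelConverter.decode is meant for decoder predictions; as in test_udp.py,
# the decoded string is truncated at the first '[s]' token.
```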