├── LICENSE
├── Misc
│   ├── overview.png
│   └── show
│       ├── SDL2.dll
│       ├── SDL2_image.dll
│       ├── instructions.md
│       ├── large+.png
│       ├── large-.png
│       ├── large.exe
│       ├── libbrotlicommon.dll
│       ├── libbrotlidec.dll
│       ├── libbrotlienc.dll
│       ├── libgcc_s_seh-1.dll
│       ├── libjpeg-8.dll
│       ├── libjxl.dll
│       ├── liblzma-5.dll
│       ├── libpng16-16.dll
│       ├── libstdc++-6.dll
│       ├── libtiff-5.dll
│       ├── libwebp-7.dll
│       ├── libwinpthread-1.dll
│       ├── libzstd.dll
│       ├── single+.png
│       ├── single-.png
│       ├── single.exe
│       └── zlib1.dll
├── README.md
├── STR_modules
│   ├── feature_extraction.py
│   ├── model.py
│   ├── prediction.py
│   ├── sequence_modeling.py
│   └── transformation.py
├── STR_test.py
├── STR_train.py
├── baselines
│   ├── Impact.ttf
│   ├── __init__.py
│   ├── fawa.py
│   ├── test_black_models.py
│   ├── transfer.py
│   ├── transfer_attacker.py
│   └── wm_attacker.py
├── comtest
│   ├── test_ali.py
│   ├── test_azure.py
│   ├── test_baidu.py
│   └── test_huawei.py
├── data
│   └── protego
│       └── up
│           └── 5.png
├── dataset.py
├── environment.yaml
├── gen_udp.py
├── main.py
├── models
│   ├── GAN_models.py
│   ├── enhancement_layers
│   │   ├── __init__.py
│   │   ├── combined.py
│   │   ├── identity.py
│   │   └── transform.py
│   └── enhancer.py
├── protego.py
├── test_udp.py
├── train_protego.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 YanruHe
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Misc/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/overview.png
--------------------------------------------------------------------------------
/Misc/show/SDL2.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/SDL2.dll
--------------------------------------------------------------------------------
/Misc/show/SDL2_image.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/SDL2_image.dll
--------------------------------------------------------------------------------
/Misc/show/instructions.md:
--------------------------------------------------------------------------------
1 | We provide two executable presentation scripts (single and large) that need to be run on Windows.
2 |
3 | For a single text image, please run the following command:
4 | ```
5 | .\single.exe .\single-.png .\single+.png
6 | ```
7 |
8 | For a large-scale text image, please run the following command:
9 | ```
10 | .\large.exe .\large-.png .\large+.png
11 | ```
--------------------------------------------------------------------------------
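Note: both executables simply alternate the paired `-`/`+` frames at the display refresh rate so that the eye averages them. A minimal Python sketch of the same idea (assuming `pygame` 2 is installed; it is not part of this repo):

```python
# Sketch (not part of the repo): flip between the two complementary frames on
# every vsync, as single.exe / large.exe do.
import pygame

pygame.init()
frames = [pygame.image.load(p) for p in ("single-.png", "single+.png")]
screen = pygame.display.set_mode(frames[0].get_size(), pygame.SCALED, vsync=1)

i, running = 0, True
clock = pygame.time.Clock()
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
    screen.blit(frames[i], (0, 0))
    pygame.display.flip()   # synced to vsync -> one frame per refresh
    i ^= 1                  # alternate the two frames
    clock.tick(240)         # safety cap if vsync is unavailable
pygame.quit()
```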
/Misc/show/large+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large+.png
--------------------------------------------------------------------------------
/Misc/show/large-.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large-.png
--------------------------------------------------------------------------------
/Misc/show/large.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/large.exe
--------------------------------------------------------------------------------
/Misc/show/libbrotlicommon.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlicommon.dll
--------------------------------------------------------------------------------
/Misc/show/libbrotlidec.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlidec.dll
--------------------------------------------------------------------------------
/Misc/show/libbrotlienc.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libbrotlienc.dll
--------------------------------------------------------------------------------
/Misc/show/libgcc_s_seh-1.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libgcc_s_seh-1.dll
--------------------------------------------------------------------------------
/Misc/show/libjpeg-8.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libjpeg-8.dll
--------------------------------------------------------------------------------
/Misc/show/libjxl.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libjxl.dll
--------------------------------------------------------------------------------
/Misc/show/liblzma-5.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/liblzma-5.dll
--------------------------------------------------------------------------------
/Misc/show/libpng16-16.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libpng16-16.dll
--------------------------------------------------------------------------------
/Misc/show/libstdc++-6.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libstdc++-6.dll
--------------------------------------------------------------------------------
/Misc/show/libtiff-5.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libtiff-5.dll
--------------------------------------------------------------------------------
/Misc/show/libwebp-7.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libwebp-7.dll
--------------------------------------------------------------------------------
/Misc/show/libwinpthread-1.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libwinpthread-1.dll
--------------------------------------------------------------------------------
/Misc/show/libzstd.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/libzstd.dll
--------------------------------------------------------------------------------
/Misc/show/single+.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single+.png
--------------------------------------------------------------------------------
/Misc/show/single-.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single-.png
--------------------------------------------------------------------------------
/Misc/show/single.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/single.exe
--------------------------------------------------------------------------------
/Misc/show/zlib1.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/Misc/show/zlib1.dll
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProTegO (ACM MM 2023)
2 | This repository provides the official PyTorch implementation of the conference paper "ProTegO: Protect Text Content against OCR Extraction Attack".
3 |
4 | ## Introduction
5 |
5 | ![overview](./Misc/overview.png)
6 | The overview of ProTegO. It consists of four main parts: preprocessing of underpainting, adversarial underpainting
7 | generation, robustness enhancement, and visual compensation. The whole pipeline can be trained end-to-end.
8 |
9 | ## Environment Setup
10 | This code is tested with Python 3.8, PyTorch 1.10, and CUDA 11.3, and requires the following dependencies:
11 |
12 | * opencv-python = 4.6.0.66
13 | * nltk = 3.8.1
14 | * trdg = 1.8.0
15 |
16 | To set up a conda environment, please use the following instructions:
17 | ```
18 | conda env create -f environment.yaml
19 | conda activate protego
20 | ```
21 |
22 |
23 | ## Preparation
24 |
25 | ### Dataset
26 | For training the 5 STR models, we use the [MJSynth (MJ)](https://www.robots.ox.ac.uk/~vgg/data/text/) dataset, but not in LMDB format; please refer to ``` dataset.py ``` for details.
27 |
28 | The overall directory structure should be:
29 | ```
30 | │data/
31 | ├── STR/
32 | │   ├── train/
33 | │   │   ├── 1/
34 | │   │   │   ├── 1/
35 | │   │   │   │   ├── 3_JERKIER_41423.png
36 | │   │   │   │   ├── .......
37 | │   ├── test/
38 | │   │   ├── 2698/
39 | │   │   │   ├── 1/
40 | │   │   │   │   ├── 108_mahavira_46071.png
41 | │   │   │   │   ├── .......
42 | │   ├── val/
43 | │   │   ├── 2425/
44 | │   │   │   ├── 1/
45 | │   │   │   │   ├── 1_Hey_36013.png
46 | │   │   │   │   ├── .......
47 | ```
48 |
49 | For our method ProTegO, we use [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator) to generate text image samples. We also provide an off-the-shelf dataset [here](https://drive.google.com/drive/folders/1TdkwhSM-CMvxG0vcZk6ImAVR-_u9mZU4?usp=drive_link).
50 |
51 | The overall directory structure should be:
52 | ```
53 | │data/
54 | ├── protego/
55 | │   ├── train/
56 | │   │   ├── arial
57 | │   │   │   ├── 0_01234_0.png
58 | │   │   │   ├── .......
59 | │   │   ├── georgia
60 | │   │   │   ├── 0_01234_0.png
61 | │   │   │   ├── .......
62 | │   │   ├── ......
63 | │   ├── test/
64 | │   │   ├── 0_unobtainable_0.png
65 | │   │   ├── .......
66 | │   ├── up/
67 | │   │   ├── 5.png
68 | │   │   ├── .......
69 | ```
70 |
71 | ### STR Models
72 | You can download the pretrained models used in our paper directly and put them at ```STR_modules/downloads_models/```. Their checkpoints can be downloaded [here](https://drive.google.com/drive/folders/1ciaBucPd1u0qDTjJe3DuxLvMIMf9R1zu?usp=drive_link).
73 |
74 |
75 | Alternatively, you can train and test your own model with [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark); please run the following commands:
76 |
77 | ```
78 | # take STARNet for example:
79 | python STR_train.py --exp_name STARNet --train_mode --train_data data/STR/train --valid_data data/STR/val \
80 | --batch_size 128 --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC
81 |
82 | python STR_test.py --output /res-STR/test --eval_data data/STR/test --name STARNet --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
83 | --saved_model ./STR_modules/saved_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive/best_accuracy.pth
84 | ```
85 |
86 | In addition, we use [DBNet++](https://arxiv.org/pdf/2202.10304.pdf) as our guidance network; you can [download](https://drive.google.com/file/d/1gmVd5hForfSAqV97PvwcxCzw9H7jaJMU/view?usp=drive_link) it and put it at ```models/```.
87 |
88 |
89 |
90 | ## Usage
91 | * First, given a preset underpainting, you need to follow the instructions in Section 3.3 of our paper, then adopt the [APCA](https://github.com/Myndex/apca-w3) algorithm to obtain a suitable underpainting. We also provide some pre-processed underpainting styles; you can [download](https://drive.google.com/drive/folders/1oOclGkU-9yQBNVAM7wEmT_X6Vxu9Imlf?usp=drive_link) them and put them at ```data/protego/up/```.
92 |
93 | * Next, to train and test ProTegO in the white-box setting, you can directly run the following commands:
94 | ```
95 | # if train with underpainting 5a (style 5 and white font):
96 | python main.py --train_path data/protego/train --test_path data/protego/test --use_eh --use_guide --dark
97 |
98 | # if train with underpainting 5b (style 5 and black font):
99 | python main.py --train_path data/protego/train --test_path data/protego/test --use_eh --use_guide
100 | ```
101 |
102 | * For black-box model evaluation, please run the following command:
103 | ```
104 | # take model CRNN as an example:
105 | python test_udp.py --output res-XX --str_model STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth --STR_name CRNN --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC
106 | ```
107 |
108 | * For evaluation on commercial OCR services, we take the [Baidu OCR API](https://cloud.baidu.com/product/ocr.html) as an example:
109 | ```
110 | python ./comtest/test_baidu.py
111 | ```
112 |
113 | Evaluation scripts for other commercial OCR APIs can be found in ```comtest/```, and more usage details are available on the official websites, such as [Microsoft Azure](https://learn.microsoft.com/zh-cn/azure/ai-services/computer-vision/overview-ocr), [Alibaba](https://ai.aliyun.com/ocr?spm=a2c4g.252763.0.0.32d53d80xtz0ZX), and [Huawei](https://support.huaweicloud.com/ocr/index.html).
114 |
115 | For the built-in commercial text recognition tools provided by smartphones or applications (like WeChat), you can use the scripts we provide in ```Misc/show/``` to alternately display the two frames of the protected text images at the current refresh rate
116 | of the monitor. Then you can take random manual screenshots and test them with the online OCR tool.
117 |
118 | * For baseline method evaluation, we provide [FAWA]() and two general transfer-based methods, [SINIFGSM, VMIFGSM](), for testing; please run the following commands:
119 | ```
120 | # For FAWA:
121 | python baselines/fawa.py
122 |
123 | # For SINIFGSM or VMIFGSM:
124 | python baselines/transfer.py --name SINIFGSM
125 | python baselines/transfer.py --name VMIFGSM
126 |
127 | # For black-box model testing, take model CRNN and method FAWA as an example:
128 | python baselines/test_black_models.py --attack_name fawa --adv_img xx/res-baselines/fawa-eps40/STARNet/wmadv \
129 | --str_model STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth \
130 | --Transformation None --FeatureExtraction VGG --SequenceModeling BiLSTM --Prediction CTC
131 | ```
132 |
133 |
134 | Since the code of these two methods, [What machines see is not what they get (CVPR'20)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xu_What_Machines_See_Is_Not_What_They_Get_Fooling_Scene_CVPR_2020_paper.pdf) and [The Best Protection is Attack (TIFS'23)](https://ieeexplore.ieee.org/abstract/document/10045728), is not open source, please contact their authors.
135 |
136 |
137 |
138 | ## Acknowledgement
139 | This repo is partially based on [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark), [FAWA](https://github.com/strongman1995/Fast-Adversarial-Watermark-Attack-on-OCR), [What machines see is not what they get](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xu_What_Machines_See_Is_Not_What_They_Get_Fooling_Scene_CVPR_2020_paper.pdf), [The Best Protection is Attack](https://ieeexplore.ieee.org/abstract/document/10045728), [SINIFGSM](https://arxiv.org/pdf/1908.06281.pdf) and [VMIFGSM](https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Enhancing_the_Transferability_of_Adversarial_Attacks_Through_Variance_Tuning_CVPR_2021_paper.pdf). Thanks for their impressive work!
140 |
141 | ## Citation
142 | If you find this work useful for your research, please cite [our paper](https://dl.acm.org/doi/abs/10.1145/3581783.3612076):
143 | ```
144 | @inproceedings{he2023protego,
145 | title={ProTegO: Protect Text Content against OCR Extraction Attack},
146 | author={He, Yanru and Chen, Kejiang and Chen, Guoqiang and Ma, Zehua and Zhang, Kui and Zhang, Jie and Bian, Huanyu and Fang, Han and Zhang, Weiming and Yu, Nenghai},
147 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
148 | pages={7424--7434},
149 | year={2023}
150 | }
151 | ```
152 |
153 | ## License
154 | The code is released under the MIT License (see the LICENSE file for details).
155 |
156 |
157 |
--------------------------------------------------------------------------------
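A note on the data layouts above: the sample filenames (e.g. ```3_JERKIER_41423.png```, ```0_unobtainable_0.png```) appear to encode the text label as the middle underscore-separated token. Below is a minimal sketch of collecting (image path, label) pairs under that assumption; the actual loading logic in ```dataset.py``` is authoritative.

```python
# Sketch only (not part of the repo): pair each image with the label that its
# filename appears to encode, e.g. "3_JERKIER_41423.png" -> "JERKIER" and
# "0_unobtainable_0.png" -> "unobtainable".
import os
from typing import List, Tuple

def collect_samples(root: str) -> List[Tuple[str, str]]:
    samples = []
    for dirpath, _, filenames in os.walk(root):
        for name in sorted(filenames):
            if not name.lower().endswith(".png"):
                continue
            parts = os.path.splitext(name)[0].split("_")
            if len(parts) < 3:
                continue  # unexpected filename, skip it
            label = "_".join(parts[1:-1])  # middle token(s) = text label
            samples.append((os.path.join(dirpath, name), label))
    return samples

if __name__ == "__main__":
    for path, label in collect_samples("data/protego/train")[:5]:
        print(label, path)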
/STR_modules/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class VGG_FeatureExtractor(nn.Module):
6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
7 |
8 | def __init__(self, input_channel, output_channel=512):
9 | super(VGG_FeatureExtractor, self).__init__()
10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4),
11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512]
12 | self.ConvNet = nn.Sequential(
13 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
14 | nn.MaxPool2d(2, 2), # 64x16x50
15 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
16 | nn.MaxPool2d(2, 2), # 128x8x25
17 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25
18 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
19 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25
20 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
21 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25
22 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
24 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25
25 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24
26 |
27 | def forward(self, input):
28 | return self.ConvNet(input)
29 |
30 |
31 | class ResNet_FeatureExtractor(nn.Module):
32 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
33 |
34 | def __init__(self, input_channel, output_channel=512):
35 | super(ResNet_FeatureExtractor, self).__init__()
36 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
37 |
38 | def forward(self, input):
39 | return self.ConvNet(input)
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None):
45 | super(BasicBlock, self).__init__()
46 | self.conv1 = self._conv3x3(inplanes, planes)
47 | self.bn1 = nn.BatchNorm2d(planes)
48 | self.conv2 = self._conv3x3(planes, planes)
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def _conv3x3(self, in_planes, out_planes, stride=1):
55 | "3x3 convolution with padding"
56 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
57 | padding=1, bias=False)
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | residual = self.downsample(x)
71 | out += residual
72 | out = self.relu(out)
73 |
74 | return out
75 |
76 |
77 | class ResNet(nn.Module):
78 |
79 | def __init__(self, input_channel, output_channel, block, layers):
80 | super(ResNet, self).__init__()
81 |
82 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
83 |
84 | self.inplanes = int(output_channel / 8)
85 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
86 | kernel_size=3, stride=1, padding=1, bias=False)
87 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
88 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
89 | kernel_size=3, stride=1, padding=1, bias=False)
90 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
91 | self.relu = nn.ReLU(inplace=True)
92 |
93 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
94 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
95 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
96 | 0], kernel_size=3, stride=1, padding=1, bias=False)
97 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
98 |
99 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
100 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
101 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
102 | 1], kernel_size=3, stride=1, padding=1, bias=False)
103 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
104 |
105 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
106 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
107 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
108 | 2], kernel_size=3, stride=1, padding=1, bias=False)
109 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
110 |
111 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
112 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
113 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
114 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
115 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
116 | 3], kernel_size=2, stride=1, padding=0, bias=False)
117 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
118 |
119 | def _make_layer(self, block, planes, blocks, stride=1):
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv0_1(x)
138 | x = self.bn0_1(x)
139 | x = self.relu(x)
140 | x = self.conv0_2(x)
141 | x = self.bn0_2(x)
142 | x = self.relu(x)
143 |
144 | x = self.maxpool1(x)
145 | x = self.layer1(x)
146 | x = self.conv1(x)
147 | x = self.bn1(x)
148 | x = self.relu(x)
149 |
150 | x = self.maxpool2(x)
151 | x = self.layer2(x)
152 | x = self.conv2(x)
153 | x = self.bn2(x)
154 | x = self.relu(x)
155 |
156 | x = self.maxpool3(x)
157 | x = self.layer3(x)
158 | x = self.conv3(x)
159 | x = self.bn3(x)
160 | x = self.relu(x)
161 |
162 | x = self.layer4(x)
163 | x = self.conv4_1(x)
164 | x = self.bn4_1(x)
165 | x = self.relu(x)
166 | x = self.conv4_2(x)
167 | x = self.bn4_2(x)
168 | x = self.relu(x)
169 |
170 | return x
171 |
--------------------------------------------------------------------------------
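A quick shape check for the two extractors above; this is a sketch run from the repo root, using the 32x100 input geometry that the in-line comments assume:

```python
# Sketch (not part of the repo): check the feature-map shapes noted in the
# comments above for a 3 x 32 x 100 input (the default imgH/imgW elsewhere).
import torch

from STR_modules.feature_extraction import VGG_FeatureExtractor, ResNet_FeatureExtractor

x = torch.rand(2, 3, 32, 100)    # batch of 2 RGB text images
vgg = VGG_FeatureExtractor(input_channel=3, output_channel=512)
res = ResNet_FeatureExtractor(input_channel=3, output_channel=512)
print(vgg(x).shape)   # expected torch.Size([2, 512, 1, 24])
print(res(x).shape)   # expected torch.Size([2, 512, 1, 26]) (slightly wider)
```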
/STR_modules/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from STR_modules.transformation import TPS_SpatialTransformerNetwork
4 | from STR_modules.feature_extraction import VGG_FeatureExtractor, ResNet_FeatureExtractor
5 | from STR_modules.sequence_modeling import BidirectionalLSTM
6 | from STR_modules.prediction import Attention
7 |
8 |
9 | class Model(nn.Module):
10 |
11 | def __init__(self, opt):
12 | super(Model, self).__init__()
13 | self.opt = opt
14 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
15 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
16 |
17 | """ Transformation : output is rectified image [batch_size x I_channel_num x I_r_height x I_r_width] """
18 | if opt.Transformation == 'TPS':
19 | self.Transformation = TPS_SpatialTransformerNetwork(
20 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
21 | else:
22 | print('No Transformation module specified')
23 |
24 |
25 | """ FeatureExtraction """
26 | if opt.FeatureExtraction == 'VGG':
27 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) # 512x1x24
28 | elif opt.FeatureExtraction == 'ResNet':
29 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
30 | else:
31 | raise Exception('No FeatureExtraction module specified')
32 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
33 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1
34 |
35 | """ Sequence modeling"""
36 | if opt.SequenceModeling == 'BiLSTM':
37 | self.SequenceModeling = nn.Sequential(
38 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
39 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) # hidden_size = 256, output is [batch_size x T x output_size]
40 | self.SequenceModeling_output = opt.hidden_size
41 | else:
42 | print('No SequenceModeling module specified')
43 | self.SequenceModeling_output = self.FeatureExtraction_output
44 |
45 | """ Prediction """
46 | if opt.Prediction == 'CTC':
47 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
48 | elif opt.Prediction == 'Attn':
49 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) # batch_size x num_steps x num_classes
50 | else:
51 |             raise Exception('Prediction is neither CTC nor Attn')
52 |
53 | def forward(self, input, text, is_train=True):
54 | """ Transformation stage """
55 | if not self.stages['Trans'] == "None":
56 | input = self.Transformation(input)
57 |
58 | """ Feature extraction stage """
59 | visual_feature = self.FeatureExtraction(input) # bx512x1x24
60 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]=bx24x512x1
61 | visual_feature = visual_feature.squeeze(3) # [b, w, c]=bx24x512
62 |
63 | """ Sequence modeling stage """
64 | if self.stages['Seq'] == 'BiLSTM':
65 | contextual_feature = self.SequenceModeling(visual_feature) # [b, w, hidden]=bx24x256
66 | else:
67 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
68 |
69 | """ Prediction stage """
70 | if self.stages['Pred'] == 'CTC':
71 | prediction = self.Prediction(contextual_feature.contiguous()) #[b, w, class]=bx24x63
72 | else:
73 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length)
74 |
75 | return prediction
76 |
--------------------------------------------------------------------------------
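A minimal sketch of instantiating the model above outside the training scripts, assuming a CRNN-style CTC configuration (None/VGG/BiLSTM/CTC) and the default hyper-parameters exposed by `STR_test.py`:

```python
# Sketch (not part of the repo): build a CRNN-style Model (None-VGG-BiLSTM-CTC)
# and run a dummy forward pass. Option names mirror the argparse flags in
# STR_test.py / STR_train.py.
import string
from types import SimpleNamespace

import torch

from STR_modules.model import Model

character = string.printable[:62]        # 0-9, a-z, A-Z (sensitive mode)
opt = SimpleNamespace(
    Transformation='None', FeatureExtraction='VGG',
    SequenceModeling='BiLSTM', Prediction='CTC',
    num_fiducial=20, imgH=32, imgW=100,
    input_channel=3, output_channel=512, hidden_size=256,
    num_class=len(character) + 1,        # +1 for the CTC blank token
    batch_max_length=25,
)

model = Model(opt)
image = torch.rand(2, opt.input_channel, opt.imgH, opt.imgW)
text = torch.zeros(2, opt.batch_max_length + 1, dtype=torch.long)  # unused by CTC
preds = model(image, text)
print(preds.shape)                       # torch.Size([2, 24, 63])
```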
/STR_modules/prediction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 | class Attention(nn.Module):
8 |
9 | def __init__(self, input_size, hidden_size, num_classes):
10 | super(Attention, self).__init__()
11 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes)
12 | self.hidden_size = hidden_size
13 | self.num_classes = num_classes
14 | self.generator = nn.Linear(hidden_size, num_classes)
15 |
16 | def _char_to_onehot(self, input_char, onehot_dim=38):
17 | input_char = input_char.unsqueeze(1)
18 | batch_size = input_char.size(0)
19 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device)
20 | one_hot = one_hot.scatter_(1, input_char, 1)
21 | return one_hot
22 |
23 | def forward(self, batch_H, text, is_train=True, batch_max_length=25):
24 | """
25 | input:
26 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels]
27 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO].
28 | output: probability distribution at each step [batch_size x num_steps x num_classes]
29 | """
30 | batch_size = batch_H.size(0)
31 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence.
32 |
33 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device)
34 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device),
35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device))
36 |
37 | if is_train:
38 | for i in range(num_steps):
39 |                 # one-hot vector for the i-th char in the batch
40 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
41 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
42 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
43 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell)
44 | probs = self.generator(output_hiddens)
45 |
46 | else:
47 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token
48 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)
49 |
50 | for i in range(num_steps):
51 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
52 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
53 | probs_step = self.generator(hidden[0])
54 | probs[:, i, :] = probs_step
55 | _, next_input = probs_step.max(1)
56 | targets = next_input
57 |
58 | return probs # batch_size x num_steps x num_classes
59 |
60 |
61 | class AttentionCell(nn.Module):
62 |
63 | def __init__(self, input_size, hidden_size, num_embeddings):
64 | super(AttentionCell, self).__init__()
65 | self.i2h = nn.Linear(input_size, hidden_size, bias=False)
66 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias
67 | self.score = nn.Linear(hidden_size, 1, bias=False)
68 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
69 | self.hidden_size = hidden_size
70 |
71 | def forward(self, prev_hidden, batch_H, char_onehots):
72 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
73 | batch_H_proj = self.i2h(batch_H)
74 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
75 |         e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step x 1
76 |
77 | alpha = F.softmax(e, dim=1)
78 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel
79 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding)
80 | cur_hidden = self.rnn(concat_context, prev_hidden)
81 | return cur_hidden, alpha
82 |
83 |
84 | # utils for label convert
85 | class CTCLabelConverter(object):
86 | """ Convert between text-label and text-index """
87 |
88 | def __init__(self, character):
89 | # character (str): set of the possible characters.
90 | dict_character = list(character)
91 |
92 | self.dict = {}
93 | for i, char in enumerate(dict_character):
94 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
95 | self.dict[char] = i + 1
96 |
97 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)
98 |
99 | def encode(self, text, batch_max_length=25):
100 | """convert text-label into text-index.
101 | input:
102 | text: text labels of each image. [batch_size]
103 | batch_max_length: max length of text label in the batch. 25 by default
104 |
105 | output:
106 | text: text index for CTCLoss. [batch_size, batch_max_length]
107 | length: length of each text. [batch_size]
108 | """
109 | length = [len(s) for s in text]
110 |
111 | # The index used for padding (=0) would not affect the CTC loss calculation.
112 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
113 | for i, t in enumerate(text):
114 | text = list(t)
115 | text = [self.dict[char] for char in text]
116 | batch_text[i][:len(text)] = torch.LongTensor(text)
117 | return (batch_text.to(device), torch.IntTensor(length).to(device))
118 |
119 | def decode(self, text_index, length):
120 | """ convert text-index into text-label. """
121 | texts = []
122 | for index, l in enumerate(length):
123 | t = text_index[index, :]
124 |
125 | char_list = []
126 | for i in range(l):
127 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
128 | char_list.append(self.character[t[i]])
129 | text = ''.join(char_list)
130 |
131 | texts.append(text)
132 | return texts
133 |
134 | class AttnLabelConverter(object):
135 | """ Convert between text-label and text-index """
136 |
137 | def __init__(self, character):
138 | # character (str): set of the possible characters.
139 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
140 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]']
141 | list_character = list(character)
142 | self.character = list_token + list_character
143 |
144 | self.dict = {}
145 | for i, char in enumerate(self.character):
146 | # print(i, char)
147 | self.dict[char] = i
148 |
149 | def encode(self, text, batch_max_length=25):
150 | """ convert text-label into text-index.
151 | input:
152 | text: text labels of each image. [batch_size]
153 | batch_max_length: max length of text label in the batch. 25 by default
154 |
155 | output:
156 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
157 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
158 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size]
159 | """
160 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
161 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting
162 | batch_max_length += 1
163 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
164 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
165 | for i, t in enumerate(text):
166 | text = list(t)
167 | text.append('[s]')
168 | text = [self.dict[char] for char in text]
169 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token
170 | return (batch_text.to(device), torch.IntTensor(length).to(device))
171 |
172 | def decode(self, text_index, length):
173 | """ convert text-index into text-label. """
174 | texts = []
175 | for index, l in enumerate(length):
176 | text = ''.join([self.character[i] for i in text_index[index, :]])
177 | texts.append(text)
178 | return texts
--------------------------------------------------------------------------------
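A usage sketch for the label converters above; the character set mirrors the `--sensitive` setting in `STR_test.py`, and the label has no adjacent repeated characters, so the greedy CTC round trip below is lossless:

```python
# Sketch (not part of the repo): encode a label into CTC indices and greedily
# decode it back. CTCLabelConverter reserves index 0 for the blank token.
import string

from STR_modules.prediction import CTCLabelConverter

converter = CTCLabelConverter(string.printable[:62])   # 0-9, a-z, A-Z
text, length = converter.encode(['ProTegO'], batch_max_length=25)
print(text.shape, length)              # torch.Size([1, 25]), length 7
print(converter.decode(text, length))  # ['ProTegO']
```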
/STR_modules/sequence_modeling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BidirectionalLSTM(nn.Module):
5 |
6 | def __init__(self, input_size, hidden_size, output_size):
7 | super(BidirectionalLSTM, self).__init__()
8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
9 | self.linear = nn.Linear(hidden_size * 2, output_size)
10 |
11 | def forward(self, input):
12 | """
13 | input : visual feature [batch_size x T x input_size]
14 | output : contextual feature [batch_size x T x output_size]
15 | """
16 | self.rnn.flatten_parameters()
17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
18 | output = self.linear(recurrent) # batch_size x T x output_size
19 | return output
20 |
--------------------------------------------------------------------------------
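A one-line shape check for the module above (a sketch; the 24-step, 512-channel input matches the feature extractor output used in `model.py`):

```python
# Sketch (not part of the repo): the BiLSTM keeps the 24-step sequence from the
# feature extractor and maps 512 input channels to 256 contextual channels.
import torch

from STR_modules.sequence_modeling import BidirectionalLSTM

lstm = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)
out = lstm(torch.rand(2, 24, 512))
print(out.shape)   # torch.Size([2, 24, 256])
```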
/STR_modules/transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 |
9 | class TPS_SpatialTransformerNetwork(nn.Module):
10 | """ Rectification Network of RARE, namely TPS(thin-plate spline) based STN """
11 |
12 | def __init__(self, F, I_size, I_r_size, I_channel_num=1):
13 | """ Based on RARE TPS
14 | input:
15 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
16 | num_fiducial: number of fiducial points, 20 as default
17 | I_size : (height, width) of the input image I
18 | I_r_size : (height, width) of the rectified image I_r
19 | I_channel_num : the number of channels of the input image I
20 | output:
21 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
22 | """
23 | super(TPS_SpatialTransformerNetwork, self).__init__()
24 | self.F = F
25 | self.I_size = I_size
26 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
27 | self.I_channel_num = I_channel_num
28 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
29 | self.GridGenerator = GridGenerator(self.F, self.I_r_size)
30 |
31 | def forward(self, batch_I):
32 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
33 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
34 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
35 |
36 | if torch.__version__ > "1.2.0":
37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
38 | else:
39 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
40 |
41 | return batch_I_r
42 |
43 |
44 | class LocalizationNetwork(nn.Module):
45 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
46 |
47 | def __init__(self, F, I_channel_num):
48 | super(LocalizationNetwork, self).__init__()
49 | self.F = F
50 | self.I_channel_num = I_channel_num
51 | self.conv = nn.Sequential(
52 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
53 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
54 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
55 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
56 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
57 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
58 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
59 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
60 | nn.AdaptiveAvgPool2d(1) # batch_size x 512
61 | )
62 |
63 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
64 | self.localization_fc2 = nn.Linear(256, self.F * 2)
65 |
66 | # Init fc2 in LocalizationNetwork
67 | self.localization_fc2.weight.data.fill_(0)
68 | """ see RARE paper Fig. 6 (a) """
69 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
70 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
71 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
72 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
73 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
74 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
75 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
76 |
77 | def forward(self, batch_I):
78 | """
79 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
80 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
81 | """
82 | batch_size = batch_I.size(0)
83 | features = self.conv(batch_I).view(batch_size, -1)
84 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
85 | return batch_C_prime
86 |
87 |
88 | class GridGenerator(nn.Module):
89 |     """ Grid Generator of RARE, which produces P_prime by multiplying T with P """
90 |
91 | def __init__(self, F, I_r_size):
92 | """ Generate P_hat and inv_delta_C for later """
93 | super(GridGenerator, self).__init__()
94 | self.eps = 1e-6
95 | self.I_r_height, self.I_r_width = I_r_size
96 | self.F = F
97 | self.C = self._build_C(self.F) # F x 2
98 | self.P = self._build_P(self.I_r_width, self.I_r_height)
99 | ## for multi-gpu, you need register buffer
100 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
101 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
102 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer
103 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
104 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
105 |
106 | def _build_C(self, F):
107 | """ Return coordinates of fiducial points in I_r; C """
108 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
109 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
110 | ctrl_pts_y_bottom = np.ones(int(F / 2))
111 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
112 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
113 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
114 | return C # F x 2
115 |
116 | def _build_inv_delta_C(self, F, C):
117 | """ Return inv_delta_C which is needed to calculate T """
118 | hat_C = np.zeros((F, F), dtype=float) # F x F
119 | for i in range(0, F):
120 | for j in range(i, F):
121 | r = np.linalg.norm(C[i] - C[j])
122 | hat_C[i, j] = r
123 | hat_C[j, i] = r
124 | np.fill_diagonal(hat_C, 1)
125 | hat_C = (hat_C ** 2) * np.log(hat_C)
126 | # print(C.shape, hat_C.shape)
127 | delta_C = np.concatenate( # F+3 x F+3
128 | [
129 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
130 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
131 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
132 | ],
133 | axis=0
134 | )
135 | inv_delta_C = np.linalg.inv(delta_C)
136 | return inv_delta_C # F+3 x F+3
137 |
138 | def _build_P(self, I_r_width, I_r_height):
139 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
140 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
141 | P = np.stack( # self.I_r_width x self.I_r_height x 2
142 | np.meshgrid(I_r_grid_x, I_r_grid_y),
143 | axis=2
144 | )
145 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
146 |
147 | def _build_P_hat(self, F, C, P):
148 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
149 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
150 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
151 | P_diff = P_tile - C_tile # n x F x 2
152 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
153 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
154 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
155 | return P_hat # n x F+3
156 |
157 | def build_P_prime(self, batch_C_prime):
158 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """
159 | batch_size = batch_C_prime.size(0)
160 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
161 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
162 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
163 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2
164 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
165 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
166 | return batch_P_prime # batch_size x n x 2
167 |
--------------------------------------------------------------------------------
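A standalone usage sketch of the TPS module above, with the 20 fiducial points and 32x100 geometry used by the rest of the repo:

```python
# Sketch (not part of the repo): rectify a dummy batch with the TPS module.
import torch

from STR_modules.transformation import TPS_SpatialTransformerNetwork

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tps = TPS_SpatialTransformerNetwork(
    F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=3).to(device)

batch_I = torch.rand(2, 3, 32, 100, device=device)
batch_I_r = tps(batch_I)
print(batch_I_r.shape)   # torch.Size([2, 3, 32, 100])
```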
/STR_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import string
5 | import argparse
6 |
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | import torch.utils.data
10 | import torch.nn.functional as F
11 |
12 | from STR_modules.prediction import CTCLabelConverter, AttnLabelConverter
13 | from STR_modules.model import Model
14 | from dataset import strdataset, train_dataset_builder
15 | from utils import Averager, Logger
16 | from torchvision import utils as vutils
17 |
18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 |
20 |
21 | def makedirs(path):
22 | if not os.path.exists(path):
23 | try:
24 | os.makedirs(path)
25 | except Exception as e:
26 | print('cannot create dirs: {}'.format(path))
27 | exit(0)
28 |
29 | def validation(model, criterion, evaluation_loader, converter, opt):
30 | """ validation or evaluation """
31 | n_correct = 0
32 |
33 | infer_time = 0
34 | valid_loss_avg = Averager()
35 |
36 | for i, data in enumerate(evaluation_loader):
37 | image_tensors, labels = data
38 | image = image_tensors.to(device)
39 |
40 | # For max length prediction
41 | length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device)
42 | text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
43 |
44 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
45 |
46 | start_time = time.time()
47 | if 'CTC' in opt.Prediction:
48 | preds = model(image, text_for_pred)
49 | forward_time = time.time() - start_time
50 |
51 |             # Calculate evaluation loss for CTC decoder.
52 | preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
53 | # permute 'preds' to use CTCloss format
54 | cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
55 |
56 |             # Select max probability (greedy decoding) then decode index to character
57 | _, preds_index = preds.max(2)
58 | preds_str = converter.decode(preds_index.data, preds_size.data)
59 |
60 | else:
61 | preds = model(image, text_for_pred, is_train=False)
62 | forward_time = time.time() - start_time
63 |
64 | preds = preds[:, :text_for_loss.shape[1] - 1, :]
65 | target = text_for_loss[:, 1:] # without [GO] Symbol
66 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
67 |
68 |             # select max probability (greedy decoding) then decode index to character
69 | _, preds_index = preds.max(2)
70 | preds_str = converter.decode(preds_index, length_for_pred)
71 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss)
72 |
73 | infer_time += forward_time
74 | valid_loss_avg.add(cost)
75 |
76 | # calculate accuracy & confidence score
77 | preds_prob = F.softmax(preds, dim=2)
78 | preds_max_prob, _ = preds_prob.max(dim=2)
79 | confidence_score_list = []
80 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
81 | if 'Attn' in opt.Prediction:
82 | gt = gt[:gt.find('[s]')]
83 | pred_EOS = pred.find('[s]')
84 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])
85 | pred_max_prob = pred_max_prob[:pred_EOS]
86 |
87 |             # To evaluate a 'case sensitive' model with an alphanumeric, case-insensitive setting.
88 | # if opt.sensitive:
89 | # pred = pred.lower()
90 | # gt = gt.lower()
91 | # alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
92 | # out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
93 | # pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
94 | # gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)
95 | if pred == gt:
96 | n_correct += 1
97 |                 vutils.save_image(image, "{}/{}_{}_{}.png".format(opt.test_out, i, gt, i)) # keep correctly recognized samples to build the test set
98 |
99 | # if not opt.train_mode:
100 | # print('GoundTruth: %-10s => Prediction: %-10s' % (gt, pred))
101 | if not opt.train_mode:
102 |             print('Success:{},\t GroundTruth:{:20} => Prediction:{:20}'.format(pred == gt, gt, pred))
103 |
104 |             # calculate confidence score (= product of pred_max_prob)
105 | try:
106 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
107 | except:
108 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s])
109 | confidence_score_list.append(confidence_score)
110 | # print(pred, gt, pred==gt, confidence_score)
111 |
112 | accuracy = n_correct / float(len(evaluation_loader)) * 100
113 |
114 | return valid_loss_avg.val(), accuracy, preds_str, confidence_score_list, labels, infer_time, len(evaluation_loader)
115 |
116 |
117 | def test(opt):
118 | """ save all the print content as log """
119 | opt.test_out = os.path.join(opt.output, opt.name)
120 | makedirs(opt.test_out)
121 | # log_file= os.path.join(opt.test_out, 'test.log')
122 | # sys.stdout = Logger(log_file)
123 |
124 | """ model configuration """
125 | if 'CTC' in opt.Prediction:
126 | converter = CTCLabelConverter(opt.character)
127 | else:
128 | converter = AttnLabelConverter(opt.character)
129 | opt.num_class = len(converter.character)
130 |
131 | if opt.rgb:
132 | opt.input_channel = 3
133 | model = Model(opt).to(device)
134 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
135 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
136 | opt.SequenceModeling, opt.Prediction)
137 | # model = torch.nn.DataParallel(model).to(device)
138 |
139 | # load model
140 | print('loading pretrained model from %s' % opt.saved_model)
141 | model.load_state_dict(torch.load(opt.saved_model, map_location=device),strict=False)
142 | # opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])
143 | print(model)
144 |
145 | """ setup loss """
146 | if 'CTC' in opt.Prediction:
147 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
148 | else:
149 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
150 |
151 | """ evaluation """
152 | model.eval()
153 | with torch.no_grad():
154 | # eval_dataset = strdataset(opt.imgH, opt.imgW, opt.eval_data)
155 | eval_dataset = train_dataset_builder(opt.imgH, opt.imgW, opt.eval_data)
156 | evaluation_loader = torch.utils.data.DataLoader(
157 | eval_dataset, batch_size=opt.batch_size,
158 | shuffle=False, num_workers=int(opt.workers),
159 | # drop_last=True, pin_memory=True
160 | )
161 |
162 | _, accuracy_by_best_model, _, _, _, _, _ = validation(
163 | model, criterion, evaluation_loader, converter, opt)
164 | print('SR:', f'{accuracy_by_best_model:0.2f}', '%')
165 |
166 |
167 | if __name__ == '__main__':
168 | parser = argparse.ArgumentParser()
169 | parser.add_argument('--output', required=True, help='Test output path')
170 | parser.add_argument('--name', required=True, help='Test model name')
171 |     parser.add_argument('--train_mode', action='store_true', help='default is test mode')
172 | parser.add_argument('--eval_data', type=str, required=True, help='path to evaluation dataset')
173 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
174 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
175 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
176 | """ Data processing """
177 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
178 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
179 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
180 | parser.add_argument('--rgb', action='store_false', help='use rgb input')
181 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
182 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode')
183 | """ Model Architecture """
184 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
185 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
186 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
187 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
188 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
189 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor')
190 | parser.add_argument('--output_channel', type=int, default=512,
191 | help='the number of output channel of Feature extractor')
192 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
193 |
194 | opt = parser.parse_args()
195 |
196 | """ vocab / character number configuration """
197 | if opt.sensitive:
198 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z)
199 |
200 | cudnn.benchmark = True
201 | cudnn.deterministic = True
202 |
203 | test(opt)
--------------------------------------------------------------------------------
/STR_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import string
6 | import argparse
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.nn.init as init
11 | import torch.optim as optim
12 | import torch.utils.data
13 | import numpy as np
14 |
15 | from STR_modules.prediction import CTCLabelConverter, AttnLabelConverter
16 | from STR_modules.model import Model
17 | from STR_test import validation
18 | from dataset import strdataset
19 | from utils import Averager
20 |
21 |
22 |
23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |
25 | def train(opt):
26 |
27 | """ dataset preparation """
28 | train_dataset = strdataset(opt.imgH, opt.imgW, opt.train_data)
29 | train_loader = torch.utils.data.DataLoader(
30 | train_dataset, batch_size=opt.batch_size,
31 | shuffle=True, num_workers=int(opt.workers),
32 | drop_last=True, pin_memory=True)
33 |
34 |
35 | valid_dataset = strdataset(opt.imgH, opt.imgW, opt.valid_data)
36 | valid_loader = torch.utils.data.DataLoader(
37 | valid_dataset, batch_size=opt.batch_size,
38 | shuffle=False, num_workers=int(opt.workers),
39 | drop_last=True, pin_memory=True)
40 |
41 | """ model configuration """
42 | if 'CTC' in opt.Prediction:
43 | converter = CTCLabelConverter(opt.character)
44 | else:
45 | converter = AttnLabelConverter(opt.character)
46 | opt.num_class = len(converter.character)
47 |
48 | model = Model(opt).to(device)
49 |
50 | # weight initialization
51 | for name, param in model.named_parameters():
52 | if 'localization_fc2' in name:
53 | print(f'Skip {name} as it is already initialized')
54 | continue
55 | try:
56 | if 'bias' in name:
57 | init.constant_(param, 0.0)
58 | elif 'weight' in name:
59 | init.kaiming_normal_(param)
60 | except Exception as e: # for batchnorm.
61 | if 'weight' in name:
62 | param.data.fill_(1)
63 | continue
64 |
65 | # # data parallel for multi-GPU
66 | # model = torch.nn.DataParallel(model).to(device)
67 | model.train()
68 | if opt.saved_model != '':
69 | print(f'loading pretrained model from {opt.saved_model}')
70 | if opt.FT:
71 | model.load_state_dict(torch.load(opt.saved_model), strict=False)
72 | else:
73 | model.load_state_dict(torch.load(opt.saved_model))
74 | print("Model:")
75 | print(model)
76 |
77 | """ setup loss """
78 | if 'CTC' in opt.Prediction:
79 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
80 | else:
81 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
82 | # loss averager
83 | loss_avg = Averager()
84 |
85 |     # keep only the parameters that require gradients (these are passed to the optimizer)
86 | filtered_parameters = []
87 | params_num = []
88 | for p in filter(lambda p: p.requires_grad, model.parameters()):
89 | filtered_parameters.append(p)
90 | params_num.append(np.prod(p.size()))
91 | print('Trainable params num : ', sum(params_num))
92 |
93 | # setup optimizer
94 | if opt.adam:
95 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
96 | else:
97 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
98 | print("Optimizer:")
99 | print(optimizer)
100 |
101 | """ final options """
102 | # print(opt)
103 | with open(f'./STR_modules/saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
104 | opt_log = '------------ Options -------------\n'
105 | args = vars(opt)
106 | for k, v in args.items():
107 | opt_log += f'{str(k)}: {str(v)}\n'
108 | opt_log += '---------------------------------------\n'
109 | print(opt_log)
110 | opt_file.write(opt_log)
111 |
112 | """ start training """
113 | start_iter = 0
114 | if opt.saved_model != '':
115 | try:
116 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
117 | print(f'continue to train, start_iter: {start_iter}')
118 | except:
119 | pass
120 |
121 | start_time = time.time()
122 | best_accuracy = -1
123 | iteration = start_iter
124 | while(True):
125 | # train part
126 | for i, data in enumerate(train_loader):
127 | image_tensors, labels = data
128 | image = image_tensors.to(device)
129 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
130 | batch_size = image.size(0)
131 |
132 | if 'CTC' in opt.Prediction:
133 | preds = model(image, text)
134 | preds_size = torch.IntTensor([preds.size(1)] * batch_size)
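                # note: CTCLoss expects log-probabilities shaped (T, N, C), hence the log_softmax + permute from (N, T, C) below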
135 | preds = preds.log_softmax(2).permute(1, 0, 2)
136 | cost = criterion(preds, text, preds_size, length)
137 | else:
138 | preds = model(image, text[:, :-1]) # align with Attention.forward
139 | target = text[:, 1:] # without [GO] Symbol
140 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
141 |
142 | model.zero_grad()
143 | cost.backward()
144 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default)
145 | optimizer.step()
146 |
147 | loss_avg.add(cost)
148 |
149 | # validation part
150 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0'
151 | elapsed_time = time.time() - start_time
152 | # for log
153 | with open(f'./STR_modules/saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
154 | model.eval()
155 | with torch.no_grad():
156 | valid_loss, current_accuracy, preds, confidence_score, labels, infer_time, length_of_data = validation(
157 | model, criterion, valid_loader, converter, opt)
158 | model.train()
159 |
160 | # training loss and validation loss
161 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
162 | loss_avg.reset()
163 |
164 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}'
165 |
166 | # keep best accuracy model (on valid dataset)
167 | if current_accuracy > best_accuracy:
168 | best_accuracy = current_accuracy
169 | torch.save(model.state_dict(), f'./STR_modules/saved_models/{opt.exp_name}/best_accuracy.pth')
170 |
171 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}'
172 |
173 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
174 | print(loss_model_log)
175 | log.write(loss_model_log + '\n')
176 |
177 | # show some predicted results
178 | dashed_line = '-' * 80
179 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
180 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
181 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
182 | if 'Attn' in opt.Prediction:
183 | gt = gt[:gt.find('[s]')]
184 | pred = pred[:pred.find('[s]')]
185 |
186 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
187 | predicted_result_log += f'{dashed_line}'
188 | print(predicted_result_log)
189 | log.write(predicted_result_log + '\n')
190 |
191 |             # save a checkpoint every 10,000 iterations
192 |             if (iteration + 1) % 10000 == 0:
193 | torch.save(
194 | model.state_dict(), f'./STR_modules/saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
195 |
196 | if (iteration + 1) == opt.num_iter:
197 | print('end the training')
198 | sys.exit()
199 | iteration += 1
200 |
201 | if __name__ == '__main__':
202 | parser = argparse.ArgumentParser()
203 | parser.add_argument('--exp_name', required=True, help='Where to store logs and models')
204 |     parser.add_argument('--train_mode', action='store_true', help='default is test mode')
205 | parser.add_argument('--train_data', type=str, required=True, help='path to training dataset')
206 | parser.add_argument('--valid_data', type=str, required=True, help='path to validation dataset')
207 | parser.add_argument('--manualSeed', type=int, default=2023, help='for random seed setting')
208 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
209 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
210 |     parser.add_argument('--num_iter', type=int, default=50000, help='number of iterations to train for') #train 3e+4, finetune 1e+4
211 | parser.add_argument('--valInterval', type=int, default=1000, help='Interval between each validation')
212 | parser.add_argument('--saved_model', type=str, default='', help="path to model to continue training")
213 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning')
214 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)')
215 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta')
216 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
217 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
218 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
219 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
220 | """ Data processing """
221 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
222 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
223 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
224 | # parser.add_argument('--rgb', action='store_false', help='default to use rgb input')
225 | parser.add_argument('--character', type=str,
226 | default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
227 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode')
228 | """ Model Architecture """
229 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
230 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
231 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
232 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
233 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
234 | parser.add_argument('--input_channel', type=int, default=3,
235 | help='the number of input channel of Feature extractor')
236 | parser.add_argument('--output_channel', type=int, default=512,
237 | help='the number of output channel of Feature extractor')
238 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
239 |
240 | opt = parser.parse_args()
241 |
242 |
243 | if not opt.exp_name:
244 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
245 | # opt.exp_name += f'-Seed{opt.manualSeed}'
246 | if opt.sensitive:
247 | opt.exp_name += f'-sensitive'
248 | print(opt.exp_name)
249 | else:
250 | opt.exp_name = f'{opt.exp_name}-{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
251 | # opt.exp_name += f'-Seed{opt.manualSeed}'
252 | if opt.sensitive:
253 | opt.exp_name += f'-sensitive'
254 | print(opt.exp_name)
255 |
256 | if opt.FT:
257 | pass
258 | else:
259 | os.makedirs(f'./STR_modules/saved_models/{opt.exp_name}', exist_ok=True)
260 |
261 | """ vocab / character number configuration """
262 | if opt.sensitive:
263 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z)
264 |
265 | """ Seed and GPU setting """
266 | print("Random Seed: ", opt.manualSeed)
267 | random.seed(opt.manualSeed)
268 | np.random.seed(opt.manualSeed)
269 | torch.manual_seed(opt.manualSeed)
270 | torch.cuda.manual_seed(opt.manualSeed)
271 |
272 | cudnn.benchmark = True
273 | cudnn.deterministic = True
274 |
275 | train(opt)
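    # Example invocation (a sketch; the dataset paths are placeholders, not files shipped with this repo):
    #   python STR_train.py --exp_name demo --train_data <path/to/train> --valid_data <path/to/valid> \
    #     --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC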
276 |
--------------------------------------------------------------------------------
/baselines/Impact.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/baselines/Impact.ttf
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/baselines/__init__.py
--------------------------------------------------------------------------------
/baselines/fawa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # sys.path.append('xx/ProTegO/')
4 | import string
5 | import shutil
6 | import argparse
7 | from PIL import Image
8 | import numpy as np
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | import torch.utils.data
13 | from torchvision import utils as vutils
14 |
15 | from dataset import test_dataset_builder
16 | from utils import Logger, np2tensor, tensor2np, get_text_mask, cvt2Image, color_map
17 | from wm_attacker import WM_Attacker
18 |
19 |
20 | def makedirs(path):
21 | if not os.path.exists(path):
22 | try:
23 | os.makedirs(path)
24 | except Exception as e:
25 | print('cannot create dirs: {}'.format(path))
26 | exit(0)
27 |
28 | def fawa(opt):
29 | """prepare log with model_name """
30 | model_name = os.path.basename(opt.str_model.split('-')[0])
31 | print('-----Attack Settings< model_name:{} --iter_num:{} --eps:{} --decay:{} --alpha:{} >'
32 | .format(model_name, opt.iter_num, opt.eps, opt.decay, opt.alpha))
33 |
34 | save_root = os.path.join(opt.save_attacks, model_name)
35 | makedirs(save_root)
36 | del_list = os.listdir(save_root)
37 | for f in del_list:
38 | file_path = os.path.join(save_root, f)
39 | if os.path.isfile(file_path):
40 | os.remove(file_path)
41 | elif os.path.isdir(file_path):
42 | shutil.rmtree(file_path)
43 |
44 | wm_save_adv_path = os.path.join(save_root, 'wmadv')
45 | wm_save_per_path = os.path.join(save_root, 'wmper')
46 | makedirs(wm_save_adv_path)
47 | makedirs(wm_save_per_path)
48 |
49 | """ save all the print content as log """
50 | log_file= os.path.join(save_root, 'train.log')
51 | sys.stdout = Logger(log_file)
52 |
53 | dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.root)
54 | dataloader = torch.utils.data.DataLoader(
55 | dataset, batch_size=opt.batch_size,
56 | shuffle=False, num_workers=4,
57 | drop_last=True,pin_memory=True)
58 |
59 | attacker = WM_Attacker(opt)
60 | time_all, suc, ED_sum = 0, 0, 0
61 | for i, data in enumerate(dataloader, start=0):
62 | label = data[1]
63 | img = data[5] # binary image
64 | img_index = data[2][0]
65 | img_path = data[4][0]
66 | img_tensor = img.to(opt.device)
67 |
68 | """basic attack"""
69 | attacker.init()
70 | adv_img, delta, epoch, preds, flag, time = attacker.basic_attack(img_tensor, label)
71 | print('img-{}_path:{} --iters:{} --Success:{} --prediction:{} --groundtruth:{} --time:{}'
72 | .format(i, img_path, epoch, flag, preds, label[0], time))
73 |
74 | """wm attack"""
75 | # find position to add watermark
76 | adv_np = tensor2np(adv_img.squeeze(0))
77 | img_np = tensor2np(img_tensor.squeeze(0))
78 | pos, frames = attacker.find_wm_pos(adv_np, img_np, True)
79 |         # sort each pos list by bounding-box area, from largest to smallest
80 | new_pos = []
81 | for _pos in pos:
82 | if len(_pos) > 1:
83 | new_pos.append(sorted(_pos, key=lambda x: (x[3]-x[1])*(x[2]-x[0]), reverse=True))
84 | else:
85 | new_pos.append(_pos)
86 | pos = new_pos
87 |
88 | # get watermark mask
89 | grayscale = 0
90 | color = (grayscale, grayscale, grayscale)
91 | wm_img = attacker.gen_wm(color)
92 | wm_arr = np.array(wm_img.convert('L'))
93 | bg_mask = ~(wm_arr != 255)
94 |
95 | # # get grayscale watermark
96 | # """
97 | # 灰度值在 76-226 有对应的彩色水印值,为了增加扰动后还在范围内,128-174,
98 | # *note by hyr:paper setting = 255*0.682 =174
99 | # """
100 | # grayscale = 174
101 | # color = (grayscale, grayscale, grayscale)
102 | # wm_img = np.array(Image.new(mode="RGB", size=wm_img.size, color=color))
103 | # wm_img[bg_mask] = 255
104 | # wm_img = Image.fromarray(wm_img)
105 |
106 | # # get color watermark
107 | grayscale = 174
108 | green_v = color_map(grayscale)
109 | color = (255, green_v, 0)
110 | wm_img = np.array(Image.new(mode="RGB", size=wm_img.size, color=color))
111 | wm_img[bg_mask] = 255
112 | wm_img = Image.fromarray(wm_img)
113 |
114 |
115 | text_img = cvt2Image(img_np)
116 | text_mask = get_text_mask(np.array(text_img)) # bool, 1 channel
117 | rgb_img = Image.new(mode="RGB", size=(text_img.size), color=(255, 255, 255)) # white bg
118 | p = -int(wm_img.size[0] * np.tan(10 * np.pi / 180))
119 | right_shift = 20
120 | xp = pos[0][0][0]+right_shift if len(pos) != 0 else right_shift
121 | rgb_img.paste(wm_img, box=(xp, p)) # first to add wm
122 | wm_mask = (np.array(rgb_img) != 255) # bool, 3 channel
123 | rgb_img.paste(text_img, mask=cvt2Image(text_mask)) # then add text
124 |
125 | wm0_img = np.array(rgb_img)
126 | wm_img = np2tensor(wm0_img)
127 | wm_mask = np2tensor(wm_mask)
128 | adv_text_mask = ~text_mask
129 | adv_text_mask = np2tensor(adv_text_mask)
130 |
131 | attacker.init()
132 | adv_img_wm, delta, epoch, preds, flag, ED, time = attacker.wm_attack(wm_img, label, wm_mask, adv_text_mask)
133 |
134 | print('wmimg-{}_path:{} --iters:{} --Success:{} --prediction:{} --groundtruth:{} --edit_distance:{} --time:{}'
135 | .format(i, img_path, epoch, flag, preds, label[0], ED, time))
136 |
137 | time_all +=time
138 | if flag:
139 | suc += 1
140 | ED_sum += ED
141 | vutils.save_image(adv_img_wm,"{}/{}_{}_adv.png".format(wm_save_adv_path, os.path.basename(img_index), label[0]))
142 | vutils.save_image(delta*100,"{}/{}_{}_delta.png".format(wm_save_per_path, os.path.basename(img_index), label[0]))
143 |
144 | print('FAWA_Total_attack_time:{} '.format(time_all))
145 | print('ASR:{:.2f}% '.format((suc/len(dataloader))*100))
146 |     print('Average Edit_distance: {}'.format(ED_sum/suc if suc else 0))
147 |
148 |
149 |
150 | if __name__ == '__main__':
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument('--root', default='/data/hyr/dataset/PTAU/new/test100-times/',help='path of original text images')
153 | parser.add_argument('--save_attacks', type=str, default='res-baselines/fawa', help='path of save adversarial images')
154 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
155 | parser.add_argument('--iter_num', type=int, default=2000, help='number of iterations')
156 |     parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation setting in paper')
157 | parser.add_argument('--decay', type=float, default=1.0, help='momentum factor')
158 | parser.add_argument('--alpha', type=float, default=0.05, help='the step size of the iteration')
159 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
160 |
161 | """ Data processing """
162 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
163 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
164 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
165 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
166 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode')
167 | """ Model Architecture """
168 | parser.add_argument('--str_model', type=str, help="well-trained models for evaluation",
169 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth')
170 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS')
171 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet')
172 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM')
173 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn')
174 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
175 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor')
176 | parser.add_argument('--output_channel', type=int, default=512,
177 | help='the number of output channel of Feature extractor')
178 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
179 | opt = parser.parse_args()
180 |
181 | """ vocab / character number configuration """
182 | device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
183 | opt.device = torch.device(device_type)
184 |
185 | opt.save_attacks = opt.save_attacks + "-eps" + str(int(opt.eps*255))
186 |
187 | if opt.sensitive:
188 | opt.character = string.printable[:62]
189 |
190 | cudnn.benchmark = True
191 | cudnn.deterministic = True
192 |
193 | torch.cuda.synchronize()
194 | fawa(opt)
195 |
--------------------------------------------------------------------------------
/baselines/test_black_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Test baseline methods on 4 black-box models: CRNN, Rosetta, RARE, TRBA"""
3 |
4 | import os, sys, time, string, argparse, shutil
5 | sys.path.append('/data/hyr/ocr/ProTegO/release/')
6 | import torch
7 | from nltk.metrics import edit_distance
8 | from dataset import test_adv_dataset
9 | from STR_modules.model import Model
10 | from utils import Logger, CTCLabelConverter, AttnLabelConverter
11 |
12 | def makedirs(path):
13 | if not os.path.exists(path):
14 | try:
15 | os.makedirs(path)
16 | except Exception as e:
17 | print('cannot create dirs: {}'.format(path))
18 | exit(0)
19 |
20 | def process_line(line):
21 | adv_img_path, recog_result = line.split(':')
22 | label, adv_preds = recog_result.split('--->')
23 | adv_preds = adv_preds.strip('\n')
24 | return adv_preds, label, adv_img_path
25 |
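# Each line of attack_success_result.txt (written in test() below) has the form
# '<adv_img_path>:<label>--->' followed by the model prediction, e.g. (hypothetical)
# 'res/adv/0_hello_adv.png:hello--->heiio'; process_line() splits it back into (prediction, label, path).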
26 | def test(opt):
27 |
28 | """ model configuration """
29 | if 'CTC' in opt.Prediction:
30 | converter = CTCLabelConverter(opt.character)
31 | else:
32 | converter = AttnLabelConverter(opt.character)
33 | opt.num_class = len(converter.character)
34 |
35 | model = Model(opt).to(opt.device)
36 | print('Loading a STR model from \"%s\" as the target model!' % opt.str_model)
37 | model.load_state_dict(torch.load(opt.str_model, map_location=opt.device),strict=False)
38 | model.eval()
39 |
40 | """ create new test model output path """
41 | makedirs(opt.output)
42 |
43 | # str_name = opt.str_model.split('/')[-2].split('-')[0]
44 | str_name = opt.str_model.split('/')[-1].split('-')[0]
45 | test_output_path = os.path.join(opt.output, opt.attack_name, str_name)
46 | attack_success_result = os.path.join(test_output_path , 'attack_success_result.txt')
47 | save_success_adv = os.path.join(test_output_path , 'attack-success-adv')
48 |
49 | makedirs(test_output_path)
50 | makedirs(save_success_adv)
51 |
52 | log_file= os.path.join(test_output_path, 'test.log')
53 | sys.stdout = Logger(log_file)
54 |
55 | test_dataset = test_adv_dataset(opt.imgH, opt.imgW, opt.adv_img)
56 | test_dataloader = torch.utils.data.DataLoader(
57 | test_dataset,
58 | batch_size=opt.batch_size,
59 | shuffle=False,
60 | num_workers=1)
61 |
62 | result = dict()
63 | for i, data in enumerate(test_dataloader):
64 | if opt.b:
65 | adv_img = data[0]
66 | else:
67 | adv_img = data[1]
68 | adv_img= adv_img.to(opt.device)
69 | label = data[2]
70 | adv_index = data[3][0]
71 | adv_path = data[5][0]
72 |
73 | length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(opt.device)
74 | text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(opt.device)
75 | if 'CTC' in opt.Prediction:
76 | preds = model(adv_img, text_for_pred).log_softmax(2)
77 | preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
78 | _, preds_index = preds.permute(1, 0, 2).max(2)
79 | preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
80 | preds_output = converter.decode(preds_index.data, preds_size)
81 | preds_output = preds_output[0]
82 | result[adv_index] = '{}:{}--->{}\n'.format(adv_path, label[0], preds_output)
83 | else: # Attention
84 | preds = model(adv_img, text_for_pred, is_train=False)
85 | _, preds_index = preds.max(2)
86 | preds_output = converter.decode(preds_index, length_for_pred)
87 | preds_output = preds_output[0]
88 | preds_output = preds_output[:preds_output.find('[s]')]
89 | result[adv_index] = '{}:{}--->{}\n'.format(adv_path, label[0], preds_output)
90 | result = sorted(result.items(), key=lambda x:x[0])
91 | with open(attack_success_result, 'w+') as f:
92 | for item in result:
93 | f.write(item[1])
94 |
95 | # calculate ASR
96 | with open(attack_success_result, 'r') as f:
97 | alladv = f.readlines()
98 | attack_success_num, ED_sum = 0, 0
99 | for line in alladv:
100 | adv_preds, label, adv_img_path = process_line(line)
101 |
102 | if adv_preds != label:
103 | ED_num = edit_distance(label, adv_preds)
104 | attack_success_num += 1
105 | shutil.copy(adv_img_path, save_success_adv)
106 | ED_sum += ED_num
107 | attack_success_rate = attack_success_num / len(test_dataset)
108 | print('ASR:{:.2f} %'.format(attack_success_rate * 100))
109 | if attack_success_num != 0:
110 | ED_num_avr = ED_sum / attack_success_num
111 | print('Average Edit_distance: {:.2f}'.format(ED_num_avr))
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser()
116 | parser.add_argument('--output', type=str, default='res-BlackModelTest/up5a', help='path to save attack results')
117 | parser.add_argument('--attack_name', required=True, help='baseline attack method name')
118 | """ Data processing """
119 |     parser.add_argument('--adv_img', required=True, help='path of the adversarial images (adv_x) generated with the STARNet model')
120 |     parser.add_argument('--b', action='store_true', help='apply binarization to adv_img')
121 | parser.add_argument('--batch_size', type=int, default=1)
122 | parser.add_argument('--img_channel', type=int, default=3, help='the number of input channel of image')
123 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
124 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
125 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
126 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
127 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode')
128 | """ Model Architecture """
129 | parser.add_argument('--str_model', type=str, help='the model path of the target model',
130 | default='/STR_modules/downloads_models/CRNN-None-VGG-BiLSTM-CTC-sensitive.pth')
131 | parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS')
132 | parser.add_argument('--FeatureExtraction', type=str, default='VGG', help='FeatureExtraction stage. VGG|RCNN|ResNet')
133 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM')
134 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn')
135 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
136 | parser.add_argument('--input_channel', type=int, default=3,
137 | help='the number of input channel of Feature extractor')
138 | parser.add_argument('--output_channel', type=int, default=512,
139 | help='the number of output channel of Feature extractor')
140 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
141 | opt = parser.parse_args()
142 | print(opt)
143 |
144 | opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
145 |
146 | """ vocab / character number configuration """
147 | if opt.sensitive:
148 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z)
149 |
150 | torch.cuda.synchronize()
151 | time_st = time.time()
152 | test(opt)
153 | time_end = time.time()
154 | time_all = time_end - time_st
155 | print('Testing time:{:.2f}'.format(time_all))
--------------------------------------------------------------------------------
/baselines/transfer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # sys.path.append('xx/ProTegO/')
4 | import string
5 | import shutil
6 | import argparse
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.utils.data
11 | from torchvision import utils as vutils
12 | from nltk.metrics import edit_distance
13 |
14 | from transfer_attacker import transfer_Attacker
15 | from dataset import test_dataset_builder
16 | from utils import Logger
17 |
18 |
19 | def makedirs(path):
20 | if not os.path.exists(path):
21 | try:
22 | os.makedirs(path)
23 | except Exception as e:
24 | print('cannot create dirs: {}'.format(path))
25 | exit(0)
26 |
27 | def transfer_attack(opt):
28 | """prepare log with model_name """
29 | model_name = os.path.basename(opt.str_model.split('-')[0])
30 | print('----------------------------------------------------------------------------------------')
31 |     print('Start attacking: model:{} attack:{} \t\niter_num:{} eps:{:.2f} alpha:{:.2f} beta:{:.2f} m:{} N:{}'
32 |             .format(model_name, opt.name, opt.iter_num, opt.eps, opt.alpha, opt.beta, opt.m, opt.N))
33 | print('----------------------------------------------------------------------------------------')
34 |
35 | save_root = os.path.join(opt.save_attacks, model_name)
36 | makedirs(save_root)
37 | del_list = os.listdir(save_root)
38 | for f in del_list:
39 | file_path = os.path.join(save_root, f)
40 | if os.path.isfile(file_path):
41 | os.remove(file_path)
42 | elif os.path.isdir(file_path):
43 | shutil.rmtree(file_path)
44 |
45 | save_adv_path = os.path.join(save_root, 'adv')
46 | save_per_path = os.path.join(save_root, 'per')
47 | makedirs(save_adv_path)
48 | makedirs(save_per_path)
49 |
50 | """ save all the print content as log """
51 | log_file= os.path.join(save_root, 'train.log')
52 | sys.stdout = Logger(log_file)
53 |
54 | dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.root)
55 | dataloader = torch.utils.data.DataLoader(
56 | dataset, batch_size=opt.batch_size,
57 | shuffle=False, num_workers=4,
58 | drop_last=True,pin_memory=True)
59 |
60 | # up = up_dataset(opt.up_path)
61 | # up = up.repeat(opt.batch_size,1,1,1)
62 |
63 | attacker = transfer_Attacker(opt)
64 | time_all, suc, ED_sum = 0, 0, 0
65 | for i, data in enumerate(dataloader, start=0):
66 | # image = data[0]
67 | label = data[1]
68 | img_index = data[2][0]
69 | img_path = data[4][0]
70 | mask = data[5]
71 | mask_tensor = mask.to(opt.device)
72 | # img_tensor = torch.mul(mask, up).to(opt.device)
73 | if opt.name == 'SINIFGSM':
74 | adv_img, delta, preds, time, iters = attacker.SINIFGSM(mask_tensor, label)
75 | elif opt.name == 'VMIFGSM':
76 | adv_img, delta, preds, time, iters = attacker.VMIFGSM(mask_tensor, label)
77 |
78 | flag = (preds != label[0])
79 | ED = edit_distance(label[0], preds)
80 |
81 | print('img-{}_path:{} --Success:{} --prediction:{} --groundtruth:{} --edit_distance:{} --time:{:2f} --iter_num:{}'
82 | .format(i, img_path, flag, preds, label[0], ED, time, iters))
83 |
84 | time_all +=time
85 |
86 | vutils.save_image(adv_img,"{}/{}_{}_adv.png".format(save_adv_path, os.path.basename(img_index), label[0]))
87 | vutils.save_image(torch.abs(delta*100),"{}/{}_{}_delta.png".format(save_per_path, os.path.basename(img_index), label[0]))
88 |
89 | if flag:
90 | suc += 1
91 | ED_sum +=ED
92 |
93 |
94 | print('{} Total_attack_time:{} '.format(opt.name, time_all))
95 | print('ASR:{:.2f}% '.format((suc/len(dataloader))*100))
96 |     print('Average Edit_distance: {}'.format(ED_sum/suc if suc else 0))
97 |
98 |
99 | if __name__ == '__main__':
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('--name', required=True, help='attack name [SINIFGSM, VMIFGSM]')
102 | parser.add_argument('--root', default='/data/hyr/dataset/PTAU/new/test100-times/',help='path of images')
103 | parser.add_argument('--save_attacks', type=str, default='res-baselines/', help='path of save adversarial text images')
104 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
105 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
106 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation')
107 | parser.add_argument('--iter_num', type=int, default=30, help='number of max iterations')
108 | parser.add_argument('--decay', type=float, default=1, help='momentum factor')
109 | parser.add_argument('--alpha', type=float, default=2/255, help='step size of each iteration')
110 | parser.add_argument('--beta', type=float, default=2/3, help='the upper bound of neighborhood.')
111 | parser.add_argument('--m', type=int, default=5, help='number of scale copies.')
112 | parser.add_argument('--N', type=int, default=5, help='the number of sampled examples in the neighborhood.')
113 |
114 | """ Data processing """
115 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
116 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
117 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
118 | # parser.add_argument('--rgb', action='store_true', help='use rgb input')
119 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz',
120 | help='character label')
121 | parser.add_argument('--sensitive', action='store_false', help='for sensitive character mode')
122 | """ Model Architecture """
123 | parser.add_argument('--str_model', type=str, help="well-trained models for evaluation",
124 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth')
125 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS')
126 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet')
127 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM')
128 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn')
129 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
130 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor')
131 | parser.add_argument('--output_channel', type=int, default=512,
132 | help='the number of output channel of Feature extractor')
133 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
134 | opt = parser.parse_args()
135 | print(opt)
136 |
137 | """ vocab / character number configuration """
138 | device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
139 | opt.device = torch.device(device_type)
140 |
141 | if opt.sensitive:
142 | opt.character = string.printable[:62]
143 |
144 | opt.save_attacks = opt.save_attacks + opt.name
145 | opt.save_attacks = opt.save_attacks + "-eps" + str(int(opt.eps*255))
146 |
147 | cudnn.benchmark = True
148 | cudnn.deterministic = True
149 |
150 | torch.cuda.synchronize()
151 | transfer_attack(opt)
--------------------------------------------------------------------------------
/baselines/transfer_attacker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 |
5 | from utils import CTCLabelConverter, AttnLabelConverter
6 | from STR_modules.model import Model
7 |
8 | r"""
9 | Base class for transfer-based attacks [SI-NI-FGSM, VMI-FGSM],
10 | modified from "https://github.com/Harry24k/adversarial-attacks-pytorch".
11 |
12 | SI-NI-FGSM in the paper 'NESTEROV ACCELERATED GRADIENT AND SCALE INVARIANCE FOR ADVERSARIAL ATTACKS'
13 | [https://arxiv.org/abs/1908.06281], Published as a conference paper at ICLR 2020
14 |
15 | VMI-FGSM in the paper 'Enhancing the Transferability of Adversarial Attacks through Variance Tuning
16 | [https://arxiv.org/abs/2103.15571], Published as a conference paper at CVPR 2021.
17 |
18 | Distance Measure : Linf
19 |
20 | Arguments:
21 | eps (float): maximum perturbation. (Default: 40/255)
22 | iter_num (int): number of iterations. (Default: 10)
23 | decay (float): momentum factor. (Default: 1.0)
24 | alpha (float): step size. (Default: 2/255)
25 | beta (float): the upper bound of neighborhood. (Default: 3/2)
26 | m (int): number of scale copies. (Default: 5)
27 | N (int): the number of sampled examples in the neighborhood. (Default: 5)
28 |
29 | """
30 |
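# Usage sketch (illustrative only; it mirrors how baselines/transfer.py drives this class,
# and the variable names below are placeholders rather than part of this module):
#
#   attacker = transfer_Attacker(opt)                     # opt: argparse options as built in transfer.py
#   adv_x, delta, pred, t, iters = attacker.SINIFGSM(img_tensor, label)
#   # or: adv_x, delta, pred, t, iters = attacker.VMIFGSM(img_tensor, label)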
31 | class transfer_Attacker(object):
32 |
33 | def __init__(self, c_para):
34 | self.device = c_para.device
35 | self.batch_size = c_para.batch_size
36 | self.eps = c_para.eps
37 | self.iter_num = c_para.iter_num
38 | self.decay = c_para.decay
39 | self.alpha = c_para.alpha
40 | self.beta = c_para.beta
41 | self.m = c_para.m
42 | self.N = c_para.N
43 |
44 | self.converter = self._load_converter(c_para)
45 | self.model = self._load_model(c_para)
46 | self.Transformation = c_para.Transformation
47 | self.FeatureExtraction = c_para.FeatureExtraction
48 | self.SequenceModeling = c_para.SequenceModeling
49 | self.Prediction = c_para.Prediction
50 | self.batch_max_length = c_para.batch_max_length
51 |
52 | self.criterion = self._load_base_loss(c_para)
53 | self.l2_loss = self._load_l2_loss()
54 |
55 | @staticmethod
56 | def _load_l2_loss():
57 | return torch.nn.MSELoss()
58 | @staticmethod
59 | def _load_base_loss(c_para):
60 | if c_para.Prediction == 'CTC':
61 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(c_para.device)
62 | else:
63 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(c_para.device)
64 | return criterion
65 | @staticmethod
66 | def _load_converter(c_para):
67 | if 'CTC' in c_para.Prediction:
68 | converter = CTCLabelConverter(c_para.character)
69 | else:
70 | converter = AttnLabelConverter(c_para.character)
71 | c_para.num_class = len(converter.character)
72 | return converter
73 | @staticmethod
74 | def _load_model(c_para):
75 | if not os.path.exists(c_para.str_model):
76 | raise FileNotFoundError("cannot find pth file in {}".format(c_para.str_model))
77 | # load model
78 | with torch.no_grad():
79 | model = Model(c_para).to(c_para.device)
80 | model.load_state_dict(torch.load(c_para.str_model, map_location=c_para.device))
81 | for name, para in model.named_parameters():
82 | para.requires_grad = False
83 | return model.eval()
84 |
85 | def model_pred(self, img, mode='CTC'):
86 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
87 | # self.model.eval()
88 | with torch.no_grad():
89 | if mode == 'CTC':
90 | pred = self.model(img, text_for_pred).log_softmax(2)
91 | size = torch.IntTensor([pred.size(1)] * self.batch_size).to(self.device)
92 | _, index = pred.permute(1, 0, 2).max(2)
93 | index = index.transpose(1, 0).contiguous().view(-1)
94 | pred_str = self.converter.decode(index.data, size.data)
95 | pred_str = pred_str[0]
96 | else: # ATTENTION
97 | pred = self.model(img, text_for_pred)
98 | size = torch.IntTensor([pred.size(1)] * self.batch_size)
99 | _, index = pred.max(2)
100 | pred_str = self.converter.decode(index, size)
101 | pred_s = pred_str[0]
102 | pred_str = pred_s[:pred_s.find('[s]')]
103 | # self.model.train()
104 | return pred_str
105 |
106 | def SINIFGSM(self, x, raw_text):
107 | torch.backends.cudnn.enabled=False
108 | x = x.clone().detach().to(self.device)
109 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
110 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length)
111 |
112 | momentum = torch.zeros_like(x).detach().to(self.device)
113 |
114 | adv_x = x.clone().detach()
115 | time_each = 0
116 | t_st = time.time()
117 | pred_org = self.model_pred(x, self.Prediction)
118 | for iters in range(self.iter_num):
119 | adv_x.requires_grad = True
120 | nes_x = adv_x + self.decay*self.alpha*momentum
121 | # Calculate sum the gradients over the scale copies of the input image
122 | adv_grad = torch.zeros_like(x).detach().to(self.device)
123 | for i in torch.arange(self.m):
124 |                 nes_x_scaled = nes_x / torch.pow(2, i)  # scale each copy from the same look-ahead point (avoid cumulative rescaling)
125 |                 preds = self.model(nes_x_scaled, text_for_pred).log_softmax(2)
126 | # Calculate loss
127 | if 'CTC' in self.Prediction:
128 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device)
129 | preds = preds.permute(1, 0, 2)
130 | cost = self.criterion(preds, text, preds_size, length)
131 | else: # ATTENTION
132 | target_text = text[:, 1:]
133 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1))
134 | adv_grad += torch.autograd.grad(cost, adv_x,
135 | retain_graph=True, create_graph=False)[0]
136 | adv_grad = adv_grad / self.m
137 |
138 | # Update adversarial images
139 | grad = self.decay*momentum + adv_grad / torch.mean(torch.abs(adv_grad), dim=(1,2,3), keepdim=True)
140 | momentum = grad
141 | adv_x = adv_x.detach() + self.alpha*grad.sign()
142 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps)
143 | adv_x = torch.clamp(x + delta, min=0, max=1).detach()
144 |
145 | pred_adv = self.model_pred(adv_x, self.Prediction)
146 | if pred_adv != pred_org:
147 | break
148 | t_en = time.time()
149 | time_each = t_en - t_st
150 | return adv_x, delta, pred_adv, time_each, iters
151 |
152 | def VMIFGSM(self, x, raw_text):
153 | torch.backends.cudnn.enabled=False
154 | x = x.clone().detach().to(self.device)
155 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
156 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length)
157 |
158 | momentum = torch.zeros_like(x).detach().to(self.device)
159 | v = torch.zeros_like(x).detach().to(self.device)
160 |
161 | adv_x = x.clone().detach()
162 | time_each = 0
163 | t_st = time.time()
164 | pred_org = self.model_pred(x, self.Prediction)
165 | for iters in range(self.iter_num):
166 | adv_x.requires_grad = True
167 | preds = self.model(adv_x, text_for_pred).log_softmax(2)
168 |
169 | # Calculate loss
170 | if 'CTC' in self.Prediction:
171 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device)
172 | preds = preds.permute(1, 0, 2)
173 | cost = self.criterion(preds, text, preds_size, length)
174 | else: # ATTENTION
175 | target_text = text[:, 1:]
176 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1))
177 |
178 | # Update adversarial images
179 | adv_grad = torch.autograd.grad(cost, adv_x,
180 | retain_graph=False, create_graph=False)[0]
181 |
182 | grad = (adv_grad + v) / torch.mean(torch.abs(adv_grad + v), dim=(1,2,3), keepdim=True)
183 | grad = grad + momentum * self.decay
184 | momentum = grad
185 |
186 | # Calculate Gradient Variance
187 | GV_grad = torch.zeros_like(x).detach().to(self.device)
188 | for _ in range(self.N):
189 | neighbor_x = adv_x.detach() + \
190 | torch.randn_like(x).uniform_(-self.eps*self.beta, self.eps*self.beta)
191 | neighbor_x.requires_grad = True
192 | preds = self.model(neighbor_x, text_for_pred).log_softmax(2)
193 |
194 | # Calculate loss
195 | if 'CTC' in self.Prediction:
196 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device)
197 | preds = preds.permute(1, 0, 2)
198 | cost = self.criterion(preds, text, preds_size, length)
199 | else: # ATTENTION
200 | target_text = text[:, 1:]
201 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1))
202 |
203 | GV_grad += torch.autograd.grad(cost, neighbor_x,
204 | retain_graph=False, create_graph=False)[0]
205 | # obtaining the gradient variance
206 | v = GV_grad / self.N - adv_grad
207 |
208 | adv_x = adv_x.detach() + self.alpha*grad.sign()
209 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps)
210 | adv_x = torch.clamp(x + delta, min=0, max=1).detach()
211 | pred_adv = self.model_pred(adv_x, self.Prediction)
212 | if pred_adv != pred_org:
213 | break
214 | t_en = time.time()
215 | time_each = t_en - t_st
216 |
217 |
218 | return adv_x, delta, pred_adv, time_each, iters
219 |
220 |
--------------------------------------------------------------------------------
/baselines/wm_attacker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import cv2
5 | import torch
6 |
7 | from skimage import morphology
8 | from trdg.generators import GeneratorFromStrings
9 | from nltk.metrics import edit_distance
10 |
11 | from utils import CTCLabelConverter, AttnLabelConverter, RGB2Hex
12 | from STR_modules.model import Model
13 |
14 | r"""
15 | Base class for FAWA (Fast Adversarial Watermark Attack).
16 |
17 | Distance Measure : Linf
18 |
19 | """
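# Pipeline sketch (illustrative; it follows the calling sequence in baselines/fawa.py, with
# placeholder tensor names): a momentum-based basic attack first locates perturbed regions,
# a watermark is rendered there, and the perturbation is then re-optimised inside the
# watermark/text masks.
#
#   attacker = WM_Attacker(opt)
#   attacker.init()
#   adv_img, delta, it, preds, flag, t = attacker.basic_attack(img_tensor, label)
#   pos, frames = attacker.find_wm_pos(tensor2np(adv_img.squeeze(0)), tensor2np(img_tensor.squeeze(0)), True)
#   wm_img = attacker.gen_wm((0, 0, 0))          # renders the 'ecml' watermark text
#   attacker.init()
#   adv_wm, delta, it, preds, flag, ED, t = attacker.wm_attack(wm_x, label, wm_mask, adv_text_mask)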
20 | class WM_Attacker(object):
21 | def __init__(self, c_para):
22 | r"""
23 | Arguments:
24 |             c_para : all arguments from the argument parser that configure the attack
25 | """
26 | self.device = c_para.device
27 | self.batch_size = c_para.batch_size
28 | self.eps = c_para.eps
29 | self.iter_num = c_para.iter_num
30 | self.decay = c_para.decay
31 | self.alpha = c_para.alpha
32 |
33 | self.converter = self._load_converter(c_para)
34 | self.model = self._load_model(c_para)
35 | self.Transformation = c_para.Transformation
36 | self.FeatureExtraction = c_para.FeatureExtraction
37 | self.SequenceModeling = c_para.SequenceModeling
38 | self.Prediction = c_para.Prediction
39 | self.batch_max_length = c_para.batch_max_length
40 |
41 | self.criterion = self._load_base_loss(c_para)
42 | self.l2_loss = self._load_l2_loss()
43 |
44 | self.img_size = (c_para.batch_size, c_para.input_channel, c_para.imgH, c_para.imgW)
45 | self.best_img = 100 * torch.ones(self.img_size)
46 | self.best_iter = -1
47 | self.best_delta = 100 * torch.ones(self.img_size)
48 | self.preds = ''
49 | self.suc = False
50 |
51 | def init(self):
52 | self.best_img = 100* torch.ones(self.img_size)
53 | self.best_iter = -1
54 | self.best_delta = 100 * torch.ones(self.img_size)
55 | self.preds = ''
56 | self.suc = False
57 |
58 | @staticmethod
59 | def _load_l2_loss():
60 | return torch.nn.MSELoss()
61 |
62 | @staticmethod
63 | def _load_base_loss(c_para):
64 | if c_para.Prediction == 'CTC':
65 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(c_para.device)
66 | else:
67 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(c_para.device)
68 | return criterion
69 |
70 | @staticmethod
71 | def _load_converter(c_para):
72 | if 'CTC' in c_para.Prediction:
73 | converter = CTCLabelConverter(c_para.character)
74 | else:
75 | converter = AttnLabelConverter(c_para.character)
76 | c_para.num_class = len(converter.character)
77 | return converter
78 | @staticmethod
79 | def _load_model(c_para):
80 | if not os.path.exists(c_para.str_model):
81 | raise FileNotFoundError("cannot find pth file in {}".format(c_para.str_model))
82 | # load model
83 | with torch.no_grad():
84 | model = Model(c_para).to(c_para.device)
85 | model.load_state_dict(torch.load(c_para.str_model, map_location=c_para.device))
86 | for name, para in model.named_parameters():
87 | para.requires_grad = False
88 | return model.eval()
89 |
90 | def model_pred(self, img, mode='CTC'):
91 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
92 | # self.model.eval()
93 | with torch.no_grad():
94 | if mode == 'CTC':
95 | pred = self.model(img, text_for_pred).log_softmax(2)
96 | size = torch.IntTensor([pred.size(1)] * self.batch_size).to(self.device)
97 | _, index = pred.permute(1, 0, 2).max(2)
98 | index = index.transpose(1, 0).contiguous().view(-1)
99 | pred_str = self.converter.decode(index.data, size.data)
100 | pred_str = pred_str[0]
101 | else: # ATTENTION
102 | pred = self.model(img, text_for_pred)
103 | size = torch.IntTensor([pred.size(1)] * self.batch_size)
104 | _, index = pred.max(2)
105 | pred_str = self.converter.decode(index, size)
106 | pred_s = pred_str[0]
107 | pred_str = pred_s[:pred_s.find('[s]')]
108 | # self.model.train()
109 | return pred_str
110 |
111 | def basic_attack(self, x, raw_text):
112 | x = x.clone().detach().to(self.device)
113 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
114 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length)
115 | pred_org = self.model_pred(x, self.Prediction) #TODO change to raw_text as label
116 |
117 | momentum = torch.zeros_like(x).detach().to(self.device)
118 |
119 | adv_x = x.clone().detach()
120 |
121 | num = 0
122 | time_each = 0
123 | t_st = time.time()
124 | for iter in range(self.iter_num):
125 | adv_x.requires_grad = True
126 |
127 |             # early stop
128 | pred_adv = self.model_pred(adv_x, self.Prediction)
129 | if pred_adv != pred_org:
130 | num += 1
131 | # print('Best results!')
132 | self.best_iter = iter
133 | self.best_img = adv_x.detach().clone()
134 | self.best_delta = delta.detach().clone()
135 | self.preds = pred_adv
136 | self.suc = True
137 | if num == 1:
138 | break
139 |
140 | # Calculate loss
141 | torch.backends.cudnn.enabled=False
142 | preds = self.model(adv_x, text_for_pred).log_softmax(2)
143 | if 'CTC' in self.Prediction:
144 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device)
145 | preds = preds.permute(1, 0, 2)
146 | cost = self.criterion(preds, text, preds_size, length)
147 | else: # ATTENTION
148 | target_text = text[:, 1:]
149 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1))
150 |
151 | # Update adversarial images
152 | grad = torch.autograd.grad(cost, adv_x,
153 | retain_graph=False, create_graph=False)[0]
154 | grad = grad / torch.mean(torch.abs(grad), dim=(1,2,3), keepdim=True)
155 | grad = grad + momentum*self.decay
156 | momentum = grad
157 |
158 | adv_x = adv_x.detach() + self.alpha*grad.sign()
159 | delta = torch.clamp(adv_x - x, min=-self.eps, max=self.eps)
160 | adv_x = torch.clamp(x + delta, min=0, max=1).detach()
161 |
162 |             if iter == self.iter_num - 1 and self.best_iter == -1:  # last iteration reached without success
163 | print('[!] Attack failed: No optimal results were found in effective iter_num!')
164 | self.best_iter = -1
165 | self.best_img = adv_x.detach().clone()
166 | self.best_delta = delta.data.detach().clone()
167 | self.preds = pred_adv
168 | self.suc = False
169 |
170 | t_end = time.time()
171 | time_each = t_end - t_st
172 |
173 | return self.best_img, self.best_delta, self.best_iter, self.preds, self.suc, time_each
174 |
175 | def find_wm_pos(self, adv_img, input_img, ret_frame_img=False):
176 | pert = np.abs(adv_img - input_img)
177 | pert = (pert > 1e-2) * 255.0
178 | wm_pos_list = []
179 | frame_img_list = []
180 | for src in pert:
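            # dilate then erode (a morphological closing) merges nearby perturbed pixels into blobs;
            # the bounding boxes of their contours become candidate watermark positions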
181 | kernel = np.ones((3, 3), np.uint8) # kernel_size 3*3
182 | dilate = cv2.dilate(src, kernel, iterations=2)
183 | erode = cv2.erode(dilate, kernel, iterations=2)
184 | remove = morphology.remove_small_objects(erode.astype('bool'), min_size=0)
185 | contours, _ = cv2.findContours((remove * 255).astype('uint8'), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
186 | wm_pos, frame_img = [], []
187 | for cont in contours:
188 | left_point = cont.min(axis=1).min(axis=0)
189 | right_point = cont.max(axis=1).max(axis=0)
190 | wm_pos.append(np.hstack((left_point, right_point)))
191 | if ret_frame_img:
192 | img = cv2.rectangle(
193 | (remove * 255).astype('uint8'), (left_point[0], left_point[1]),
194 | (right_point[0], right_point[1]), (255, 255, 255), 2)
195 | frame_img.append(img)
196 | wm_pos_list.append(wm_pos)
197 | frame_img_list.append(frame_img)
198 |
199 | if ret_frame_img:
200 | return (wm_pos_list, frame_img_list)
201 | else:
202 | return wm_pos_list
203 |
204 | def gen_wm(self, RGB):
205 | generator = GeneratorFromStrings(
206 | strings=['ecml'],
207 | count=1,
208 | fonts=['baselines/Impact.ttf'], # TODO change ['Impact.tff']]
209 | language='en',
210 | size=78, # default: 32
211 | skewing_angle=15,
212 | random_skew=False,
213 | blur=0,
214 | random_blur=False,
215 | background_type=1, # gaussian noise (0), plain white (1), quasicrystal (2) or picture (3)
216 | distorsion_type=0, # None(0), Sine wave(1),Cosine wave(2),Random(3)
217 | distorsion_orientation=0,
218 | is_handwritten=False,
219 | width=-1,
220 | alignment=1,
221 | text_color=RGB2Hex(RGB),
222 | orientation=0,
223 | space_width=1.0,
224 | character_spacing=0,
225 | margins=(0, 0, 0, 0),
226 | fit=True,
227 | )
228 | img_list = [img for img, _ in generator]
229 | return img_list[0]
230 |
231 | def wm_attack(self, wm_x, raw_text, wm_mask, adv_text_mask):
232 | wm_x = wm_x.clone().detach().to(self.device)
233 |
234 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
235 | text, length = self.converter.encode(raw_text, batch_max_length=self.batch_max_length)
236 | pred_org = self.model_pred(wm_x, self.Prediction)
237 |
238 | wm_mask = wm_mask.clone().detach().to(self.device)
239 | adv_text_mask = adv_text_mask.clone().detach().to(self.device)
240 |
241 | momentum = torch.zeros_like(wm_x).detach().to(self.device)
242 |
243 | adv_x = wm_x.clone().detach()
244 |
245 | num = 0
246 | ED_num= 0
247 | time_each = 0
248 | t_st = time.time()
249 | for iter in range(self.iter_num):
250 | adv_x.requires_grad = True
251 |
252 |             # early stop
253 | pred_adv = self.model_pred(adv_x, self.Prediction)
254 | if pred_adv != pred_org:
255 | ED_num = edit_distance(pred_org, pred_adv)
256 | num += 1
257 | # print('Best results!')
258 | self.best_iter = iter
259 | self.best_img = adv_x.detach().clone()
260 | self.best_delta = delta.detach().clone()
261 | self.preds = pred_adv
262 | self.suc = True
263 | if num == 1:
264 | break
265 |
266 | # Calculate loss
267 | torch.backends.cudnn.enabled=False
268 | preds = self.model(adv_x, text_for_pred).log_softmax(2)
269 | if 'CTC' in self.Prediction:
270 | preds_size = torch.IntTensor([preds.size(1)] * self.batch_size).to(self.device)
271 | preds = preds.permute(1, 0, 2)
272 | cost = self.criterion(preds, text, preds_size, length)
273 | else: # ATTENTION
274 | target_text = text[:, 1:]
275 | cost = self.criterion(preds.view(-1, preds.shape[-1]), target_text.contiguous().view(-1))
276 |
277 | # Update adversarial images
278 | grad = torch.autograd.grad(cost, adv_x,
279 | retain_graph=False, create_graph=False)[0]
280 | grad = grad / torch.mean(torch.abs(grad), dim=(1,2,3), keepdim=True)
281 | grad = grad + momentum*self.decay
282 | momentum = grad
283 |
284 | adv_x = adv_x.detach() + self.alpha*grad.sign()
285 | delta = torch.clamp(adv_x - wm_x, min=-self.eps, max=self.eps)
286 | delta = torch.mul(delta, adv_text_mask)
287 | delta = torch.mul(delta, wm_mask)
288 | adv_x = torch.clamp(wm_x + delta, min=0, max=1).detach()
289 |
290 |             if iter == self.iter_num - 1 and self.best_iter == -1:  # last iteration reached without success
291 | print('[!] Attack failed: No optimal results were found in effective iter_num!')
292 | ED_num = 0
293 | self.best_iter = -1
294 | self.best_img = adv_x.detach().clone()
295 | self.best_delta = delta.data.detach().clone()
296 | self.preds = pred_adv
297 | self.suc = False
298 |
299 | t_end = time.time()
300 | time_each = t_end - t_st
301 |
302 | return self.best_img, self.best_delta, self.best_iter, self.preds, self.suc, ED_num, time_each
303 |
304 |
305 |
--------------------------------------------------------------------------------
/comtest/test_ali.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os,sys
3 | import json
4 | from typing import List
5 |
6 | from alibabacloud_tea_openapi.client import Client as OpenApiClient
7 | from alibabacloud_tea_openapi import models as open_api_models
8 | from alibabacloud_tea_util import models as util_models
9 | from nltk.metrics import edit_distance
10 | from utils import Logger
11 |
12 | def get_file_content(filePath):
13 | with open(filePath, 'rb') as fp:
14 | return fp.read()
15 |
16 | class Sample:
17 | def __init__(self):
18 | pass
19 |
20 | @staticmethod
21 | def create_client(
22 | access_key_id: str,
23 | access_key_secret: str,
24 | ) -> OpenApiClient:
25 | """
26 |         Initialize the account Client with AccessKey ID & Secret (AK/SK)
27 | @param access_key_id:
28 | @param access_key_secret:
29 | @return: Client
30 | @throws Exception
31 | """
32 | config = open_api_models.Config(
33 |             # required: your AccessKey ID,
34 |             access_key_id=access_key_id,
35 |             # required: your AccessKey Secret,
36 | access_key_secret=access_key_secret
37 | )
38 |         # endpoint to access
39 | config.endpoint = f'ocr-api.cn-hangzhou.aliyuncs.com'
40 | return OpenApiClient(config)
41 |
42 | @staticmethod
43 | def create_api_info() -> open_api_models.Params:
44 | """
45 |         API-related settings
46 | @param path: params
47 | @return: OpenApi.Params
48 | """
49 | params = open_api_models.Params(
50 |             # API name,
51 |             action='RecognizeGeneral',
52 |             # API version,
53 |             version='2021-07-07',
54 |             # protocol,
55 |             protocol='HTTPS',
56 |             # HTTP method,
57 |             method='POST',
58 |             auth_type='AK',
59 |             style='V3',
60 |             # API path,
61 |             pathname=f'/',
62 |             # request body content format,
63 |             req_body_type='json',
64 |             # response body content format,
65 | body_type='json'
66 | )
67 | return params
68 |
69 | @staticmethod
70 | def main(
71 | imagefiles: List[str],
72 | ) -> None:
73 |         # Leaking project code may expose your AccessKey and endanger all resources under this account. The code below is for reference only; the more secure STS approach is recommended. For other authentication methods see: https://help.aliyun.com/document_detail/378659.html
74 | client = Sample.create_client('xx', 'xx')
75 | params = Sample.create_api_info()
76 | # runtime options
77 | runtime = util_models.RuntimeOptions()
78 |
79 | fcnt, ED_sum = 0, 0
80 | for imagefile in imagefiles:
81 | img = get_file_content(imagefile)
82 | request = open_api_models.OpenApiRequest(stream=img)
83 |             # print the API return value yourself when running this code
84 |             # the return value is a Map with three parts: response body, response headers, and the HTTP statusCode
85 | result = client.call_api(params, request, runtime)
86 | data = json.loads(result['body']['Data'])
87 | content = data['content'].strip()
88 | # print(f"file: {file}")
89 | # print(content)
90 | imagename = os.path.basename(imagefile)
91 | true_label = imagename.split("_")[1]
92 | if true_label != content:
93 | fcnt += 1
94 | ED = edit_distance(true_label, content)
95 | else:
96 | ED = 0
97 | print("label:{} ---> result:{}".format(true_label, content))
98 | ED_sum += ED
99 |
100 | if fcnt != 0:
101 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt}
102 | else:
103 | score = {"DSR":0, "ED":0}
104 | return score
105 |
106 | if __name__ == '__main__':
107 | log_file= os.path.join('./res-comtest/ali', 'up5a.log')
108 | sys.stdout = Logger(log_file)
109 |
110 | # prepare your own test data
111 | file_path = "xx/adv+"
112 | # file_path = "xx/adv-"
113 |
114 | # get file paths
115 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)]
116 | imagefiles = imagefiles[:100]
117 | score = Sample.main(imagefiles)
118 | print(score)
--------------------------------------------------------------------------------
/comtest/test_azure.py:
--------------------------------------------------------------------------------
1 | # requires: pip install azure-cognitiveservices-vision-computervision pillow
2 | # rate limits: 5 calls per minute, 5,000 calls per month!
3 | from azure.cognitiveservices.vision.computervision import ComputerVisionClient
4 | from azure.cognitiveservices.vision.computervision.models import OperationStatusCodes
5 | from azure.cognitiveservices.vision.computervision.models import VisualFeatureTypes
6 | from msrest.authentication import CognitiveServicesCredentials
7 |
8 |
9 | import os,sys,time
10 | from nltk.metrics import edit_distance
11 | from utils import Logger
12 |
13 | '''
14 | Authenticate
15 | Authenticates your credentials and creates a client.
16 | '''
17 | subscription_key = "xx"
18 | endpoint = "xx"
19 | request_interval = 4
20 |
21 | def main(imagefiles):
22 | computervision_client = ComputerVisionClient(endpoint, CognitiveServicesCredentials(subscription_key))
23 |
24 | fcnt, ED_sum = 0, 0
25 | for imagefile in imagefiles:
26 | imagename = os.path.basename(imagefile)
27 | true_label = imagename.split("_")[1]
28 | # print(f"file: {file}")
29 |         with open(imagefile, 'rb') as image_stream:
30 |             read_response = computervision_client.recognize_printed_text_in_stream(image_stream, raw=True)
31 | time.sleep(request_interval)
32 | for region in read_response.output.regions:
33 | lines = region.lines
34 | for line in lines:
35 | line_text = " ".join([word.text for word in line.words])
36 | if true_label != line_text:
37 | fcnt += 1
38 | ED = edit_distance(true_label, line_text)
39 | else:
40 | ED = 0
41 | print("label:{} ---> result:{}".format(true_label, line_text))
42 | ED_sum += ED
43 | if fcnt != 0:
44 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt}
45 | else:
46 | score = {"DSR":0, "ED":0}
47 | return score
48 |
49 |
50 |
51 | if __name__ == "__main__":
52 | log_file= os.path.join('./res-comtest/azure', 'up5a.log')
53 | sys.stdout = Logger(log_file)
54 |
55 | # prepare your own test data
56 | file_path = "xx/adv+"
57 | # file_path = "xx/adv-"
58 |
59 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)]
60 | score = main(imagefiles)
61 | print(score)
--------------------------------------------------------------------------------
/comtest/test_baidu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os,sys
3 | import time
4 | from aip import AipOcr
5 | from nltk.metrics import edit_distance
6 | from utils import Logger
7 |
8 | APP_ID = 'xx'
9 | API_KEY = 'xx'
10 | SECRET_KEY = 'xx'
11 | request_interval = 1
12 |
13 | def get_file_content(filePath):
14 | with open(filePath, 'rb') as fp:
15 | return fp.read()
16 |
17 | def get_score(jsons):
18 | fcnt, ED_sum = 0, 0
19 | for json in jsons:
20 | fcnt += 0 if json["label"] == json["words_result"][0]["words"] else 1
21 | ED_sum += edit_distance(json["label"], json["words_result"][0]["words"])
22 | if fcnt != 0:
23 | score = {"DSR":fcnt/len(jsons), "ED":ED_sum/fcnt}
24 | else:
25 | score = {"DSR":fcnt/len(jsons), "ED":0}
26 | return score
27 |
28 |
29 | if __name__ == "__main__":
30 | log_file= os.path.join('./res-comtest/baidu', 'up5a.log')
31 | sys.stdout = Logger(log_file)
32 | # create client
33 | client = AipOcr(APP_ID, API_KEY, SECRET_KEY)
34 |
35 | # prepare your own test data
36 | data_path = "xx/adv+"
37 | # data_path = "xx/adv-"
38 |
39 | imagefiles = [os.path.join(data_path,n) for n in os.listdir(data_path)]
40 | imagefiles = imagefiles[:100]
41 |
42 | # upload to baidu ocr
43 | result_jsons = []
44 | for imagefile in imagefiles:
45 | b_img = get_file_content(imagefile)
46 | answer = client.basicGeneral(b_img)
47 | imagename = os.path.basename(imagefile)
48 | answer["label"] = imagename.split("_")[1]
49 | try:
50 | print("label:{} ---> result:{}".format(
51 | answer["label"],answer["words_result"][0]["words"]
52 | ))
53 | except:
54 | answer["words_result"].append({'words':""})
55 | print(answer)
56 | result_jsons.append(answer)
57 | time.sleep(request_interval)
58 |
59 | # calc score
60 | score = get_score(result_jsons)
61 | print(score)
62 |
63 |
64 | '''
65 | basicGeneral: general text recognition (standard edition)
66 | return json:
67 | {
68 | "words_result": [
69 | {
70 | "words": "firm"
71 | }
72 | ],
73 | "words_result_num": 1,
74 | "log_id": 1589115807642170340
75 | }
76 |
77 | when no words are found,
78 | the returned json is:
79 | {
80 | "words_result": [],
81 | "words_result_num": 0,
82 | "log_id": 1589115807642170340
83 | }
84 | '''
--------------------------------------------------------------------------------
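A minimal sketch of how the JSON shapes documented above feed the DSR/ED scoring in test_baidu.py; the sample responses and labels below are hypothetical, and the empty-result case is padded with an empty string exactly as main() does:

from nltk.metrics import edit_distance

samples = [
    {"label": "firm",  "words_result": [{"words": "firm"}]},  # recognized correctly
    {"label": "house", "words_result": [{"words": ""}]},      # empty result, padded like main()
]

fcnt, ED_sum = 0, 0
for item in samples:
    pred = item["words_result"][0]["words"]
    if item["label"] != pred:
        fcnt += 1
    ED_sum += edit_distance(item["label"], pred)

score = {"DSR": fcnt / len(samples), "ED": ED_sum / fcnt} if fcnt else {"DSR": 0, "ED": 0}
print(score)  # {'DSR': 0.5, 'ED': 5.0}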
/comtest/test_huawei.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from huaweicloudsdkcore.auth.credentials import BasicCredentials
3 | from huaweicloudsdkocr.v1.region.ocr_region import OcrRegion
4 | from huaweicloudsdkcore.exceptions import exceptions
5 | from huaweicloudsdkocr.v1 import *
6 | import base64
7 | import os,sys
8 | from nltk.metrics import edit_distance
9 | from utils import Logger
10 |
11 | ak = "xx"
12 | sk = "xx"
13 |
14 | def get_file_content(filePath):
15 | with open(filePath, 'rb') as image_file:
16 | return base64.b64encode(image_file.read())
17 |
18 | def main(imagefiles):
19 | credentials = BasicCredentials(ak, sk)
20 |
21 | client = OcrClient.new_builder() \
22 | .with_credentials(credentials) \
23 | .with_region(OcrRegion.value_of("cn-east-3")) \
24 | .build()
25 |
26 | fcnt, ED_sum = 0, 0
27 | try:
28 | for imagefile in imagefiles:
29 | encoded_string = get_file_content(imagefile)
30 | # with open(file, "rb") as image_file:
31 | # encoded_string = base64.b64encode(image_file.read())
32 | request = RecognizeGeneralTextRequest()
33 | request.body = GeneralTextRequestBody(image=encoded_string)
34 | response = client.recognize_general_text(request)
35 | data = response.result.to_dict()
36 | words = [word['words'] for word in data['words_block_list']]
37 | content = "".join(words)
38 |
39 | imagename = os.path.basename(imagefile)
40 | true_label = imagename.split("_")[1]
41 | if true_label != content:
42 | fcnt += 1
43 | ED = edit_distance(true_label, content)
44 | else:
45 | ED = 0
46 | print("label:{} ---> result:{}".format(true_label, content))
47 | ED_sum += ED
48 |
49 | except exceptions.ClientRequestException as e:
50 | print(e.status_code)
51 | print(e.request_id)
52 | print(e.error_code)
53 | print(e.error_msg)
54 | if fcnt != 0:
55 | score = {"DSR":fcnt/len(imagefiles), "ED":ED_sum/fcnt}
56 | else:
57 | score = {"DSR":0, "ED":0}
58 | return score
59 |
60 | if __name__ == "__main__":
61 | log_file= os.path.join('./res-comtest/huawei', 'up5a.log')
62 | sys.stdout = Logger(log_file)
63 |
64 | # prepare your own test data
65 | file_path = "xx/adv+"
66 | # file_path = "xx/adv-"
67 |
68 | # get file paths
69 | imagefiles = [os.path.join(file_path,n) for n in os.listdir(file_path)]
70 | imagefiles = imagefiles[:1]
71 | score = main(imagefiles)
72 | print(score)
73 |
--------------------------------------------------------------------------------
/data/protego/up/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruby-He/ProTegO/5fbaa18ef15654a473e899571f9400a83cb38277/data/protego/up/5.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import torch.utils.data
5 | from torch.utils.data import Dataset
6 |
7 |
8 | """ProtegO dataset"""
9 | def up_dataset(up_path):
10 | up = cv2.imread(up_path)
11 | up = cv2.resize(up, (100, 32))
12 | up = cv2.cvtColor(up, cv2.COLOR_BGR2RGB)
13 | up = torch.FloatTensor(up)
14 | up = up / 255 # normalization to [0,1]
15 | up = up.permute(2,0,1) # [C, H, W]
16 |
17 | return up
18 |
19 | class train_dataset_builder(Dataset):
20 | def __init__(self, height, width, img_path):
21 |         '''
22 |         height: input height to the model
23 |         width: input width to the model
24 |         img_path: directory containing all training images
25 |         (file names are expected to encode the label as <index>_<label>_<suffix>)
26 |         '''
27 | self.height = height
28 | self.width = width
29 | self.img_path = img_path
30 | self.dataset = []
31 |
32 | img = []
33 | for i,j,k in os.walk(self.img_path):
34 | for file in k:
35 | file_name = os.path.join(i ,file)
36 | img.append(file_name)
37 | self.total_img_name = img
38 |
39 | for img_name in self.total_img_name:
40 | _, label, _ = img_name.split('_')
41 | self.dataset.append([img_name, label])
42 |
43 |
44 | def __getitem__(self, index):
45 | img_name, label = self.dataset[index]
46 | IMG = cv2.imread(img_name)
47 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize
48 |
49 | # binarization processing
50 | gray = cv2.cvtColor(IMG, cv2.COLOR_BGR2GRAY)
51 | _, binary = cv2.threshold(gray,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
52 | IMG = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
53 | IMG = torch.FloatTensor(IMG) # [H, W, C]
54 | IMG = IMG / 255 # normalization to [0,1]
55 | IMG = IMG.permute(2,0,1) # [C, H, W]
56 |
57 | return IMG, label
58 |
59 | def __len__(self):
60 | return len(self.dataset)
61 |
62 | class test_dataset_builder(Dataset):
63 | def __init__(self, height, width, img_path):
64 | self.height = height
65 | self.width = width
66 | self.img_path = img_path
67 | self.dataset = []
68 |
69 | img = []
70 | for i,j,k in os.walk(self.img_path):
71 | for file in k:
72 | file_name = os.path.join(i ,file)
73 | img.append(file_name)
74 | self.total_img_name = img
75 |
76 | for img_name in self.total_img_name:
77 | img_index, label, img_adv = img_name.split('_')
78 | img_adv = img_adv.split('.')
79 | index_or_advlogo = img_adv[0]
80 | self.dataset.append([img_name, label, img_index, index_or_advlogo])
81 | self.dataset = sorted(self.dataset)
82 |
83 | def __getitem__(self, index):
84 | img_name, label, img_index, index_or_advlogo = self.dataset[index]
85 | IMG = cv2.imread(img_name)
86 | ORG = cv2.resize(IMG, (self.width, self.height))
87 |
88 | IMG = cv2.cvtColor(ORG, cv2.COLOR_BGR2RGB)
89 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C]
90 | IMG = IMG / 255
91 | IMG = IMG.permute(2,0,1) # [C, H, W]
92 |
93 | gray = cv2.cvtColor(ORG, cv2.COLOR_BGR2GRAY)
94 | _, binary = cv2.threshold(gray,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
95 | mask = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
96 | mask = torch.FloatTensor(mask) # [H, W, C]
97 | mask = mask / 255 # normalization to [0,1]
98 | mask = mask.permute(2,0,1) # [C, H, W]
99 |
100 | return IMG, label, img_index, index_or_advlogo, img_name, mask
101 |
102 | def __len__(self):
103 | return len(self.dataset)
104 |
105 | class test_adv_dataset(Dataset):
106 | def __init__(self, height, width, img_path):
107 | self.height = height
108 | self.width = width
109 | self.img_path = img_path
110 | self.dataset = []
111 |
112 | img = []
113 | for i,j,k in os.walk(self.img_path):
114 | for file in k:
115 | file_name = os.path.join(i ,file)
116 | img.append(file_name)
117 | self.total_img_name = img
118 |
119 | for img_name in self.total_img_name:
120 | img_index, label, img_adv = img_name.split('_')
121 | img_adv = img_adv.split('.')
122 | index_or_advlogo = img_adv[0]
123 | self.dataset.append([img_name, label, img_index, index_or_advlogo])
124 | self.dataset = sorted(self.dataset)
125 |
126 | def __getitem__(self, index):
127 | img_name, label, img_index, index_or_advlogo = self.dataset[index]
128 | IMG = cv2.imread(img_name)
129 | IMG = cv2.resize(IMG, (self.width, self.height))
130 |
131 | # binarization processing
132 | gray = cv2.cvtColor(IMG, cv2.COLOR_BGR2GRAY)
133 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
134 | img_b = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
135 | img_b = torch.FloatTensor(img_b)
136 | img_b = img_b / 255 # normalization to [0,1]
137 | img_b = img_b.permute(2,0,1) # [C, H, W]
138 |
139 | img = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB)
140 | img = torch.FloatTensor(img)
141 | img = img /255 # normalization to [0,1]
142 | img = img.permute(2,0,1) # [C, H, W]
143 |
144 | return img_b, img, label, img_index, index_or_advlogo, img_name
145 |
146 | def __len__(self):
147 | return len(self.dataset)
148 |
149 |
150 | """STR models dataset"""
151 | class strdataset(Dataset):
152 | def __init__(self, height, width, total_img_path):
153 |         '''
154 |         height: input height to the model
155 |         width: input width to the model
156 |         total_img_path: directory containing all images
157 |         (file names are expected to encode the label as <index>_<label>_<suffix>)
158 |         '''
159 | self.total_img_path = total_img_path
160 | self.height = height
161 | self.width = width
162 | img = []
163 | self.dataset = []
164 |
165 | for i,_,k in os.walk(total_img_path):
166 | for file in k:
167 | file_name = os.path.join(i ,file)
168 | img.append(file_name)
169 | self.total_img_name = img
170 |
171 | for img_name in self.total_img_name:
172 | _, label, _ = img_name.split('_')
173 | self.dataset.append([img_name, label])
174 |
175 | def __getitem__(self, index):
176 | img_name, label = self.dataset[index]
177 | IMG = cv2.imread(img_name)
178 | IMG = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB)
179 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize
180 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1]
181 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C]
182 | IMG = IMG.permute(2,0,1) # [C, H, W]
183 |
184 | return IMG, label
185 |
186 | def __len__(self):
187 | return len(self.dataset)
--------------------------------------------------------------------------------
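The dataset builders in /dataset.py above recover the ground-truth label purely from the file name (the middle field when the path is split on '_'). A minimal usage sketch, assuming a hypothetical directory data/protego/train whose files are named like 1_hello_0.png:

from torch.utils.data import DataLoader
from dataset import train_dataset_builder

# Hypothetical training directory; file names must look like <index>_<label>_<suffix>.png
train_set = train_dataset_builder(height=32, width=100, img_path='data/protego/train')
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([4, 3, 32, 100]) -- binarized images normalized to [0, 1]
print(labels)        # batch of 4 label strings parsed from the file names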
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: protego
2 | channels:
3 | - pytorch
4 | - http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
5 | - conda-forge
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
11 | dependencies:
12 | - _libgcc_mutex=0.1=main
13 | - _openmp_mutex=4.5=1_gnu
14 | - absl-py=1.0.0=pyhd8ed1ab_0
15 | - aiohttp=3.7.4.post0=py38h7f8727e_2
16 | - async-timeout=3.0.1=py_1000
17 | - attrs=21.4.0=pyhd8ed1ab_0
18 | - backcall=0.2.0=pyhd3eb1b0_0
19 | - blas=1.0=mkl
20 | - blinker=1.4=py_1
21 | - blosc=1.21.0=h8c45485_0
22 | - brotli=1.0.9=he6710b0_2
23 | - brotlipy=0.7.0=py38h497a2fe_1001
24 | - brunsli=0.1=h2531618_0
25 | - bzip2=1.0.8=h7b6447c_0
26 | - c-ares=1.17.1=h27cfd23_0
27 | - ca-certificates=2021.10.8=ha878542_0
28 | - cachetools=5.0.0=pyhd8ed1ab_0
29 | - certifi=2021.10.8=py38h578d9bd_2
30 | - cffi=1.15.0=py38hd667e15_1
31 | - cfitsio=3.470=hf0d0db6_6
32 | - chardet=4.0.0=py38h578d9bd_3
33 | - charls=2.2.0=h2531618_0
34 | - charset-normalizer=2.0.12=pyhd8ed1ab_0
35 | - click=8.1.3=py38h578d9bd_0
36 | - cloudpickle=2.0.0=pyhd3eb1b0_0
37 | - cryptography=35.0.0=py38ha5dfef3_0
38 | - cudatoolkit=11.3.1=h2bc3f7f_2
39 | - cycler=0.11.0=pyhd3eb1b0_0
40 | - cytoolz=0.11.0=py38h7b6447c_0
41 | - dask-core=2021.10.0=pyhd3eb1b0_0
42 | - decorator=5.1.0=pyhd3eb1b0_0
43 | - ffmpeg=4.3=hf484d3e_0
44 | - freetype=2.11.0=h70c0345_0
45 | - fsspec=2021.10.1=pyhd3eb1b0_0
46 | - giflib=5.2.1=h7b6447c_0
47 | - gmp=6.2.1=h2531618_2
48 | - gnutls=3.6.15=he1e5248_0
49 | - google-auth=2.6.6=pyh6c4a22f_0
50 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
51 | - idna=3.3=pyhd8ed1ab_0
52 | - imagecodecs=2021.8.26=py38h4cda21f_0
53 | - imageio=2.9.0=pyhd3eb1b0_0
54 | - importlib-metadata=4.11.3=py38h578d9bd_1
55 | - intel-openmp=2021.4.0=h06a4308_3561
56 | - ipython=7.29.0=py38hb070fc8_0
57 | - jedi=0.18.0=py38h06a4308_1
58 | - jpeg=9d=h7f8727e_0
59 | - jxrlib=1.1=h7b6447c_2
60 | - krb5=1.19.2=hac12032_0
61 | - lame=3.100=h7b6447c_0
62 | - lcms2=2.12=h3be6417_0
63 | - ld_impl_linux-64=2.35.1=h7274673_9
64 | - lerc=3.0=h295c915_0
65 | - libaec=1.0.4=he6710b0_1
66 | - libcurl=7.80.0=h0b77cf5_0
67 | - libdeflate=1.8=h7f8727e_5
68 | - libedit=3.1.20210910=h7f8727e_0
69 | - libev=4.33=h7f8727e_1
70 | - libffi=3.3=he6710b0_2
71 | - libgcc-ng=9.3.0=h5101ec6_17
72 | - libgfortran-ng=7.5.0=ha8ba4b0_17
73 | - libgfortran4=7.5.0=ha8ba4b0_17
74 | - libgomp=9.3.0=h5101ec6_17
75 | - libiconv=1.15=h63c8f33_5
76 | - libidn2=2.3.2=h7f8727e_0
77 | - libnghttp2=1.46.0=hce63b2e_0
78 | - libpng=1.6.37=hbc83047_0
79 | - libprotobuf=3.15.8=h780b84a_0
80 | - libssh2=1.9.0=h1ba5d50_1
81 | - libstdcxx-ng=9.3.0=hd4cf53a_17
82 | - libtasn1=4.16.0=h27cfd23_0
83 | - libtiff=4.2.0=h85742a9_0
84 | - libunistring=0.9.10=h27cfd23_0
85 | - libuv=1.40.0=h7b6447c_0
86 | - libwebp=1.2.0=h89dd481_0
87 | - libwebp-base=1.2.0=h27cfd23_0
88 | - libzopfli=1.0.3=he6710b0_0
89 | - locket=0.2.1=py38h06a4308_1
90 | - lz4-c=1.9.3=h295c915_1
91 | - markdown=3.3.7=pyhd8ed1ab_0
92 | - matplotlib-base=3.5.0=py38h3ed280b_0
93 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
94 | - mkl=2021.4.0=h06a4308_640
95 | - mkl-service=2.4.0=py38h7f8727e_0
96 | - mkl_fft=1.3.1=py38hd3c417c_0
97 | - mkl_random=1.2.2=py38h51133e4_0
98 | - multidict=5.2.0=py38h7f8727e_2
99 | - munkres=1.1.4=py_0
100 | - ncurses=6.3=heee7806_1
101 | - nettle=3.7.3=hbbd107a_1
102 | - networkx=2.6.3=pyhd3eb1b0_0
103 | - numpy=1.21.2=py38h20f2e39_0
104 | - numpy-base=1.21.2=py38h79a1101_0
105 | - oauthlib=3.2.0=pyhd8ed1ab_0
106 | - olefile=0.46=pyhd3eb1b0_0
107 | - openh264=2.1.0=hd408876_0
108 | - openjpeg=2.4.0=h3ad879b_0
109 | - openssl=1.1.1n=h7f8727e_0
110 | - packaging=21.3=pyhd3eb1b0_0
111 | - parso=0.8.2=pyhd3eb1b0_0
112 | - partd=1.2.0=pyhd3eb1b0_0
113 | - pexpect=4.8.0=pyhd3eb1b0_3
114 | - pickleshare=0.7.5=pyhd3eb1b0_1003
115 | - pip=21.2.4=py38h06a4308_0
116 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
117 | - ptyprocess=0.7.0=pyhd3eb1b0_2
118 | - pyasn1=0.4.8=py_0
119 | - pycparser=2.21=pyhd8ed1ab_0
120 | - pygments=2.10.0=pyhd3eb1b0_0
121 | - pyjwt=2.3.0=pyhd8ed1ab_1
122 | - pyopenssl=22.0.0=pyhd8ed1ab_0
123 | - pysocks=1.7.1=py38h578d9bd_5
124 | - python=3.8.12=h12debd9_0
125 | - python-dateutil=2.8.2=pyhd3eb1b0_0
126 | - python_abi=3.8=2_cp38
127 | - pytorch=1.10.0=py3.8_cuda11.3_cudnn8.2.0_0
128 | - pytorch-mutex=1.0=cuda
129 | - pyu2f=0.1.5=pyhd8ed1ab_0
130 | - pywavelets=1.1.1=py38h7b6447c_2
131 | - pyyaml=6.0=py38h7f8727e_1
132 | - readline=8.1=h27cfd23_0
133 | - requests=2.27.1=pyhd8ed1ab_0
134 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0
135 | - rsa=4.8=pyhd8ed1ab_0
136 | - setuptools=58.0.4=py38h06a4308_0
137 | - six=1.16.0=pyhd3eb1b0_0
138 | - snappy=1.1.8=he6710b0_0
139 | - sqlite=3.36.0=hc218d9a_0
140 | - tensorboard=2.9.0=pyhd8ed1ab_0
141 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
142 | - tifffile=2021.7.2=pyhd3eb1b0_2
143 | - tk=8.6.11=h1ccaba5_0
144 | - toolz=0.11.2=pyhd3eb1b0_0
145 | - torchaudio=0.10.0=py38_cu113
146 | - torchvision=0.11.1=py38_cu113
147 | - traitlets=5.1.1=pyhd3eb1b0_0
148 | - urllib3=1.26.9=pyhd8ed1ab_0
149 | - wcwidth=0.2.5=pyhd3eb1b0_0
150 | - wheel=0.37.0=pyhd3eb1b0_1
151 | - xz=5.2.5=h7b6447c_0
152 | - yaml=0.2.5=h7b6447c_0
153 | - yarl=1.6.3=py38h27cfd23_0
154 | - zfp=0.5.5=h2531618_6
155 | - zipp=3.8.0=pyhd8ed1ab_0
156 | - zlib=1.2.11=h7b6447c_3
157 | - zstd=1.4.9=haebb681_0
158 | - pip:
159 | - addict==2.4.0
160 | - alibabacloud-credentials==0.3.2
161 | - alibabacloud-endpoint-util==0.0.3
162 | - alibabacloud-gateway-spi==0.0.1
163 | - alibabacloud-openapi-util==0.2.1
164 | - alibabacloud-tea==0.3.1
165 | - alibabacloud-tea-openapi==0.3.7
166 | - alibabacloud-tea-util==0.3.8
167 | - alibabacloud-tea-xml==0.0.2
168 | - anyconfig==0.9.10
169 | - anyio==3.6.1
170 | - arabic-reshaper==2.1.3
171 | - asynctest==0.13.0
172 | - azure-cognitiveservices-vision-computervision==0.9.0
173 | - azure-common==1.1.28
174 | - azure-core==1.26.4
175 | - baidu-aip==2.2.17.0
176 | - beautifulsoup4==4.11.1
177 | - codecov==2.1.12
178 | - colorama==0.4.5
179 | - commonmark==0.9.1
180 | - coverage==6.4.4
181 | - diffimg==0.2.3
182 | - editdistance==0.3.1
183 | - fire==0.5.0
184 | - flake8==5.0.4
185 | - flask==2.2.2
186 | - fonttools==4.28.1
187 | - future==0.18.2
188 | - gevent==21.12.0
189 | - gevent-websocket==0.10.1
190 | - greenlet==1.1.3
191 | - grpcio==1.46.0
192 | - h11==0.12.0
193 | - httpcore==0.15.0
194 | - httpx==0.23.0
195 | - huaweicloudsdkcore==3.1.37
196 | - huaweicloudsdkocr==3.1.37
197 | - imgaug==0.2.8
198 | - iniconfig==1.1.1
199 | - isodate==0.6.1
200 | - isort==5.10.1
201 | - itsdangerous==2.1.2
202 | - jarowinkler==1.2.3
203 | - jinja2==3.1.2
204 | - joblib==1.2.0
205 | - kiwisolver==1.3.2
206 | - kornia==0.6.7
207 | - kwarray==0.6.4
208 | - lanms-neo==1.0.2
209 | - lmdb==1.3.0
210 | - lpips==0.1.4
211 | - markupsafe==2.1.1
212 | - mccabe==0.7.0
213 | - mmcv-full==1.6.2
214 | - mmdet==2.25.2
215 | - mmocr==0.6.2
216 | - model-index==0.1.11
217 | - msrest==0.7.1
218 | - munch==2.5.0
219 | - natsort==8.1.0
220 | - nltk==3.8.1
221 | - opencv-python==4.6.0.66
222 | - openmim==0.3.2
223 | - ordered-set==4.1.0
224 | - pandas==1.4.2
225 | - pillow==9.2.0
226 | - pluggy==1.0.0
227 | - polygon3==3.0.9.1
228 | - protobuf==3.20.1
229 | - py==1.11.0
230 | - pyasn1-modules==0.2.8
231 | - pyclipper==1.1.0.post3
232 | - pycocotools==2.0.5
233 | - pycodestyle==2.9.1
234 | - pyflakes==2.5.0
235 | - pyparsing==3.0.6
236 | - pytest==7.1.3
237 | - pytest-cov==4.0.0
238 | - pytest-runner==6.0.0
239 | - python-bidi==0.4.2
240 | - pytz==2021.3
241 | - rapidfuzz==2.10.2
242 | - regex==2022.10.31
243 | - requests-toolbelt==0.10.1
244 | - rfc3986==1.5.0
245 | - rich==12.5.1
246 | - scikit-image==0.19.3
247 | - scipy==1.4.1
248 | - setuptools-scm==6.0.1
249 | - shapely==1.8.2
250 | - simplejson==3.19.1
251 | - sniffio==1.2.0
252 | - sortedcontainers==2.4.0
253 | - soupsieve==2.3.2.post1
254 | - tabulate==0.8.10
255 | - tensorboard-data-server==0.6.1
256 | - tensorboardx==2.5
257 | - termcolor==2.2.0
258 | - terminaltables==3.1.10
259 | - tomli==2.0.1
260 | - torch-tb-profiler==0.4.0
261 | - torchsummary==1.5.1
262 | - tqdm==4.62.3
263 | - trdg==1.8.0
264 | - typing-extensions==4.3.0
265 | - ubelt==1.2.2
266 | - websocket-client==1.3.3
267 | - websockets==10.3
268 | - werkzeug==2.2.2
269 | - wikipedia==1.4.0
270 | - xdoctest==1.1.0
271 | - yapf==0.32.0
272 | - zope-event==4.5.0
273 | - zope-interface==5.4.0
274 |
--------------------------------------------------------------------------------
/gen_udp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from torchvision import utils as vutils
7 |
8 | import models.GAN_models as G_models
9 | from dataset import test_dataset_builder, up_dataset
10 |
11 | def gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path):
12 | BOX_MIN = 0
13 | BOX_MAX = 255
14 |
15 | # load the well-trained generator
16 | pretrained_generator_path = os.path.join(Generator_path + '/netG_epoch_' + str(opt.epochs) + '.pth')
17 | pretrained_G = G_models.Generator(opt.img_channel).to(device)
18 | pretrained_G.load_state_dict(torch.load(pretrained_generator_path))
19 | pretrained_G.eval()
20 |
21 | test_dataset = test_dataset_builder(opt.imgH, opt.imgW, opt.test_path)
22 | test_dataloader = torch.utils.data.DataLoader(
23 | test_dataset, batch_size=1,
24 | shuffle=False, num_workers=4)
25 |
26 | up = up_dataset(opt.up_path).to(device)
27 | up = up.repeat(1,1,1,1)
28 | gui_net = torch.load(opt.dt_model).to(device)
29 | gui_net.eval()
30 |
31 | for i, data in enumerate(test_dataloader, 0):
32 | ori_labels = data[1][0]
33 | img_index = data[3][0]
34 | test_img = data[5]
35 | test_img = test_img.to(device)
36 | mask = test_img.detach().to(device)
37 | if opt.dark:
38 | test_img = (1-test_img) + torch.mul(mask, up)
39 | else:
40 | test_img = torch.mul(test_img, up)
41 |
42 | # gen adv_img
43 | test_map = G_models.guided_net(test_img, gui_net)
44 | vutils.save_image(test_map,"{}/{}_{}_map.png".format(map_output_path, img_index, ori_labels))
45 |
46 | perturbation = pretrained_G(up)
47 | perturbation = torch.clamp(perturbation, -opt.eps, opt.eps)
48 | vutils.save_image(perturbation, "{}/{}_{}_per.png".format(per_output_path, img_index, ori_labels))
49 |
50 | permap = G_models.guided_net(perturbation, gui_net)
51 | vutils.save_image(permap, "{}/{}_{}_permap.png".format(map_output_path, img_index, ori_labels))
52 | vutils.save_image(permap*100, "{}/{}_{}_permap100.png".format(map_output_path, img_index, ori_labels))
53 | perturbation = torch.mul(mask, perturbation)
54 |
55 | """convert float32 to uint8:
56 | Avoid the effect of float32 on
57 | generating fully complementary frames
58 | """
59 | perturbation_int = (perturbation*255).type(torch.int8)
60 | adv_img1_uint = (test_img*255).type(torch.uint8) - perturbation_int
61 | adv_img1_uint = torch.clamp(adv_img1_uint, BOX_MIN, BOX_MAX)
62 | adv_img2_uint = (test_img*255).type(torch.uint8) + perturbation_int
63 | adv_img2_uint = torch.clamp(adv_img2_uint, BOX_MIN, BOX_MAX)
64 | print((adv_img1_uint + perturbation_int).equal(adv_img2_uint - perturbation_int))
65 |
66 | adv1_map = G_models.guided_net(adv_img1_uint/255, gui_net)
67 | adv2_map = G_models.guided_net(adv_img2_uint/255, gui_net)
68 | vutils.save_image(adv1_map,"{}/{}_{}_map-.png".format(map_output_path, img_index, ori_labels))
69 | vutils.save_image(adv2_map,"{}/{}_{}_map+.png".format(map_output_path, img_index, ori_labels))
70 |
71 | adv_img1_uint = adv_img1_uint.squeeze(0).permute(1,2,0)
72 | adv_img1_uint = np.uint8(adv_img1_uint.cpu())
73 | adv_img1 = cv2.cvtColor(adv_img1_uint, cv2.COLOR_RGB2BGR)
74 | adv_img2_uint = adv_img2_uint.squeeze(0).permute(1,2,0)
75 | adv_img2_uint = np.uint8(adv_img2_uint.cpu())
76 | adv_img2 = cv2.cvtColor(adv_img2_uint, cv2.COLOR_RGB2BGR)
77 |
78 | cv2.imwrite("{}/{}_{}_adv-.png".format(adv_output_path1, img_index, ori_labels), adv_img1)
79 | cv2.imwrite("{}/{}_{}_adv+.png".format(adv_output_path2, img_index, ori_labels), adv_img2)
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument('--test_path', type= str, required=True, help='path of font test dataset')
85 | parser.add_argument('--up_path', type= str, default='data/protego/up/5.png', help='underpaintings path')
86 | parser.add_argument('--dark', action='store_true', help='use dark background and white text.')
87 |     parser.add_argument('--dt_model', type=str, default='models/dbnet++.pth',
88 | help='path of our guided network DBnet++')
89 | parser.add_argument('--batchsize', type= int, default=4, help='batchsize of training ProTegO')
90 | parser.add_argument('--epochs', type= int, default=60, help='epochs of training ProTegO')
91 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation')
92 | parser.add_argument('--use_eh', action='store_true', help='Use enhancement layers')
93 | parser.add_argument('--use_guide', action='store_true', help='use guided network')
94 | parser.add_argument('--img_channel', type=int, default=3,
95 | help='the number of input channel of text images')
96 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
97 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
98 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
99 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
100 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode')
101 |
102 |     parser.add_argument('--output', default='res-font', help='the path to save all output results')
103 | parser.add_argument('--test_output', default='test-out', help='the path to save output of test results')
104 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial text images')
105 | parser.add_argument('--per_output', default='perturbation', help='the path to save output of adversarial perturbation')
106 | parser.add_argument('--map_output', default='map', help='the path to save mapping results')
107 | parser.add_argument('--train_output', default='train-out', help='the path to save output of intermediate training results')
108 |
109 | parser.add_argument('--saveG', required=True, help='the path to save generator which is used for generated AEs')
110 |
111 | opt = parser.parse_args()
112 | print(opt)
113 |
114 | output_path = opt.output
115 | Generator_path = opt.saveG
116 |
117 | font_name = opt.test_path.split('/')[-1]
118 | test_output_path = os.path.join(output_path, font_name, opt.test_output)
119 | adv_output_path1 = os.path.join(test_output_path, opt.adv_output, 'adv-')
120 | adv_output_path2 = os.path.join(test_output_path, opt.adv_output, 'adv+')
121 | per_output_path = os.path.join(test_output_path, opt.per_output)
122 | map_output_path = os.path.join(test_output_path, opt.map_output)
123 |
124 |
125 | if not os.path.exists(test_output_path):
126 | os.makedirs(test_output_path)
127 | if not os.path.exists(adv_output_path1):
128 | os.makedirs(adv_output_path1)
129 | if not os.path.exists(adv_output_path2):
130 | os.makedirs(adv_output_path2)
131 | if not os.path.exists(per_output_path):
132 | os.makedirs(per_output_path)
133 | if not os.path.exists(map_output_path):
134 | os.makedirs(map_output_path)
135 |
136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137 | gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path)
138 |
139 |
--------------------------------------------------------------------------------
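A small numeric sketch (with made-up pixel values) of the integer arithmetic that gen_udp.py relies on: working in uint8 with an int8 perturbation keeps the adv- and adv+ frames exactly complementary around the underpainted image, which the equality check in the loop above verifies, whereas float32 rounding could break this.

import torch

base = torch.tensor([120, 200, 80], dtype=torch.uint8)   # hypothetical underpainted pixel values (0-255)
per  = torch.tensor([ 25, -10, 40], dtype=torch.int8)    # hypothetical perturbation scaled by 255

adv_minus = torch.clamp(base - per, 0, 255)   # the "adv-" frame
adv_plus  = torch.clamp(base + per, 0, 255)   # the "adv+" frame

# Away from the 0/255 clamping boundary, both frames recover the same base values,
# so the two frames are exact mirror images of each other around the base image.
print(adv_minus + per)   # [120, 200, 80]
print(adv_plus  - per)   # [120, 200, 80]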
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import string
6 | import shutil
7 | import argparse
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import torch.utils.data
11 | import numpy as np
12 | from utils import Logger
13 | from train_protego import run_train
14 | from gen_udp import gen_underpaintings
15 | from test_udp import test_udp
16 |
17 |
18 | """ Basic parameters settings """
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--manualSeed', type=int, default=3407,
21 |                     help='random seed (refer to the settings in the paper)')
22 | parser.add_argument('--train_path', type= str, required=True, help='path of training dataset')
23 | parser.add_argument('--test_path', type= str, required=True, help='path of test dataset')
24 | parser.add_argument('--up_path', type= str, default= "data/protego/up/5.png",
25 | help='path of the pre-processed underpaintings')
26 | parser.add_argument('--dt_model', type=str, default='models/dbnet++.pth', help='path of our guided network DBnet++')
27 | parser.add_argument('--batchsize', type= int, default=4, help='batchsize of training ProTegO')
28 | parser.add_argument('--epochs', type= int, default=60, help='epochs of training ProTegO')
29 | parser.add_argument('--eps', type=float, default=40/255, help='maximum perturbation')
30 | parser.add_argument('--lambda1', type= float, default=1e-3, help='the weight of hinge_loss')
31 | parser.add_argument('--lambda2', type= float, default=2, help='the weight of guide_loss')
32 | parser.add_argument('--lambda3', type= float, default=1, help='the weight of gan_loss')
33 | parser.add_argument('--lambda4', type= float, default=10, help='the weight of adv_loss')
34 | parser.add_argument('--dark', action='store_true', help='use dark background and white text')
35 | parser.add_argument('--b', action='store_true', help='robust test for both frames of adversarial text images')
36 | parser.add_argument('--use_eh', action='store_true', help='Use enhancement layers')
37 | parser.add_argument('--use_guide', action='store_true', help='use guided network')
38 |
39 | """ Model Architecture """
40 | parser.add_argument('--str_model', type=str, help="path of pretrained STR models for evaluation",
41 | default='STR_modules/downloads_models/STARNet-TPS-ResNet-BiLSTM-CTC-sensitive.pth')
42 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS')
43 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet')
44 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM')
45 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. CTC|Attn')
46 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
47 | parser.add_argument('--input_channel', type=int, default=3,
48 | help='the number of input channel of Feature extractor')
49 | parser.add_argument('--output_channel', type=int, default=512,
50 | help='the number of output channel of Feature extractor')
51 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
52 |
53 | """ Data processing """
54 | parser.add_argument('--img_channel', type=int, default=3,
55 | help='the number of input channel of text images')
56 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
57 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
58 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
59 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
60 | parser.add_argument('--sensitive', action='store_false', help='default for sensitive character mode')
61 |
62 | """ Output settings """
63 | parser.add_argument('--output', default=f'{time.strftime("res-%m%d-%H%M")}',
64 |                     help='the path to save all output results')
65 | parser.add_argument('--train_output', default='train-out', help='the path to save intermediate training results')
66 | parser.add_argument('--saveG', default='Generators', help='the path to save generators')
67 | parser.add_argument('--loss', default='losses', help='the path to save all training losses')
68 | parser.add_argument('--test_output', default='test-out', help='the path to save output of test results')
69 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial text images')
70 | parser.add_argument('--per_output', default='perturbation', help='the path to save adversarial perturbation')
71 | parser.add_argument('--map_output', default='map', help='the path to save mapping results')
72 |
73 | opt = parser.parse_args()
74 | # print(opt)
75 |
76 | """ Seed and GPU setting """
77 | print("Random Seed: ", opt.manualSeed)
78 | random.seed(opt.manualSeed)
79 | np.random.seed(opt.manualSeed)
80 | torch.manual_seed(opt.manualSeed)
81 | torch.cuda.manual_seed(opt.manualSeed)
82 | cudnn.benchmark = True
83 | cudnn.deterministic = True
84 |
85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86 | """ vocab / character number configuration """
87 | if opt.sensitive:
88 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z)
89 |
90 | """ output configuration """""
91 | output_path = opt.output
92 | if not os.path.exists(output_path):
93 | os.makedirs(output_path)
94 |
95 | # del all the output directories and files
96 | del_list = os.listdir(output_path)
97 | for f in del_list:
98 | file_path = os.path.join(output_path, f)
99 | if os.path.isfile(file_path):
100 | os.remove(file_path)
101 | elif os.path.isdir(file_path):
102 | shutil.rmtree(file_path)
103 |
104 | """ save all the print content as log """
105 | log_file= os.path.join(output_path, 'protego.log')
106 | sys.stdout = Logger(log_file)
107 |
108 | """ make all save directories """
109 | train_adv_path = os.path.join(output_path, opt.train_output, 'adv')
110 | train_per_path = os.path.join(output_path, opt.train_output, 'per')
111 | Generator_path = os.path.join(output_path,opt.saveG)
112 | loss_path = os.path.join(output_path, opt.loss)
113 | test_output_path = os.path.join(output_path, opt.test_output)
114 | adv_output_path1 = os.path.join(test_output_path, opt.adv_output, 'adv-')
115 | adv_output_path2 = os.path.join(test_output_path, opt.adv_output, 'adv+')
116 | per_output_path = os.path.join(test_output_path, opt.per_output)
117 | map_output_path = os.path.join(test_output_path, opt.map_output)
118 |
119 | if not os.path.exists(train_adv_path):
120 | os.makedirs(train_adv_path)
121 | if not os.path.exists(train_per_path):
122 | os.makedirs(train_per_path)
123 | if not os.path.exists(Generator_path):
124 | os.makedirs(Generator_path)
125 | if not os.path.exists(loss_path):
126 | os.makedirs(loss_path)
127 | if not os.path.exists(test_output_path):
128 | os.makedirs(test_output_path)
129 | if not os.path.exists(adv_output_path1):
130 | os.makedirs(adv_output_path1)
131 | if not os.path.exists(adv_output_path2):
132 | os.makedirs(adv_output_path2)
133 | if not os.path.exists(per_output_path):
134 | os.makedirs(per_output_path)
135 | if not os.path.exists(map_output_path):
136 | os.makedirs(map_output_path)
137 |
138 | torch.cuda.synchronize()
139 | time_start = time.time()
140 |
141 | print(opt)
142 |
143 | # train ProTegO
144 | run_train(opt, device, train_adv_path, train_per_path, Generator_path, loss_path)
145 |
146 | # Generate adversarial underpaintings for text images
147 | gen_start = time.time()
148 | gen_underpaintings(opt, device, Generator_path, adv_output_path1, adv_output_path2, per_output_path, map_output_path)
149 | gen_end = time.time() - gen_start
150 | print('Generation time:' + str(gen_end))
151 |
152 | # Test ProTegO performance
153 | test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path)
154 | time_end = time.time()
155 | time_sum = time_end - time_start
156 | print('Total time:' + str(time_sum))
--------------------------------------------------------------------------------
/models/GAN_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def guided_net(x, net):
4 | net.eval()
5 | b_out = net.backbone(x) #resnet50
6 | n_out = net.neck(b_out)
7 | out = net.bbox_head.binarize(n_out)
8 | return out
9 |
10 | class Discriminator(nn.Module):
11 | def __init__(self, image_nc):
12 | super(Discriminator, self).__init__()
13 | # syn90k:3*32*100
14 | model = [
15 | nn.Conv2d(image_nc, 8, kernel_size=(4,20), stride=2, padding=0, bias=True),
16 | nn.LeakyReLU(0.2),
17 | # :8*15*41
18 | nn.Conv2d(8, 16, kernel_size=(4,20), stride=2, padding=0, bias=True),
19 | nn.BatchNorm2d(16),
20 | nn.LeakyReLU(0.2),
21 | # 16*6*11
22 | nn.Conv2d(16, 32, kernel_size=(5,10), stride=2, padding=0, bias=True),
23 | nn.BatchNorm2d(32),
24 | nn.LeakyReLU(0.2),
25 | nn.Conv2d(32, 1, 1),
26 | nn.Sigmoid()
27 | #1*1*1
28 | ]
29 | self.model = nn.Sequential(*model)
30 |
31 | def forward(self, x):
32 | output = self.model(x)
33 | return output.squeeze()
34 |         # output = self.model(x.unsqueeze(1))  # x ---> torch.Size([16, 32, 100]); a channel dimension must be added. output ---> torch.Size([16, 1, 1, 1])
35 | # return output.squeeze()
36 |
37 | class Generator(nn.Module):
38 | def __init__(self, image_nc):
39 | super(Generator, self).__init__()
40 |
41 | encoder_lis = [
42 | # syn90k:image_nc*32*100
43 | nn.Conv2d(image_nc, 8, kernel_size=3, stride=1, padding=0, bias=True),
44 | nn.InstanceNorm2d(8),
45 | nn.ReLU(),
46 | # 8*30*98
47 | nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0, bias=True),
48 | nn.InstanceNorm2d(16),
49 | nn.ReLU(),
50 | # 16*14*48
51 | nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0, bias=True),
52 | nn.InstanceNorm2d(32),
53 | nn.ReLU(),
54 | # 32*6*23
55 | ]
56 |
57 | bottle_neck_lis = [ResnetBlock(32),
58 | ResnetBlock(32),
59 | ResnetBlock(32),
60 | ResnetBlock(32),]
61 |
62 | decoder_lis = [
63 | # input 32*6*23
64 | nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=0, bias=False),
65 | nn.InstanceNorm2d(16),
66 | nn.ReLU(),
67 | # state size. 16*13*47
68 | nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=0, bias=False),
69 | nn.InstanceNorm2d(8),
70 | nn.ReLU(),
71 | # state size. 8*27*95
72 | nn.ConvTranspose2d(8, image_nc, kernel_size=6, stride=1, padding=0, bias=False),
73 | nn.Tanh()
74 | # state size. image_nc*32*100
75 | ]
76 |
77 | self.encoder = nn.Sequential(*encoder_lis)
78 | self.bottle_neck = nn.Sequential(*bottle_neck_lis)
79 | self.decoder = nn.Sequential(*decoder_lis)
80 |
81 | def forward(self, x):
82 | x = self.encoder(x)
83 | x = self.bottle_neck(x)
84 | x = self.decoder(x)
85 |
86 | return x
87 |
88 | class ResnetBlock(nn.Module):
89 | def __init__(self, dim, padding_type='zero', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
90 | super(ResnetBlock, self).__init__()
91 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
92 |
93 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
94 |
95 | conv_block = []
96 | p = 0
97 | if padding_type == 'reflect':
98 | conv_block += [nn.ReflectionPad2d(1)]
99 | elif padding_type == 'replicate':
100 | conv_block += [nn.ReplicationPad2d(1)]
101 | elif padding_type == 'zero':
102 | p = 1
103 | else:
104 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
105 |
106 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
107 | norm_layer(dim),
108 | nn.ReLU(True)]
109 | if use_dropout:
110 | conv_block += [nn.Dropout(0.5)]
111 |
112 | p = 0
113 | if padding_type == 'reflect':
114 | conv_block += [nn.ReflectionPad2d(1)]
115 | elif padding_type == 'replicate':
116 | conv_block += [nn.ReplicationPad2d(1)]
117 | elif padding_type == 'zero':
118 | p = 1
119 | else:
120 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
121 |
122 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
123 | norm_layer(dim)]
124 |
125 | return nn.Sequential(*conv_block)
126 |
127 | def forward(self, x):
128 | out = x + self.conv_block(x)
129 | return out
--------------------------------------------------------------------------------
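A minimal shape-check sketch for the two networks above, assuming a batch of four 3x32x100 text images as annotated in the layer comments (run from the repository root):

import torch
from models.GAN_models import Generator, Discriminator

x = torch.rand(4, 3, 32, 100)   # batch of syn90k-sized text images
netG = Generator(image_nc=3)
netD = Discriminator(image_nc=3)

print(netG(x).shape)   # torch.Size([4, 3, 32, 100]) -- the perturbation has the same size as the input
print(netD(x).shape)   # torch.Size([4]) -- one real/fake score per image after squeeze()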
/models/enhancement_layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .identity import Identity
--------------------------------------------------------------------------------
/models/enhancement_layers/combined.py:
--------------------------------------------------------------------------------
1 | from .identity import Identity
2 | from .transform import Translate, Resize, D_Binarization, R_Binarization
3 | import torch.nn as nn
4 | import random
5 |
6 | class Combined(nn.Module):
7 | def __init__(self, list=None):
8 | super(Combined, self).__init__()
9 | if list is None:
10 | list = [Identity()]
11 | self.list = list
12 |
13 | def forward(self, adv_image):
14 | id = random.randint(0, len(self.list) - 1)
15 | print(f"[+] Batch Combined {self.list[id]}")
16 | return self.list[id](adv_image)
17 |
18 |
--------------------------------------------------------------------------------
/models/enhancement_layers/identity.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class Identity(nn.Module):
4 | """
5 | Identity-mapping noise layer. Does not change the image
6 | """
7 | def __init__(self):
8 | super(Identity, self).__init__()
9 |
10 | def forward(self, adv_image):
11 | return adv_image
12 |
--------------------------------------------------------------------------------
/models/enhancement_layers/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import kornia.geometry.transform as Ktrans
8 |
9 | def image(path):
10 | IMG = cv2.imread(path)
11 | IMG = cv2.resize(IMG, (100, 32)) # resize
12 |
13 | IMG = cv2.cvtColor(IMG, cv2.COLOR_BGR2RGB)
14 | IMG = torch.FloatTensor(IMG) # [H, W, C]
15 | IMG = IMG / 255 # normalization to [0,1]
16 | IMG = IMG.permute(2,0,1) # [C, H, W]
17 |
18 | return IMG.unsqueeze(0) # [B, C, H, W]
19 |
20 | # real binarization
21 | class R_Binarization(nn.Module):
22 | def __init__(self):
23 | super(R_Binarization, self).__init__()
24 |
25 | def RB(self, x):
26 | x_np = x.squeeze_(0).permute(1,2,0).cpu().numpy() #[h,w,c]
27 | maxValue = x_np.max()
28 | x_np = x_np * 255 / maxValue
29 | x_uint = np.uint8(x_np)
30 | gray = cv2.cvtColor(x_uint, cv2.COLOR_RGB2GRAY)
31 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
32 | x_b = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
33 | x_b = torch.FloatTensor(x_b)
34 | x_b = x_b / 255
35 | x_b = x_b.permute(2,0,1).unsqueeze_(0)
36 | return x_b.to(x.device)
37 |
38 | def forward(self, x):
39 | enhance_adv = self.RB(x)
40 | return enhance_adv
41 |
42 | # differentiable binarization
43 | class D_Binarization(nn.Module):
44 | def __init__(self):
45 | super(D_Binarization, self).__init__()
46 | self.k = 20
47 |
48 | def DB(self, x):
49 | y = 0.8 *torch.ones_like(x)
50 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
51 |
52 | def forward(self, x):
53 | enhance_adv = self.DB(x)
54 | return enhance_adv
55 |
56 | class Translate(nn.Module):
57 | """
58 | Translate the image.
59 | """
60 | def __init__(self):
61 | super(Translate, self).__init__()
62 | i = random.choice([0,1])
63 | if i == 0:
64 | self.translation = torch.tensor([[4., 4.]])
65 | else:
66 | self.translation = torch.tensor([[-4., -4.]])
67 |
68 | def forward(self, x):
69 | device = x.device
70 | enhance_adv = Ktrans.translate(x, self.translation.to(device), mode='bilinear', padding_mode='border', align_corners=True)
71 |
72 | return enhance_adv
73 |
74 | class Resize(nn.Module):
75 | """
76 | Resize the image. The target size is
77 | """
78 |
79 | def __init__(self):
80 | super(Resize, self).__init__()
81 | self.resize_rate = 1.60
82 | self.img_h = 32
83 | self.img_w = 100
84 |
85 | def input_diversity(self, x):
86 | img_resize_h = int(self.img_h * self.resize_rate)
87 | img_resize_w = int(self.img_w * self.resize_rate)
88 |
89 | rnd_h = torch.randint(low=self.img_h, high=img_resize_h, size=(1,), dtype=torch.int32)
90 | rnd_w = torch.randint(low=self.img_w , high=img_resize_w, size=(1,), dtype=torch.int32)
91 | rescaled = F.interpolate(x, size=[rnd_h, rnd_w], mode='bilinear', align_corners=False)
92 | h_rem = img_resize_h - rnd_h
93 | w_rem = img_resize_w - rnd_w
94 | pad_top = torch.randint(low=0, high=h_rem.item(), size=(1,), dtype=torch.int32)
95 | pad_bottom = h_rem - pad_top
96 | pad_left = torch.randint(low=0, high=w_rem.item(), size=(1,), dtype=torch.int32)
97 | pad_right = w_rem - pad_left
98 |
99 | padded = F.pad(rescaled, [pad_left.item(), pad_right.item(), pad_top.item(), pad_bottom.item()], value=0)
100 | padded = F.interpolate(padded, size=[self.img_h, self.img_w])
101 | return padded
102 |
103 | def forward(self, x):
104 | enhance_adv = self.input_diversity(x)
105 | return enhance_adv
--------------------------------------------------------------------------------
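For reference, D_Binarization above approximates hard thresholding with a steep sigmoid, roughly 1/(1 + exp(-20*(x - 0.8))), so intensities well below 0.8 are pushed toward 0 and those above toward 1 while remaining differentiable. A tiny sketch with made-up intensities:

import torch
from models.enhancement_layers.transform import D_Binarization

db = D_Binarization()
x = torch.tensor([[[[0.2, 0.7, 0.8, 0.9]]]])   # hypothetical pixel intensities in [0, 1]
print(db(x))   # approximately [[[[0.0000, 0.1192, 0.5000, 0.8808]]]]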
/models/enhancer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .enhancement_layers.combined import Combined
3 | from .enhancement_layers.identity import Identity
4 | from .enhancement_layers.transform import Resize, Translate, D_Binarization, R_Binarization
5 |
6 | class Enhancer(nn.Module):
7 | """
8 |     This module combines different enhancement layers into a sequential module.
9 | """
10 | def __init__(self, layers):
11 | super(Enhancer, self).__init__()
12 | for i in range(len(layers)):
13 | layers[i] = eval(layers[i])
14 | self.enhance = nn.Sequential(*layers)
15 |
16 | def forward(self, adv_image):
17 | enhance_adv = self.enhance(adv_image)
18 | return enhance_adv
19 |
--------------------------------------------------------------------------------
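A minimal usage sketch of the enhancement module, mirroring how protego.py builds it when --use_eh is set; Combined picks one of the listed layers at random on every forward pass:

import torch
from models.enhancer import Enhancer

# Same layer-specification string that protego.py uses during training
mixed_layers = ["Combined([Identity(), Translate(), Resize(), D_Binarization()])"]
enhance = Enhancer(mixed_layers)

adv_images = torch.rand(4, 3, 32, 100)   # hypothetical batch of adversarial text images
enhanced = enhance(adv_images)           # one randomly chosen transform is applied to the whole batch
print(enhanced.shape)                    # torch.Size([4, 3, 32, 100])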
/protego.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 | import matplotlib.pyplot as plt
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchvision import utils as vutils
9 |
10 | import models.GAN_models as G_models
11 | from models.enhancer import Enhancer
12 | from dataset import up_dataset
13 |
14 | """ custom weights initialization called on netG and netD """
15 | def weights_init(m):
16 | classname = m.__class__.__name__
17 | if classname.find('Conv') != -1:
18 | nn.init.normal_(m.weight.data, 0.0, 0.02)
19 | elif classname.find('BatchNorm') != -1:
20 | nn.init.normal_(m.weight.data, 1.0, 0.02)
21 | nn.init.constant_(m.bias.data, 0)
22 |
23 | def fix_bn(m):
24 | classname = m.__class__.__name__
25 | if classname.find('BatchNorm') != -1:
26 | m.eval()
27 |
28 | class framework():
29 | def __init__(self,
30 | device,
31 | model,
32 | dt_model,
33 | converter,
34 | criterion,
35 | batch_max_length,
36 | up_path,
37 | dark,
38 | batch_size,
39 | image_nc,
40 | height,
41 | width,
42 | eps,
43 | lambda1,
44 | lambda2,
45 | lambda3,
46 | lambda4,
47 | use_eh,
48 | use_guide):
49 | self.device = device
50 | self.model = model
51 | self.dt_model = dt_model
52 | self.converter = converter
53 | self.criterion = criterion
54 | self.batch_max_length = batch_max_length
55 | self.up_path = up_path
56 | self.dark = dark
57 | self.batch_size = batch_size
58 | self.image_nc = image_nc
59 | self.height = height
60 | self.width = width
61 | self.box_min = 0
62 | self.box_max = 1
63 | self.c = 0.1 # user-specified bound
64 | self.eps = eps
65 | self.lambda1 = lambda1
66 | self.lambda2 = lambda2
67 | self.lambda3 = lambda3
68 | self.lambda4 = lambda4
69 | self.use_eh = use_eh
70 | self.use_guide = use_guide
71 |
72 | self.model.apply(fix_bn).to(self.device)
73 | self.netG = G_models.Generator(self.image_nc).to(self.device)
74 | self.netD = G_models.Discriminator(self.image_nc).to(self.device)
75 | # initialize all weights
76 | self.netG.apply(weights_init)
77 | self.netD.apply(weights_init)
78 | # initialize optimizers
79 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.001)
80 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=0.001)
81 |
82 |
83 | def train_batch(self, images, labels, up, gui_net, epoch):
84 | # -----------optimize D-----------
85 | if self.dark:
86 | mask = images.detach()
87 | images = (1-images) + torch.mul(mask, up)
88 | else:
89 | mask = images.detach()
90 | images = torch.mul(images, up)
91 |
92 | perturbation = self.netG(up)
93 | perturbation = torch.clamp(perturbation, -self.eps, self.eps)
94 |
95 | if self.use_guide:
96 | map = G_models.guided_net(perturbation, gui_net)
97 | loss_guide = - torch.mean(torch.tanh(map*1000))
98 |
99 | else:
100 | map = torch.zeros_like(perturbation).to(self.device)
101 | loss_guide = torch.zeros(1).to(self.device)
102 |
103 | perturbation = torch.mul(mask, perturbation)
104 |
105 | if self.use_eh:
106 | mixed_layers = ["Combined([Identity(), Translate(), Resize(), D_Binarization()])"]
107 | print('epoch{}---enhancement_layers{}'.format(epoch, mixed_layers[0]))
108 | enhance = Enhancer(mixed_layers).to(self.device)
109 | adv_images1 = images - perturbation
110 | adv_images1 = enhance(adv_images1)
111 | adv_images1 = torch.clamp(adv_images1, self.box_min, self.box_max)
112 | adv_images2 = images + perturbation
113 | adv_images2 = enhance(adv_images2)
114 | adv_images2 = torch.clamp(adv_images2, self.box_min, self.box_max)
115 | else:
116 | adv_images1 = images - perturbation
117 | adv_images1 = torch.clamp(adv_images1, self.box_min, self.box_max)
118 | adv_images2 = images + perturbation
119 | adv_images2 = torch.clamp(adv_images2, self.box_min, self.box_max)
120 |
121 | self.optimizer_D.zero_grad()
122 | pred_real = self.netD(images)
123 | loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))
124 |
125 | pred_fake1 = self.netD(adv_images1.detach())
126 | loss_D_fake1 = F.mse_loss(pred_fake1, torch.zeros_like(pred_fake1, device=self.device))
127 | pred_fake2 = self.netD(adv_images2.detach())
128 | loss_D_fake2 = F.mse_loss(pred_fake2, torch.zeros_like(pred_fake2, device=self.device))
129 |
130 | loss_D_gan = loss_D_fake1 + loss_D_fake2 + loss_D_real
131 |
132 | loss_D_gan.backward()
133 | self.optimizer_D.step()
134 | # -----------optimize G-----------
135 | self.optimizer_G.zero_grad()
136 |
137 | # the hinge Loss part of L (calculate perturbation norm)
138 | loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))
139 | loss_hinge = torch.max(torch.zeros(1, device=self.device), loss_perturb - self.c)
140 |
141 | # the adv Loss part of L
142 | torch.backends.cudnn.enabled=False
143 | """actually, text_for_pred is the param for attention model"""
144 | text_for_pred = torch.LongTensor(self.batch_size, self.batch_max_length + 1).fill_(0).to(self.device)
145 | targets, target_len = self.converter.encode(labels, self.batch_max_length)
146 |
147 | preds1 = self.model(adv_images1, text_for_pred)
148 | preds_size1 = torch.IntTensor([preds1.size(1)] * self.batch_size)
149 | preds1 = preds1.log_softmax(2).permute(1, 0, 2)
150 | loss_adv1 = - self.criterion(preds1, targets, preds_size1, target_len)
151 | preds2 = self.model(adv_images2, text_for_pred)
152 | preds_size2 = torch.IntTensor([preds2.size(1)] * self.batch_size)
153 | preds2 = preds2.log_softmax(2).permute(1, 0, 2)
154 | loss_adv2 = - self.criterion(preds2, targets, preds_size2, target_len)
155 | loss_adv = loss_adv1 + loss_adv2
156 |
157 | # cal G's loss in GAN
158 | pred_fake1 = self.netD(adv_images1)
159 | loss_G_gan1 = F.mse_loss(pred_fake1, torch.ones_like(pred_fake1, device=self.device))
160 | pred_fake2 = self.netD(adv_images2)
161 | loss_G_gan2 = F.mse_loss(pred_fake2, torch.ones_like(pred_fake2, device=self.device))
162 | loss_G_gan = loss_G_gan1 + loss_G_gan2
163 | loss_G_gan.backward(retain_graph=True)
164 |
165 | loss_G = self.lambda1*loss_hinge + self.lambda2*loss_guide + self.lambda3*loss_G_gan + self.lambda4*loss_adv
166 | self.model.zero_grad()
167 | loss_G.backward()
168 | self.optimizer_G.step()
169 |
170 | return loss_G.item(), loss_D_gan.item(), loss_G_gan.item(), loss_hinge.item(), loss_adv.item(), loss_guide.item(), \
171 | map, perturbation, adv_images1, adv_images2
172 |
173 | def train(self, train_dataloader, epochs, train_adv_path, train_per_path, Generator_path, loss_path):
174 |
175 | loss_G, loss_D_gan, loss_G_gan, loss_hinge, loss_adv, loss_guide= [], [], [], [], [], []
176 |
177 | if self.use_eh and self.use_guide:
178 | print("==> Use enhancement and guidance module !")
179 | elif self.use_eh:
180 | print("==> ONLY Use enhancement layer ...")
181 | elif self.use_guide:
182 | print("==> ONLY Use guided network ...")
183 | else:
184 | print("Do not use any trick !")
185 |
186 | up = up_dataset(self.up_path)
187 | up = up.repeat(self.batch_size,1,1,1).to(self.device)
188 | # up = torch.clamp(up, self.eps, 1-self.eps)
189 | # vutils.save_image(up, "{}/up.png".format(loss_path)) # TODO remove when release
190 |
191 | print('Loading text detection model from \"%s\" as our guided net!' % self.dt_model)
192 | gui_net = torch.load(self.dt_model).to(self.device)
193 | gui_net.eval()
194 |
195 | for epoch in range(1, epochs+1):
196 | if epoch == 20:
197 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
198 | lr=0.0001)
199 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
200 | lr=0.0001)
201 | if epoch == 40:
202 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
203 | lr=0.00001)
204 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
205 | lr=0.00001)
206 | if epoch == 60:
207 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
208 | lr=0.000001)
209 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
210 | lr=0.000001)
211 |
212 | loss_G_sum, loss_D_gan_sum, loss_G_gan_sum, loss_hinge_sum, loss_adv_sum, loss_guide_sum = 0, 0, 0, 0, 0, 0
213 |
214 | for i_batch, data in enumerate(train_dataloader, start=0):
215 | images, labels = data
216 | images = images.to(self.device)
217 |
218 | loss_G_batch, loss_D_gan_batch, loss_G_gan_batch, \
219 | loss_hinge_batch, loss_adv_batch, loss_guide_batch,\
220 | map, perturbation, adv_images1, adv_images2 = self.train_batch(images, labels, up, gui_net, epoch)
221 |
222 | loss_G_sum += loss_G_batch
223 | loss_D_gan_sum += loss_D_gan_batch
224 | loss_G_gan_sum += loss_G_gan_batch
225 | loss_hinge_sum += loss_hinge_batch
226 | loss_adv_sum += loss_adv_batch
227 | loss_guide_sum += loss_guide_batch
228 |
229 | vutils.save_image(adv_images1, "{}/{}_{}_adv-.png".format(train_adv_path, epoch, i_batch))
230 | vutils.save_image(adv_images2, "{}/{}_{}_adv+.png".format(train_adv_path, epoch, i_batch))
231 | vutils.save_image(map*1000, "{}/{}_{}map.png".format(train_per_path, epoch, i_batch))
232 | vutils.save_image(perturbation, "{}/{}_{}per.png".format(train_per_path, epoch, i_batch))
233 |
234 | # print statistics
235 |             num_batches = len(train_dataloader)  # number of batches in one epoch
236 |             print('epoch {}: \nloss G: {}, \n\tloss_G_gan: {}, \n\tloss_hinge: {}, \n\tloss_adv: {}, \n\tloss_guide: {}, \nloss D_gan: {}\n'.format(
237 |                 epoch,
238 |                 loss_G_sum/num_batches,
239 |                 loss_G_gan_sum/num_batches,
240 |                 loss_hinge_sum/num_batches,
241 |                 loss_adv_sum/num_batches,
242 |                 loss_guide_sum/num_batches,
243 |                 loss_D_gan_sum/num_batches,
244 |                 ))
245 |
246 |             loss_G.append(loss_G_sum / num_batches)
247 |             loss_D_gan.append(loss_D_gan_sum / num_batches)
248 |             loss_G_gan.append(loss_G_gan_sum / num_batches)
249 |             loss_hinge.append(loss_hinge_sum / num_batches)
250 |             loss_adv.append(loss_adv_sum / num_batches)
251 |             loss_guide.append(loss_guide_sum / num_batches)
252 |
253 | # save generator
254 |             if epoch % 2 == 0:
255 | netG_file_name = Generator_path + '/netG_epoch_' + str(epoch) + '.pth'
256 | torch.save(self.netG.state_dict(), netG_file_name)
257 |
258 | plt.figure()
259 | plt.plot(loss_G)
260 | plt.title("loss_G")
261 | plt.savefig(loss_path + '/loss_G.png')
262 |
263 | plt.figure()
264 | plt.plot(loss_D_gan)
265 | plt.title("loss_D_gan")
266 | plt.savefig(loss_path + '/loss_D_gan.png')
267 |
268 | plt.figure()
269 | plt.plot(loss_G_gan)
270 | plt.title("loss_G_gan")
271 | plt.savefig(loss_path + '/loss_G_gan.png')
272 |
273 | plt.figure()
274 | plt.plot(loss_hinge)
275 | plt.title("loss_hinge")
276 | plt.savefig(loss_path + '/loss_hinge.png')
277 |
278 | plt.figure()
279 | plt.plot(loss_adv)
280 | plt.title("loss_adv")
281 | plt.savefig(loss_path + '/loss_adv.png')
282 |
283 | plt.figure()
284 | plt.plot(loss_guide)
285 | plt.title("loss_guide")
286 | plt.savefig(loss_path + '/loss_guide.png')
287 |
288 | plt.close('all')
--------------------------------------------------------------------------------
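For reference, the staged learning-rate drops hard-coded at epochs 20/40/60 in `train` above can also be expressed with PyTorch's built-in `MultiStepLR` scheduler. A minimal sketch, assuming an initial learning rate of 1e-3 (the starting value is set earlier in the file and is not shown in this excerpt) and using placeholder modules in place of the real netG and netD:

```python
import torch

# Placeholder modules standing in for the real generator / discriminator.
netG = torch.nn.Linear(8, 8)
netD = torch.nn.Linear(8, 8)

# Assumed initial learning rate of 1e-3; gamma=0.1 then yields the
# 1e-4 / 1e-5 / 1e-6 rates used from epochs 20, 40 and 60 onwards.
optimizer_G = torch.optim.Adam(netG.parameters(), lr=1e-3)
optimizer_D = torch.optim.Adam(netD.parameters(), lr=1e-3)

# Milestones of 19/39/59 make the decayed rate take effect from epochs 20/40/60
# when the scheduler is stepped at the end of each 1-indexed epoch.
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=[19, 39, 59], gamma=0.1)
scheduler_D = torch.optim.lr_scheduler.MultiStepLR(optimizer_D, milestones=[19, 39, 59], gamma=0.1)

epochs = 80  # placeholder epoch count
for epoch in range(1, epochs + 1):
    # ... run the per-batch training step here ...
    scheduler_G.step()
    scheduler_D.step()
```

Note that, unlike the original code, this keeps the Adam optimizers' running moments across the rate changes instead of re-creating the optimizers at each milestone.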
/test_udp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import shutil
5 | import string
6 | import torch
7 | from utils import CTCLabelConverter, AttnLabelConverter, Logger
8 | from dataset import test_adv_dataset
9 | from STR_modules.model import Model
10 | from models.enhancer import Enhancer
11 | from nltk.metrics import edit_distance
12 |
13 |
14 |
15 | def fix_bn(m):
16 | classname = m.__class__.__name__
17 | if classname.find('BatchNorm') != -1:
18 | m.eval()
19 | def process_line(line):
20 | adv_img_path, recog_result = line.split(':')
21 | label, adv_preds = recog_result.split('--->')
22 | adv_preds = adv_preds.strip('\n')
23 | return adv_preds, label, adv_img_path
24 |
25 | def test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path):
26 | batch_size = 1
27 | save_success_adv = os.path.join(test_output_path , 'attack-success-adv')
28 | save_binary = os.path.join(test_output_path , 'binary_adv')
29 | attack_success_result1 = os.path.join(test_output_path , 'attack_success_result1.txt')
30 | attack_success_result2 = os.path.join(test_output_path , 'attack_success_result2.txt')
31 | if not os.path.exists(save_success_adv):
32 | os.makedirs(save_success_adv)
33 | if not os.path.exists(save_binary):
34 | os.makedirs(save_binary)
35 |
36 | """ model configuration """
37 | if 'CTC' in opt.Prediction:
38 | converter = CTCLabelConverter(opt.character)
39 | else:
40 | converter = AttnLabelConverter(opt.character)
41 | opt.num_class = len(converter.character)
42 |
43 | model = Model(opt).to(device)
44 | print('Loading a STR model from \"%s\" as the target model!' % opt.str_model)
45 |     model.load_state_dict(torch.load(opt.str_model, map_location=device), strict=False)
46 | model.eval()
47 | model.apply(fix_bn)
48 |
49 | mixed_layers = ["Combined([Identity(), Translate(), D_Binarization()])"]
50 | robust_test = Enhancer(mixed_layers).to(device)
51 |
52 | test_dataset1= test_adv_dataset(opt.imgH, opt.imgW, adv_output_path1)
53 | test_dataloader1= torch.utils.data.DataLoader(
54 | test_dataset1,
55 | batch_size=batch_size,
56 | shuffle=False,
57 | num_workers=1)
58 |
59 | result1 = dict() # adv-
60 | for i, data in enumerate(test_dataloader1):
61 | adv_img1 = data[1]
62 | label1 = data[2]
63 | adv_index1 = data[3][0]
64 | adv_path1 = data[5][0]
65 |         if opt.b:
66 |             # apply the binarization-based robustness pre-processing before recognition
67 |             adv_img1 = robust_test(adv_img1)
68 |
69 |         adv_img1 = adv_img1.to(device)
70 |         # dummy decoder input and maximum prediction length expected by the STR model
71 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
72 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
73 | if 'CTC' in opt.Prediction:
74 | preds1 = model(adv_img1, text_for_pred).log_softmax(2)
75 | preds_size1 = torch.IntTensor([preds1.size(1)] * batch_size)
76 | _, preds_index1 = preds1.permute(1, 0, 2).max(2)
77 | preds_index1 = preds_index1.transpose(1, 0).contiguous().view(-1)
78 | preds_output1 = converter.decode(preds_index1.data, preds_size1)
79 | preds_output1 = preds_output1[0]
80 | result1[adv_index1] = '{}:{}--->{}\n'.format(adv_path1, label1[0], preds_output1)
81 | else: # Attention
82 | preds1 = model(adv_img1, text_for_pred, is_train=False)
83 | _, preds_index1 = preds1.max(2)
84 | preds_output1 = converter.decode(preds_index1, length_for_pred)
85 | preds_output1 = preds_output1[0]
86 | preds_output1 = preds_output1[:preds_output1.find('[s]')]
87 | result1[adv_index1] = '{}:{}--->{}\n'.format(adv_path1, label1[0], preds_output1)
88 | result1 = sorted(result1.items(), key=lambda x:x[0])
89 | with open(attack_success_result1, 'w+') as f:
90 | for item in result1:
91 | f.write(item[1])
92 |
93 |
94 | test_dataset2= test_adv_dataset(opt.imgH, opt.imgW, adv_output_path2)
95 | test_dataloader2= torch.utils.data.DataLoader(
96 | test_dataset2,
97 | batch_size=batch_size,
98 | shuffle=False,
99 | num_workers=4)
100 | result2 = dict() # adv+
101 | for i, data in enumerate(test_dataloader2):
102 |         adv_img2 = data[1]
103 |         label2 = data[2]
104 |         adv_index2 = data[3][0]
105 |         adv_path2 = data[5][0]
106 |
107 |         if opt.b:
108 |             adv_img2 = robust_test(adv_img2)
109 |
110 |         adv_img2 = adv_img2.to(device)
111 |         length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
112 |         text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
113 | if 'CTC' in opt.Prediction:
114 | preds2 = model(adv_img2, text_for_pred).log_softmax(2)
115 | preds_size2 = torch.IntTensor([preds2.size(1)] * batch_size)
116 | _, preds_index2 = preds2.permute(1, 0, 2).max(2)
117 | preds_index2 = preds_index2.transpose(1, 0).contiguous().view(-1)
118 | preds_output2 = converter.decode(preds_index2.data, preds_size2)
119 | preds_output2 = preds_output2[0]
120 | result2[adv_index2] = '{}:{}--->{}\n'.format(adv_path2, label2[0], preds_output2)
121 | else:
122 | preds2 = model(adv_img2, text_for_pred, is_train=False)
123 | _, preds_index2 = preds2.max(2)
124 | preds_output2 = converter.decode(preds_index2, length_for_pred)
125 | preds_output2 = preds_output2[0]
126 | preds_output2 = preds_output2[:preds_output2.find('[s]')]
127 | result2[adv_index2] = '{}:{}--->{}\n'.format(adv_path2, label2[0], preds_output2)
128 | result2 = sorted(result2.items(), key=lambda x:x[0])
129 | with open(attack_success_result2, 'w+') as f:
130 | for item in result2:
131 | f.write(item[1])
132 |
133 | # calculate ASR
134 | with open(attack_success_result1, 'r') as f:
135 | alladv1 = f.readlines()
136 | with open(attack_success_result2, 'r') as f:
137 | alladv2 = f.readlines()
138 |
139 |     attack_success_num, asc1, asc2 = 0, 0, 0
140 |
141 |     ED_num1, ED_num2 = 0, 0
142 |     for line1, line2 in zip(alladv1, alladv2):
143 |         adv_preds1, label1, adv_img_path1 = process_line(line1)
144 |         adv_preds2, label2, adv_img_path2 = process_line(line2)
145 |         if adv_preds1 != label1:
146 |             asc1 += 1
147 |         if adv_preds2 != label2:
148 |             asc2 += 1
149 | if adv_preds1 != label1 and adv_preds2 != label2:
150 | ED_num1 += edit_distance(label1, adv_preds1)
151 | ED_num2 += edit_distance(label2, adv_preds2)
152 | attack_success_num += 1
153 | shutil.copy(adv_img_path1, save_success_adv)
154 | shutil.copy(adv_img_path2, save_success_adv)
155 | print("***********Test Finished !***********")
156 | psr1 = asc1 / len(test_dataset1)
157 | psr2 = asc2 / len(test_dataset2)
158 | attack_success_rate = attack_success_num / len(test_dataset1)
159 | print('PSR1:{:.2%}'.format(psr1))
160 | print('PSR2:{:.2%}'.format(psr2))
161 | print('PSR:{:.2%}'.format(attack_success_rate))
162 | if attack_success_num != 0:
163 | ED_num1_avr = ED_num1 / attack_success_num
164 | ED_num2_avr = ED_num2 / attack_success_num
165 | ED_avr = (ED_num1_avr + ED_num2_avr) / 2
166 | print('Average Edit_distance-: {:.2f}'.format(ED_num1_avr))
167 | print('Average Edit_distance+: {:.2f}'.format(ED_num2_avr))
168 |         print('Average Edit_distance (mean of - and +): {:.2f}'.format(ED_avr))
169 |
170 |
171 |
172 | if __name__ == '__main__':
173 |     """evaluate the generated adversarial underpainting examples against the target STR model"""
174 | parser = argparse.ArgumentParser()
175 | parser.add_argument('--output', required=True, help='the path of the white-box model (STARNet) results')
176 |     parser.add_argument('--STR_name', required=True, help='the path to save output results of different models')
177 | parser.add_argument('--saveG', default='Generators', help='the path to save generator which is used for generated AEs')
178 | parser.add_argument('--adv_output', default='adv', help='the path to save adversarial examples results')
179 | parser.add_argument('--per_output', default='perturbation', help='the path to save output of adversarial perturbation')
180 | parser.add_argument('--up_output', default='up', help='the path to save underpainting and mapping results')
181 | parser.add_argument('--b', action='store_true', help='Use binarization processing to test AEs.')
182 | """ Data processing """
183 | parser.add_argument('--img_channel', type=int, default=3, help='the number of input channel of image')
184 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
185 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
186 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
187 | parser.add_argument('--character', type=str,default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
188 |     parser.add_argument('--sensitive', action='store_false', help='case-sensitive character mode (on by default; pass --sensitive to turn it off)')
189 | """ Model Architecture """
190 | parser.add_argument('--str_model', type=str, required=True, help='the model path of the target model')
191 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
192 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
193 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
194 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
195 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
196 | parser.add_argument('--input_channel', type=int, default=3,
197 | help='the number of input channel of Feature extractor')
198 | parser.add_argument('--output_channel', type=int, default=512,
199 | help='the number of output channel of Feature extractor')
200 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
201 | opt = parser.parse_args()
202 | print(opt)
203 |
204 | """ create new test model output path """
205 | if opt.b:
206 | test_output_path = os.path.join(opt.output, opt.STR_name, 'RB')
207 | else:
208 | test_output_path = os.path.join(opt.output, opt.STR_name)
209 | if not os.path.exists(test_output_path):
210 | os.makedirs(test_output_path)
211 |
212 | log_file= os.path.join(test_output_path, 'test.log')
213 | sys.stdout = Logger(log_file)
214 |
215 | """ already exist """
216 | Generator_path = os.path.join(opt.output, opt.saveG)
217 | adv_output_path1 = os.path.join(opt.output, 'test-out', opt.adv_output, 'adv-')
218 | adv_output_path2 = os.path.join(opt.output, 'test-out', opt.adv_output, 'adv+')
219 | per_output_path = os.path.join(opt.output, 'test-out', opt.per_output)
220 | up_output_path = os.path.join(opt.output, 'test-out', opt.up_output)
221 |
222 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223 |
224 | """ vocab / character number configuration """
225 | if opt.sensitive:
226 | opt.character = string.printable[:62] # use 62 char (0~9, a~z, A~Z)
227 |
228 | test_udp(opt, device, adv_output_path1, adv_output_path2, test_output_path)
--------------------------------------------------------------------------------
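As a usage sketch, test_udp.py is driven entirely by the arguments defined above. The invocation below assumes a STAR-Net-style TPS-ResNet-BiLSTM-CTC target model; the checkpoint path and the output directory are placeholders, not values taken from the repository:

```python
# Hypothetical invocation of test_udp.py; the paths below are placeholders.
import subprocess

subprocess.run([
    "python", "test_udp.py",
    "--output", "output/protego",                           # results directory of the white-box run
    "--STR_name", "STARNet",
    "--str_model", "pretrained/TPS-ResNet-BiLSTM-CTC.pth",  # placeholder checkpoint path
    "--Transformation", "TPS",
    "--FeatureExtraction", "ResNet",
    "--SequenceModeling", "BiLSTM",
    "--Prediction", "CTC",
    "--b",                                                  # optional: binarization robustness test
], check=True)
```

Omitting --b evaluates the raw adversarial examples; with --b they are first passed through the Enhancer-based binarization pipeline and the results are written to an extra RB/ subdirectory.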
/train_protego.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from protego import framework
4 | from STR_modules.model import Model
5 |
6 | from dataset import train_dataset_builder
7 | from utils import CTCLabelConverter, AttnLabelConverter
8 |
9 |
10 | def run_train(opt, device, train_adv_path, train_per_path, Generator_path, loss_path):
11 | """ data preparing """
12 | train_dataset = train_dataset_builder(opt.imgH, opt.imgW, opt.train_path)
13 | train_dataloader = torch.utils.data.DataLoader(
14 | train_dataset, batch_size=opt.batchsize,
15 | shuffle=True, num_workers=4,
16 | drop_last=True, pin_memory=True)
17 |
18 | """ model configuration """
19 | if 'CTC' in opt.Prediction:
20 | converter = CTCLabelConverter(opt.character)
21 | else:
22 | converter = AttnLabelConverter(opt.character)
23 | opt.num_class = len(converter.character)
24 |
25 | model = Model(opt).to(device)
26 | print('Loading STR pretrained model from %s' % opt.str_model)
27 |     model.load_state_dict(torch.load(opt.str_model, map_location=device), strict=False)
28 | model.eval()
29 |
30 | """ setup loss """
31 | if 'CTC' in opt.Prediction:
32 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
33 | else:
34 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0
35 |
36 | """ attack setting """
37 | ProTegO = framework(device, model, opt.dt_model, converter, criterion, opt.batch_max_length,
38 | opt.up_path, opt.dark, opt.batchsize, opt.img_channel, opt.imgH, opt.imgW,
39 | opt.eps, opt.lambda1, opt.lambda2, opt.lambda3, opt.lambda4,
40 | opt.use_eh, opt.use_guide)
41 |
42 | # train
43 | ProTegO.train(train_dataloader, opt.epochs, train_adv_path, train_per_path, Generator_path, loss_path)
44 |
45 |
46 |
--------------------------------------------------------------------------------
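run_train above reads a number of fields from the opt namespace that are defined elsewhere (presumably in main.py, which is not shown here). A minimal sketch of those fields, where every value is an illustrative placeholder rather than the project's actual defaults:

```python
from argparse import Namespace

# Placeholder configuration only; the real defaults are parsed in main.py.
opt = Namespace(
    # data settings
    train_path="data/protego/train", imgH=32, imgW=100, batchsize=16,
    img_channel=3, batch_max_length=25,
    character="0123456789abcdefghijklmnopqrstuvwxyz",
    # target STR model (a TPS-ResNet-BiLSTM-CTC configuration as an example)
    str_model="pretrained/TPS-ResNet-BiLSTM-CTC.pth",
    Transformation="TPS", FeatureExtraction="ResNet",
    SequenceModeling="BiLSTM", Prediction="CTC",
    num_fiducial=20, input_channel=3, output_channel=512, hidden_size=256,
    # ProTegO-specific settings consumed by framework(...)
    dt_model="pretrained/guided_net.pth", up_path="data/protego/up",
    dark=False, eps=0.2, epochs=80,
    lambda1=1.0, lambda2=1.0, lambda3=1.0, lambda4=1.0,
    use_eh=True, use_guide=True,
)

# run_train(opt, device, train_adv_path, train_per_path, Generator_path, loss_path)
```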
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # encoding: utf-8
3 | import sys
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 | from torchvision import transforms
8 |
9 |
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 |
12 | class Logger(object):
13 |     def __init__(self, filename="train.log"):
14 |         self.terminal = sys.stdout
15 |         self.log = open(filename, "w")
16 |
17 | def write(self, message):
18 | self.terminal.write(message)
19 | self.log.write(message)
20 |
21 | def flush(self):
22 | pass
23 |
24 | class Averager(object):
25 | """Compute average for torch.Tensor, used for loss average."""
26 |
27 | def __init__(self):
28 | self.reset()
29 |
30 | def add(self, v):
31 | count = v.data.numel()
32 | v = v.data.sum()
33 | self.n_count += count
34 | self.sum += v
35 |
36 | def reset(self):
37 | self.n_count = 0
38 | self.sum = 0
39 |
40 | def val(self):
41 | res = 0
42 | if self.n_count != 0:
43 | res = self.sum / float(self.n_count)
44 | return res
45 |
46 | """ utils for fawa """
47 | def tensor2np(img):
48 | trans = transforms.Grayscale(1)
49 | img = trans(img).squeeze(0)
50 | return img.detach().cpu().numpy()
51 | def np2tensor(img: np.ndarray):
52 | if len(img.shape) == 2:
53 | img_tensor = torch.from_numpy(img).float() # bool to float
54 | img_tensor = img_tensor.unsqueeze_(0).repeat(3, 1, 1)
55 | if len(img.shape) == 3:
56 | img_tensor = torch.from_numpy(img).float()
57 | img_tensor = img_tensor.permute(2,0,1)
58 | if torch.max(img_tensor) <= 1:
59 | img_tensor = img_tensor * 255
60 | img_tensor = img_tensor / 255.
61 | return img_tensor.unsqueeze_(0).to(device) # add batch dim
62 |
63 | def get_text_mask(img: np.ndarray):
64 | if img.max() <= 1:
65 | return img < 1 / 1.25
66 | else:
67 | return img < 255 / 1.25
68 |
69 | def cvt2Image(array):
70 | if array.max() <= 0.5:
71 | return Image.fromarray(((array + 0.5) * 255).astype('uint8'))
72 | elif array.max() <= 1:
73 | return Image.fromarray((array * 255).astype('uint8'))
74 | elif array.max() <= 255:
75 | return Image.fromarray(array.astype('uint8'))
76 |
77 | def RGB2Hex(RGB): # RGB is a 3-tuple
78 | color = '#'
79 | for num in RGB:
80 | color += str(hex(num))[-2:].replace('x', '0').upper()
81 | return color
82 |
83 | def color_map(grayscale):
84 | gray_map = (grayscale - 255*0.299 - 0*0.114) / 0.587
85 | return int(gray_map)
86 |
87 |
88 | class CTCLabelConverter(object):
89 | """ Convert between text-label and text-index """
90 |
91 | def __init__(self, character):
92 | # character (str): set of the possible characters.
93 | dict_character = list(character)
94 |
95 | self.dict = {}
96 | for i, char in enumerate(dict_character):
97 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
98 | self.dict[char] = i + 1
99 |
100 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0)
101 |
102 | def encode(self, text, batch_max_length=25):
103 | """convert text-label into text-index.
104 | input:
105 | text: text labels of each image. Note:in our dataset,label is list, and len=batch_size
106 | batch_max_length: max length of text label in the batch. 25 by default
107 |
108 | output:
109 | text: text index for CTCLoss. [batch_size, batch_max_length]
110 | length: length of each text. [batch_size]
111 | """
112 | length = [len(s) for s in text]
113 |
114 | # The index used for padding (=0) would not affect the CTC loss calculation.
115 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
116 | for i, t in enumerate(text):
117 | text = list(t)
118 |             text = [self.dict[char] for char in text]  # index of each char in text; shape = [len(text)], i.e. the word length
119 | batch_text[i][:len(text)] = torch.LongTensor(text)
120 |
121 |         return (batch_text.to(device), torch.IntTensor(length).to(device))  # batch_text: [batch_size, batch_max_length]; length: [batch_size]
122 |
123 | def decode(self, text_index, length):
124 | """ convert text-index into text-label. """
125 | texts = []
126 | index = 0
127 | for l in length:
128 | t = text_index[index:index + l]
129 |
130 | char_list = []
131 | for i in range(l):
132 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
133 | char_list.append(self.character[t[i]])
134 | text = ''.join(char_list)
135 |
136 | texts.append(text)
137 | index += l
138 | return texts
139 |
140 | class AttnLabelConverter(object):
141 | """ Convert between text-label and text-index """
142 |
143 | def __init__(self, character):
144 | # character (str): set of the possible characters.
145 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
146 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]']
147 | list_character = list(character)
148 | self.character = list_token + list_character
149 |
150 | self.dict = {}
151 | for i, char in enumerate(self.character):
152 | # print(i, char)
153 | self.dict[char] = i
154 |
155 | def encode(self, text, batch_max_length=25):
156 | """ convert text-label into text-index.
157 | input:
158 | text: text labels of each image. [batch_size]
159 | batch_max_length: max length of text label in the batch. 25 by default
160 |
161 | output:
162 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
163 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
164 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size]
165 | """
166 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
167 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting
168 | batch_max_length += 1
169 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
170 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
171 | for i, t in enumerate(text):
172 | text = list(t)
173 | text.append('[s]')
174 | text = [self.dict[char] for char in text]
175 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token
176 | return (batch_text.to(device), torch.IntTensor(length).to(device))
177 |
178 | def decode(self, text_index, length):
179 | """ convert text-index into text-label. """
180 | texts = []
181 | for index, l in enumerate(length):
182 | text = ''.join([self.character[i] for i in text_index[index, :]])
183 | texts.append(text)
184 | return texts
185 |
186 |
187 |
--------------------------------------------------------------------------------
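As a quick illustration of the converter API defined in utils.py, the sketch below round-trips a small batch of labels through CTCLabelConverter; the character set is simply the lowercase-alphanumeric default used elsewhere in the repository:

```python
import torch
from utils import CTCLabelConverter

converter = CTCLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')

# encode: labels -> zero-padded index tensor plus per-label lengths
labels = ['world', 'protego']
batch_text, length = converter.encode(labels, batch_max_length=25)
print(batch_text.shape)  # torch.Size([2, 25])
print(length)            # e.g. tensor([5, 7], dtype=torch.int32)

# decode expects a flat index sequence plus the per-sample lengths;
# feeding the ground-truth indices back reproduces the labels.
flat = torch.cat([batch_text[i, :int(length[i])] for i in range(len(labels))])
print(converter.decode(flat, length))  # ['world', 'protego']

# Note: CTC decoding collapses repeated characters, so a label with adjacent
# duplicates (e.g. 'hello') would not round-trip exactly this way.
```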