├── bezier.py
├── tensorboard.py
├── README.md
├── model.py
├── train_bezier.py
├── vggnet.py
├── resnet.py
├── train.py
└── synth.py
--------------------------------------------------------------------------------
/bezier.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np

canvas_width = 256
output_width = 64

def normal(x):
    # Map a coordinate in [0, 1] to an integer pixel index on the canvas.
    return int(x * (canvas_width - 1) + 0.5)

def draw(f):
    # f holds 9 stroke parameters in [0, 1]: three control points (x, y) of a
    # quadratic Bezier curve and a stroke width z at each control point.
    x0, y0, z0, x1, y1, z1, x2, y2, z2 = f
    x0 = normal(x0)
    x1 = normal(x1)
    x2 = normal(x2)
    y0 = normal(y0)
    y1 = normal(y1)
    y2 = normal(y2)
    z0 = int(z0 * 32 + 2)
    z1 = int(z1 * 32 + 2)
    z2 = int(z2 * 32 + 2)
    canvas = np.zeros([canvas_width, canvas_width]).astype('float32')
    tmp = 1. / 100
    for i in range(100):
        t = i * tmp
        # Quadratic Bezier: B(t) = (1-t)^2 P0 + 2t(1-t) P1 + t^2 P2; the
        # width is interpolated with the same weights.
        x = int((1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2)
        y = int((1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2)
        z = int((1 - t) * (1 - t) * z0 + 2 * t * (1 - t) * z1 + t * t * z2)
        cv2.circle(canvas, (y, x), z, 1., -1)
    # Invert so strokes are dark (0) on a white (1) background, then downsample.
    return 1 - cv2.resize(canvas, dsize=(output_width, output_width))
--------------------------------------------------------------------------------
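
A minimal usage sketch for bezier.draw (illustrative, not a file in this
repository; the output filename is a hypothetical choice):

import numpy as np
import cv2
from bezier import draw

f = np.random.uniform(0, 1, 9)    # 3 control points + 3 widths, all in [0, 1]
stroke = draw(f)                  # 64x64 float image, dark stroke on white
cv2.imwrite('stroke_demo.png', (stroke * 255).astype('uint8'))

--------------------------------------------------------------------------------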
/tensorboard.py:
--------------------------------------------------------------------------------
import PIL
import scipy.misc
from io import BytesIO
import tensorboardX as tb
from tensorboardX.summary import Summary

class TensorBoard(object):
    def __init__(self, model_dir):
        self.summary_writer = tb.FileWriter(model_dir)

    def add_image(self, tag, img, step):
        summary = Summary()
        bio = BytesIO()

        if type(img) == str:
            img = PIL.Image.open(img)
        elif type(img) == PIL.Image.Image:
            pass
        else:
            # Note: scipy.misc.toimage was removed in SciPy 1.2, so this
            # branch requires SciPy < 1.2.
            img = scipy.misc.toimage(img)

        img.save(bio, format="png")
        image_summary = Summary.Image(encoded_image_string=bio.getvalue())
        summary.value.add(tag=tag, image=image_summary)
        self.summary_writer.add_summary(summary, global_step=step)

    def add_scalar(self, tag, value, step):
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.summary_writer.add_summary(summary, global_step=step)
--------------------------------------------------------------------------------
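
A minimal sketch of the TensorBoard wrapper in use (illustrative; 'log/demo' is
a hypothetical log directory):

import numpy as np
from tensorboard import TensorBoard

writer = TensorBoard('log/demo')
writer.add_scalar('demo/loss', 0.5, 0)                         # scalar at step 0
writer.add_image('demo/noise.png', np.random.rand(64, 64), 0)  # float array in [0, 1]

--------------------------------------------------------------------------------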
/README.md:
--------------------------------------------------------------------------------
# Stroke-based Character Reconstruction

---> https://arxiv.org/abs/1806.08990

Please refer to our extended work: https://github.com/megvii-research/ICCV2019-LearningToPaint

## Abstract

Reconstructing characters from noisy images or images taken in real scenes remains a challenging problem because of cluttered backgrounds, uneven illumination, low resolution, and various distortions. We propose a stroke-based character reconstruction (SCR) method that uses a weighted quadratic Bezier curve (WQBC) to represent the strokes of a character. Although trained only on our synthetic data, our stroke extractor reconstructs characters well in real scenes. It also helps character recognizers defend against adversarial attacks.

## Installation

Use [anaconda](https://conda.io/miniconda.html) to manage the environment:

```
$ conda create -n py36 python=3.6
$ source activate py36
```

### Dependencies

* [PyTorch](http://pytorch.org/) 0.4
* [tensorboardX](https://github.com/lanpa/tensorboard-pytorch/tree/master/tensorboardX)
* [opencv-python](https://pypi.org/project/opencv-python/)

## Reference

[pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) (model)

[FeatureSqueezing](https://github.com/uvasrg/FeatureSqueezing) (adversarial experiment)
--------------------------------------------------------------------------------
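
To make the WQBC parameterization concrete, a short sketch (assumption: it
mirrors bezier.py, where each stroke is 9 numbers, i.e. three control points
plus a width at each):

import numpy as np

def wqbc_point(p0, p1, p2, w0, w1, w2, t):
    # Quadratic Bezier position and interpolated stroke width at t in [0, 1].
    b = (1 - t) ** 2 * p0 + 2 * t * (1 - t) * p1 + t ** 2 * p2
    w = (1 - t) ** 2 * w0 + 2 * t * (1 - t) * w1 + t ** 2 * w2
    return b, w

p0, p1, p2 = np.array([0., 0.]), np.array([0.5, 1.]), np.array([1., 0.])
print(wqbc_point(p0, p1, p2, 0.1, 0.3, 0.1, 0.5))  # midpoint [0.5, 0.5], width 0.2

--------------------------------------------------------------------------------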
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class FCN(nn.Module):
    # Decoder: renders 9 stroke parameters into a 64x64 grayscale stroke image.
    def __init__(self, width):
        super(FCN, self).__init__()
        self.fc0 = nn.Linear(9, 512)
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 4096)
        self.conv0 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 1, kernel_size=3, padding=1)
        self.relu = F.relu
        self.sigmoid = torch.sigmoid
        self.upsample = nn.Upsample(scale_factor=2)
        self.init()

    def init(self):
        nn.init.kaiming_uniform_(self.fc0.weight)
        nn.init.kaiming_uniform_(self.fc1.weight)
        nn.init.kaiming_uniform_(self.fc2.weight)
        nn.init.kaiming_uniform_(self.fc3.weight)
        nn.init.kaiming_uniform_(self.conv0.weight)
        nn.init.kaiming_uniform_(self.conv1.weight)
        nn.init.kaiming_uniform_(self.conv2.weight)
        nn.init.kaiming_uniform_(self.conv3.weight)

    def forward(self, x):
        # Expand the 9 parameters to 4096 features = 16 channels x 16 x 16.
        x = self.fc0(x)
        x = self.relu(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = x.view(-1, 16, 16, 16)
        x = self.upsample(x)            # 16x16 -> 32x32
        x = self.relu(self.conv0(x))
        x = self.relu(self.conv1(x))
        x = self.upsample(x)            # 32x32 -> 64x64
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        x = self.sigmoid(x)
        return x.reshape(-1, 64, 64)
--------------------------------------------------------------------------------
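
A quick shape sanity check for the decoder (illustrative):

import torch
from model import FCN

net = FCN(64)
out = net(torch.rand(8, 9))   # a batch of 8 strokes, 9 parameters each
print(out.shape)              # torch.Size([8, 64, 64])

--------------------------------------------------------------------------------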
/train_bezier.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
from tensorboard import TensorBoard
from model import FCN

from bezier import *
writer = TensorBoard('log/')
import torch.optim as optim
criterion = nn.MSELoss()
Decoder = FCN(64)
optimizer = optim.Adam(Decoder.parameters(), lr=3e-4)
batch_size = 64

use_cuda = True
step = 0

def save_model():
    if use_cuda:
        Decoder.cpu()
    torch.save(Decoder.state_dict(), './Decoder.pkl')
    if use_cuda:
        Decoder.cuda()

def load_weights():
    Decoder.load_state_dict(torch.load('./Decoder.pkl'))

# Resume from a checkpoint if one exists; otherwise train from scratch.
if os.path.exists('./Decoder.pkl'):
    load_weights()
while True:
    Decoder.train()
    train_batch = []
    ground_truth = []
    for i in range(batch_size):
        # Random stroke parameters with the rendered curve as supervision.
        f = np.random.uniform(0, 1, 9)
        train_batch.append(f)
        ground_truth.append(draw(f))

    train_batch = torch.tensor(train_batch).float()
    ground_truth = torch.tensor(ground_truth).float()
    if use_cuda:
        Decoder = Decoder.cuda()
        train_batch = train_batch.cuda()
        ground_truth = ground_truth.cuda()
    gen = Decoder(train_batch)
    optimizer.zero_grad()
    loss = criterion(gen, ground_truth)
    loss.backward()
    optimizer.step()
    print(step, loss.item())
    writer.add_scalar('train/loss', loss.item(), step)
    if step % 100 == 0:
        Decoder.eval()
        gen = Decoder(train_batch)
        loss = criterion(gen, ground_truth)
        writer.add_scalar('validate/loss', loss.item(), step)
        for i in range(64):
            G = gen[i].cpu().data.numpy()
            GT = ground_truth[i].cpu().data.numpy()
            # Give each sample its own tag so the images do not overwrite each other.
            writer.add_image(str(step) + '/train/gen' + str(i) + '.png', G, step)
            writer.add_image(str(step) + '/train/ground_truth' + str(i) + '.png', GT, step)
    if step % 10000 == 0:
        save_model()
    step += 1
--------------------------------------------------------------------------------
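
After training, the decoder can be spot-checked against the reference renderer
(illustrative):

import torch
import numpy as np
from model import FCN
from bezier import draw

net = FCN(64)
net.load_state_dict(torch.load('./Decoder.pkl'))
net.eval()
f = np.random.uniform(0, 1, 9)
approx = net(torch.tensor(f).float().view(1, 9))[0].data.numpy()
print(np.abs(approx - draw(f)).mean())   # mean pixel error of the neural renderer

--------------------------------------------------------------------------------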
/vggnet.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable
import torch.nn.functional as F

def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)

def cfg(depth):
    # Convolution widths for the standard VGG variants; 'mp' marks 2x2 max pooling.
    depth_lst = [11, 13, 16, 19]
    assert (depth in depth_lst), "Error: VGGnet depth should be either 11, 13, 16 or 19"
    cf_dict = {
        '11': [
            64, 'mp',
            128, 'mp',
            256, 256, 'mp',
            512, 512, 'mp',
            512, 512, 'mp'],
        '13': [
            64, 64, 'mp',
            128, 128, 'mp',
            256, 256, 'mp',
            512, 512, 'mp',
            512, 512, 'mp'
        ],
        '16': [
            64, 64, 'mp',
            128, 128, 'mp',
            256, 256, 256, 'mp',
            512, 512, 512, 'mp',
            512, 512, 512, 'mp'
        ],
        '19': [
            64, 64, 'mp',
            128, 128, 'mp',
            256, 256, 256, 256, 'mp',
            512, 512, 512, 512, 'mp',
            512, 512, 512, 512, 'mp'
        ],
    }

    return cf_dict[str(depth)]

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)

class VGG(nn.Module):
    def __init__(self, depth, num_outputs):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg(depth))
        # For 64x64 inputs the flattened feature map is 512 x 2 x 2 = 2048.
        self.fc0 = nn.Linear(2048, 256)
        self.fc1 = nn.Linear(256, num_outputs)
        self.fc2 = nn.Linear(num_outputs, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc0(x))
        # x: stroke parameters in [0, 1]; y: digit class logits predicted from them.
        x = torch.sigmoid(self.fc1(x))
        y = F.relu(self.fc2(x))
        y = self.fc3(y)
        return x, y

    def _make_layers(self, cfg):
        layers = []
        in_planes = 3

        for x in cfg:
            if x == 'mp':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
                in_planes = x

        # After the cfg convolutions
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

if __name__ == "__main__":
    # The fully connected head expects 64x64 inputs (see fc0 above).
    net = VGG(16, 10)
    x, y = net(Variable(torch.randn(1, 3, 64, 64)))
    print(x.size(), y.size())
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from torch.autograd import Variable
import sys

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)

def conv_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)

def cfg(depth):
    depth_lst = [18, 34, 50, 101, 152]
    assert (depth in depth_lst), "Error: ResNet depth should be either 18, 34, 50, 101 or 152"
    cf_dict = {
        '18': (BasicBlock, [2, 2, 2, 2]),
        '34': (BasicBlock, [3, 4, 6, 3]),
        '50': (Bottleneck, [3, 4, 6, 3]),
        '101': (Bottleneck, [3, 4, 23, 3]),
        '152': (Bottleneck, [3, 8, 36, 3]),
    }

    return cf_dict[str(depth)]

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)

        return out

class ResNet(nn.Module):
    def __init__(self, depth, num_outputs):
        super(ResNet, self).__init__()
        self.in_planes = 8

        block, num_blocks = cfg(depth)

        self.conv1 = conv3x3(3, 8, 2)
        self.bn1 = nn.BatchNorm2d(8)
        self.layer1 = self._make_layer(block, 8, num_blocks[0], stride=2)
        self.layer2 = self._make_layer(block, 16, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, num_blocks[2], stride=2)
        # self.linear = nn.Linear(64*block.expansion, num_classes)
        # For 64x64 inputs the flattened feature map is (32 * expansion) x 4 x 4.
        self.fc1 = nn.Linear(32 * 16 * block.expansion, num_outputs)
        self.fc2 = nn.Linear(num_outputs, 512)
        self.fc3 = nn.Linear(512, 10)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x = F.avg_pool2d(x, 8)
        x = x.view(x.size(0), -1)
        # x: stroke parameters in [0, 1]; y: digit class logits predicted from them.
        x = torch.sigmoid(self.fc1(x))
        y = F.relu(self.fc2(x))
        y = self.fc3(y)
        return x, y

if __name__ == '__main__':
    # The fully connected head expects 64x64 inputs (see fc1 above).
    net = ResNet(50, 10)
    x, y = net(Variable(torch.randn(1, 3, 64, 64)))
    print(x.size(), y.size())
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import cv2
import torch
import numpy as np
import random

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from tensorboard import TensorBoard
from model import FCN
from synth import Generator

from bezier import *
from vggnet import *
Encoder = VGG(16, 36)   # 36 outputs = 4 strokes x 9 parameters

writer = TensorBoard('log/')
import torch.optim as optim
criterion = nn.MSELoss()
criterion2 = nn.CrossEntropyLoss()

Decoder = FCN(64)
optimizerE = optim.Adam(Encoder.parameters(), lr=3e-4)
optimizerD = optim.Adam(Decoder.parameters(), lr=3e-4)
batch_size = 64
data_size = 100000
generated_size = 0
val_data_size = 512
first_generate = True

use_cuda = True
step = 0
Train_batch = [None] * data_size
Ground_truth = [None] * data_size
Label_batch = [None] * data_size
Val_train_batch = [None] * val_data_size
Val_ground_truth = [None] * val_data_size
Val_label_batch = [None] * val_data_size
G = Generator()

def hisEqulColor(img):
    # Histogram-equalize the luminance channel only.
    img_yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV)
    img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])
    img_output = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
    return img_output

def save_model():
    if use_cuda:
        Decoder.cpu()
        Encoder.cpu()
    torch.save(Decoder.state_dict(), './Decoder.pkl')
    torch.save(Encoder.state_dict(), './Encoder.pkl')
    if use_cuda:
        Decoder.cuda()
        Encoder.cuda()

def load_weights():
    # Decoder.pkl comes from running train_bezier.py first.
    Decoder.load_state_dict(torch.load('./Decoder.pkl'))
    # Encoder.load_state_dict(torch.load('./Encoder.pkl'))

def decode(x, train_bezier=False):  # x: b x 36 = 4 strokes x 9 parameters
    x = x.reshape(-1, 9)
    if train_bezier:
        y = Decoder(x.detach())
    else:
        y = None
    x = Decoder(x)
    x = x.reshape(-1, 4, 64, 64)
    # Compose the 4 strokes with a per-pixel minimum (dark strokes on white).
    return torch.min(x.permute(0, 2, 3, 1), dim=3)[0], y

def sample(n, test=False):
    input_batch = []
    ground_truth = []
    label_batch = []
    if not test:
        batch = random.sample(range(min(data_size, generated_size)), n)
    for i in range(n):
        if test:
            input_batch.append(Val_train_batch[i])
            ground_truth.append(Val_ground_truth[i])
            label_batch.append(Val_label_batch[i])
        else:
            input_batch.append(Train_batch[batch[i]])
            ground_truth.append(Ground_truth[batch[i]])
            label_batch.append(Label_batch[batch[i]])
    input_batch = torch.tensor(input_batch).float()
    ground_truth = torch.tensor(ground_truth).float()
    label_batch = torch.tensor(np.array(label_batch))
    return input_batch, ground_truth, label_batch

def generate_data():
    print('Generating data')
    global Train_batch, Ground_truth
    global first_generate
    global generated_size
    if first_generate:
        # Build the fixed validation set from SVHN on the first call.
        first_generate = False
        import scipy.io as sio
        mat = sio.loadmat('../data/svhn/train_32x32.mat')
        Data = mat['X']
        Label = mat['y']
        for i in range(len(Label)):
            # SVHN uses label 10 for the digit 0.
            if Label[i][0] % 10 == 0:
                Label[i][0] = 0
        for i in range(val_data_size):
            img = np.array(Data[..., i])
            origin = noised = img
            origin = cv2.cvtColor(origin, cv2.COLOR_RGB2GRAY)
            origin = cv2.resize(origin, dsize=(64, 64), interpolation=cv2.INTER_CUBIC)
            noised = cv2.resize(noised, dsize=(64, 64), interpolation=cv2.INTER_CUBIC)
            noised = hisEqulColor(noised) / 255.
            Val_train_batch[i] = noised
            Val_ground_truth[i] = 1 - origin / 255.
            Val_label_batch[i] = Label[i][0]
    # Refill part of the training pool with fresh synthetic samples.
    for i in range(1000):
        id = generated_size % data_size
        img, origin, label = G.generate()
        origin = 255. - cv2.cvtColor(origin, cv2.COLOR_RGB2GRAY)
        Train_batch[id] = hisEqulColor(img) / 255.
        Ground_truth[id] = origin / 255.
        Label_batch[id] = label
        generated_size += 1

if use_cuda:
    Decoder = Decoder.cuda()
    Encoder = Encoder.cuda()

def train_bezier(x, img):
    # Fine-tune the Decoder against the reference renderer on the Encoder's strokes.
    Decoder.train()
    x = x.detach().cpu().numpy().reshape(-1, 9)
    bezier = []
    for i in range(x.shape[0]):
        bezier.append(draw(x[i]))
    bezier = torch.tensor(bezier).float()
    if use_cuda:
        bezier = bezier.cuda()
    optimizerD.zero_grad()
    loss = criterion(img, bezier)
    loss.backward()
    optimizerD.step()
    Decoder.eval()
    writer.add_scalar('train/bezier_loss', loss.item(), step)

def train():
    Encoder.train()
    train_batch, ground_truth, label_batch = sample(batch_size, test=False)
    train_batch = train_batch.permute(0, 3, 1, 2)
    if use_cuda:
        train_batch = train_batch.cuda()
        ground_truth = ground_truth.cuda()
        label_batch = label_batch.cuda()
    infered_stroke, infered_class = Encoder(train_batch)
    # if step % 5 == 0:
    #     img, stroke_img = decode(infered_stroke, True)
    #     train_bezier(infered_stroke, stroke_img)
    # else:
    img, _ = decode(infered_stroke, False)
    optimizerE.zero_grad()
    loss1 = criterion(img, ground_truth)
    loss2 = criterion2(infered_class, label_batch)
    acc = torch.sum(infered_class.max(1)[1] == label_batch.long()).item() / batch_size
    (loss1 + loss2 * 0.01).backward(retain_graph=True)
    optimizerE.step()
    print('train_loss: ', step, loss1.item(), loss2.item())
    writer.add_scalar('train/img_loss', loss1.item(), step)
    writer.add_scalar('train/class_loss', loss2.item(), step)
    writer.add_scalar('train/acc', acc, step)
    if step % 50 == 0:
        for i in range(10):
            train_img = train_batch[i].cpu().data.numpy()
            gen_img = img[i].cpu().data.numpy()
            ground_truth_img = ground_truth[i].cpu().data.numpy()
            writer.add_image('train/' + str(i) + '/input.png', train_img, step)
            writer.add_image('train/' + str(i) + '/gen.png', gen_img, step)
            writer.add_image('train/' + str(i) + '/ground_truth.png', ground_truth_img, step)

def test():
    Encoder.eval()
    train_batch, ground_truth, label_batch = sample(512, test=True)
    train_batch = train_batch.permute(0, 3, 1, 2)
    if use_cuda:
        train_batch = train_batch.cuda()
        ground_truth = ground_truth.cuda()
        label_batch = label_batch.cuda()
    infered_stroke, infered_class = Encoder(train_batch)
    img, _ = decode(infered_stroke)
    loss = criterion(img, ground_truth)
    print('validate_loss: ', step, loss.item())
    acc = torch.sum(infered_class.max(1)[1] == label_batch.long()).item() / 512
    writer.add_scalar('validate/loss', loss.item(), step)
    writer.add_scalar('validate/acc', acc, step)
    for i in range(10):
        train_img = train_batch[i].cpu().data.numpy()
        gen_img = img[i].cpu().data.numpy()
        ground_truth_img = ground_truth[i].cpu().data.numpy()
        writer.add_image('validate/' + str(i) + '/input.png', train_img, step)
        writer.add_image('validate/' + str(i) + '/gen.png', gen_img, step)
        writer.add_image('validate/' + str(i) + '/ground_truth.png', ground_truth_img, step)

load_weights()
while True:
    if step % 100 == 0:
        generate_data()
    train()
    if step % 500 == 0:
        test()
    if step % 1000 == 0:
        save_model()
    step += 1
--------------------------------------------------------------------------------
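
A minimal inference sketch for a trained pair of networks (illustrative;
assumes Encoder.pkl and Decoder.pkl were produced by this script, and
'input.png' is a hypothetical 64x64 test image; train.py additionally
histogram-equalizes its inputs, which is omitted here for brevity):

import cv2
import torch
from model import FCN
from vggnet import VGG

encoder, decoder = VGG(16, 36), FCN(64)
encoder.load_state_dict(torch.load('./Encoder.pkl'))
decoder.load_state_dict(torch.load('./Decoder.pkl'))
encoder.eval(); decoder.eval()

img = cv2.imread('input.png') / 255.                        # HWC float in [0, 1]
x = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
strokes, logits = encoder(x)                                # 36 = 4 strokes x 9 params
recon = decoder(strokes.reshape(-1, 9)).reshape(-1, 4, 64, 64)
recon = torch.min(recon.permute(0, 2, 3, 1), dim=3)[0]      # compose the 4 strokes
print('predicted digit:', logits.max(1)[1].item())
cv2.imwrite('recon.png', (recon[0].data.numpy() * 255).astype('uint8'))

--------------------------------------------------------------------------------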
/synth.py:
--------------------------------------------------------------------------------
import os
import cv2
import numpy as np
import imgaug as ia
from imgaug import augmenters as iaa

img_width = 64
img_list = []

for i in range(5000):
    img = cv2.imread('../digits_resize/digit' + str(i) + '.png')
    img_list.append(img)

def rand():
    return np.random.uniform(0, 1)

sometimes = lambda aug: iaa.Sometimes(0.5, aug)
aug_with_origin = iaa.Sequential(
    [
        iaa.OneOf([
            sometimes(iaa.Affine(
                scale={"x": (0.6, 1), "y": (0.6, 1)},
                rotate=(-15, 15),  # rotate by -15 to +15 degrees
                shear=(-15, 15),   # shear by -15 to +15 degrees
                order=[0, 1],      # use nearest-neighbour or bilinear interpolation (fast)
                cval=(0, 0),       # if mode is constant, fill with black
                mode='constant'
            )),
            sometimes(iaa.Pad(
                percent=(0, 0.4),
                pad_mode='constant',
                pad_cval=(0, 0)
            )),
        ]),
        iaa.OneOf([
            sometimes(iaa.PerspectiveTransform(scale=(0.01, 0.1))),
            sometimes(iaa.PiecewiseAffine(scale=(0.01, 0.05))),  # sometimes move parts of the image around
        ])
    ]
)
aug = iaa.Sequential(
    [
        iaa.SomeOf((0, 6),
            [
                iaa.Superpixels(p_replace=(0, 1.0), n_segments=(100, 200)),
                # convert images into their superpixel representation
                iaa.OneOf([
                    iaa.GaussianBlur((0, 3.0)),  # blur images with a sigma between 0 and 3.0
                    iaa.AverageBlur(k=(2, 7)),   # blur image using local means with kernel sizes between 2 and 7
                    iaa.MedianBlur(k=(3, 9)),    # blur image using local medians with kernel sizes between 3 and 9
                ]),
                iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),
                iaa.Grayscale(alpha=(0.0, 1.0)),
                iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5),
                iaa.Sharpen(alpha=(0, 1.0), lightness=(0.75, 1.5)),  # sharpen images
                iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0)),  # emboss images
                iaa.Invert(0.5, per_channel=True),  # invert color channels
                iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5),  # add gaussian noise to images
                iaa.AddToHueAndSaturation((-20, 20)),  # change hue and saturation
                iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.25),  # move pixels locally around (with random strengths)
            ],
            random_order=True
        ),
        sometimes(iaa.OneOf([
            iaa.Dropout((0.01, 0.1), per_channel=0.5),  # randomly remove up to 10% of the pixels
            iaa.SaltAndPepper(p=(0.03, 0.3)),
        ])),
    ]
)


def uneven_light():
    # A random smooth intensity field: cosh(y') * sin(x') evaluated on a
    # randomly rotated, scaled and translated grid, normalized to [0, 1].
    def norm(Z):
        min_in_z = np.min(Z)
        max_in_z = np.max(Z)
        Z = (Z - min_in_z) / (max_in_z - min_in_z)
        return Z

    def rotate(x, y, cos):
        sin = np.sqrt(1 - cos ** 2)

        x_ = cos * x + sin * y
        y_ = -sin * x + cos * y

        return x_, y_

    def scale(x, y, s1, s2):
        x_, y_ = x * s1, y * s2
        return x_, y_

    def translate(x, y, x0, y0):
        return x + x0, y + y0

    w, h = img_width, img_width

    x = np.linspace(-1, 1, w)
    y = np.linspace(-1, 1, h)

    r = np.random.rand()
    s1, s2 = np.random.rand() * 5, np.random.rand() * 5
    t1, t2 = np.random.rand() * 5, np.random.rand() * 5

    X, Y = np.meshgrid(x, y)
    X1, Y1 = rotate(X, Y, r)
    X1, Y1 = scale(X1, Y1, s1, s2)
    X1, Y1 = translate(X1, Y1, t1, t2)

    Z = np.cosh(Y1) * np.sin(X1)   # i.e. (exp(Y1) + exp(-Y1)) / 2 * sin(X1)
    Z = norm(Z)
    return Z

class Generator:
    def __init__(self):
        pass

    def center_square_crop(self, img, wx, wy, cx=0, cy=0, origin=None):
        h = img.shape[0]
        w = img.shape[1]
        if cx:
            x = np.max([cx - wx // 2, 0])
            y = np.max([cy - wy // 2, 0])
        else:
            x = (h - wx) // 2
            y = (w - wy) // 2
        img = img[x : x + wx, y : y + wy]
        if type(img) == type(origin):
            # Apply the same crop and resize to the clean ground-truth image.
            origin = origin[x : x + wx, y : y + wy]
            origin = cv2.resize(origin, dsize=(h, w), interpolation=cv2.INTER_CUBIC)
        img = cv2.resize(img, dsize=(h, w), interpolation=cv2.INTER_CUBIC)
        return img, origin

    def random_rotate(self, img, ang, origin=None):
        ang = np.random.randint(-ang, ang)
        x = img.shape[0]
        y = img.shape[1]
        center = (y // 2, x // 2)
        M = cv2.getRotationMatrix2D(center, ang, 0.8)
        img = cv2.warpAffine(img, M, (y, x))
        if type(img) == type(origin):
            origin = cv2.warpAffine(origin, M, (y, x))
        return img, origin

    def random_color(self, h, w, same=False):
        R = np.ones((h, w)) * rand()
        G = np.ones((h, w)) * rand()
        B = np.ones((h, w)) * rand()
        L = uneven_light()
        R += L * np.random.uniform(-1, 1)
        G += L * np.random.uniform(-1, 1)
        B += L * np.random.uniform(-1, 1)
        if same:
            return np.stack((R, R, R), axis=2)
        else:
            return np.stack((R, G, B), axis=2)

    def generate_color(self):
        if rand() < 0.2:
            C1 = self.random_color(img_width, img_width, same=True)
        else:
            C1 = self.random_color(img_width, img_width)
        if rand() < 0.2:
            C2 = self.random_color(img_width, img_width, same=True)
        else:
            C2 = self.random_color(img_width, img_width)
        return C1, C2

    def add_color(self, img):
        C1, C2 = self.generate_color()
        # Resample until the foreground and background colors differ enough.
        # (Misplaced parenthesis fixed: the comparison belongs outside np.max.)
        while np.max(np.abs(C1 - C2)) < 0.3:
            C1, C2 = self.generate_color()
        img = C1 * img + C2 * (1 - img)
        img = np.clip(img, 0, 1)
        return img

    def forgery_data(self, img1, img2, img3):
        # Paste three digits side by side; only the center digit is the target.
        if np.random.rand() < 0.2:
            img1 = np.zeros([64, 48, 3])
        if np.random.rand() < 0.2:
            img3 = np.zeros([64, 48, 3])
        img = np.hstack([np.array(img1), np.array(img2), np.array(img3)]).astype('uint8')
        origin = np.hstack([np.zeros(img1.shape), np.array(img2), np.zeros(img3.shape)]).astype('uint8')
        cy = img1.shape[1] + img2.shape[1] // 2
        cx = img2.shape[0] // 2
        # print(img.shape, cx, cy)
        h = np.random.randint(54, 64 + 1)
        w = np.random.randint(img2.shape[1], min(img2.shape[1] * 5, origin.shape[1] + 1))
        img, origin = self.center_square_crop(img, 64, w, cx, cy, origin)
        img = cv2.resize(img, dsize=(img_width, img_width), interpolation=cv2.INTER_CUBIC)
        origin = cv2.resize(origin, dsize=(img_width, img_width), interpolation=cv2.INTER_CUBIC)
        # img, origin = self.random_rotate(img, 20, origin)
        # print(img.dtype, origin.dtype)
        img = (self.add_color(img / 255.) * 255).astype('uint8')
        # Augment the noisy image and the ground truth with the same geometric
        # transform by stacking them along the channel axis.
        tmp = np.concatenate([img, origin], 2)
        tmp = aug_with_origin.augment_image(tmp)
        img = tmp[:, :, 0 : 3]
        origin = tmp[:, :, 3 : 6]
        img = aug.augment_image(img)
        return img, origin

    def generate(self):
        label0 = np.random.randint(10)
        label1 = np.random.randint(10)
        label2 = np.random.randint(10)
        img0 = img_list[np.random.randint(500) * 10 + label0]
        img1 = img_list[np.random.randint(500) * 10 + label1]
        img2 = img_list[np.random.randint(500) * 10 + label2]
        label1 += 1
        if label1 == 10:
            label1 = 0
        img, origin = self.forgery_data(img0, img1, img2)
        return img, origin, label1

if __name__ == '__main__':
    G = Generator()
    for i in range(100):
        img, origin, label = G.generate()
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
        origin = cv2.resize(origin, (224, 224), interpolation=cv2.INTER_CUBIC)
        cv2.imshow('img', img)
        cv2.imshow('ground_truth', origin)
        cv2.waitKey(0)
--------------------------------------------------------------------------------