├── gallery
│   ├── nst.jpg
│   ├── 8bitart.jpg
│   ├── gif_teaser_1.gif
│   ├── diamond_markerpen.jpg
│   └── apple_oilpaintbrush.jpg
├── test_images
│   ├── yui.jpg
│   ├── joker.jpg
│   ├── kasumi.png
│   └── satomi.png
├── style_images
│   ├── fire.jpg
│   ├── mosaic.jpg
│   ├── scream.jpg
│   └── picasso.jpg
├── brushes
│   ├── brush_fromweb2_large_horizontal.png
│   ├── brush_fromweb2_large_vertical.png
│   ├── brush_fromweb2_small_horizontal.png
│   └── brush_fromweb2_small_vertical.png
├── Requirements.txt
├── README.md
├── morphology.py
├── train_imitator.py
├── pytorch_batch_sinkhorn.py
├── demo.py
├── demo_nst.py
├── demo_prog.py
├── demo_8bitart.py
├── loss.py
├── LICENSE
├── imitator.py
├── image_to_paint.ipynb
├── utils.py
├── renderer.py
├── painter.py
└── networks.py

/gallery/nst.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/nst.jpg
--------------------------------------------------------------------------------
/gallery/8bitart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/8bitart.jpg
--------------------------------------------------------------------------------
/test_images/yui.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/yui.jpg
--------------------------------------------------------------------------------
/style_images/fire.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/fire.jpg
--------------------------------------------------------------------------------
/style_images/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/mosaic.jpg
--------------------------------------------------------------------------------
/style_images/scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/scream.jpg
--------------------------------------------------------------------------------
/test_images/joker.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/joker.jpg
--------------------------------------------------------------------------------
/test_images/kasumi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/kasumi.png
--------------------------------------------------------------------------------
/test_images/satomi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/satomi.png
--------------------------------------------------------------------------------
/gallery/gif_teaser_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/gif_teaser_1.gif
--------------------------------------------------------------------------------
/style_images/picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/picasso.jpg
--------------------------------------------------------------------------------
/gallery/diamond_markerpen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/diamond_markerpen.jpg
--------------------------------------------------------------------------------
/gallery/apple_oilpaintbrush.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/apple_oilpaintbrush.jpg
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_large_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_large_horizontal.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_large_vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_large_vertical.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_small_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_small_horizontal.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_small_vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_small_vertical.png
--------------------------------------------------------------------------------
/Requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | scikit-image
3 | scikit-learn
4 | scipy
5 | numpy
6 | torch
7 | torchvision
8 | opencv-python
9 | opencv-contrib-python
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Open the image_to_paint.ipynb file and it will be displayed as a Jupyter Notebook. Click the "Open in Colab" button at the top to run it on Google Colab. Any environment will do, as long as it can run Google Colab.
2 | 
3 | # image_to_paint.ipynb
4 | Unlike conventional methods that transform an image pixel by pixel, Stylized Neural Painting converts a photo into an oil painting by sequentially rendering vectorized stroke data, much like layering paint on a canvas with a brush. See [cedro-blog](http://cedro3.com/ai/image-to-paint/) for details.
5 | 
--------------------------------------------------------------------------------
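For quick reference, the Colab notebook drives the painter by filling an argparse Namespace by hand rather than from the command line. A minimal sketch of that same pattern (the input path is a hypothetical example; any image under test_images/ works):

    import argparse
    from painter import *  # provides ProgressivePainter and its helpers

    parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
    args = parser.parse_args(args=[])
    args.img_path = './test_images/kasumi.png'
    args.renderer = 'oilpaintbrush'
    args.canvas_color = 'black'
    args.canvas_size = 512
    args.max_m_strokes = 500
    args.max_divide = 5
    args.beta_L1 = 1.0
    args.with_ot_loss = False
    args.beta_ot = 0.1
    args.net_G = 'zou-fusion-net'
    args.renderer_checkpoint_dir = './checkpoints_G_oilpaintbrush'
    args.lr = 0.005
    args.output_dir = './output'

    pt = ProgressivePainter(args=args)
    # then run the optimize_x() routine defined in image_to_paint.ipynb / demo_prog.py
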
/morphology.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pdb
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class Erosion2d(nn.Module):
9 | 
10 |     def __init__(self, m=1):
11 | 
12 |         super(Erosion2d, self).__init__()
13 |         self.m = m
14 |         self.pad = [m, m, m, m]
15 |         self.unfold = nn.Unfold(2*m+1, padding=0, stride=1)
16 | 
17 |     def forward(self, x):
18 |         batch_size, c, h, w = x.shape
19 |         x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)  # pad with a huge value so borders do not affect the min
20 |         for i in range(c):
21 |             channel = self.unfold(x_pad[:, [i], :, :])  # all (2m+1)x(2m+1) neighborhoods of channel i
22 |             channel = torch.min(channel, dim=1, keepdim=True)[0]
23 |             channel = channel.view([batch_size, 1, h, w])
24 |             x[:, [i], :, :] = channel  # note: writes the result back into the input tensor in place
25 | 
26 |         return x
27 | 
28 | 
29 | 
30 | class Dilation2d(nn.Module):
31 | 
32 |     def __init__(self, m=1):
33 | 
34 |         super(Dilation2d, self).__init__()
35 |         self.m = m
36 |         self.pad = [m, m, m, m]
37 |         self.unfold = nn.Unfold(2*m+1, padding=0, stride=1)
38 | 
39 |     def forward(self, x):
40 |         batch_size, c, h, w = x.shape
41 |         x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)  # pad with a very negative value so borders do not affect the max
42 |         for i in range(c):
43 |             channel = self.unfold(x_pad[:, [i], :, :])
44 |             channel = torch.max(channel, dim=1, keepdim=True)[0]
45 |             channel = channel.view([batch_size, 1, h, w])
46 |             x[:, [i], :, :] = channel
47 | 
48 |         return x
49 | 
--------------------------------------------------------------------------------
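A minimal sketch of how these operators behave (the shapes and values are illustrative): dilation grows bright regions over a (2m+1)x(2m+1) window, erosion shrinks them. Both modules write their result into the input tensor, hence the .clone() calls here:

    import torch
    from morphology import Erosion2d, Dilation2d

    # a 1x1x5x5 map with a single bright pixel in the center
    x = torch.zeros(1, 1, 5, 5)
    x[0, 0, 2, 2] = 1.0

    dilated = Dilation2d(m=1)(x.clone())  # the bright pixel grows into a 3x3 block
    eroded = Erosion2d(m=1)(x.clone())    # the isolated pixel disappears (all zeros)
    print(dilated[0, 0], eroded[0, 0])
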
/train_imitator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import utils
5 | from imitator import *
6 | 
7 | # settings
8 | parser = argparse.ArgumentParser(description='ZZX TRAIN IMITATOR')
9 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
10 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, '
11 |                          'bezier, circle, square, rectangle] (default: oilpaintbrush)')
12 | parser.add_argument('--batch_size', type=int, default=64, metavar='N',
13 |                     help='input batch size for training (default: 64)')
14 | parser.add_argument('--print_models', action='store_true', default=False,
15 |                     help='visualize and print networks')
16 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
17 |                     help='net_G: plain-dcgan or plain-unet or huang-net or zou-fusion-net')
18 | parser.add_argument('--checkpoint_dir', type=str, default=r'./checkpoints_G', metavar='str',
19 |                     help='dir to save checkpoints (default: ./checkpoints_G)')
20 | parser.add_argument('--vis_dir', type=str, default=r'./val_out_G', metavar='str',
21 |                     help='dir to save results during training (default: ./val_out_G)')
22 | parser.add_argument('--lr', type=float, default=2e-4,
23 |                     help='learning rate (default: 0.0002)')
24 | parser.add_argument('--max_num_epochs', type=int, default=400, metavar='N',
25 |                     help='max number of training epochs (default 400)')
26 | args = parser.parse_args()
27 | 
28 | 
29 | if __name__ == '__main__':
30 | 
31 |     dataloaders = utils.get_renderer_loaders(args)
32 |     imt = Imitator(args=args, dataloaders=dataloaders)
33 |     imt.train_models()
34 | 
35 |     # # How to check if the data is loading correctly?
36 |     # dataloaders = utils.get_renderer_loaders(args)
37 |     # for i in range(100):
38 |     #     data = next(iter(dataloaders['train']))
39 |     #     vis_A = data['A']
40 |     #     vis_B = utils.make_numpy_grid(data['B'])
41 |     #     print(data['A'].cpu().numpy().shape[1])
42 |     #     print(data['B'].shape)
43 |     #     plt.imshow(vis_B)
44 |     #     plt.show()
45 | 
46 | 
--------------------------------------------------------------------------------
/pytorch_batch_sinkhorn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | sinkhorn_pointcloud.py
4 | 
5 | Discrete OT : Sinkhorn algorithm for point cloud marginals.
6 | 
7 | """
8 | 
9 | import torch
10 | from torch.autograd import Variable  # unused legacy import
11 | 
12 | # Decide which device we want to run on
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 | 
15 | 
16 | def sinkhorn_normalized(x, y, epsilon, niter, mass_x=None, mass_y=None):
17 | 
18 |     Wxy = sinkhorn_loss(x, y, epsilon, niter, mass_x, mass_y)
19 |     Wxx = sinkhorn_loss(x, x, epsilon, niter, mass_x, mass_x)
20 |     Wyy = sinkhorn_loss(y, y, epsilon, niter, mass_y, mass_y)
21 |     return 2 * Wxy - Wxx - Wyy
22 | 
23 | def sinkhorn_loss(x, y, epsilon, niter, mass_x=None, mass_y=None):
24 |     """
25 |     Given two empirical measures with n points each, with locations x and y,
26 |     outputs an approximation of the OT cost with regularization parameter epsilon;
27 |     niter is the max. number of steps in the Sinkhorn loop
28 |     """
29 | 
30 |     # The Sinkhorn algorithm takes as input three variables :
31 |     C = cost_matrix(x, y)  # Wasserstein cost function
32 | 
33 |     nx = x.shape[1]
34 |     ny = y.shape[1]
35 |     batch_size = x.shape[0]
36 | 
37 |     if mass_x is None:
38 |         # assign marginal to fixed with equal weights
39 |         mu = 1. / nx * torch.ones([batch_size, nx]).to(device)
40 |     else:  # normalize
41 |         mass_x.data = torch.clamp(mass_x.data, min=0, max=1e9)
42 |         mass_x = mass_x + 1e-9
43 |         mu = (mass_x / mass_x.sum(dim=-1, keepdim=True)).to(device)
44 | 
45 |     if mass_y is None:
46 |         # assign marginal to fixed with equal weights
47 |         nu = 1. / ny * torch.ones([batch_size, ny]).to(device)
48 |     else:  # normalize
49 |         mass_y.data = torch.clamp(mass_y.data, min=0, max=1e9)
50 |         mass_y = mass_y + 1e-9
51 |         nu = (mass_y / mass_y.sum(dim=-1, keepdim=True)).to(device)
52 | 
53 |     def M(u, v):
54 |         "Modified cost for logarithmic updates"
55 |         "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
56 |         return (-C + u.unsqueeze(2) + v.unsqueeze(1)) / epsilon
57 | 
58 |     def lse(A):
59 |         "log-sum-exp"
60 |         return torch.log(torch.exp(A).sum(2, keepdim=True) + 1e-6)  # add 10^-6 to prevent NaN
61 | 
62 |     # Actual Sinkhorn loop ......................................................................
63 |     u, v, err = 0. * mu, 0. * nu, 0.
64 | 
65 |     for i in range(niter):
66 |         u = epsilon * (torch.log(mu) - lse(M(u, v)).squeeze()) + u
67 |         v = epsilon * (torch.log(nu) - lse(M(u, v).transpose(dim0=1, dim1=2)).squeeze()) + v
68 | 
69 |     U, V = u, v
70 |     pi = torch.exp(M(U, V))  # Transport plan pi = diag(a)*K*diag(b)
71 |     cost = torch.sum(pi * C, dim=[1, 2])  # Sinkhorn cost
72 | 
73 |     return torch.mean(cost)
74 | 
75 | 
76 | def cost_matrix(x, y, p=2):
77 |     "Returns the matrix of $|x_i-y_j|^p$."
78 |     x_col = x.unsqueeze(2)  # [batch, n, 1, dim]
79 |     y_lin = y.unsqueeze(1)  # [batch, 1, m, dim]
80 |     c = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)  # pairwise costs, [batch, n, m]
81 |     return c
82 | 
--------------------------------------------------------------------------------
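A minimal usage sketch (the point counts are illustrative). The inputs must live on the module's device, since the marginals mu and nu are created there; epsilon=0.01 and niter=5 match the defaults used by SinkhornLoss in loss.py:

    import torch
    from pytorch_batch_sinkhorn import sinkhorn_loss, device

    # two batches of 2-D point clouds: [batch, n_points, coord_dim]
    x = torch.rand(1, 16, 2).to(device)
    y = torch.rand(1, 16, 2).to(device)

    # entropy-regularized OT cost between uniform marginals on x and y
    cost = sinkhorn_loss(x, y, epsilon=0.01, niter=5)
    print(cost.item())
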
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | torch.cuda.current_device()
5 | import torch.optim as optim
6 | 
7 | from painter import *
8 | 
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 |                     help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 |                     help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='str',
18 |                     help='size of the canvas for stroke rendering')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='str',
20 |                     help='max number of strokes (default 500)')
21 | parser.add_argument('--m_grid', type=int, default=5, metavar='N',
22 |                     help='divide an image to m_grid x m_grid patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 |                     help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve the convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 |                     help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 |                     help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
32 |                     help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 |                     help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 |                     help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 | 
39 | 
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 | 
43 | 
44 | def optimize_x(pt):
45 | 
46 |     pt._load_checkpoint()
47 |     pt.net_G.eval()
48 | 
49 |     pt.initialize_params()
50 |     pt.x_ctt.requires_grad = True
51 |     pt.x_color.requires_grad = True
52 |     pt.x_alpha.requires_grad = True
53 |     utils.set_requires_grad(pt.net_G, False)
54 | 
55 |     pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)
56 | 
57 |     print('begin to draw...')
58 |     pt.step_id = 0
59 |     for pt.anchor_id in range(0, pt.m_strokes_per_block):
60 |         pt.stroke_sampler(pt.anchor_id)
61 |         iters_per_stroke = 20
62 |         if pt.anchor_id == pt.m_strokes_per_block - 1:
63 |             iters_per_stroke = 40
64 |         for i in range(iters_per_stroke):
65 | 
66 |             pt.optimizer_x.zero_grad()
67 | 
68 |             pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
69 |             pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
70 |             pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
71 | 
72 |             if args.canvas_color == 'white':
73 |                 pt.G_pred_canvas = torch.ones([args.m_grid ** 2, 3, 128, 128]).to(device)
74 |             else:
75 |                 pt.G_pred_canvas = torch.zeros(args.m_grid ** 2, 3, 128, 128).to(device)
76 | 
77 |             pt._forward_pass()
78 |             pt._drawing_step_states()
79 |             pt._backward_x()
80 |             pt.optimizer_x.step()
81 | 
82 |             pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
83 |             pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
84 |             pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
85 | 
86 |             pt.step_id += 1
87 | 
88 |     v = pt.x.detach().cpu().numpy()
89 |     pt._save_stroke_params(v)
90 |     v_n = pt._normalize_strokes(pt.x)
91 |     pt.final_rendered_images = pt._render_on_grids(v_n)
92 |     pt._save_rendered_images()
93 | 
94 | 
95 | 
96 | if __name__ == '__main__':
97 | 
98 |     pt = Painter(args=args)
99 |     optimize_x(pt)
100 | 
101 | 
--------------------------------------------------------------------------------
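The loop above follows a projected-gradient pattern: clamp the stroke parameters into their valid box, take an RMSprop step on the rendering loss, then re-clamp. A self-contained toy version of the same pattern (the target and clamp range here are arbitrary):

    import torch
    import torch.optim as optim

    x = torch.rand(8, requires_grad=True)        # toy "stroke parameters"
    target = torch.linspace(0.1, 0.9, 8)
    opt = optim.RMSprop([x], lr=0.005)
    for step in range(100):
        opt.zero_grad()
        x.data = torch.clamp(x.data, 0.1, 0.9)   # project into the feasible box
        loss = torch.mean((x - target) ** 2)
        loss.backward()
        opt.step()
        x.data = torch.clamp(x.data, 0.1, 0.9)   # re-project after the update
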
/demo_nst.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | torch.cuda.current_device()
5 | import torch.optim as optim
6 | 
7 | from painter import *
8 | 
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
12 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
13 | parser.add_argument('--vector_file', type=str, default='./output/sunflowers_strokes.npz', metavar='str',
14 |                     help='path to pre-generated stroke vector file (default: ./output/sunflowers_strokes.npz)')
15 | parser.add_argument('--style_img_path', type=str, default='./style_images/fire.jpg', metavar='str',
16 |                     help='path to style image (default: ./style_images/fire.jpg)')
17 | parser.add_argument('--content_img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
18 |                     help='path to content image (default: ./test_images/sunflowers.jpg)')
19 | parser.add_argument('--transfer_mode', type=int, default=1, metavar='N',
20 |                     help='style transfer mode, 0: transfer color only, 1: transfer both color and texture, '
21 |                          'default: 1')
22 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
23 |                     help='canvas_color: [black, white] (default black)')
24 | parser.add_argument('--canvas_size', type=int, default=512, metavar='str',
25 |                     help='size of the canvas for stroke rendering')
26 | parser.add_argument('--beta_L1', type=float, default=1.0,
27 |                     help='weight for L1 loss (default: 1.0)')
28 | parser.add_argument('--beta_sty', type=float, default=0.5,
29 |                     help='weight for vgg style loss (default: 0.5)')
30 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
31 |                     help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
32 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
33 |                     help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
34 | parser.add_argument('--lr', type=float, default=0.005,
35 |                     help='learning rate for stroke searching (default: 0.005)')
36 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
37 |                     help='dir to save style transfer results (default: ./output)')
38 | args = parser.parse_args()
39 | 
40 | 
41 | # Decide which device we want to run on
42 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43 | 
44 | 
45 | def optimize_x(pt):
46 | 
47 |     pt._load_checkpoint()
48 |     pt.net_G.eval()
49 | 
50 |     if args.transfer_mode == 0:  # transfer color only
51 |         pt.x_ctt.requires_grad = False
52 |         pt.x_color.requires_grad = True
53 |         pt.x_alpha.requires_grad = False
54 |     else:  # transfer both color and texture
55 |         pt.x_ctt.requires_grad = True
56 |         pt.x_color.requires_grad = True
57 |         pt.x_alpha.requires_grad = True
58 | 
59 |     pt.optimizer_x_sty = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)
60 | 
61 |     iters_per_stroke = 100
62 |     for i in range(iters_per_stroke):
63 |         pt.optimizer_x_sty.zero_grad()
64 | 
65 |         pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
66 |         pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
67 |         pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
68 | 
69 |         if args.canvas_color == 'white':
70 |             pt.G_pred_canvas = torch.ones([pt.m_grid*pt.m_grid, 3, 128, 128]).to(device)
71 |         else:
72 |             pt.G_pred_canvas = torch.zeros(pt.m_grid*pt.m_grid, 3, 128, 128).to(device)
73 | 
74 |         pt._forward_pass()
75 |         pt._style_transfer_step_states()
76 |         pt._backward_x_sty()
77 |         pt.optimizer_x_sty.step()
78 | 
79 |         pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
80 |         pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
81 |         pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
82 | 
83 |         pt.step_id += 1
84 | 
85 |     print('saving style transfer result...')
86 |     v_n = pt._normalize_strokes(pt.x)
87 |     pt.final_rendered_images = pt._render_on_grids(v_n)
88 | 
89 |     file_dir = os.path.join(
90 |         args.output_dir, args.content_img_path.split('/')[-1][:-4])
91 |     plt.imsave(file_dir + '_style_img_' +
92 |                args.style_img_path.split('/')[-1][:-4] + '.png', pt.style_img_)
93 |     plt.imsave(file_dir + '_style_transfer_' +
94 |                args.style_img_path.split('/')[-1][:-4] + '.png', pt.final_rendered_images[-1])
95 | 
96 | 
97 | if __name__ == '__main__':
98 | 
99 |     pt = NeuralStyleTransfer(args=args)
100 |     optimize_x(pt)
101 | 
--------------------------------------------------------------------------------
/demo_prog.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | torch.cuda.current_device()
5 | import torch.optim as optim
6 | 
7 | from painter import *
8 | 
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 |                     help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 |                     help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='str',
18 |                     help='size of the canvas for stroke rendering')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='str',
20 |                     help='max number of strokes (default 500)')
21 | parser.add_argument('--max_divide', type=int, default=5, metavar='N',
22 |                     help='divide an image up-to max_divide x max_divide patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 |                     help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve the convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 |                     help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 |                     help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
32 |                     help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 |                     help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 |                     help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 | 
39 | 
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 | 
43 | 
44 | def optimize_x(pt):
45 | 
46 |     pt._load_checkpoint()
47 |     pt.net_G.eval()
48 | 
49 |     print('begin drawing...')
50 | 
51 |     PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)
52 | 
53 |     if pt.rderr.canvas_color == 'white':
54 |         CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)
55 |     else:
56 |         CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)
57 | 
58 |     for pt.m_grid in range(1, pt.max_divide + 1):
59 | 
60 |         pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)
61 |         pt.G_final_pred_canvas = CANVAS_tmp
62 | 
63 |         pt.initialize_params()
64 |         pt.x_ctt.requires_grad = True
65 |         pt.x_color.requires_grad = True
66 |         pt.x_alpha.requires_grad = True
67 |         utils.set_requires_grad(pt.net_G, False)
68 | 
69 |         pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)
70 | 
71 |         pt.step_id = 0
72 |         for pt.anchor_id in range(0, pt.m_strokes_per_block):
73 |             pt.stroke_sampler(pt.anchor_id)
74 |             iters_per_stroke = 40
75 |             for i in range(iters_per_stroke):
76 |                 pt.G_pred_canvas = CANVAS_tmp
77 | 
78 |                 # update x
79 |                 pt.optimizer_x.zero_grad()
80 | 
81 |                 pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
82 |                 pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
83 |                 pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
84 | 
85 |                 pt._forward_pass()
86 |                 pt._drawing_step_states()
87 |                 pt._backward_x()
88 | 
89 |                 pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
90 |                 pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
91 |                 pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
92 | 
93 |                 pt.optimizer_x.step()
94 |                 pt.step_id += 1
95 | 
96 |         v = pt._normalize_strokes(pt.x)
97 |         PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)
98 |         CANVAS_tmp = pt._render(PARAMS)[-1]
99 |         CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)
100 | 
101 |     pt._save_stroke_params(PARAMS)
102 |     pt.final_rendered_images = pt._render(PARAMS)
103 |     pt._save_rendered_images()
104 | 
105 | 
106 | 
107 | if __name__ == '__main__':
108 | 
109 |     pt = ProgressivePainter(args=args)
110 |     optimize_x(pt)
111 | 
--------------------------------------------------------------------------------
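Unlike demo.py, demo_prog.py paints coarse-to-fine: the grid level m_grid grows from 1 to max_divide, at each level the image is split into m_grid x m_grid patches that are optimized in parallel, and the previous level's rendering is reused as the starting canvas. A trivial sketch of that schedule:

    max_divide = 5
    for m_grid in range(1, max_divide + 1):
        print('level %d: %d patches' % (m_grid, m_grid * m_grid))
    # level 1: 1 patch, level 2: 4 patches, ... level 5: 25 patches
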
/demo_8bitart.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | torch.cuda.current_device()
5 | import torch.optim as optim
6 | 
7 | from painter import *
8 | 
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 |                     help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='rectangle', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: rectangle)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 |                     help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='str',
18 |                     help='size of the canvas for stroke rendering')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='str',
20 |                     help='max number of strokes (default 500)')
21 | parser.add_argument('--max_divide', type=int, default=5, metavar='N',
22 |                     help='divide an image up-to max_divide x max_divide patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 |                     help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve the convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 |                     help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 |                     help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_rectangle', metavar='str',
32 |                     help='dir to load neu-renderer (default: ./checkpoints_G_rectangle)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 |                     help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 |                     help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 | 
39 | 
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 | 
43 | 
44 | def optimize_x(pt):
45 | 
46 |     pt._load_checkpoint()
47 |     pt.net_G.eval()
48 | 
49 |     print('begin drawing...')
50 | 
51 |     PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)
52 | 
53 |     if pt.rderr.canvas_color == 'white':
54 |         CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)
55 |     else:
56 |         CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)
57 | 
58 |     for pt.m_grid in range(1, pt.max_divide + 1):
59 | 
60 |         pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)
61 |         pt.G_final_pred_canvas = CANVAS_tmp
62 | 
63 |         pt.initialize_params()
64 |         pt.x_ctt.requires_grad = True
65 |         pt.x_color.requires_grad = True
66 |         pt.x_alpha.requires_grad = True
67 |         utils.set_requires_grad(pt.net_G, False)
68 | 
69 |         pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)
70 | 
71 |         pt.step_id = 0
72 |         for pt.anchor_id in range(0, pt.m_strokes_per_block):
73 |             pt.stroke_sampler(pt.anchor_id)
74 |             iters_per_stroke = 20
75 |             for i in range(iters_per_stroke):
76 |                 pt.G_pred_canvas = CANVAS_tmp
77 | 
78 |                 # update x
79 |                 pt.optimizer_x.zero_grad()
80 | 
81 |                 pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0, 1)
82 |                 pt.x_ctt.data[:, :, -1] = torch.clamp(pt.x_ctt.data[:, :, -1], 0, 0)  # pin the last stroke-shape parameter to 0
83 |                 pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
84 |                 pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 1, 1)  # keep strokes fully opaque
85 | 
86 |                 pt._forward_pass()
87 |                 pt._backward_x()
88 | 
89 |                 pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0, 1)
90 |                 pt.x_ctt.data[:, :, -1] = torch.clamp(pt.x_ctt.data[:, :, -1], 0, 0)
91 |                 pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
92 |                 pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 1, 1)
93 | 
94 |                 pt._drawing_step_states()
95 | 
96 |                 pt.optimizer_x.step()
97 |                 pt.step_id += 1
98 | 
99 |         v = pt._normalize_strokes(pt.x)
100 |         PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)
101 |         CANVAS_tmp = pt._render(PARAMS)[-1]
102 |         CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)
103 | 
104 |     pt._save_stroke_params(PARAMS)
105 |     pt.final_rendered_images = pt._render(PARAMS)
106 |     pt._save_rendered_images()
107 | 
108 | 
109 | 
110 | if __name__ == '__main__':
111 | 
112 |     pt = ProgressivePainter(args=args)
113 |     optimize_x(pt)
114 | 
115 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.cuda.current_device()
3 | import torch.nn as nn
4 | import torchvision
5 | import utils
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import random
9 | import pytorch_batch_sinkhorn as spc
10 | 
11 | # Decide which device we want to run on
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 | 
14 | 
15 | class PixelLoss(nn.Module):
16 | 
17 |     def __init__(self, p=1):
18 |         super(PixelLoss, self).__init__()
19 |         self.p = p
20 | 
21 |     def forward(self, canvas, gt, ignore_color=False):
22 |         if ignore_color:
23 |             canvas = torch.mean(canvas, dim=1)
24 |             gt = torch.mean(gt, dim=1)
25 |         loss = torch.mean(torch.abs(canvas-gt)**self.p)
26 |         return loss
27 | 
28 | 
29 | class VGGPerceptualLoss(torch.nn.Module):
30 |     def __init__(self, resize=True):
31 |         super(VGGPerceptualLoss, self).__init__()
32 |         vgg = torchvision.models.vgg16(pretrained=True).to(device)
33 |         blocks = []
34 |         blocks.append(vgg.features[:4].eval())
35 |         blocks.append(vgg.features[4:9].eval())
36 |         blocks.append(vgg.features[9:16].eval())
37 |         blocks.append(vgg.features[16:23].eval())
38 |         for bl in blocks:
39 |             for p in bl:
40 |                 p.requires_grad = False
41 |         self.blocks = torch.nn.ModuleList(blocks)
42 |         self.transform = torch.nn.functional.interpolate
43 |         self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
44 |         self.std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
45 |         self.resize = resize
46 | 
47 |     def forward(self, input, target, ignore_color=False):
48 |         self.mean = self.mean.type_as(input)
49 |         self.std = self.std.type_as(input)
50 |         if ignore_color:
51 |             input = torch.mean(input, dim=1, keepdim=True)
52 |             target = torch.mean(target, dim=1, keepdim=True)
53 |         if input.shape[1] != 3:
54 |             input = input.repeat(1, 3, 1, 1)
55 |             target = target.repeat(1, 3, 1, 1)
56 |         input = (input-self.mean) / self.std
57 |         target = (target-self.mean) / self.std
58 |         if self.resize:
59 |             input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
60 |             target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
61 |         loss = 0.0
62 |         x = input
63 |         y = target
64 |         for block in self.blocks:
65 |             x = block(x)
66 |             y = block(y)
67 |             loss += torch.nn.functional.l1_loss(x, y)
68 |         return loss
69 | 
70 | 
71 | 
72 | class VGGStyleLoss(torch.nn.Module):
73 |     def __init__(self, transfer_mode, resize=True):
74 |         super(VGGStyleLoss, self).__init__()
75 |         vgg = torchvision.models.vgg16(pretrained=True).to(device)
76 |         for i, layer in enumerate(vgg.features):
77 |             if isinstance(layer, torch.nn.MaxPool2d):
78 |                 vgg.features[i] = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
79 | 
80 |         blocks = []
81 |         if transfer_mode == 0:  # transfer color only
82 |             blocks.append(vgg.features[:4].eval())
83 |             blocks.append(vgg.features[4:9].eval())
84 |         else:  # transfer both color and texture
85 |             blocks.append(vgg.features[:4].eval())
86 |             blocks.append(vgg.features[4:9].eval())
87 |             blocks.append(vgg.features[9:16].eval())
88 |             blocks.append(vgg.features[16:23].eval())
89 | 
90 |         for bl in blocks:
91 |             for p in bl:
92 |                 p.requires_grad = False
93 |         self.blocks = torch.nn.ModuleList(blocks)
94 | 
95 |         self.transform = torch.nn.functional.interpolate
96 |         self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
97 |         self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
98 |         self.resize = resize
99 | 
100 |     def gram_matrix(self, y):
101 |         (b, ch, h, w) = y.size()
102 |         features = y.view(b, ch, w * h)
103 |         features_t = features.transpose(1, 2)
104 |         gram = features.bmm(features_t) / (ch * w * h)
105 |         return gram
106 | 
107 |     def forward(self, input, target):
108 |         if input.shape[1] != 3:
109 |             input = input.repeat(1, 3, 1, 1)
110 |             target = target.repeat(1, 3, 1, 1)
111 |         input = (input - self.mean) / self.std
112 |         target = (target - self.mean) / self.std
113 |         if self.resize:
114 |             input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
115 |             target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
116 |         loss = 0.0
117 |         x = input
118 |         y = target
119 |         for block in self.blocks:
120 |             x = block(x)
121 |             y = block(y)
122 |             gm_x = self.gram_matrix(x)
123 |             gm_y = self.gram_matrix(y)
124 |             loss += torch.sum((gm_x-gm_y)**2)
125 |         return loss
126 | 
127 | 
128 | 
129 | class SinkhornLoss(nn.Module):
130 | 
131 |     def __init__(self, epsilon=0.01, niter=5, normalize=False):
132 |         super(SinkhornLoss, self).__init__()
133 |         self.epsilon = epsilon
134 |         self.niter = niter
135 |         self.normalize = normalize
136 | 
137 |     def _mesh_grids(self, batch_size, h, w):
138 | 
139 |         a = torch.linspace(0.0, h - 1.0, h).to(device)
140 |         b = torch.linspace(0.0, w - 1.0, w).to(device)
141 |         y_grid = a.view(-1, 1).repeat(batch_size, 1, w) / h
142 |         x_grid = b.view(1, -1).repeat(batch_size, h, 1) / w
143 |         grids = torch.cat([y_grid.view(batch_size, -1, 1), x_grid.view(batch_size, -1, 1)], dim=-1)
144 |         return grids
145 | 
146 |     def forward(self, canvas, gt):
147 | 
148 |         batch_size, c, h, w = gt.shape
149 |         if h > 24:
150 |             canvas = nn.functional.interpolate(canvas, [24, 24], mode='area')
151 |             gt = nn.functional.interpolate(gt, [24, 24], mode='area')
152 |             batch_size, c, h, w = gt.shape
153 | 
154 |         canvas_grids = self._mesh_grids(batch_size, h, w)
155 |         gt_grids = torch.clone(canvas_grids)
156 | 
157 |         # randomly select one color channel, to speed up computation and reduce memory use
158 |         i = random.randint(0, 2)
159 | 
160 |         img_1 = canvas[:, [i], :, :]
161 |         img_2 = gt[:, [i], :, :]
162 | 
163 |         mass_x = img_1.reshape(batch_size, -1)
164 |         mass_y = img_2.reshape(batch_size, -1)
165 |         if self.normalize:
166 |             loss = spc.sinkhorn_normalized(
167 |                 canvas_grids, gt_grids, epsilon=self.epsilon, niter=self.niter,
168 |                 mass_x=mass_x, mass_y=mass_y)
169 |         else:
170 |             loss = spc.sinkhorn_loss(
171 |                 canvas_grids, gt_grids, epsilon=self.epsilon, niter=self.niter,
172 |                 mass_x=mass_x, mass_y=mass_y)
173 | 
174 | 
175 |         return loss
176 | 
177 | 
178 | 
179 | 
--------------------------------------------------------------------------------
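A minimal sketch of how painter.py-style code consumes these losses (batch size and resolution are illustrative; importing loss initializes CUDA via torch.cuda.current_device(), so this assumes a GPU-capable setup):

    import torch
    from loss import PixelLoss, SinkhornLoss

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    canvas = torch.rand(4, 3, 128, 128).to(device)
    gt = torch.rand(4, 3, 128, 128).to(device)

    print(PixelLoss(p=1)(canvas, gt).item())  # mean |canvas - gt|
    print(SinkhornLoss()(canvas, gt).item())  # OT loss on a 24x24 downsample of one random channel
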
Creative Commons Legal Code
2 | 
3 | CC0 1.0 Universal
4 | 
5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
12 | HEREUNDER.
13 | 
14 | Statement of Purpose
15 | 
16 | The laws of most jurisdictions throughout the world automatically confer
17 | exclusive Copyright and Related Rights (defined below) upon the creator
18 | and subsequent owner(s) (each and all, an "owner") of an original work of
19 | authorship and/or a database (each, a "Work").
20 | 
21 | Certain owners wish to permanently relinquish those rights to a Work for
22 | the purpose of contributing to a commons of creative, cultural and
23 | scientific works ("Commons") that the public can reliably and without fear
24 | of later claims of infringement build upon, modify, incorporate in other
25 | works, reuse and redistribute as freely as possible in any form whatsoever
26 | and for any purposes, including without limitation commercial purposes.
27 | These owners may contribute to the Commons to promote the ideal of a free
28 | culture and the further production of creative, cultural and scientific
29 | works, or to gain reputation or greater distribution for their Work in
30 | part through the use and efforts of others.
31 | 
32 | For these and/or other purposes and motivations, and without any
33 | expectation of additional consideration or compensation, the person
34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she
35 | is an owner of Copyright and Related Rights in the Work, voluntarily
36 | elects to apply CC0 to the Work and publicly distribute the Work under its
37 | terms, with knowledge of his or her Copyright and Related Rights in the
38 | Work and the meaning and intended legal effect of CC0 on those rights.
39 | 
40 | 1. Copyright and Related Rights. A Work made available under CC0 may be
41 | protected by copyright and related or neighboring rights ("Copyright and
42 | Related Rights"). Copyright and Related Rights include, but are not
43 | limited to, the following:
44 | 
45 |   i. the right to reproduce, adapt, distribute, perform, display,
46 |      communicate, and translate a Work;
47 |  ii. moral rights retained by the original author(s) and/or performer(s);
48 | iii. publicity and privacy rights pertaining to a person's image or
49 |      likeness depicted in a Work;
50 |  iv. rights protecting against unfair competition in regards to a Work,
51 |      subject to the limitations in paragraph 4(a), below;
52 |   v. rights protecting the extraction, dissemination, use and reuse of data
53 |      in a Work;
54 |  vi. database rights (such as those arising under Directive 96/9/EC of the
55 |      European Parliament and of the Council of 11 March 1996 on the legal
56 |      protection of databases, and under any national implementation
57 |      thereof, including any amended or successor version of such
58 |      directive); and
59 | vii. other similar, equivalent or corresponding rights throughout the
60 |      world based on applicable law or treaty, and any national
61 |      implementations thereof.
62 | 
63 | 2. Waiver.
To the greatest extent permitted by, but not in contravention
64 | of, applicable law, Affirmer hereby overtly, fully, permanently,
65 | irrevocably and unconditionally waives, abandons, and surrenders all of
66 | Affirmer's Copyright and Related Rights and associated claims and causes
67 | of action, whether now known or unknown (including existing as well as
68 | future claims and causes of action), in the Work (i) in all territories
69 | worldwide, (ii) for the maximum duration provided by applicable law or
70 | treaty (including future time extensions), (iii) in any current or future
71 | medium and for any number of copies, and (iv) for any purpose whatsoever,
72 | including without limitation commercial, advertising or promotional
73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
74 | member of the public at large and to the detriment of Affirmer's heirs and
75 | successors, fully intending that such Waiver shall not be subject to
76 | revocation, rescission, cancellation, termination, or any other legal or
77 | equitable action to disrupt the quiet enjoyment of the Work by the public
78 | as contemplated by Affirmer's express Statement of Purpose.
79 | 
80 | 3. Public License Fallback. Should any part of the Waiver for any reason
81 | be judged legally invalid or ineffective under applicable law, then the
82 | Waiver shall be preserved to the maximum extent permitted taking into
83 | account Affirmer's express Statement of Purpose. In addition, to the
84 | extent the Waiver is so judged Affirmer hereby grants to each affected
85 | person a royalty-free, non transferable, non sublicensable, non exclusive,
86 | irrevocable and unconditional license to exercise Affirmer's Copyright and
87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the
88 | maximum duration provided by applicable law or treaty (including future
89 | time extensions), (iii) in any current or future medium and for any number
90 | of copies, and (iv) for any purpose whatsoever, including without
91 | limitation commercial, advertising or promotional purposes (the
92 | "License"). The License shall be deemed effective as of the date CC0 was
93 | applied by Affirmer to the Work. Should any part of the License for any
94 | reason be judged legally invalid or ineffective under applicable law, such
95 | partial invalidity or ineffectiveness shall not invalidate the remainder
96 | of the License, and in such case Affirmer hereby affirms that he or she
97 | will not (i) exercise any of his or her remaining Copyright and Related
98 | Rights in the Work or (ii) assert any associated claims and causes of
99 | action with respect to the Work, in either case contrary to Affirmer's
100 | express Statement of Purpose.
101 | 
102 | 4. Limitations and Disclaimers.
103 | 
104 |  a. No trademark or patent rights held by Affirmer are waived, abandoned,
105 |     surrendered, licensed or otherwise affected by this document.
106 |  b. Affirmer offers the Work as-is and makes no representations or
107 |     warranties of any kind concerning the Work, express, implied,
108 |     statutory or otherwise, including without limitation warranties of
109 |     title, merchantability, fitness for a particular purpose, non
110 |     infringement, or the absence of latent or other defects, accuracy, or
111 |     the presence or absence of errors, whether or not discoverable, all to
112 |     the greatest extent permissible under applicable law.
113 |  c.
Affirmer disclaims responsibility for clearing rights of other persons
114 |     that may apply to the Work or any use thereof, including without
115 |     limitation any person's Copyright and Related Rights in the Work.
116 |     Further, Affirmer disclaims responsibility for obtaining any necessary
117 |     consents, permissions or other rights required for any use of the
118 |     Work.
119 |  d. Affirmer understands and acknowledges that Creative Commons is not a
120 |     party to this document and has no duty or obligation with respect to
121 |     this CC0 or use of the Work.
122 | 
--------------------------------------------------------------------------------
/imitator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 | 
5 | import utils
6 | import loss
7 | from networks import *
8 | 
9 | import torch
10 | torch.cuda.current_device()
11 | import torch.optim as optim
12 | from torch.optim import lr_scheduler
13 | import torch.nn as nn
14 | 
15 | import renderer
16 | 
17 | 
18 | # Decide which device we want to run on
19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20 | 
21 | 
22 | class Imitator():
23 | 
24 |     def __init__(self, args, dataloaders):
25 | 
26 |         self.dataloaders = dataloaders
27 | 
28 |         self.rderr = renderer.Renderer(renderer=args.renderer)
29 | 
30 |         # define G
31 |         self.net_G = define_G(rdrr=self.rderr, netG=args.net_G).to(device)
32 | 
33 |         # Learning rate
34 |         self.lr = args.lr
35 | 
36 |         # define optimizers
37 |         self.optimizer_G = optim.Adam(
38 |             self.net_G.parameters(), lr=self.lr, betas=(0.9, 0.999))
39 | 
40 |         # define lr schedulers
41 |         self.exp_lr_scheduler_G = lr_scheduler.StepLR(
42 |             self.optimizer_G, step_size=100, gamma=0.1)
43 | 
44 |         # define some other vars to record the training states
45 |         self.running_acc = []
46 |         self.epoch_acc = 0
47 |         self.best_val_acc = 0.0
48 |         self.best_epoch_id = 0
49 |         self.epoch_to_start = 0
50 |         self.max_num_epochs = args.max_num_epochs
51 |         self.G_pred_foreground = None
52 |         self.G_pred_alpha = None
53 |         self.batch = None
54 |         self.G_loss = None
55 |         self.is_training = False
56 |         self.batch_id = 0
57 |         self.epoch_id = 0
58 |         self.checkpoint_dir = args.checkpoint_dir
59 |         self.vis_dir = args.vis_dir
60 | 
61 |         # define the loss functions
62 |         self._pxl_loss = loss.PixelLoss(p=2)
63 | 
64 |         self.VAL_ACC = np.array([], np.float32)
65 |         if os.path.exists(os.path.join(self.checkpoint_dir, 'val_acc.npy')):
66 |             self.VAL_ACC = np.load(os.path.join(self.checkpoint_dir, 'val_acc.npy'))
67 | 
68 |         # check and create model dir
69 |         if os.path.exists(self.checkpoint_dir) is False:
70 |             os.mkdir(self.checkpoint_dir)
71 |         if os.path.exists(self.vis_dir) is False:
72 |             os.mkdir(self.vis_dir)
73 | 
74 |         # visualize model
75 |         if args.print_models:
76 |             self._visualize_models()
77 | 
78 | 
79 |     def _visualize_models(self):
80 | 
81 |         from torchviz import make_dot
82 | 
83 |         # visualize models with the package torchviz
84 |         data = next(iter(self.dataloaders['train']))
85 |         y = self.net_G(data['A'].to(device))
86 |         mygraph = make_dot(y.mean(), params=dict(self.net_G.named_parameters()))
87 |         mygraph.render('G')
88 | 
89 | 
90 |     def _load_checkpoint(self):
91 | 
92 |         if os.path.exists(os.path.join(self.checkpoint_dir, 'last_ckpt.pt')):
93 |             print('loading last checkpoint...')
94 |             # load the entire checkpoint
95 |             checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'last_ckpt.pt'))
96 | 
97 |             # update net_G states
98 |             self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
99 |             self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
100 |             self.exp_lr_scheduler_G.load_state_dict(
101 |                 checkpoint['exp_lr_scheduler_G_state_dict'])
102 |             self.net_G.to(device)
103 | 
104 |             # update some other states
105 |             self.epoch_to_start = checkpoint['epoch_id'] + 1
106 |             self.best_val_acc = checkpoint['best_val_acc']
107 |             self.best_epoch_id = checkpoint['best_epoch_id']
108 | 
109 |             print('Epoch_to_start = %d, Historical_best_acc = %.4f (at epoch %d)' %
110 |                   (self.epoch_to_start, self.best_val_acc, self.best_epoch_id))
111 |             print()
112 | 
113 |         else:
114 |             print('training from scratch...')
115 | 
116 | 
117 |     def _save_checkpoint(self, ckpt_name):
118 |         torch.save({
119 |             'epoch_id': self.epoch_id,
120 |             'best_val_acc': self.best_val_acc,
121 |             'best_epoch_id': self.best_epoch_id,
122 |             'model_G_state_dict': self.net_G.state_dict(),
123 |             'optimizer_G_state_dict': self.optimizer_G.state_dict(),
124 |             'exp_lr_scheduler_G_state_dict': self.exp_lr_scheduler_G.state_dict()
125 |         }, os.path.join(self.checkpoint_dir, ckpt_name))
126 | 
127 | 
128 |     def _update_lr_schedulers(self):
129 |         self.exp_lr_scheduler_G.step()
130 | 
131 | 
132 |     def _compute_acc(self):
133 | 
134 |         target_foreground = self.batch['B'].to(device).detach()
135 |         target_alpha_map = self.batch['ALPHA'].to(device).detach()
136 |         foreground = self.G_pred_foreground.detach()
137 |         alpha_map = self.G_pred_alpha.detach()
138 |         psnr1 = utils.cpt_batch_psnr(foreground, target_foreground, PIXEL_MAX=1.0)
139 |         psnr2 = utils.cpt_batch_psnr(alpha_map, target_alpha_map, PIXEL_MAX=1.0)
140 |         return (psnr1 + psnr2)/2.0
141 | 
142 | 
143 |     def _collect_running_batch_states(self):
144 |         self.running_acc.append(self._compute_acc().item())
145 | 
146 |         m = len(self.dataloaders['train'])
147 |         if self.is_training is False:
148 |             m = len(self.dataloaders['val'])
149 | 
150 |         if np.mod(self.batch_id, 100) == 1:
151 |             print('Is_training: %s. [%d,%d][%d,%d], G_loss: %.5f, running_acc: %.5f'
152 |                   % (self.is_training, self.epoch_id, self.max_num_epochs-1, self.batch_id, m,
153 |                      self.G_loss.item(), np.mean(self.running_acc)))
154 | 
155 |         if np.mod(self.batch_id, 1000) == 1:
156 |             vis_pred_foreground = utils.make_numpy_grid(self.G_pred_foreground)
157 |             vis_gt_foreground = utils.make_numpy_grid(self.batch['B'])
158 |             vis_pred_alpha = utils.make_numpy_grid(self.G_pred_alpha)
159 |             vis_gt_alpha = utils.make_numpy_grid(self.batch['ALPHA'])
160 |             vis = np.concatenate([vis_pred_foreground, vis_gt_foreground,
161 |                                   vis_pred_alpha, vis_gt_alpha], axis=0)
162 |             vis = np.clip(vis, a_min=0.0, a_max=1.0)
163 |             file_name = os.path.join(
164 |                 self.vis_dir, 'istrain_'+str(self.is_training)+'_'+
165 |                 str(self.epoch_id)+'_'+str(self.batch_id)+'.jpg')
166 |             plt.imsave(file_name, vis)
167 | 
168 | 
169 | 
170 |     def _collect_epoch_states(self):
171 | 
172 |         self.epoch_acc = np.mean(self.running_acc)
173 |         print('Is_training: %s. Epoch %d / %d, epoch_acc= %.5f' %
174 |               (self.is_training, self.epoch_id, self.max_num_epochs-1, self.epoch_acc))
175 |         print()
176 | 
177 | 
178 |     def _update_checkpoints(self):
179 | 
180 |         # save current model
181 |         self._save_checkpoint(ckpt_name='last_ckpt.pt')
182 |         print('Latest model updated. Epoch_acc=%.4f, Historical_best_acc=%.4f (at epoch %d)'
183 |               % (self.epoch_acc, self.best_val_acc, self.best_epoch_id))
184 |         print()
185 | 
186 |         self.VAL_ACC = np.append(self.VAL_ACC, [self.epoch_acc])
187 |         np.save(os.path.join(self.checkpoint_dir, 'val_acc.npy'), self.VAL_ACC)
188 | 
189 |         # update the best model (based on eval acc)
190 |         if self.epoch_acc > self.best_val_acc:
191 |             self.best_val_acc = self.epoch_acc
192 |             self.best_epoch_id = self.epoch_id
193 |             self._save_checkpoint(ckpt_name='best_ckpt.pt')
194 |             print('*' * 10 + 'Best model updated!')
195 |             print()
196 | 
197 | 
198 |     def _clear_cache(self):
199 |         self.running_acc = []
200 | 
201 | 
202 |     def _forward_pass(self, batch):
203 |         self.batch = batch
204 |         z_in = batch['A'].to(device)
205 |         self.G_pred_foreground, self.G_pred_alpha = self.net_G(z_in)
206 | 
207 | 
208 |     def _backward_G(self):
209 | 
210 |         gt_foreground = self.batch['B'].to(device)
211 |         gt_alpha = self.batch['ALPHA'].to(device)
212 | 
213 |         pixel_loss1 = self._pxl_loss(self.G_pred_foreground, gt_foreground)
214 |         pixel_loss2 = self._pxl_loss(self.G_pred_alpha, gt_alpha)
215 |         self.G_loss = 100 * (pixel_loss1 + pixel_loss2) / 2.0
216 |         self.G_loss.backward()
217 | 
218 | 
219 |     def train_models(self):
220 | 
221 |         self._load_checkpoint()
222 | 
223 |         # loop over the dataset multiple times
224 |         for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs):
225 | 
226 |             ################## train #################
227 |             ##########################################
228 |             self._clear_cache()
229 |             self.is_training = True
230 |             self.net_G.train()  # Set model to training mode
231 |             # Iterate over data.
232 |             for self.batch_id, batch in enumerate(self.dataloaders['train'], 0):
233 |                 self._forward_pass(batch)
234 |                 # update G
235 |                 self.optimizer_G.zero_grad()
236 |                 self._backward_G()
237 |                 self.optimizer_G.step()
238 |                 self._collect_running_batch_states()
239 |             self._collect_epoch_states()
240 |             self._update_lr_schedulers()
241 | 
242 |             ########### Update_Checkpoints ###########
243 |             ##########################################
244 |             self._update_checkpoints()
245 | 
--------------------------------------------------------------------------------
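A minimal sketch of restoring a trained renderer for inference (this assumes a checkpoint produced by Imitator._save_checkpoint exists under ./checkpoints_G; the keys match the dict saved above):

    import os
    import torch
    import renderer
    from networks import define_G

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    rderr = renderer.Renderer(renderer='oilpaintbrush')
    net_G = define_G(rdrr=rderr, netG='zou-fusion-net').to(device)

    checkpoint = torch.load(os.path.join('./checkpoints_G', 'best_ckpt.pt'),
                            map_location=device)
    net_G.load_state_dict(checkpoint['model_G_state_dict'])
    net_G.eval()
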
/image_to_paint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "image-to-paint",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "machine_shape": "hm",
11 | "include_colab_link": true
12 | },
13 | "kernelspec": {
14 | "display_name": "Python 3",
15 | "name": "python3"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "<a href=\"https://colab.research.google.com/github/cedro3/stylized-neural-painting/blob/main/image_to_paint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "LFLeFyHW7AfS"
33 | },
34 | "source": [
35 | "# Copy the code from GitHub\n",
36 | "\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "metadata": {
42 | "id": "h72s1Mk26j4V"
43 | },
44 | "source": [
45 | "# clone the code from GitHub\n",
46 | "!git clone https://github.com/cedro3/stylized-neural-painting.git\n",
47 | "%cd stylized-neural-painting"
48 | ],
49 | "execution_count": null,
50 | "outputs": []
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "source": [
55 | "# Download the pretrained model"
56 | ],
57 | "metadata": {
58 | "id": "FDboP9pR76ma"
59 | }
60 | },
61 | {
62 | "cell_type": "code",
63 | "source": [
64 | "# download the pretrained weights\n",
65 | "! pip install --upgrade gdown\n",
66 | "import gdown\n",
67 | "gdown.download('https://drive.google.com/uc?id=1sqWhgBKqaBJggl2A8sD1bLSq2_B1ScMG', './checkpoints_G_oilpaintbrush.zip', quiet=False)\n",
68 | "! unzip checkpoints_G_oilpaintbrush.zip"
69 | ],
70 | "metadata": {
71 | "id": "3EWAl4aq73JT"
72 | },
73 | "execution_count": null,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {
79 | "id": "RwEdnkcI5fFE"
80 | },
81 | "source": [
82 | "# Main code"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "metadata": {
88 | "id": "eD8de5Il7JdO"
89 | },
90 | "source": [
91 | "import argparse\n",
92 | "import torch\n",
93 | "torch.cuda.current_device()\n",
94 | "import torch.optim as optim\n",
95 | "from painter import *\n",
96 | "\n",
97 | "# Decide which device we want to run on\n",
98 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
99 | "\n",
100 | "# settings\n",
101 | "parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')\n",
102 | "args = parser.parse_args(args=[])\n",
103 | "args.img_path = './test_images/kasumi.png' # path to input photo\n",
104 | "args.renderer = 'oilpaintbrush' # [watercolor, markerpen, oilpaintbrush, rectangle]\n",
105 | "args.canvas_color = 'black' # [black, white]\n",
106 | "args.canvas_size = 512 # size of the canvas for stroke rendering\n",
107 | "args.max_m_strokes = 500 # max number of strokes\n",
108 | "args.max_divide = 5 # divide an image up-to max_divide x max_divide patches\n",
109 | "args.beta_L1 = 1.0 # weight for L1 loss\n",
110 | "args.with_ot_loss = False # set True for improving the convergence by using optimal transportation loss, but will slow down the speed\n",
111 | "args.beta_ot = 0.1 # weight for optimal transportation loss\n",
112 | "args.net_G = 'zou-fusion-net' # renderer architecture\n",
113 | "args.renderer_checkpoint_dir = './checkpoints_G_oilpaintbrush' # dir to load the pretrained neu-renderer\n",
114 | "args.lr = 0.005 # learning rate for stroke searching\n",
115 | "args.output_dir = './output' # dir to save painting results\n",
116 | "\n",
117 | "\n",
118 | "def _drawing_step_states(pt):\n",
119 | "    acc = pt._compute_acc().item()\n",
120 | "    print('iteration step %d, G_loss: %.5f, step_acc: %.5f, grid_scale: %d / %d, strokes: %d / %d'\n",
121 | "          % (pt.step_id, pt.G_loss.item(), acc,\n",
122 | "             pt.m_grid, pt.max_divide,\n",
123 | "             pt.anchor_id, pt.m_strokes_per_block))\n",
124 | "    vis2 = utils.patches2img(pt.G_final_pred_canvas, pt.m_grid).clip(min=0, max=1)\n",
125 | "\n",
126 | "\n",
127 | "def optimize_x(pt):\n",
128 | "    pt._load_checkpoint()\n",
129 | "    pt.net_G.eval()\n",
130 | "    print('begin drawing...')\n",
131 | "\n",
132 | "    PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)\n",
133 | "\n",
134 | "    if pt.rderr.canvas_color == 'white':\n",
135 | "        CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)\n",
136 | "    else:\n",
137 | "        CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)\n",
138 | "\n",
139 | "    for pt.m_grid in range(1, pt.max_divide + 1):\n",
140 | "\n",
141 | "        pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)\n",
142 | "        pt.G_final_pred_canvas = CANVAS_tmp\n",
143 | "\n",
144 | "        pt.initialize_params()\n",
145 | "        pt.x_ctt.requires_grad = True\n",
146 | "        pt.x_color.requires_grad = True\n",
147 | "        pt.x_alpha.requires_grad = True\n",
148 | "        utils.set_requires_grad(pt.net_G, False)\n",
149 | "\n",
150 | "        pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)\n",
151 | "\n",
| " pt.step_id = 0\n", 153 | " for pt.anchor_id in range(0, pt.m_strokes_per_block):\n", 154 | " pt.stroke_sampler(pt.anchor_id)\n", 155 | " iters_per_stroke = 80\n", 156 | " for i in range(iters_per_stroke):\n", 157 | " pt.G_pred_canvas = CANVAS_tmp\n", 158 | "\n", 159 | " # update x\n", 160 | " pt.optimizer_x.zero_grad()\n", 161 | "\n", 162 | " pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)\n", 163 | " pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)\n", 164 | " pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)\n", 165 | "\n", 166 | " pt._forward_pass()\n", 167 | " _drawing_step_states(pt)\n", 168 | " pt._backward_x()\n", 169 | "\n", 170 | " pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)\n", 171 | " pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)\n", 172 | " pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)\n", 173 | "\n", 174 | " pt.optimizer_x.step()\n", 175 | " pt.step_id += 1\n", 176 | "\n", 177 | " v = pt._normalize_strokes(pt.x)\n", 178 | " PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)\n", 179 | " CANVAS_tmp = pt._render(PARAMS)[-1]\n", 180 | " CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)\n", 181 | "\n", 182 | " pt._save_stroke_params(PARAMS)\n", 183 | " pt.final_rendered_images = pt._render(PARAMS)\n", 184 | " pt._save_rendered_images()" 185 | ], 186 | "execution_count": null, 187 | "outputs": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": { 192 | "id": "b_QxLKdc-7nr" 193 | }, 194 | "source": [ 195 | "# レンダリング" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "metadata": { 201 | "id": "7yGUv65G-UCm" 202 | }, 203 | "source": [ 204 | " pt = ProgressivePainter(args=args)\n", 205 | " optimize_x(pt)" 206 | ], 207 | "execution_count": null, 208 | "outputs": [] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": { 213 | "id": "-p4XHLUh_ppN" 214 | }, 215 | "source": [ 216 | "# 画像表示" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "metadata": { 222 | "id": "lmYpxdcW_Azv" 223 | }, 224 | "source": [ 225 | "# show picture\n", 226 | "fig = plt.figure(figsize=(8,4))\n", 227 | "plt.subplot(1,2,1)\n", 228 | "plt.imshow(pt.img_), plt.title('input')\n", 229 | "plt.subplot(1,2,2)\n", 230 | "plt.imshow(pt.final_rendered_images[-1]), plt.title('generated')\n", 231 | "plt.show()" 232 | ], 233 | "execution_count": null, 234 | "outputs": [] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "metadata": { 239 | "id": "5LAIizo7_rxz" 240 | }, 241 | "source": [ 242 | "# make animation\n", 243 | "import matplotlib.animation as animation\n", 244 | "from IPython.display import HTML\n", 245 | "\n", 246 | "fig = plt.figure(figsize=(8,8))\n", 247 | "plt.axis('off')\n", 248 | "ims = [[plt.imshow(img, animated=True)] for img in pt.final_rendered_images[::10]]\n", 249 | "ani = animation.ArtistAnimation(fig, ims, interval=100)\n", 250 | "\n", 251 | "HTML(ani.to_jshtml())" 252 | ], 253 | "execution_count": null, 254 | "outputs": [] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "metadata": { 259 | "id": "AlDC5MmeYuYL" 260 | }, 261 | "source": [ 262 | "# save animation\n", 263 | "ani.save('anime.mp4', writer='ffmpeg')" 264 | ], 265 | "execution_count": null, 266 | "outputs": [] 267 | } 268 | ] 269 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | #from 
 5 | from skimage.metrics import structural_similarity as sk_cpt_ssim
 6 | 
 7 | 
 8 | import os
 9 | import glob
 10 | import random
 11 | 
 12 | import torch
 13 | torch.cuda.current_device()
 14 | import torchvision.transforms.functional as TF
 15 | from torch.utils.data import Dataset, DataLoader, Subset
 16 | from torchvision import transforms, utils
 17 | import renderer
 18 | 
 19 | 
 20 | 
 21 | M_RENDERING_SAMPLES_PER_EPOCH = 50000
 22 | 
 23 | class PairedDataAugmentation:
 24 | 
 25 |     def __init__(
 26 |             self,
 27 |             img_size,
 28 |             with_random_hflip=False,
 29 |             with_random_vflip=False,
 30 |             with_random_rot90=False,
 31 |             with_random_rot180=False,
 32 |             with_random_rot270=False,
 33 |             with_random_crop=False,
 34 |             with_random_patch=False
 35 |     ):
 36 |         self.img_size = img_size
 37 |         self.with_random_hflip = with_random_hflip
 38 |         self.with_random_vflip = with_random_vflip
 39 |         self.with_random_rot90 = with_random_rot90
 40 |         self.with_random_rot180 = with_random_rot180
 41 |         self.with_random_rot270 = with_random_rot270
 42 |         self.with_random_crop = with_random_crop
 43 |         self.with_random_patch = with_random_patch
 44 | 
 45 |     def transform(self, img1, img2):
 46 | 
 47 |         # resize images and convert to PIL; to-tensor conversion happens at the end
 48 |         img1 = TF.to_pil_image(img1)
 49 |         img1 = TF.resize(img1, [self.img_size, self.img_size], interpolation=3)
 50 |         img2 = TF.to_pil_image(img2)
 51 |         img2 = TF.resize(img2, [self.img_size, self.img_size], interpolation=3)
 52 | 
 53 |         if self.with_random_hflip and random.random() > 0.5:
 54 |             img1 = TF.hflip(img1)
 55 |             img2 = TF.hflip(img2)
 56 | 
 57 |         if self.with_random_vflip and random.random() > 0.5:
 58 |             img1 = TF.vflip(img1)
 59 |             img2 = TF.vflip(img2)
 60 | 
 61 |         if self.with_random_rot90 and random.random() > 0.5:
 62 |             img1 = TF.rotate(img1, 90)
 63 |             img2 = TF.rotate(img2, 90)
 64 | 
 65 |         if self.with_random_rot180 and random.random() > 0.5:
 66 |             img1 = TF.rotate(img1, 180)
 67 |             img2 = TF.rotate(img2, 180)
 68 | 
 69 |         if self.with_random_rot270 and random.random() > 0.5:
 70 |             img1 = TF.rotate(img1, 270)
 71 |             img2 = TF.rotate(img2, 270)
 72 | 
 73 |         if self.with_random_crop and random.random() > 0.5:
 74 |             i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
 75 |                 get_params(img=img1, scale=(0.5, 1.0), ratio=(0.9, 1.1))
 76 |             img1 = TF.resized_crop(
 77 |                 img1, i, j, h, w, size=(self.img_size, self.img_size))
 78 |             img2 = TF.resized_crop(
 79 |                 img2, i, j, h, w, size=(self.img_size, self.img_size))
 80 | 
 81 |         if self.with_random_patch:
 82 |             i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
 83 |                 get_params(img=img1, scale=(1/16.0, 1/9.0), ratio=(0.9, 1.1))
 84 |             img1 = TF.resized_crop(
 85 |                 img1, i, j, h, w, size=(self.img_size, self.img_size))
 86 |             img2 = TF.resized_crop(
 87 |                 img2, i, j, h, w, size=(self.img_size, self.img_size))
 88 | 
 89 |         # to tensor
 90 |         img1 = TF.to_tensor(img1)
 91 |         img2 = TF.to_tensor(img2)
 92 | 
 93 |         return img1, img2
 94 | 
 95 | 
 96 | 
 97 | class StrokeDataset(Dataset):
 98 | 
 99 |     def __init__(self, renderer_type, is_train=True):
 100 |         self.rderr = renderer.Renderer(renderer=renderer_type, CANVAS_WIDTH=128, train=True)
 101 |         self.is_train = is_train
 102 | 
 103 |     def __len__(self):
 104 |         if self.is_train:
 105 |             return M_RENDERING_SAMPLES_PER_EPOCH
 106 |         else:
 107 |             return int(M_RENDERING_SAMPLES_PER_EPOCH / 20)
 108 | 
 109 |     def __getitem__(self, idx):
 110 | 
 111 |         self.rderr.foreground = None
 112 |         self.rderr.stroke_alpha_map = None
 113 | 
 114 |         self.rderr.random_stroke_params()
 115 |         self.rderr.draw_stroke()
 116 | 
 117 |         # to tensor
 118 |         params = torch.tensor(np.array(self.rderr.stroke_params, dtype=np.float32))
 119 |         params = torch.reshape(params, [-1, 1, 1])
 120 |         foreground = TF.to_tensor(np.array(self.rderr.foreground, dtype=np.float32))
 121 |         stroke_alpha_map = TF.to_tensor(np.array(self.rderr.stroke_alpha_map, dtype=np.float32))
 122 | 
 123 |         data = {'A': params, 'B': foreground, 'ALPHA': stroke_alpha_map}
 124 | 
 125 |         return data
 126 | 
 127 | 
 128 | 
 129 | 
 130 | def get_renderer_loaders(args):
 131 | 
 132 |     training_set = StrokeDataset(renderer_type=args.renderer, is_train=True)
 133 |     val_set = StrokeDataset(renderer_type=args.renderer, is_train=False)
 134 | 
 135 |     datasets = {'train': training_set, 'val': val_set}
 136 |     dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size,
 137 |                                  shuffle=True, num_workers=4)
 138 |                    for x in ['train', 'val']}
 139 | 
 140 |     return dataloaders
 141 | 
 142 | 
 143 | 
 144 | def set_requires_grad(nets, requires_grad=False):
 145 |     """Set requires_grad=False for all the networks to avoid unnecessary computations
 146 |     Parameters:
 147 |         nets (network list)   -- a list of networks
 148 |         requires_grad (bool)  -- whether the networks require gradients or not
 149 |     """
 150 |     if not isinstance(nets, list):
 151 |         nets = [nets]
 152 |     for net in nets:
 153 |         if net is not None:
 154 |             for param in net.parameters():
 155 |                 param.requires_grad = requires_grad
 156 | 
 157 | 
 158 | 
 159 | def make_numpy_grid(tensor_data):
 160 |     tensor_data = tensor_data.detach()
 161 |     vis = utils.make_grid(tensor_data)
 162 |     vis = np.array(vis.cpu()).transpose((1,2,0))
 163 |     if vis.shape[2] == 1:
 164 |         vis = np.stack([vis, vis, vis], axis=-1)
 165 |     return vis.clip(min=0, max=1)
 166 | 
 167 | 
 168 | 
 169 | def tensor2img(tensor_data):
 170 |     if tensor_data.shape[0] > 1:
 171 |         raise NotImplementedError('batch size > 1, please use make_numpy_grid')
 172 |     tensor_data = tensor_data.detach()[0, :]
 173 |     img = np.array(tensor_data.cpu()).transpose((1, 2, 0))
 174 |     if img.shape[2] == 1:
 175 |         img = np.stack([img, img, img], axis=-1)
 176 |     return img.clip(min=0, max=1)
 177 | 
 178 | 
 179 | 
 180 | def cpt_ssim(img, img_gt, normalize=False):
 181 | 
 182 |     if normalize:
 183 |         img = (img - img.min()) / (img.max() - img.min() + 1e-9)
 184 |         img_gt = (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9)
 185 | 
 186 |     SSIM = sk_cpt_ssim(img, img_gt, data_range=img_gt.max() - img_gt.min())
 187 | 
 188 |     return SSIM
 189 | 
 190 | 
 191 | def cpt_psnr(img, img_gt, PIXEL_MAX=1.0, normalize=False):
 192 | 
 193 |     if normalize:
 194 |         img = (img - img.min()) / (img.max() - img.min() + 1e-9)
 195 |         img_gt
= (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9) 196 | 197 | mse = np.mean((img - img_gt) ** 2) 198 | psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse)) 199 | 200 | return psnr 201 | 202 | 203 | def cpt_cos_similarity(img, img_gt, normalize=False): 204 | 205 | if normalize: 206 | img = (img - img.min()) / (img.max() - img.min() + 1e-9) 207 | img_gt = (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9) 208 | 209 | cos_dist = np.sum(img*img_gt) / np.sqrt(np.sum(img**2)*np.sum(img_gt**2) + 1e-9) 210 | 211 | return cos_dist 212 | 213 | 214 | def cpt_batch_psnr(img, img_gt, PIXEL_MAX): 215 | mse = torch.mean((img - img_gt) ** 2) 216 | psnr = 20 * torch.log10(PIXEL_MAX / torch.sqrt(mse)) 217 | return psnr 218 | 219 | 220 | def rotate_pt(pt, rotate_center, theta, return_int=True): 221 | 222 | # theta in [0, pi] 223 | x, y = pt[0], pt[1] 224 | xc, yc = rotate_center[0], rotate_center[1] 225 | 226 | x_ = (x-xc) * np.cos(theta) + (y-yc) * np.sin(theta) + xc 227 | y_ = -1 * (x-xc) * np.sin(theta) + (y-yc) * np.cos(theta) + yc 228 | 229 | if return_int: 230 | x_, y_ = int(x_), int(y_) 231 | 232 | pt_ = (x_, y_) 233 | 234 | return pt_ 235 | 236 | 237 | def img2patches(img, m_grid, to_tensor=True): 238 | # input img: h, w, 3 (np.float32) 239 | # output patches: N, 3, 128, 128 (tensor) 240 | 241 | img = cv2.resize(img, (m_grid * 128, m_grid * 128)) 242 | img_batch = np.zeros([m_grid ** 2, 3, 128, 128]) 243 | for y_id in range(m_grid): 244 | for x_id in range(m_grid): 245 | patch = img[y_id * 128:y_id * 128 + 128, 246 | x_id * 128:x_id * 128 + 128, :].transpose([2, 0, 1]) 247 | img_batch[y_id * m_grid + x_id, :, :, :] = patch 248 | 249 | if to_tensor: 250 | img_batch = torch.tensor(img_batch) 251 | 252 | return img_batch 253 | 254 | 255 | 256 | def patches2img(img_batch, m_grid, to_numpy=True): 257 | # input patches: m_grid**2, 3, 128, 128 (tensor) 258 | # output img: 128*m_grid, 128*m_grid, 3 (np.float32) 259 | 260 | img = torch.zeros([128*m_grid, 128*m_grid, 3]) 261 | 262 | for y_id in range(m_grid): 263 | for x_id in range(m_grid): 264 | patch = img_batch[y_id * m_grid + x_id, :, :, :] 265 | img[y_id * 128:y_id * 128 + 128, x_id * 128:x_id * 128 + 128, :] \ 266 | = patch.permute([1, 2, 0]) 267 | if to_numpy: 268 | img = img.detach().numpy() 269 | else: 270 | img = img.permute([2,0,1]).unsqueeze(0) 271 | 272 | return img 273 | 274 | 275 | 276 | 277 | def create_transformed_brush(brush, canvas_w, canvas_h, 278 | x0, y0, w, h, theta, R0, G0, B0, R2, G2, B2): 279 | 280 | brush_alpha = np.stack([brush, brush, brush], axis=-1) 281 | brush_alpha = (brush_alpha > 0).astype(np.float32) 282 | brush_alpha = (brush_alpha*255).astype(np.uint8) 283 | colormap = np.zeros([brush.shape[0], brush.shape[1], 3], np.float32) 284 | for ii in range(brush.shape[0]): 285 | t = ii / brush.shape[0] 286 | this_color = [(1 - t) * R0 + t * R2, 287 | (1 - t) * G0 + t * G2, 288 | (1 - t) * B0 + t * B2] 289 | colormap[ii, :, :] = np.expand_dims(this_color, axis=0) 290 | 291 | brush = np.expand_dims(brush, axis=-1).astype(np.float32) / 255. 
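    # tint the grayscale brush texture with the head-to-tail color gradient built above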
292 | brush = (brush * colormap * 255).astype(np.uint8) 293 | # plt.imshow(brush), plt.show() 294 | 295 | M1 = build_transformation_matrix([-brush.shape[1]/2, -brush.shape[0]/2, 0]) 296 | M2 = build_scale_matrix(sx=w/brush.shape[1], sy=h/brush.shape[0]) 297 | M3 = build_transformation_matrix([0,0,theta]) 298 | M4 = build_transformation_matrix([x0, y0, 0]) 299 | 300 | M = update_transformation_matrix(M1, M2) 301 | M = update_transformation_matrix(M, M3) 302 | M = update_transformation_matrix(M, M4) 303 | 304 | brush = cv2.warpAffine( 305 | brush, M, (canvas_w, canvas_h), 306 | borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_AREA) 307 | brush_alpha = cv2.warpAffine( 308 | brush_alpha, M, (canvas_w, canvas_h), 309 | borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_AREA) 310 | 311 | return brush, brush_alpha 312 | 313 | 314 | def build_scale_matrix(sx, sy): 315 | transform_matrix = np.zeros((2, 3)) 316 | transform_matrix[0, 0] = sx 317 | transform_matrix[1, 1] = sy 318 | return transform_matrix 319 | 320 | 321 | def update_transformation_matrix(M, m): 322 | 323 | # extend M and m to 3x3 by adding an [0,0,1] to their 3rd row 324 | M_ = np.concatenate([M, np.zeros([1,3])], axis=0) 325 | M_[-1, -1] = 1 326 | m_ = np.concatenate([m, np.zeros([1,3])], axis=0) 327 | m_[-1, -1] = 1 328 | 329 | M_new = np.matmul(m_, M_) 330 | return M_new[0:2, :] 331 | # 332 | 333 | def build_transformation_matrix(transform): 334 | """Convert transform list to transformation matrix 335 | 336 | :param transform: transform list as [dx, dy, da] 337 | :return: transform matrix as 2d (2, 3) numpy array 338 | """ 339 | transform_matrix = np.zeros((2, 3)) 340 | 341 | transform_matrix[0, 0] = np.cos(transform[2]) 342 | transform_matrix[0, 1] = -np.sin(transform[2]) 343 | transform_matrix[1, 0] = np.sin(transform[2]) 344 | transform_matrix[1, 1] = np.cos(transform[2]) 345 | transform_matrix[0, 2] = transform[0] 346 | transform_matrix[1, 2] = transform[1] 347 | 348 | return transform_matrix 349 | 350 | 351 | 352 | -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import random 4 | import utils 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def _random_floats(low, high, size): 10 | return [random.uniform(low, high) for _ in range(size)] 11 | 12 | 13 | def _normalize(x, width): 14 | return (int)(x * (width - 1) + 0.5) 15 | 16 | 17 | class Renderer(): 18 | 19 | def __init__(self, renderer='bezier', CANVAS_WIDTH=128, train=False, canvas_color='black'): 20 | 21 | self.CANVAS_WIDTH = CANVAS_WIDTH 22 | self.renderer = renderer 23 | self.stroke_params = None 24 | self.canvas_color = canvas_color 25 | 26 | self.canvas = None 27 | self.create_empty_canvas() 28 | 29 | self.train = train 30 | 31 | if self.renderer in ['markerpen']: 32 | self.d = 12 # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A 33 | self.d_shape = 8 34 | self.d_color = 3 35 | self.d_alpha = 1 36 | elif self.renderer in ['watercolor']: 37 | self.d = 15 # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A 38 | self.d_shape = 8 39 | self.d_color = 6 40 | self.d_alpha = 1 41 | elif self.renderer in ['oilpaintbrush']: 42 | self.d = 12 # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A 43 | self.d_shape = 5 44 | self.d_color = 6 45 | self.d_alpha = 1 46 | self.brush_small_vertical = cv2.imread( 47 | r'./brushes/brush_fromweb2_small_vertical.png', cv2.IMREAD_GRAYSCALE) 48 | 
self.brush_small_horizontal = cv2.imread( 49 | r'./brushes/brush_fromweb2_small_horizontal.png', cv2.IMREAD_GRAYSCALE) 50 | self.brush_large_vertical = cv2.imread( 51 | r'./brushes/brush_fromweb2_large_vertical.png', cv2.IMREAD_GRAYSCALE) 52 | self.brush_large_horizontal = cv2.imread( 53 | r'./brushes/brush_fromweb2_large_horizontal.png', cv2.IMREAD_GRAYSCALE) 54 | elif self.renderer in ['rectangle']: 55 | self.d = 9 # xc, yc, w, h, theta, R, G, B, A 56 | self.d_shape = 5 57 | self.d_color = 3 58 | self.d_alpha = 1 59 | else: 60 | raise NotImplementedError( 61 | 'Wrong renderer name %s (choose one from [watercolor, markerpen, oilpaintbrush, rectangle] ...)' 62 | % self.renderer) 63 | 64 | def create_empty_canvas(self): 65 | if self.canvas_color == 'white': 66 | self.canvas = np.ones( 67 | [self.CANVAS_WIDTH, self.CANVAS_WIDTH, 3]).astype('float32') 68 | else: 69 | self.canvas = np.zeros( 70 | [self.CANVAS_WIDTH, self.CANVAS_WIDTH, 3]).astype('float32') 71 | 72 | 73 | def random_stroke_params(self): 74 | self.stroke_params = np.array(_random_floats(0, 1, self.d), dtype=np.float32) 75 | 76 | def random_stroke_params_sampler(self, err_map, img): 77 | 78 | map_h, map_w, c = img.shape 79 | 80 | err_map = cv2.resize(err_map, (self.CANVAS_WIDTH, self.CANVAS_WIDTH)) 81 | err_map[err_map < 0] = 0 82 | if np.all((err_map == 0)): 83 | err_map = np.ones_like(err_map) 84 | err_map = err_map / (np.sum(err_map) + 1e-99) 85 | 86 | index = np.random.choice(range(err_map.size), size=1, p=err_map.ravel())[0] 87 | 88 | cy = (index // self.CANVAS_WIDTH) / self.CANVAS_WIDTH 89 | cx = (index % self.CANVAS_WIDTH) / self.CANVAS_WIDTH 90 | 91 | if self.renderer in ['markerpen']: 92 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A 93 | x0, y0, x1, y1, x2, y2 = cx, cy, cx, cy, cx, cy 94 | x = [x0, y0, x1, y1, x2, y2] 95 | r = _random_floats(0.1, 0.5, 2) 96 | color = img[int(cy*map_h), int(cx*map_w), :].tolist() 97 | alpha = _random_floats(0.8, 0.98, 1) 98 | self.stroke_params = np.array(x + r + color + alpha, dtype=np.float32) 99 | elif self.renderer in ['watercolor']: 100 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A 101 | x0, y0, x1, y1, x2, y2 = cx, cy, cx, cy, cx, cy 102 | x = [x0, y0, x1, y1, x2, y2] 103 | r = _random_floats(0.1, 0.5, 2) 104 | color = img[int(cy*map_h), int(cx*map_w), :].tolist() 105 | color = color + color 106 | alpha = _random_floats(0.98, 1.0, 1) 107 | self.stroke_params = np.array(x + r + color + alpha, dtype=np.float32) 108 | elif self.renderer in ['oilpaintbrush']: 109 | # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A 110 | x = [cx, cy] 111 | wh = _random_floats(0.1, 0.5, 2) 112 | theta = _random_floats(0, 1, 1) 113 | color = img[int(cy*map_h), int(cx*map_w), :].tolist() 114 | color = color + color 115 | alpha = _random_floats(0.98, 1.0, 1) 116 | self.stroke_params = np.array(x + wh + theta + color + alpha, dtype=np.float32) 117 | elif self.renderer in ['rectangle']: 118 | # xc, yc, w, h, theta, R, G, B, A 119 | x = [cx, cy] 120 | wh = _random_floats(0.1, 0.5, 2) 121 | theta = [0] 122 | color = img[int(cy*map_h), int(cx*map_w), :].tolist() 123 | alpha = _random_floats(0.8, 0.98, 1) 124 | self.stroke_params = np.array(x + wh + theta + color + alpha, dtype=np.float32) 125 | 126 | 127 | def check_stroke(self): 128 | r_ = 1.0 129 | if self.renderer in ['markerpen', 'watercolor']: 130 | r_ = max(self.stroke_params[6], self.stroke_params[7]) 131 | elif self.renderer in ['oilpaintbrush']: 132 | r_ = max(self.stroke_params[2], self.stroke_params[3]) 133 | elif 
self.renderer in ['rectangle']: 134 | r_ = max(self.stroke_params[2], self.stroke_params[3]) 135 | if r_ > 3/128.: 136 | return True 137 | else: 138 | return False 139 | 140 | 141 | def draw_stroke(self): 142 | 143 | if self.renderer == 'watercolor': 144 | return self._draw_watercolor() 145 | elif self.renderer == 'markerpen': 146 | return self._draw_markerpen() 147 | elif self.renderer == 'oilpaintbrush': 148 | return self._draw_oilpaintbrush() 149 | elif self.renderer == 'rectangle': 150 | return self._draw_rectangle() 151 | 152 | 153 | def _draw_watercolor(self): 154 | 155 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A 156 | x0, y0, x1, y1, x2, y2, radius0, radius2 = self.stroke_params[0:8] 157 | R0, G0, B0, R2, G2, B2, ALPHA = self.stroke_params[8:] 158 | x1 = x0 + (x2 - x0) * x1 159 | y1 = y0 + (y2 - y0) * y1 160 | x0 = _normalize(x0, self.CANVAS_WIDTH) 161 | x1 = _normalize(x1, self.CANVAS_WIDTH) 162 | x2 = _normalize(x2, self.CANVAS_WIDTH) 163 | y0 = _normalize(y0, self.CANVAS_WIDTH) 164 | y1 = _normalize(y1, self.CANVAS_WIDTH) 165 | y2 = _normalize(y2, self.CANVAS_WIDTH) 166 | radius0 = (int)(1 + radius0 * self.CANVAS_WIDTH // 4) 167 | radius2 = (int)(1 + radius2 * self.CANVAS_WIDTH // 4) 168 | 169 | stroke_alpha_value = self.stroke_params[-1] 170 | 171 | self.foreground = np.zeros_like( 172 | self.canvas, dtype=np.uint8) # uint8 for antialiasing 173 | self.stroke_alpha_map = np.zeros_like( 174 | self.canvas, dtype=np.uint8) # uint8 for antialiasing 175 | 176 | alpha = (stroke_alpha_value * 255, 177 | stroke_alpha_value * 255, 178 | stroke_alpha_value * 255) 179 | tmp = 1. / 100 180 | for i in range(100): 181 | t = i * tmp 182 | x = (int)((1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2) 183 | y = (int)((1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2) 184 | radius = (int)((1 - t) * radius0 + t * radius2) 185 | color = ((1-t)*R0*255 + t*R2*255, 186 | (1-t)*G0*255 + t*G2*255, 187 | (1-t)*B0*255 + t*B2*255) 188 | cv2.circle(self.foreground, (x, y), radius, color, -1, lineType=cv2.LINE_AA) 189 | cv2.circle(self.stroke_alpha_map, (x, y), radius, alpha, -1, lineType=cv2.LINE_AA) 190 | 191 | if not self.train: 192 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2])) 193 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2])) 194 | 195 | self.foreground = np.array(self.foreground, dtype=np.float32)/255. 196 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255. 
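        # alpha-composite the finished stroke onto the running canvas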
 197 |         self.canvas = self._update_canvas()
 198 | 
 199 | 
 200 |     def _draw_rectangle(self):
 201 | 
 202 |         # xc, yc, w, h, theta, R, G, B, A
 203 |         x0, y0, w, h, theta = self.stroke_params[0:5]
 204 |         R0, G0, B0, ALPHA = self.stroke_params[5:]
 205 |         x0 = _normalize(x0, self.CANVAS_WIDTH)
 206 |         y0 = _normalize(y0, self.CANVAS_WIDTH)
 207 |         w = (int)(1 + w * self.CANVAS_WIDTH // 4)
 208 |         h = (int)(1 + h * self.CANVAS_WIDTH // 4)
 209 |         theta = np.pi*theta
 210 |         stroke_alpha_value = self.stroke_params[-1]
 211 | 
 212 |         self.foreground = np.zeros_like(
 213 |             self.canvas, dtype=np.uint8)  # uint8 for antialiasing
 214 |         self.stroke_alpha_map = np.zeros_like(
 215 |             self.canvas, dtype=np.uint8)  # uint8 for antialiasing
 216 | 
 217 |         color = (R0 * 255, G0 * 255, B0 * 255)
 218 |         alpha = (stroke_alpha_value * 255,
 219 |                  stroke_alpha_value * 255,
 220 |                  stroke_alpha_value * 255)
 221 |         ptc = (x0, y0)
 222 |         pt0 = utils.rotate_pt(pt=(x0 - w, y0 - h), rotate_center=ptc, theta=theta)
 223 |         pt1 = utils.rotate_pt(pt=(x0 + w, y0 - h), rotate_center=ptc, theta=theta)
 224 |         pt2 = utils.rotate_pt(pt=(x0 + w, y0 + h), rotate_center=ptc, theta=theta)
 225 |         pt3 = utils.rotate_pt(pt=(x0 - w, y0 + h), rotate_center=ptc, theta=theta)
 226 | 
 227 |         ppt = np.array([pt0, pt1, pt2, pt3], np.int32)
 228 |         ppt = ppt.reshape((-1, 1, 2))
 229 |         cv2.fillPoly(self.foreground, [ppt], color, lineType=cv2.LINE_AA)
 230 |         cv2.fillPoly(self.stroke_alpha_map, [ppt], alpha, lineType=cv2.LINE_AA)
 231 | 
 232 |         if not self.train:
 233 |             self.foreground = cv2.dilate(self.foreground, np.ones([2, 2]))
 234 |             self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2]))
 235 | 
 236 |         self.foreground = np.array(self.foreground, dtype=np.float32)/255.
 237 |         self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255.
 238 |         self.canvas = self._update_canvas()
 239 | 
 240 | 
 241 |     def _draw_markerpen(self):
 242 | 
 243 |         # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A
 244 |         x0, y0, x1, y1, x2, y2, radius, _ = self.stroke_params[0:8]
 245 |         R0, G0, B0, ALPHA = self.stroke_params[8:]
 246 |         x1 = x0 + (x2 - x0) * x1
 247 |         y1 = y0 + (y2 - y0) * y1
 248 |         x0 = _normalize(x0, self.CANVAS_WIDTH)
 249 |         x1 = _normalize(x1, self.CANVAS_WIDTH)
 250 |         x2 = _normalize(x2, self.CANVAS_WIDTH)
 251 |         y0 = _normalize(y0, self.CANVAS_WIDTH)
 252 |         y1 = _normalize(y1, self.CANVAS_WIDTH)
 253 |         y2 = _normalize(y2, self.CANVAS_WIDTH)
 254 |         radius = (int)(1 + radius * self.CANVAS_WIDTH // 4)
 255 | 
 256 |         stroke_alpha_value = self.stroke_params[-1]
 257 | 
 258 |         self.foreground = np.zeros_like(
 259 |             self.canvas, dtype=np.uint8)  # uint8 for antialiasing
 260 |         self.stroke_alpha_map = np.zeros_like(
 261 |             self.canvas, dtype=np.uint8)  # uint8 for antialiasing
 262 | 
 263 |         if abs(x0-x2) + abs(y0-y2) < 4:  # stroke too small, don't draw
 264 |             self.foreground = np.array(self.foreground, dtype=np.float32) / 255.
 265 |             self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32) / 255.
 266 |             self.canvas = self._update_canvas()
 267 |             return
 268 | 
 269 |         color = (R0 * 255, G0 * 255, B0 * 255)
 270 |         alpha = (stroke_alpha_value * 255,
 271 |                  stroke_alpha_value * 255,
 272 |                  stroke_alpha_value * 255)
 273 |         tmp = 1.
/ 100 274 | for i in range(100): 275 | t = i * tmp 276 | x = (1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2 277 | y = (1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2 278 | 279 | ptc = (x, y) 280 | dx = 2 * (t - 1) * x0 + 2 * (1 - 2 * t) * x1 + 2 * t * x2 281 | dy = 2 * (t - 1) * y0 + 2 * (1 - 2 * t) * y1 + 2 * t * y2 282 | 283 | theta = np.arctan2(dx, dy) - np.pi/2 284 | pt0 = utils.rotate_pt(pt=(x - radius, y - radius), rotate_center=ptc, theta=theta) 285 | pt1 = utils.rotate_pt(pt=(x + radius, y - radius), rotate_center=ptc, theta=theta) 286 | pt2 = utils.rotate_pt(pt=(x + radius, y + radius), rotate_center=ptc, theta=theta) 287 | pt3 = utils.rotate_pt(pt=(x - radius, y + radius), rotate_center=ptc, theta=theta) 288 | ppt = np.array([pt0, pt1, pt2, pt3], np.int32) 289 | ppt = ppt.reshape((-1, 1, 2)) 290 | cv2.fillPoly(self.foreground, [ppt], color, lineType=cv2.LINE_AA) 291 | cv2.fillPoly(self.stroke_alpha_map, [ppt], alpha, lineType=cv2.LINE_AA) 292 | 293 | if not self.train: 294 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2])) 295 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2])) 296 | 297 | self.foreground = np.array(self.foreground, dtype=np.float32)/255. 298 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255. 299 | self.canvas = self._update_canvas() 300 | 301 | 302 | 303 | def _draw_oilpaintbrush(self): 304 | 305 | # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A 306 | x0, y0, w, h, theta = self.stroke_params[0:5] 307 | R0, G0, B0, R2, G2, B2, ALPHA = self.stroke_params[5:] 308 | x0 = _normalize(x0, self.CANVAS_WIDTH) 309 | y0 = _normalize(y0, self.CANVAS_WIDTH) 310 | w = (int)(1 + w * self.CANVAS_WIDTH) 311 | h = (int)(1 + h * self.CANVAS_WIDTH) 312 | theta = np.pi*theta 313 | 314 | if w * h / (self.CANVAS_WIDTH**2) > 0.1: 315 | if h > w: 316 | brush = self.brush_large_vertical 317 | else: 318 | brush = self.brush_large_horizontal 319 | else: 320 | if h > w: 321 | brush = self.brush_small_vertical 322 | else: 323 | brush = self.brush_small_horizontal 324 | self.foreground, self.stroke_alpha_map = utils.create_transformed_brush( 325 | brush, self.CANVAS_WIDTH, self.CANVAS_WIDTH, 326 | x0, y0, w, h, theta, R0, G0, B0, R2, G2, B2) 327 | 328 | if not self.train: 329 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2])) 330 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2])) 331 | 332 | self.foreground = np.array(self.foreground, dtype=np.float32)/255. 333 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255. 
334 | self.canvas = self._update_canvas() 335 | 336 | 337 | def _update_canvas(self): 338 | return self.foreground * self.stroke_alpha_map + \ 339 | self.canvas * (1 - self.stroke_alpha_map) 340 | -------------------------------------------------------------------------------- /painter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import random 4 | 5 | import utils 6 | import loss 7 | from networks import * 8 | import morphology 9 | 10 | import renderer 11 | 12 | import torch 13 | torch.cuda.current_device() 14 | 15 | # Decide which device we want to run on 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class PainterBase(): 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | self.rderr = renderer.Renderer(renderer=args.renderer, 24 | CANVAS_WIDTH=args.canvas_size, canvas_color=args.canvas_color) 25 | 26 | # define G 27 | self.net_G = define_G(rdrr=self.rderr, netG=args.net_G).to(device) 28 | 29 | # define some other vars to record the training states 30 | self.x_ctt = None 31 | self.x_color = None 32 | self.x_alpha = None 33 | 34 | self.G_pred_foreground = None 35 | self.G_pred_alpha = None 36 | self.G_final_pred_canvas = torch.zeros([1, 3, 128, 128]).to(device) 37 | 38 | self.G_loss = torch.tensor(0.0) 39 | self.step_id = 0 40 | self.anchor_id = 0 41 | self.renderer_checkpoint_dir = args.renderer_checkpoint_dir 42 | self.output_dir = args.output_dir 43 | self.lr = args.lr 44 | 45 | # define the loss functions 46 | self._pxl_loss = loss.PixelLoss(p=1) 47 | self._sinkhorn_loss = loss.SinkhornLoss(epsilon=0.01, niter=5, normalize=False) 48 | 49 | # some other vars to be initialized in child classes 50 | self.img_path = None 51 | self.img_batch = None 52 | self.img_ = None 53 | self.final_rendered_images = None 54 | self.m_grid = None 55 | self.m_strokes_per_block = None 56 | 57 | if os.path.exists(self.output_dir) is False: 58 | os.mkdir(self.output_dir) 59 | 60 | def _load_checkpoint(self): 61 | 62 | # load renderer G 63 | if os.path.exists((os.path.join( 64 | self.renderer_checkpoint_dir, 'last_ckpt.pt'))): 65 | print('loading renderer from pre-trained checkpoint...') 66 | # load the entire checkpoint 67 | checkpoint = torch.load(os.path.join( 68 | self.renderer_checkpoint_dir, 'last_ckpt.pt')) 69 | # update net_G states 70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict']) 71 | self.net_G.to(device) 72 | self.net_G.eval() 73 | else: 74 | print('pre-trained renderer does not exist...') 75 | exit() 76 | 77 | 78 | def _compute_acc(self): 79 | 80 | target = self.img_batch.detach() 81 | canvas = self.G_pred_canvas.detach() 82 | psnr = utils.cpt_batch_psnr(canvas, target, PIXEL_MAX=1.0) 83 | 84 | return psnr 85 | 86 | def _save_stroke_params(self, v): 87 | 88 | d_shape = self.rderr.d_shape 89 | d_color = self.rderr.d_color 90 | d_alpha = self.rderr.d_alpha 91 | 92 | x_ctt = v[:, :, 0:d_shape] 93 | x_color = v[:, :, d_shape:d_shape+d_color] 94 | x_alpha = v[:, :, d_shape+d_color:d_shape+d_color+d_alpha] 95 | print('saving stroke parameters...') 96 | file_name = os.path.join( 97 | self.output_dir, self.img_path.split('/')[-1][:-4]) 98 | np.savez(file_name + '_strokes.npz', x_ctt=x_ctt, 99 | x_color=x_color, x_alpha=x_alpha) 100 | 101 | 102 | def _save_rendered_images(self): 103 | print('saving rendered images...') 104 | file_name = os.path.join( 105 | self.output_dir, self.img_path.split('/')[-1][:-4]) 106 | plt.imsave(file_name+'_input.png', self.img_) 107 | 
for i in range(len(self.final_rendered_images)): 108 | plt.imsave(file_name + '_rendered_stroke_' + str((i+1)).zfill(4) + 109 | '.png', self.final_rendered_images[i]) 110 | 111 | 112 | def _normalize_strokes(self, v): 113 | 114 | v = np.array(v.detach().cpu()) 115 | 116 | if self.rderr.renderer in ['watercolor', 'markerpen']: 117 | # x0, y0, x1, y1, x2, y2, radius0, radius2, ... 118 | xs = np.array([0, 4]) 119 | ys = np.array([1, 5]) 120 | rs = np.array([6, 7]) 121 | elif self.rderr.renderer in ['oilpaintbrush', 'rectangle']: 122 | # xc, yc, w, h, theta ... 123 | xs = np.array([0]) 124 | ys = np.array([1]) 125 | rs = np.array([2, 3]) 126 | else: 127 | raise NotImplementedError('renderer [%s] is not implemented' % self.rderr.renderer) 128 | 129 | for y_id in range(self.m_grid): 130 | for x_id in range(self.m_grid): 131 | y_bias = y_id / self.m_grid 132 | x_bias = x_id / self.m_grid 133 | v[y_id * self.m_grid + x_id, :, ys] = \ 134 | y_bias + v[y_id * self.m_grid + x_id, :, ys] / self.m_grid 135 | v[y_id * self.m_grid + x_id, :, xs] = \ 136 | x_bias + v[y_id * self.m_grid + x_id, :, xs] / self.m_grid 137 | v[y_id * self.m_grid + x_id, :, rs] /= self.m_grid 138 | 139 | return v 140 | 141 | 142 | def initialize_params(self): 143 | 144 | self.x_ctt = np.random.rand( 145 | self.m_grid*self.m_grid, self.m_strokes_per_block, 146 | self.rderr.d_shape).astype(np.float32) 147 | self.x_ctt = torch.tensor(self.x_ctt).to(device) 148 | 149 | self.x_color = np.random.rand( 150 | self.m_grid*self.m_grid, self.m_strokes_per_block, 151 | self.rderr.d_color).astype(np.float32) 152 | self.x_color = torch.tensor(self.x_color).to(device) 153 | 154 | self.x_alpha = np.random.rand( 155 | self.m_grid*self.m_grid, self.m_strokes_per_block, 156 | self.rderr.d_alpha).astype(np.float32) 157 | self.x_alpha = torch.tensor(self.x_alpha).to(device) 158 | 159 | 160 | def stroke_sampler(self, anchor_id): 161 | 162 | if anchor_id == self.m_strokes_per_block: 163 | return 164 | 165 | err_maps = torch.sum( 166 | torch.abs(self.img_batch - self.G_final_pred_canvas), 167 | dim=1, keepdim=True).detach() 168 | 169 | for i in range(self.m_grid*self.m_grid): 170 | this_err_map = err_maps[i,0,:,:].cpu().numpy() 171 | this_err_map = cv2.blur(this_err_map, (20, 20)) 172 | this_err_map = this_err_map ** 4 173 | this_img = self.img_batch[i, :, :, :].detach().permute([1, 2, 0]).cpu().numpy() 174 | 175 | self.rderr.random_stroke_params_sampler( 176 | err_map=this_err_map, img=this_img) 177 | 178 | self.x_ctt.data[i, anchor_id, :] = torch.tensor( 179 | self.rderr.stroke_params[0:self.rderr.d_shape]) 180 | self.x_color.data[i, anchor_id, :] = torch.tensor( 181 | self.rderr.stroke_params[self.rderr.d_shape:self.rderr.d_shape+self.rderr.d_color]) 182 | self.x_alpha.data[i, anchor_id, :] = torch.tensor(self.rderr.stroke_params[-1]) 183 | 184 | 185 | def _backward_x(self): 186 | 187 | self.G_loss = 0 188 | self.G_loss += self.args.beta_L1 * self._pxl_loss( 189 | canvas=self.G_final_pred_canvas, gt=self.img_batch) 190 | if self.args.with_ot_loss: 191 | self.G_loss += self.args.beta_ot * self._sinkhorn_loss( 192 | self.G_final_pred_canvas, self.img_batch) 193 | self.G_loss.backward() 194 | 195 | 196 | def _forward_pass(self): 197 | 198 | self.x = torch.cat([self.x_ctt, self.x_color, self.x_alpha], dim=-1) 199 | 200 | v = torch.reshape(self.x[:, 0:self.anchor_id+1, :], 201 | [self.m_grid*self.m_grid*(self.anchor_id+1), -1, 1, 1]) 202 | self.G_pred_foregrounds, self.G_pred_alphas = self.net_G(v) 203 | 204 | self.G_pred_foregrounds = 
morphology.Dilation2d(m=1)(self.G_pred_foregrounds) 205 | self.G_pred_alphas = morphology.Erosion2d(m=1)(self.G_pred_alphas) 206 | 207 | self.G_pred_foregrounds = torch.reshape( 208 | self.G_pred_foregrounds, [self.m_grid*self.m_grid, self.anchor_id+1, 3, 128, 128]) 209 | self.G_pred_alphas = torch.reshape( 210 | self.G_pred_alphas, [self.m_grid*self.m_grid, self.anchor_id+1, 3, 128, 128]) 211 | 212 | for i in range(self.anchor_id+1): 213 | G_pred_foreground = self.G_pred_foregrounds[:, i] 214 | G_pred_alpha = self.G_pred_alphas[:, i] 215 | self.G_pred_canvas = G_pred_foreground * G_pred_alpha \ 216 | + self.G_pred_canvas * (1 - G_pred_alpha) 217 | 218 | self.G_final_pred_canvas = self.G_pred_canvas 219 | 220 | 221 | 222 | 223 | class Painter(PainterBase): 224 | 225 | def __init__(self, args): 226 | super(Painter, self).__init__(args=args) 227 | 228 | self.m_grid = args.m_grid 229 | 230 | self.max_m_strokes = args.max_m_strokes 231 | 232 | self.img_path = args.img_path 233 | self.img_ = cv2.imread(args.img_path, cv2.IMREAD_COLOR) 234 | self.img_ = cv2.cvtColor(self.img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. 235 | self.img_ = cv2.resize(self.img_, (128 * args.m_grid, 128 * args.m_grid)) 236 | 237 | self.m_strokes_per_block = int(args.max_m_strokes / (args.m_grid * args.m_grid)) 238 | 239 | self.img_batch = utils.img2patches(self.img_, args.m_grid).to(device) 240 | 241 | self.final_rendered_images = None 242 | 243 | 244 | def _drawing_step_states(self): 245 | acc = self._compute_acc().item() 246 | print('iteration step %d, G_loss: %.5f, step_psnr: %.5f, strokes: %d / %d' 247 | % (self.step_id, self.G_loss.item(), acc, 248 | (self.anchor_id+1)*self.m_grid*self.m_grid, 249 | self.max_m_strokes)) 250 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1) 251 | cv2.imshow('G_pred', vis2[:,:,::-1]) 252 | cv2.imshow('input', self.img_[:, :, ::-1]) 253 | cv2.waitKey(1) 254 | 255 | def _render_on_grids(self, v): 256 | 257 | rendered_imgs = [] 258 | 259 | self.rderr.create_empty_canvas() 260 | 261 | grid_idx = list(range(self.m_grid ** 2)) 262 | random.shuffle(grid_idx) 263 | for j in range(v.shape[1]): # for each group of stroke 264 | for i in range(len(grid_idx)): # for each random patch 265 | self.rderr.stroke_params = v[grid_idx[i], j, :] 266 | if self.rderr.check_stroke(): 267 | self.rderr.draw_stroke() 268 | rendered_imgs.append(self.rderr.canvas) 269 | 270 | return rendered_imgs 271 | 272 | 273 | 274 | 275 | class ProgressivePainter(PainterBase): 276 | 277 | def __init__(self, args): 278 | super(ProgressivePainter, self).__init__(args=args) 279 | 280 | self.max_divide = args.max_divide 281 | 282 | self.max_m_strokes = args.max_m_strokes 283 | 284 | self.m_strokes_per_block = self.stroke_parser() 285 | 286 | self.m_grid = 1 287 | 288 | self.img_path = args.img_path 289 | self.img_ = cv2.imread(args.img_path, cv2.IMREAD_COLOR) 290 | self.img_ = cv2.cvtColor(self.img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. 
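        # resize so the finest grid level tiles the image into max_divide x max_divide 128px patches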
291 | self.img_ = cv2.resize(self.img_, (128 * args.max_divide, 128 * args.max_divide)) 292 | 293 | 294 | 295 | def _render(self, v): 296 | 297 | v = v[0,:,:] 298 | 299 | rendered_imgs = [] 300 | 301 | self.rderr.create_empty_canvas() 302 | 303 | for i in range(v.shape[0]): # for each stroke 304 | self.rderr.stroke_params = v[i, :] 305 | if self.rderr.check_stroke(): 306 | self.rderr.draw_stroke() 307 | rendered_imgs.append(self.rderr.canvas) 308 | 309 | return rendered_imgs 310 | 311 | 312 | def stroke_parser(self): 313 | 314 | total_blocks = 0 315 | for i in range(0, self.max_divide + 1): 316 | total_blocks += i ** 2 317 | 318 | return int(self.max_m_strokes / total_blocks) 319 | 320 | 321 | def _drawing_step_states(self): 322 | acc = self._compute_acc().item() 323 | print('iteration step %d, G_loss: %.5f, step_acc: %.5f, grid_scale: %d / %d, strokes: %d / %d' 324 | % (self.step_id, self.G_loss.item(), acc, 325 | self.m_grid, self.max_divide, 326 | self.anchor_id, self.m_strokes_per_block)) 327 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1) 328 | cv2.imshow('G_pred', vis2[:,:,::-1]) 329 | cv2.imshow('input', self.img_[:, :, ::-1]) 330 | cv2.waitKey(1) 331 | 332 | 333 | 334 | class NeuralStyleTransfer(PainterBase): 335 | 336 | def __init__(self, args): 337 | super(NeuralStyleTransfer, self).__init__(args=args) 338 | 339 | self._style_loss = loss.VGGStyleLoss(transfer_mode=args.transfer_mode, resize=True) 340 | 341 | npzfile = np.load(args.vector_file) 342 | 343 | self.x_ctt = torch.tensor(npzfile['x_ctt']).to(device) 344 | self.x_color = torch.tensor(npzfile['x_color']).to(device) 345 | self.x_alpha = torch.tensor(npzfile['x_alpha']).to(device) 346 | self.m_grid = int(np.sqrt(self.x_ctt.shape[0])) 347 | 348 | self.anchor_id = self.x_ctt.shape[1] - 1 349 | 350 | img_ = cv2.imread(args.content_img_path, cv2.IMREAD_COLOR) 351 | img_ = cv2.cvtColor(img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. 352 | self.img_ = cv2.resize(img_, (128*self.m_grid, 128*self.m_grid)) 353 | self.img_batch = utils.img2patches(self.img_, self.m_grid).to(device) 354 | 355 | style_img = cv2.imread(args.style_img_path, cv2.IMREAD_COLOR) 356 | self.style_img_ = cv2.cvtColor(style_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255. 357 | self.style_img = cv2.blur(cv2.resize(self.style_img_, (128, 128)), (2, 2)) 358 | self.style_img = torch.tensor(self.style_img).permute([2, 0, 1]).unsqueeze(0).to(device) 359 | 360 | 361 | def _style_transfer_step_states(self): 362 | acc = self._compute_acc().item() 363 | print('running style transfer... 
iteration step %d, G_loss: %.5f, step_psnr: %.5f' 364 | % (self.step_id, self.G_loss.item(), acc)) 365 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1) 366 | cv2.imshow('G_pred', vis2[:,:,::-1]) 367 | cv2.imshow('input', self.img_[:, :, ::-1]) 368 | cv2.waitKey(1) 369 | 370 | 371 | def _backward_x_sty(self): 372 | canvas = utils.patches2img( 373 | self.G_final_pred_canvas, self.m_grid, to_numpy=False).to(device) 374 | self.G_loss = self.args.beta_L1 * self._pxl_loss( 375 | canvas=self.G_final_pred_canvas, gt=self.img_batch, ignore_color=True) 376 | self.G_loss += self.args.beta_sty * self._style_loss(canvas, self.style_img) 377 | self.G_loss.backward() 378 | 379 | 380 | def _render_on_grids(self, v): 381 | 382 | rendered_imgs = [] 383 | 384 | self.rderr.create_empty_canvas() 385 | 386 | grid_idx = list(range(self.m_grid ** 2)) 387 | random.shuffle(grid_idx) 388 | for j in range(v.shape[1]): # for each group of stroke 389 | for i in range(len(grid_idx)): # for each random patch 390 | self.rderr.stroke_params = v[grid_idx[i], j, :] 391 | if self.rderr.check_stroke(): 392 | self.rderr.draw_stroke() 393 | rendered_imgs.append(self.rderr.canvas) 394 | 395 | return rendered_imgs 396 | 397 | 398 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torchvision import models 6 | import torch.nn.functional as F 7 | from torch.optim import lr_scheduler 8 | import math 9 | import utils 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | # Decide which device we want to run on 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | PI = math.pi 17 | ############################################################################### 18 | # Helper Functions 19 | ############################################################################### 20 | 21 | 22 | class Identity(nn.Module): 23 | def forward(self, x): 24 | return x 25 | 26 | def get_norm_layer(norm_type='instance'): 27 | """Return a normalization layer 28 | 29 | Parameters: 30 | norm_type (str) -- the name of the normalization layer: batch | instance | none 31 | 32 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 33 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 34 | """ 35 | if norm_type == 'batch': 36 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 37 | elif norm_type == 'instance': 38 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 39 | elif norm_type == 'none': 40 | norm_layer = lambda x: Identity() 41 | else: 42 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 43 | return norm_layer 44 | 45 | 46 | def init_weights(net, init_type='normal', init_gain=0.02): 47 | """Initialize network weights. 48 | 49 | Parameters: 50 | net (network) -- network to be initialized 51 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 52 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 53 | 54 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 55 | work better for some applications. Feel free to try yourself. 
56 | """ 57 | def init_func(m): # define the initialization function 58 | classname = m.__class__.__name__ 59 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 60 | if init_type == 'normal': 61 | init.normal_(m.weight.data, 0.0, init_gain) 62 | elif init_type == 'xavier': 63 | init.xavier_normal_(m.weight.data, gain=init_gain) 64 | elif init_type == 'kaiming': 65 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 66 | elif init_type == 'orthogonal': 67 | init.orthogonal_(m.weight.data, gain=init_gain) 68 | else: 69 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 70 | if hasattr(m, 'bias') and m.bias is not None: 71 | init.constant_(m.bias.data, 0.0) 72 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 73 | init.normal_(m.weight.data, 1.0, init_gain) 74 | init.constant_(m.bias.data, 0.0) 75 | 76 | print('initialize network with %s' % init_type) 77 | net.apply(init_func) # apply the initialization function 78 | 79 | 80 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 81 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 82 | Parameters: 83 | net (network) -- the network to be initialized 84 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 85 | gain (float) -- scaling factor for normal, xavier and orthogonal. 86 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 87 | 88 | Return an initialized network. 89 | """ 90 | if len(gpu_ids) > 0: 91 | assert(torch.cuda.is_available()) 92 | net.to(gpu_ids[0]) 93 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 94 | init_weights(net, init_type, init_gain=init_gain) 95 | return net 96 | 97 | 98 | def define_G(rdrr, netG, init_type='normal', init_gain=0.02, gpu_ids=[]): 99 | net = None 100 | if netG == 'plain-dcgan': 101 | net = DCGAN(rdrr) 102 | elif netG == 'plain-unet': 103 | net = UNet(rdrr) 104 | elif netG == 'huang-net': 105 | net = HuangNet(rdrr) 106 | elif netG == 'zou-fusion-net': 107 | net = ZouFCNFusion(rdrr) 108 | else: 109 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 110 | return init_net(net, init_type, init_gain, gpu_ids) 111 | 112 | 113 | class DCGAN(nn.Module): 114 | def __init__(self, rdrr, ngf=64): 115 | super(DCGAN, self).__init__() 116 | input_nc = rdrr.d 117 | self.main = nn.Sequential( 118 | # input is Z, going into a convolution 119 | nn.ConvTranspose2d(input_nc, ngf * 8, 4, 1, 0, bias=False), 120 | nn.BatchNorm2d(ngf * 8), 121 | nn.ReLU(True), 122 | # state size. (ngf*8) x 4 x 4 123 | 124 | nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False), 125 | nn.BatchNorm2d(ngf * 8), 126 | nn.ReLU(True), 127 | # state size. (ngf*4) x 8 x 8 128 | 129 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 130 | nn.BatchNorm2d(ngf * 4), 131 | nn.ReLU(True), 132 | # state size. (ngf*2) x 16 x 16 133 | 134 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 135 | nn.BatchNorm2d(ngf * 2), 136 | nn.ReLU(True), 137 | # state size. (ngf*2) x 32 x 32 138 | 139 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 140 | nn.BatchNorm2d(ngf), 141 | nn.ReLU(True), 142 | # state size. (ngf*2) x 64 x 64 143 | 144 | nn.ConvTranspose2d(ngf, 6, 4, 2, 1, bias=False), 145 | # state size. 
 154 | class PixelShuffleNet(nn.Module):
 155 |     def __init__(self, input_nc):
 156 |         super(PixelShuffleNet, self).__init__()
 157 |         self.fc1 = (nn.Linear(input_nc, 512))
 158 |         self.fc2 = (nn.Linear(512, 1024))
 159 |         self.fc3 = (nn.Linear(1024, 2048))
 160 |         self.fc4 = (nn.Linear(2048, 4096))
 161 |         self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
 162 |         self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
 163 |         self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
 164 |         self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
 165 |         self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
 166 |         self.conv6 = (nn.Conv2d(8, 4*3, 3, 1, 1))
 167 |         self.pixel_shuffle = nn.PixelShuffle(2)
 168 | 
 169 |     def forward(self, x):
 170 |         x = x.squeeze()
 171 |         x = F.relu(self.fc1(x))
 172 |         x = F.relu(self.fc2(x))
 173 |         x = F.relu(self.fc3(x))
 174 |         x = F.relu(self.fc4(x))
 175 |         x = x.view(-1, 16, 16, 16)
 176 |         x = F.relu(self.conv1(x))
 177 |         x = self.pixel_shuffle(self.conv2(x))
 178 |         x = F.relu(self.conv3(x))
 179 |         x = self.pixel_shuffle(self.conv4(x))
 180 |         x = F.relu(self.conv5(x))
 181 |         x = self.pixel_shuffle(self.conv6(x))
 182 |         x = x.view(-1, 3, 128, 128)
 183 |         return x
 184 | 
 185 | 
 186 | 
 187 | class HuangNet(nn.Module):
 188 |     def __init__(self, rdrr):
 189 |         super(HuangNet, self).__init__()
 190 |         self.rdrr = rdrr
 191 |         self.fc1 = (nn.Linear(rdrr.d, 512))
 192 |         self.fc2 = (nn.Linear(512, 1024))
 193 |         self.fc3 = (nn.Linear(1024, 2048))
 194 |         self.fc4 = (nn.Linear(2048, 4096))
 195 |         self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
 196 |         self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
 197 |         self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
 198 |         self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
 199 |         self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
 200 |         self.conv6 = (nn.Conv2d(8, 4 * 6, 3, 1, 1))
 201 |         self.pixel_shuffle = nn.PixelShuffle(2)
 202 | 
 203 | 
 204 |     def forward(self, x):
 205 |         x = x.squeeze()
 206 |         x = F.relu(self.fc1(x))
 207 |         x = F.relu(self.fc2(x))
 208 |         x = F.relu(self.fc3(x))
 209 |         x = F.relu(self.fc4(x))
 210 |         x = x.view(-1, 16, 16, 16)
 211 |         x = F.relu(self.conv1(x))
 212 |         x = self.pixel_shuffle(self.conv2(x))
 213 |         x = F.relu(self.conv3(x))
 214 |         x = self.pixel_shuffle(self.conv4(x))
 215 |         x = F.relu(self.conv5(x))
 216 |         x = self.pixel_shuffle(self.conv6(x))
 217 |         output_tensor = x.view(-1, 6, 128, 128)
 218 |         return output_tensor[:,0:3,:,:], output_tensor[:,3:6,:,:]
 219 | 
 220 | 
 221 | 
 222 | class ZouFCNFusion(nn.Module):
 223 |     def __init__(self, rdrr):
 224 |         super(ZouFCNFusion, self).__init__()
 225 |         self.rdrr = rdrr
 226 |         self.huangnet = PixelShuffleNet(rdrr.d_shape)
 227 |         self.dcgan = DCGAN(rdrr)
 228 | 
 229 |     def forward(self, x):
 230 |         x_shape = x[:, 0:self.rdrr.d_shape, :, :]
 231 |         x_alpha = x[:, [-1], :, :]
 232 |         if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
 233 |             x_alpha = torch.tensor(1.0).to(device)
 234 | 
 235 |         mask = self.huangnet(x_shape)
 236 |         color, _ = self.dcgan(x)
 237 | 
 238 |         return color * mask, x_alpha * mask
 239 | 
 240 | 
 241 | 
 242 | 
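# Usage sketch for the fusion renderer (shapes as used by painter.py, which
# feeds stroke parameter vectors reshaped to [N, d, 1, 1]; rdrr here would be
# a Renderer instance from renderer.py):
#     net = ZouFCNFusion(rdrr).to(device)
#     v = torch.rand(8, rdrr.d, 1, 1).to(device)   # 8 random stroke vectors
#     foreground, alpha = net(v)                   # each of shape [8, 3, 128, 128]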
248 | """ 249 | super(UNet, self).__init__() 250 | norm_layer = get_norm_layer(norm_type='batch') 251 | self.unet = UnetGenerator(rdrr.d, 6, 7, norm_layer=norm_layer, use_dropout=False) 252 | 253 | def forward(self, x): 254 | """ 255 | In the forward function we accept a Tensor of input data and we must return 256 | a Tensor of output data. We can use Modules defined in the constructor as 257 | well as arbitrary operators on Tensors. 258 | """ 259 | # resnet layers 260 | x = x.repeat(1, 1, 128, 128) 261 | output_tensor = self.unet(x) 262 | return output_tensor[:,0:3,:,:], output_tensor[:,3:6,:,:] 263 | 264 | 265 | 266 | class UnetGenerator(nn.Module): 267 | """Create a Unet-based generator""" 268 | 269 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 270 | """Construct a Unet generator 271 | Parameters: 272 | input_nc (int) -- the number of channels in input images 273 | output_nc (int) -- the number of channels in output images 274 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 275 | image of size 128x128 will become of size 1x1 # at the bottleneck 276 | ngf (int) -- the number of filters in the last conv layer 277 | norm_layer -- normalization layer 278 | 279 | We construct the U-Net from the innermost layer to the outermost layer. 280 | It is a recursive process. 281 | """ 282 | super(UnetGenerator, self).__init__() 283 | # construct unet structure 284 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 285 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 286 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 287 | # gradually reduce the number of filters from ngf * 8 to ngf 288 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 289 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 290 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 291 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 292 | 293 | def forward(self, input): 294 | """Standard forward""" 295 | return self.model(input) 296 | 297 | 298 | class UnetSkipConnectionBlock(nn.Module): 299 | """Defines the Unet submodule with skip connection. 300 | X -------------------identity---------------------- 301 | |-- downsampling -- |submodule| -- upsampling --| 302 | """ 303 | 304 | def __init__(self, outer_nc, inner_nc, input_nc=None, 305 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 306 | """Construct a Unet submodule with skip connections. 
 307 | 
 308 |         Parameters:
 309 |             outer_nc (int) -- the number of filters in the outer conv layer
 310 |             inner_nc (int) -- the number of filters in the inner conv layer
 311 |             input_nc (int) -- the number of channels in input images/features
 312 |             submodule (UnetSkipConnectionBlock) -- previously defined submodules
 313 |             outermost (bool)   -- if this module is the outermost module
 314 |             innermost (bool)   -- if this module is the innermost module
 315 |             norm_layer         -- normalization layer
 316 |             use_dropout (bool) -- whether to use dropout layers.
 317 |         """
 318 |         super(UnetSkipConnectionBlock, self).__init__()
 319 |         self.outermost = outermost
 320 |         if type(norm_layer) == functools.partial:
 321 |             use_bias = norm_layer.func == nn.InstanceNorm2d
 322 |         else:
 323 |             use_bias = norm_layer == nn.InstanceNorm2d
 324 |         if input_nc is None:
 325 |             input_nc = outer_nc
 326 |         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
 327 |                              stride=2, padding=1, bias=use_bias)
 328 |         downrelu = nn.LeakyReLU(0.2, True)
 329 |         downnorm = norm_layer(inner_nc)
 330 |         uprelu = nn.ReLU(True)
 331 |         upnorm = norm_layer(outer_nc)
 332 | 
 333 |         if outermost:
 334 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
 335 |                                         kernel_size=4, stride=2,
 336 |                                         padding=1)
 337 |             down = [downconv]
 338 |             # up = [uprelu, upconv, nn.Tanh()]
 339 |             # up = [uprelu, upconv, nn.Sigmoid()]  # ZZX
 340 |             up = [uprelu, upconv]  # ZZX
 341 |             model = down + [submodule] + up
 342 |         elif innermost:
 343 |             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
 344 |                                         kernel_size=4, stride=2,
 345 |                                         padding=1, bias=use_bias)
 346 |             down = [downrelu, downconv]
 347 |             up = [uprelu, upconv, upnorm]
 348 |             model = down + up
 349 |         else:
 350 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
 351 |                                         kernel_size=4, stride=2,
 352 |                                         padding=1, bias=use_bias)
 353 |             down = [downrelu, downconv, downnorm]
 354 |             up = [uprelu, upconv, upnorm]
 355 | 
 356 |             if use_dropout:
 357 |                 model = down + [submodule] + up + [nn.Dropout(0.5)]
 358 |             else:
 359 |                 model = down + [submodule] + up
 360 | 
 361 |         self.model = nn.Sequential(*model)
 362 | 
 363 |     def forward(self, x):
 364 |         if self.outermost:
 365 |             return self.model(x)
 366 |         else:  # add skip connections
 367 |             return torch.cat([x, self.model(x)], 1)
 368 | 
--------------------------------------------------------------------------------