├── gallery
│   ├── nst.jpg
│   ├── 8bitart.jpg
│   ├── gif_teaser_1.gif
│   ├── diamond_markerpen.jpg
│   └── apple_oilpaintbrush.jpg
├── test_images
│   ├── yui.jpg
│   ├── joker.jpg
│   ├── kasumi.png
│   └── satomi.png
├── style_images
│   ├── fire.jpg
│   ├── mosaic.jpg
│   ├── scream.jpg
│   └── picasso.jpg
├── brushes
│   ├── brush_fromweb2_large_horizontal.png
│   ├── brush_fromweb2_large_vertical.png
│   ├── brush_fromweb2_small_horizontal.png
│   └── brush_fromweb2_small_vertical.png
├── Requirements.txt
├── README.md
├── morphology.py
├── train_imitator.py
├── pytorch_batch_sinkhorn.py
├── demo.py
├── demo_nst.py
├── demo_prog.py
├── demo_8bitart.py
├── loss.py
├── LICENSE
├── imitator.py
├── image_to_paint.ipynb
├── utils.py
├── renderer.py
├── painter.py
└── networks.py
/gallery/nst.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/nst.jpg
--------------------------------------------------------------------------------
/gallery/8bitart.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/8bitart.jpg
--------------------------------------------------------------------------------
/test_images/yui.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/yui.jpg
--------------------------------------------------------------------------------
/style_images/fire.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/fire.jpg
--------------------------------------------------------------------------------
/style_images/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/mosaic.jpg
--------------------------------------------------------------------------------
/style_images/scream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/scream.jpg
--------------------------------------------------------------------------------
/test_images/joker.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/joker.jpg
--------------------------------------------------------------------------------
/test_images/kasumi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/kasumi.png
--------------------------------------------------------------------------------
/test_images/satomi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/test_images/satomi.png
--------------------------------------------------------------------------------
/gallery/gif_teaser_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/gif_teaser_1.gif
--------------------------------------------------------------------------------
/style_images/picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/style_images/picasso.jpg
--------------------------------------------------------------------------------
/gallery/diamond_markerpen.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/diamond_markerpen.jpg
--------------------------------------------------------------------------------
/gallery/apple_oilpaintbrush.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/gallery/apple_oilpaintbrush.jpg
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_large_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_large_horizontal.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_large_vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_large_vertical.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_small_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_small_horizontal.png
--------------------------------------------------------------------------------
/brushes/brush_fromweb2_small_vertical.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/stylized-neural-painting/main/brushes/brush_fromweb2_small_vertical.png
--------------------------------------------------------------------------------
/Requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | scikit-image
3 | scikit-learn
4 | scipy
5 | numpy
6 | torch
7 | torchvision
8 | opencv-python
9 | opencv-contrib-python
--------------------------------------------------------------------------------
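
The list above is the complete dependency set. A minimal install sketch in Python, equivalent to running "pip install -r Requirements.txt" from a shell:

    import subprocess
    import sys

    # Install everything listed in Requirements.txt into the current environment.
    subprocess.run([sys.executable, "-m", "pip", "install", "-r", "Requirements.txt"], check=True)
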
/README.md:
--------------------------------------------------------------------------------
1 | Opening the image_to_paint.ipynb file displays it as a Jupyter Notebook. Clicking the "Open in Colab" button at the top lets you run it on Google Colab. Any environment will do, as long as it can run Google Colab.
2 |
3 | # image_to_paint.ipynb
4 | Unlike the conventional approach of converting an image pixel by pixel, Stylized Neural Painting converts a photo into an oil painting by sequentially rendering vectorized stroke data, much like layering paint on canvas with a brush. For details, see [cedro-blog](http://cedro3.com/ai/image-to-paint/).
5 |
--------------------------------------------------------------------------------
/morphology.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pdb
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Erosion2d(nn.Module):
9 |
10 | def __init__(self, m=1):
11 |
12 | super(Erosion2d, self).__init__()
13 | self.m = m
14 | self.pad = [m, m, m, m]
15 | self.unfold = nn.Unfold(2*m+1, padding=0, stride=1)
16 |
17 | def forward(self, x):
18 | batch_size, c, h, w = x.shape
19 | x_pad = F.pad(x, pad=self.pad, mode='constant', value=1e9)
20 | for i in range(c):
21 | channel = self.unfold(x_pad[:, [i], :, :])
22 | channel = torch.min(channel, dim=1, keepdim=True)[0]
23 | channel = channel.view([batch_size, 1, h, w])
24 | x[:, [i], :, :] = channel
25 |
26 | return x
27 |
28 |
29 |
30 | class Dilation2d(nn.Module):
31 |
32 | def __init__(self, m=1):
33 |
34 | super(Dilation2d, self).__init__()
35 | self.m = m
36 | self.pad = [m, m, m, m]
37 | self.unfold = nn.Unfold(2*m+1, padding=0, stride=1)
38 |
39 | def forward(self, x):
40 | batch_size, c, h, w = x.shape
41 | x_pad = F.pad(x, pad=self.pad, mode='constant', value=-1e9)
42 | for i in range(c):
43 | channel = self.unfold(x_pad[:, [i], :, :])
44 | channel = torch.max(channel, dim=1, keepdim=True)[0]
45 | channel = channel.view([batch_size, 1, h, w])
46 | x[:, [i], :, :] = channel
47 |
48 | return x
49 |
--------------------------------------------------------------------------------
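
Erosion2d and Dilation2d implement grayscale morphology with a (2m+1)x(2m+1) square window: nn.Unfold extracts each window and the output pixel is the window minimum (erosion) or maximum (dilation). Note that forward() writes its result back into the input tensor in place. A small self-contained sanity check:

    import torch
    from morphology import Erosion2d, Dilation2d

    # A 1x1x5x5 binary "image" with a 3x3 white square in the centre.
    x = torch.zeros(1, 1, 5, 5)
    x[:, :, 1:4, 1:4] = 1.0

    # forward() modifies its input in place, so clone before each call.
    eroded = Erosion2d(m=1)(x.clone())
    dilated = Dilation2d(m=1)(x.clone())

    print(eroded.sum().item())   # 1.0  -> only the centre pixel survives a 3x3 erosion
    print(dilated.sum().item())  # 25.0 -> the square grows to cover the whole 5x5 image
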
/train_imitator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import utils
5 | from imitator import *
6 |
7 | # settings
8 | parser = argparse.ArgumentParser(description='ZZX TRAIN IMITATOR')
9 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
10 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle, '
11 |                          'bezier, circle, square] (default: oilpaintbrush)')
12 | parser.add_argument('--batch_size', type=int, default=64, metavar='N',
13 |                     help='input batch size for training (default: 64)')
14 | parser.add_argument('--print_models', action='store_true', default=False,
15 | help='visualize and print networks')
16 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
17 | help='net_G: plain-dcgan or plain-unet or huang-net or zou-fusion-net')
18 | parser.add_argument('--checkpoint_dir', type=str, default=r'./checkpoints_G', metavar='str',
19 |                     help='dir to save checkpoints (default: ./checkpoints_G)')
20 | parser.add_argument('--vis_dir', type=str, default=r'./val_out_G', metavar='str',
21 | help='dir to save results during training (default: ./val_out_G)')
22 | parser.add_argument('--lr', type=float, default=2e-4,
23 | help='learning rate (default: 0.0002)')
24 | parser.add_argument('--max_num_epochs', type=int, default=400, metavar='N',
25 | help='max number of training epochs (default 400)')
26 | args = parser.parse_args()
27 |
28 |
29 | if __name__ == '__main__':
30 |
31 | dataloaders = utils.get_renderer_loaders(args)
32 | imt = Imitator(args=args, dataloaders=dataloaders)
33 | imt.train_models()
34 |
35 | # # How to check if the data is loading correctly?
36 | # dataloaders = utils.get_renderer_loaders(args)
37 | # for i in range(100):
38 | # data = next(iter(dataloaders['train']))
39 | # vis_A = data['A']
40 | # vis_B = utils.make_numpy_grid(data['B'])
41 | # print(data['A'].cpu().numpy().shape[1])
42 | # print(data['B'].shape)
43 | # plt.imshow(vis_B)
44 | # plt.show()
45 |
46 |
47 |
--------------------------------------------------------------------------------
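
A hedged launch sketch for the script above: training the neural renderer for a different brush type. The flags are the ones defined by the parser; checkpoints land in ./checkpoints_G and validation renders in ./val_out_G by default. This assumes utils.get_renderer_loaders can build the stroke dataset locally.

    import subprocess
    import sys

    # Train the stroke imitator (neural renderer) for the markerpen brush.
    subprocess.run([
        sys.executable, "train_imitator.py",
        "--renderer", "markerpen",
        "--net_G", "zou-fusion-net",
        "--batch_size", "64",
        "--lr", "2e-4",
    ], check=True)
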
/pytorch_batch_sinkhorn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | pytorch_batch_sinkhorn.py
4 |
5 | Discrete OT : Sinkhorn algorithm for point cloud marginals.
6 |
7 | """
8 |
9 | import torch
10 | from torch.autograd import Variable
11 |
12 | # Decide which device we want to run on
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | def sinkhorn_normalized(x, y, epsilon, niter, mass_x=None, mass_y=None):
17 |
18 | Wxy = sinkhorn_loss(x, y, epsilon, niter, mass_x, mass_y)
19 | Wxx = sinkhorn_loss(x, x, epsilon, niter, mass_x, mass_x)
20 | Wyy = sinkhorn_loss(y, y, epsilon, niter, mass_y, mass_y)
21 | return 2 * Wxy - Wxx - Wyy
22 |
23 | def sinkhorn_loss(x, y, epsilon, niter, mass_x=None, mass_y=None):
24 | """
25 | Given two emprical measures with n points each with locations x and y
26 | outputs an approximation of the OT cost with regularization parameter epsilon
27 | niter is the max. number of steps in sinkhorn loop
28 | """
29 |
30 | # The Sinkhorn algorithm takes as input three variables :
31 | C = cost_matrix(x, y) # Wasserstein cost function
32 |
33 | nx = x.shape[1]
34 | ny = y.shape[1]
35 | batch_size = x.shape[0]
36 |
37 | if mass_x is None:
38 | # assign marginal to fixed with equal weights
39 | mu = 1. / nx * torch.ones([batch_size, nx]).to(device)
40 | else: # normalize
41 | mass_x.data = torch.clamp(mass_x.data, min=0, max=1e9)
42 | mass_x = mass_x + 1e-9
43 | mu = (mass_x / mass_x.sum(dim=-1, keepdim=True)).to(device)
44 |
45 | if mass_y is None:
46 | # assign marginal to fixed with equal weights
47 | nu = 1. / ny * torch.ones([batch_size, ny]).to(device)
48 | else: # normalize
49 | mass_y.data = torch.clamp(mass_y.data, min=0, max=1e9)
50 | mass_y = mass_y + 1e-9
51 | nu = (mass_y / mass_y.sum(dim=-1, keepdim=True)).to(device)
52 |
53 | def M(u, v):
54 | "Modified cost for logarithmic updates"
55 |         r"$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
56 | return (-C + u.unsqueeze(2) + v.unsqueeze(1)) / epsilon
57 |
58 | def lse(A):
59 | "log-sum-exp"
60 | return torch.log(torch.exp(A).sum(2, keepdim=True) + 1e-6) # add 10^-6 to prevent NaN
61 |
62 | # Actual Sinkhorn loop ......................................................................
63 | u, v, err = 0. * mu, 0. * nu, 0.
64 |
65 | for i in range(niter):
66 | u = epsilon * (torch.log(mu) - lse(M(u, v)).squeeze()) + u
67 | v = epsilon * (torch.log(nu) - lse(M(u, v).transpose(dim0=1, dim1=2)).squeeze()) + v
68 |
69 | U, V = u, v
70 | pi = torch.exp(M(U, V)) # Transport plan pi = diag(a)*K*diag(b)
71 | cost = torch.sum(pi * C, dim=[1, 2]) # Sinkhorn cost
72 |
73 | return torch.mean(cost)
74 |
75 |
76 | def cost_matrix(x, y, p=2):
77 | "Returns the matrix of $|x_i-y_j|^p$."
78 | x_col = x.unsqueeze(2)
79 | y_lin = y.unsqueeze(1)
80 | c = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
81 | return c
82 |
--------------------------------------------------------------------------------
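
A quick sanity check of sinkhorn_loss on random 2-D point clouds; shapes follow the code above (batch, n_points, dim). With mass_x/mass_y omitted the marginals are uniform, the result approximates the entropic OT cost, and comparing a cloud with itself should give a value near zero:

    import torch
    from pytorch_batch_sinkhorn import sinkhorn_loss, device

    x = torch.rand(1, 16, 2).to(device)  # a batch of one cloud with 16 2-D points
    y = torch.rand(1, 16, 2).to(device)

    print(sinkhorn_loss(x, y, epsilon=0.01, niter=5).item())  # entropic OT cost > 0
    print(sinkhorn_loss(x, x, epsilon=0.01, niter=5).item())  # ~0 for identical clouds
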
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
5 | import torch.optim as optim
6 |
7 | from painter import *
8 |
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 | help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 | help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='N',
18 |                     help='size of the canvas for stroke rendering (default: 512)')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='N',
20 |                     help='max number of strokes (default: 500)')
21 | parser.add_argument('--m_grid', type=int, default=5, metavar='N',
22 | help='divide an image to m_grid x m_grid patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 | help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 | help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 | help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
32 | help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 | help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 | help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 |
39 |
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 |
43 |
44 | def optimize_x(pt):
45 |
46 | pt._load_checkpoint()
47 | pt.net_G.eval()
48 |
49 | pt.initialize_params()
50 | pt.x_ctt.requires_grad = True
51 | pt.x_color.requires_grad = True
52 | pt.x_alpha.requires_grad = True
53 | utils.set_requires_grad(pt.net_G, False)
54 |
55 | pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)
56 |
57 | print('begin to draw...')
58 | pt.step_id = 0
59 | for pt.anchor_id in range(0, pt.m_strokes_per_block):
60 | pt.stroke_sampler(pt.anchor_id)
61 | iters_per_stroke = 20
62 | if pt.anchor_id == pt.m_strokes_per_block - 1:
63 | iters_per_stroke = 40
64 | for i in range(iters_per_stroke):
65 |
66 | pt.optimizer_x.zero_grad()
67 |
68 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
69 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
70 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
71 |
72 | if args.canvas_color == 'white':
73 | pt.G_pred_canvas = torch.ones([args.m_grid ** 2, 3, 128, 128]).to(device)
74 | else:
75 | pt.G_pred_canvas = torch.zeros(args.m_grid ** 2, 3, 128, 128).to(device)
76 |
77 | pt._forward_pass()
78 | pt._drawing_step_states()
79 | pt._backward_x()
80 | pt.optimizer_x.step()
81 |
82 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
83 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
84 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
85 |
86 | pt.step_id += 1
87 |
88 | v = pt.x.detach().cpu().numpy()
89 | pt._save_stroke_params(v)
90 | v_n = pt._normalize_strokes(pt.x)
91 | pt.final_rendered_images = pt._render_on_grids(v_n)
92 | pt._save_rendered_images()
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 |
98 | pt = Painter(args=args)
99 | optimize_x(pt)
100 |
101 |
--------------------------------------------------------------------------------
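
A hedged usage sketch for demo.py. It assumes the pretrained oilpaintbrush renderer checkpoint has already been unzipped to ./checkpoints_G_oilpaintbrush (the notebook below downloads it with gdown). The default sunflowers.jpg is not present in this fork's test_images, so an existing image is passed explicitly:

    import subprocess
    import sys

    subprocess.run([
        sys.executable, "demo.py",
        "--img_path", "./test_images/kasumi.png",
        "--renderer", "oilpaintbrush",
        "--canvas_color", "black",
        "--max_m_strokes", "500",
        "--m_grid", "5",
        "--output_dir", "./output",
    ], check=True)
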
/demo_nst.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
5 | import torch.optim as optim
6 |
7 | from painter import *
8 |
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
12 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
13 | parser.add_argument('--vector_file', type=str, default='./output/sunflowers_strokes.npz', metavar='str',
14 | help='path to pre-generated stroke vector file (default: ...)')
15 | parser.add_argument('--style_img_path', type=str, default='./style_images/fire.jpg', metavar='str',
16 | help='path to style image (default: ...)')
17 | parser.add_argument('--content_img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
18 | help='path to content image (default: ...)')
19 | parser.add_argument('--transfer_mode', type=int, default=1, metavar='N',
20 |                     help='style transfer mode, 0: transfer color only, 1: transfer both color and texture '
21 |                          '(default: 1)')
22 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
23 | help='canvas_color: [black, white] (default black)')
24 | parser.add_argument('--canvas_size', type=int, default=512, metavar='N',
25 |                     help='size of the canvas for stroke rendering (default: 512)')
26 | parser.add_argument('--beta_L1', type=float, default=1.0,
27 | help='weight for L1 loss (default: 1.0)')
28 | parser.add_argument('--beta_sty', type=float, default=0.5,
29 | help='weight for vgg style loss (default: 0.5)')
30 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
31 | help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
32 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
33 | help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
34 | parser.add_argument('--lr', type=float, default=0.005,
35 | help='learning rate for stroke searching (default: 0.005)')
36 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
37 | help='dir to save style transfer results (default: ./output)')
38 | args = parser.parse_args()
39 |
40 |
41 | # Decide which device we want to run on
42 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
43 |
44 |
45 | def optimize_x(pt):
46 |
47 | pt._load_checkpoint()
48 | pt.net_G.eval()
49 |
50 | if args.transfer_mode == 0: # transfer color only
51 | pt.x_ctt.requires_grad = False
52 | pt.x_color.requires_grad = True
53 | pt.x_alpha.requires_grad = False
54 | else: # transfer both color and texture
55 | pt.x_ctt.requires_grad = True
56 | pt.x_color.requires_grad = True
57 | pt.x_alpha.requires_grad = True
58 |
59 | pt.optimizer_x_sty = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr)
60 |
61 | iters_per_stroke = 100
62 | for i in range(iters_per_stroke):
63 | pt.optimizer_x_sty.zero_grad()
64 |
65 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
66 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
67 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
68 |
69 | if args.canvas_color == 'white':
70 | pt.G_pred_canvas = torch.ones([pt.m_grid*pt.m_grid, 3, 128, 128]).to(device)
71 | else:
72 | pt.G_pred_canvas = torch.zeros(pt.m_grid*pt.m_grid, 3, 128, 128).to(device)
73 |
74 | pt._forward_pass()
75 | pt._style_transfer_step_states()
76 | pt._backward_x_sty()
77 | pt.optimizer_x_sty.step()
78 |
79 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
80 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
81 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
82 |
83 | pt.step_id += 1
84 |
85 | print('saving style transfer result...')
86 | v_n = pt._normalize_strokes(pt.x)
87 | pt.final_rendered_images = pt._render_on_grids(v_n)
88 |
89 | file_dir = os.path.join(
90 | args.output_dir, args.content_img_path.split('/')[-1][:-4])
91 | plt.imsave(file_dir + '_style_img_' +
92 | args.style_img_path.split('/')[-1][:-4] + '.png', pt.style_img_)
93 | plt.imsave(file_dir + '_style_transfer_' +
94 | args.style_img_path.split('/')[-1][:-4] + '.png', pt.final_rendered_images[-1])
95 |
96 |
97 | if __name__ == '__main__':
98 |
99 | pt = NeuralStyleTransfer(args=args)
100 | optimize_x(pt)
101 |
102 |
--------------------------------------------------------------------------------
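
demo_nst.py re-optimizes stroke parameters saved by demo.py or demo_prog.py, so a *_strokes.npz file must exist under ./output first, and the content image should be the one the strokes were fitted to. A hedged sketch; the vector file name below is an assumption based on the scripts' default naming:

    import subprocess
    import sys

    subprocess.run([
        sys.executable, "demo_nst.py",
        "--vector_file", "./output/kasumi_strokes.npz",   # hypothetical: written by a prior demo run
        "--style_img_path", "./style_images/fire.jpg",
        "--content_img_path", "./test_images/kasumi.png",
        "--transfer_mode", "1",                           # transfer both color and texture
    ], check=True)
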
/demo_prog.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
5 | import torch.optim as optim
6 |
7 | from painter import *
8 |
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 | help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: oilpaintbrush)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 | help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='N',
18 |                     help='size of the canvas for stroke rendering (default: 512)')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='N',
20 |                     help='max number of strokes (default: 500)')
21 | parser.add_argument('--max_divide', type=int, default=5, metavar='N',
22 | help='divide an image up-to max_divide x max_divide patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 | help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 | help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 | help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush', metavar='str',
32 | help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 | help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 | help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 |
39 |
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 |
43 |
44 | def optimize_x(pt):
45 |
46 | pt._load_checkpoint()
47 | pt.net_G.eval()
48 |
49 | print('begin drawing...')
50 |
51 | PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)
52 |
53 | if pt.rderr.canvas_color == 'white':
54 | CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)
55 | else:
56 | CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)
57 |
58 | for pt.m_grid in range(1, pt.max_divide + 1):
59 |
60 | pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)
61 | pt.G_final_pred_canvas = CANVAS_tmp
62 |
63 | pt.initialize_params()
64 | pt.x_ctt.requires_grad = True
65 | pt.x_color.requires_grad = True
66 | pt.x_alpha.requires_grad = True
67 | utils.set_requires_grad(pt.net_G, False)
68 |
69 | pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)
70 |
71 | pt.step_id = 0
72 | for pt.anchor_id in range(0, pt.m_strokes_per_block):
73 | pt.stroke_sampler(pt.anchor_id)
74 | iters_per_stroke = 40
75 | for i in range(iters_per_stroke):
76 | pt.G_pred_canvas = CANVAS_tmp
77 |
78 | # update x
79 | pt.optimizer_x.zero_grad()
80 |
81 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
82 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
83 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
84 |
85 | pt._forward_pass()
86 | pt._drawing_step_states()
87 | pt._backward_x()
88 |
89 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
90 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
91 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)
92 |
93 | pt.optimizer_x.step()
94 | pt.step_id += 1
95 |
96 | v = pt._normalize_strokes(pt.x)
97 | PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)
98 | CANVAS_tmp = pt._render(PARAMS)[-1]
99 | CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)
100 |
101 | pt._save_stroke_params(PARAMS)
102 | pt.final_rendered_images = pt._render(PARAMS)
103 | pt._save_rendered_images()
104 |
105 |
106 |
107 | if __name__ == '__main__':
108 |
109 | pt = ProgressivePainter(args=args)
110 | optimize_x(pt)
111 |
112 |
--------------------------------------------------------------------------------
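
Unlike demo.py, which optimizes a fixed m_grid x m_grid division, demo_prog.py refines progressively: at each scale m it splits the image into m x m patches of 128x128, optimizes strokes per patch, re-renders the accumulated canvas, and re-divides it at scale m + 1. A small worked count of what the default settings imply:

    # With max_divide = 5 (the default), the loop over pt.m_grid optimizes
    # 1 + 4 + 9 + 16 + 25 = 55 patches of 128x128 in total.
    total_patches = sum(m * m for m in range(1, 5 + 1))
    print(total_patches)  # 55
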
/demo_8bitart.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
5 | import torch.optim as optim
6 |
7 | from painter import *
8 |
9 | # settings
10 | parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
11 | parser.add_argument('--img_path', type=str, default='./test_images/sunflowers.jpg', metavar='str',
12 | help='path to test image (default: ./test_images/sunflowers.jpg)')
13 | parser.add_argument('--renderer', type=str, default='rectangle', metavar='str',
14 |                     help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default: rectangle)')
15 | parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
16 | help='canvas_color: [black, white] (default black)')
17 | parser.add_argument('--canvas_size', type=int, default=512, metavar='N',
18 |                     help='size of the canvas for stroke rendering (default: 512)')
19 | parser.add_argument('--max_m_strokes', type=int, default=500, metavar='N',
20 |                     help='max number of strokes (default: 500)')
21 | parser.add_argument('--max_divide', type=int, default=5, metavar='N',
22 | help='divide an image up-to max_divide x max_divide patches (default 5)')
23 | parser.add_argument('--beta_L1', type=float, default=1.0,
24 | help='weight for L1 loss (default: 1.0)')
25 | parser.add_argument('--with_ot_loss', action='store_true', default=False,
26 |                     help='improve convergence by using optimal transportation loss')
27 | parser.add_argument('--beta_ot', type=float, default=0.1,
28 | help='weight for optimal transportation loss (default: 0.1)')
29 | parser.add_argument('--net_G', type=str, default='zou-fusion-net', metavar='str',
30 | help='net_G: plain-dcgan, plain-unet, huang-net, or zou-fusion-net (default: zou-fusion-net)')
31 | parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_rectangle', metavar='str',
32 | help='dir to load neu-renderer (default: ./checkpoints_G_rectangle)')
33 | parser.add_argument('--lr', type=float, default=0.005,
34 | help='learning rate for stroke searching (default: 0.005)')
35 | parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
36 | help='dir to save painting results (default: ./output)')
37 | args = parser.parse_args()
38 |
39 |
40 | # Decide which device we want to run on
41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42 |
43 |
44 | def optimize_x(pt):
45 |
46 | pt._load_checkpoint()
47 | pt.net_G.eval()
48 |
49 | print('begin drawing...')
50 |
51 | PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)
52 |
53 | if pt.rderr.canvas_color == 'white':
54 | CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)
55 | else:
56 | CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)
57 |
58 | for pt.m_grid in range(1, pt.max_divide + 1):
59 |
60 | pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)
61 | pt.G_final_pred_canvas = CANVAS_tmp
62 |
63 | pt.initialize_params()
64 | pt.x_ctt.requires_grad = True
65 | pt.x_color.requires_grad = True
66 | pt.x_alpha.requires_grad = True
67 | utils.set_requires_grad(pt.net_G, False)
68 |
69 | pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)
70 |
71 | pt.step_id = 0
72 | for pt.anchor_id in range(0, pt.m_strokes_per_block):
73 | pt.stroke_sampler(pt.anchor_id)
74 | iters_per_stroke = 20
75 | for i in range(iters_per_stroke):
76 | pt.G_pred_canvas = CANVAS_tmp
77 |
78 | # update x
79 | pt.optimizer_x.zero_grad()
80 |
81 |                 pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0, 1)
82 |                 pt.x_ctt.data[:, :, -1] = torch.clamp(pt.x_ctt.data[:, :, -1], 0, 0)  # pin the last shape component to 0 (presumably the rotation), keeping rectangles axis-aligned
83 |                 pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
84 |                 pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 1, 1)  # pin alpha to 1: fully opaque strokes
85 |
86 | pt._forward_pass()
87 | pt._backward_x()
88 |
89 | pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0, 1)
90 | pt.x_ctt.data[:, :, -1] = torch.clamp(pt.x_ctt.data[:, :, -1], 0, 0)
91 | pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
92 | pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 1, 1)
93 |
94 | pt._drawing_step_states()
95 |
96 | pt.optimizer_x.step()
97 | pt.step_id += 1
98 |
99 | v = pt._normalize_strokes(pt.x)
100 | PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)
101 | CANVAS_tmp = pt._render(PARAMS)[-1]
102 | CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)
103 |
104 | pt._save_stroke_params(PARAMS)
105 | pt.final_rendered_images = pt._render(PARAMS)
106 | pt._save_rendered_images()
107 |
108 |
109 |
110 | if __name__ == '__main__':
111 |
112 | pt = ProgressivePainter(args=args)
113 | optimize_x(pt)
114 |
115 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
3 | import torch.nn as nn
4 | import torchvision
5 | import utils
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import random
9 | import pytorch_batch_sinkhorn as spc
10 |
11 | # Decide which device we want to run on
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | class PixelLoss(nn.Module):
16 |
17 | def __init__(self, p=1):
18 | super(PixelLoss, self).__init__()
19 | self.p = p
20 |
21 | def forward(self, canvas, gt, ignore_color=False):
22 | if ignore_color:
23 | canvas = torch.mean(canvas, dim=1)
24 | gt = torch.mean(gt, dim=1)
25 | loss = torch.mean(torch.abs(canvas-gt)**self.p)
26 | return loss
27 |
28 |
29 | class VGGPerceptualLoss(torch.nn.Module):
30 | def __init__(self, resize=True):
31 | super(VGGPerceptualLoss, self).__init__()
32 | vgg = torchvision.models.vgg16(pretrained=True).to(device)
33 | blocks = []
34 | blocks.append(vgg.features[:4].eval())
35 | blocks.append(vgg.features[4:9].eval())
36 | blocks.append(vgg.features[9:16].eval())
37 | blocks.append(vgg.features[16:23].eval())
38 | for bl in blocks:
39 | for p in bl:
40 | p.requires_grad = False
41 | self.blocks = torch.nn.ModuleList(blocks)
42 | self.transform = torch.nn.functional.interpolate
43 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
44 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
45 | self.resize = resize
46 |
47 | def forward(self, input, target, ignore_color=False):
48 | self.mean = self.mean.type_as(input)
49 | self.std = self.std.type_as(input)
50 | if ignore_color:
51 | input = torch.mean(input, dim=1, keepdim=True)
52 | target = torch.mean(target, dim=1, keepdim=True)
53 | if input.shape[1] != 3:
54 | input = input.repeat(1, 3, 1, 1)
55 | target = target.repeat(1, 3, 1, 1)
56 | input = (input-self.mean) / self.std
57 | target = (target-self.mean) / self.std
58 | if self.resize:
59 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
60 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
61 | loss = 0.0
62 | x = input
63 | y = target
64 | for block in self.blocks:
65 | x = block(x)
66 | y = block(y)
67 | loss += torch.nn.functional.l1_loss(x, y)
68 | return loss
69 |
70 |
71 |
72 | class VGGStyleLoss(torch.nn.Module):
73 | def __init__(self, transfer_mode, resize=True):
74 | super(VGGStyleLoss, self).__init__()
75 | vgg = torchvision.models.vgg16(pretrained=True).to(device)
76 | for i, layer in enumerate(vgg.features):
77 | if isinstance(layer, torch.nn.MaxPool2d):
78 | vgg.features[i] = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
79 |
80 | blocks = []
81 | if transfer_mode == 0: # transfer color only
82 | blocks.append(vgg.features[:4].eval())
83 | blocks.append(vgg.features[4:9].eval())
84 | else: # transfer both color and texture
85 | blocks.append(vgg.features[:4].eval())
86 | blocks.append(vgg.features[4:9].eval())
87 | blocks.append(vgg.features[9:16].eval())
88 | blocks.append(vgg.features[16:23].eval())
89 |
90 | for bl in blocks:
91 | for p in bl:
92 | p.requires_grad = False
93 | self.blocks = torch.nn.ModuleList(blocks)
94 |
95 | self.transform = torch.nn.functional.interpolate
96 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
97 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
98 | self.resize = resize
99 |
100 | def gram_matrix(self, y):
101 | (b, ch, h, w) = y.size()
102 | features = y.view(b, ch, w * h)
103 | features_t = features.transpose(1, 2)
104 | gram = features.bmm(features_t) / (ch * w * h)
105 | return gram
106 |
107 | def forward(self, input, target):
108 | if input.shape[1] != 3:
109 | input = input.repeat(1, 3, 1, 1)
110 | target = target.repeat(1, 3, 1, 1)
111 | input = (input - self.mean) / self.std
112 | target = (target - self.mean) / self.std
113 | if self.resize:
114 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
115 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
116 | loss = 0.0
117 | x = input
118 | y = target
119 | for block in self.blocks:
120 | x = block(x)
121 | y = block(y)
122 | gm_x = self.gram_matrix(x)
123 | gm_y = self.gram_matrix(y)
124 | loss += torch.sum((gm_x-gm_y)**2)
125 | return loss
126 |
127 |
128 |
129 | class SinkhornLoss(nn.Module):
130 |
131 | def __init__(self, epsilon=0.01, niter=5, normalize=False):
132 | super(SinkhornLoss, self).__init__()
133 | self.epsilon = epsilon
134 | self.niter = niter
135 | self.normalize = normalize
136 |
137 | def _mesh_grids(self, batch_size, h, w):
138 |
139 | a = torch.linspace(0.0, h - 1.0, h).to(device)
140 | b = torch.linspace(0.0, w - 1.0, w).to(device)
141 | y_grid = a.view(-1, 1).repeat(batch_size, 1, w) / h
142 | x_grid = b.view(1, -1).repeat(batch_size, h, 1) / w
143 | grids = torch.cat([y_grid.view(batch_size, -1, 1), x_grid.view(batch_size, -1, 1)], dim=-1)
144 | return grids
145 |
146 | def forward(self, canvas, gt):
147 |
148 | batch_size, c, h, w = gt.shape
149 | if h > 24:
150 | canvas = nn.functional.interpolate(canvas, [24, 24], mode='area')
151 | gt = nn.functional.interpolate(gt, [24, 24], mode='area')
152 | batch_size, c, h, w = gt.shape
153 |
154 | canvas_grids = self._mesh_grids(batch_size, h, w)
155 | gt_grids = torch.clone(canvas_grids)
156 |
157 |         # randomly select one color channel, to speed up and reduce memory use
158 | i = random.randint(0, 2)
159 |
160 | img_1 = canvas[:, [i], :, :]
161 | img_2 = gt[:, [i], :, :]
162 |
163 | mass_x = img_1.reshape(batch_size, -1)
164 | mass_y = img_2.reshape(batch_size, -1)
165 | if self.normalize:
166 | loss = spc.sinkhorn_normalized(
167 | canvas_grids, gt_grids, epsilon=self.epsilon, niter=self.niter,
168 | mass_x=mass_x, mass_y=mass_y)
169 | else:
170 | loss = spc.sinkhorn_loss(
171 | canvas_grids, gt_grids, epsilon=self.epsilon, niter=self.niter,
172 | mass_x=mass_x, mass_y=mass_y)
173 |
174 |
175 | return loss
176 |
177 |
178 |
179 |
180 |
--------------------------------------------------------------------------------
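
A minimal smoke test for the two losses that need no pretrained weights (VGGPerceptualLoss and VGGStyleLoss download torchvision's VGG-16 on first use). SinkhornLoss internally downsamples anything larger than 24x24 with area interpolation before computing the OT cost over a pixel grid:

    import torch
    from loss import PixelLoss, SinkhornLoss, device

    canvas = torch.rand(4, 3, 128, 128).to(device)
    gt = torch.rand(4, 3, 128, 128).to(device)

    print(PixelLoss(p=1)(canvas, gt).item())                       # mean absolute error
    print(SinkhornLoss(epsilon=0.01, niter=5)(canvas, gt).item())  # entropic OT distance
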
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Legal Code
2 |
3 | CC0 1.0 Universal
4 |
5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
12 | HEREUNDER.
13 |
14 | Statement of Purpose
15 |
16 | The laws of most jurisdictions throughout the world automatically confer
17 | exclusive Copyright and Related Rights (defined below) upon the creator
18 | and subsequent owner(s) (each and all, an "owner") of an original work of
19 | authorship and/or a database (each, a "Work").
20 |
21 | Certain owners wish to permanently relinquish those rights to a Work for
22 | the purpose of contributing to a commons of creative, cultural and
23 | scientific works ("Commons") that the public can reliably and without fear
24 | of later claims of infringement build upon, modify, incorporate in other
25 | works, reuse and redistribute as freely as possible in any form whatsoever
26 | and for any purposes, including without limitation commercial purposes.
27 | These owners may contribute to the Commons to promote the ideal of a free
28 | culture and the further production of creative, cultural and scientific
29 | works, or to gain reputation or greater distribution for their Work in
30 | part through the use and efforts of others.
31 |
32 | For these and/or other purposes and motivations, and without any
33 | expectation of additional consideration or compensation, the person
34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she
35 | is an owner of Copyright and Related Rights in the Work, voluntarily
36 | elects to apply CC0 to the Work and publicly distribute the Work under its
37 | terms, with knowledge of his or her Copyright and Related Rights in the
38 | Work and the meaning and intended legal effect of CC0 on those rights.
39 |
40 | 1. Copyright and Related Rights. A Work made available under CC0 may be
41 | protected by copyright and related or neighboring rights ("Copyright and
42 | Related Rights"). Copyright and Related Rights include, but are not
43 | limited to, the following:
44 |
45 | i. the right to reproduce, adapt, distribute, perform, display,
46 | communicate, and translate a Work;
47 | ii. moral rights retained by the original author(s) and/or performer(s);
48 | iii. publicity and privacy rights pertaining to a person's image or
49 | likeness depicted in a Work;
50 | iv. rights protecting against unfair competition in regards to a Work,
51 | subject to the limitations in paragraph 4(a), below;
52 | v. rights protecting the extraction, dissemination, use and reuse of data
53 | in a Work;
54 | vi. database rights (such as those arising under Directive 96/9/EC of the
55 | European Parliament and of the Council of 11 March 1996 on the legal
56 | protection of databases, and under any national implementation
57 | thereof, including any amended or successor version of such
58 | directive); and
59 | vii. other similar, equivalent or corresponding rights throughout the
60 | world based on applicable law or treaty, and any national
61 | implementations thereof.
62 |
63 | 2. Waiver. To the greatest extent permitted by, but not in contravention
64 | of, applicable law, Affirmer hereby overtly, fully, permanently,
65 | irrevocably and unconditionally waives, abandons, and surrenders all of
66 | Affirmer's Copyright and Related Rights and associated claims and causes
67 | of action, whether now known or unknown (including existing as well as
68 | future claims and causes of action), in the Work (i) in all territories
69 | worldwide, (ii) for the maximum duration provided by applicable law or
70 | treaty (including future time extensions), (iii) in any current or future
71 | medium and for any number of copies, and (iv) for any purpose whatsoever,
72 | including without limitation commercial, advertising or promotional
73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
74 | member of the public at large and to the detriment of Affirmer's heirs and
75 | successors, fully intending that such Waiver shall not be subject to
76 | revocation, rescission, cancellation, termination, or any other legal or
77 | equitable action to disrupt the quiet enjoyment of the Work by the public
78 | as contemplated by Affirmer's express Statement of Purpose.
79 |
80 | 3. Public License Fallback. Should any part of the Waiver for any reason
81 | be judged legally invalid or ineffective under applicable law, then the
82 | Waiver shall be preserved to the maximum extent permitted taking into
83 | account Affirmer's express Statement of Purpose. In addition, to the
84 | extent the Waiver is so judged Affirmer hereby grants to each affected
85 | person a royalty-free, non transferable, non sublicensable, non exclusive,
86 | irrevocable and unconditional license to exercise Affirmer's Copyright and
87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the
88 | maximum duration provided by applicable law or treaty (including future
89 | time extensions), (iii) in any current or future medium and for any number
90 | of copies, and (iv) for any purpose whatsoever, including without
91 | limitation commercial, advertising or promotional purposes (the
92 | "License"). The License shall be deemed effective as of the date CC0 was
93 | applied by Affirmer to the Work. Should any part of the License for any
94 | reason be judged legally invalid or ineffective under applicable law, such
95 | partial invalidity or ineffectiveness shall not invalidate the remainder
96 | of the License, and in such case Affirmer hereby affirms that he or she
97 | will not (i) exercise any of his or her remaining Copyright and Related
98 | Rights in the Work or (ii) assert any associated claims and causes of
99 | action with respect to the Work, in either case contrary to Affirmer's
100 | express Statement of Purpose.
101 |
102 | 4. Limitations and Disclaimers.
103 |
104 | a. No trademark or patent rights held by Affirmer are waived, abandoned,
105 | surrendered, licensed or otherwise affected by this document.
106 | b. Affirmer offers the Work as-is and makes no representations or
107 | warranties of any kind concerning the Work, express, implied,
108 | statutory or otherwise, including without limitation warranties of
109 | title, merchantability, fitness for a particular purpose, non
110 | infringement, or the absence of latent or other defects, accuracy, or
111 | the present or absence of errors, whether or not discoverable, all to
112 | the greatest extent permissible under applicable law.
113 | c. Affirmer disclaims responsibility for clearing rights of other persons
114 | that may apply to the Work or any use thereof, including without
115 | limitation any person's Copyright and Related Rights in the Work.
116 | Further, Affirmer disclaims responsibility for obtaining any necessary
117 | consents, permissions or other rights required for any use of the
118 | Work.
119 | d. Affirmer understands and acknowledges that Creative Commons is not a
120 | party to this document and has no duty or obligation with respect to
121 | this CC0 or use of the Work.
122 |
--------------------------------------------------------------------------------
/imitator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 |
5 | import utils
6 | import loss
7 | from networks import *
8 |
9 | import torch
10 | if torch.cuda.is_available(): torch.cuda.current_device()  # initialize CUDA early; skipped on CPU-only machines
11 | import torch.optim as optim
12 | from torch.optim import lr_scheduler
13 | import torch.nn as nn
14 |
15 | import renderer
16 |
17 |
18 | # Decide which device we want to run on
19 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20 |
21 |
22 | class Imitator():
23 |
24 | def __init__(self, args, dataloaders):
25 |
26 | self.dataloaders = dataloaders
27 |
28 | self.rderr = renderer.Renderer(renderer=args.renderer)
29 |
30 | # define G
31 | self.net_G = define_G(rdrr=self.rderr, netG=args.net_G).to(device)
32 |
33 | # Learning rate
34 | self.lr = args.lr
35 |
36 | # define optimizers
37 | self.optimizer_G = optim.Adam(
38 | self.net_G.parameters(), lr=self.lr, betas=(0.9, 0.999))
39 |
40 | # define lr schedulers
41 | self.exp_lr_scheduler_G = lr_scheduler.StepLR(
42 | self.optimizer_G, step_size=100, gamma=0.1)
43 |
44 | # define some other vars to record the training states
45 | self.running_acc = []
46 | self.epoch_acc = 0
47 | self.best_val_acc = 0.0
48 | self.best_epoch_id = 0
49 | self.epoch_to_start = 0
50 | self.max_num_epochs = args.max_num_epochs
51 | self.G_pred_foreground = None
52 | self.G_pred_alpha = None
53 | self.batch = None
54 | self.G_loss = None
55 | self.is_training = False
56 | self.batch_id = 0
57 | self.epoch_id = 0
58 | self.checkpoint_dir = args.checkpoint_dir
59 | self.vis_dir = args.vis_dir
60 |
61 | # define the loss functions
62 | self._pxl_loss = loss.PixelLoss(p=2)
63 |
64 | self.VAL_ACC = np.array([], np.float32)
65 | if os.path.exists(os.path.join(self.checkpoint_dir, 'val_acc.npy')):
66 | self.VAL_ACC = np.load(os.path.join(self.checkpoint_dir, 'val_acc.npy'))
67 |
68 | # check and create model dir
69 | if os.path.exists(self.checkpoint_dir) is False:
70 | os.mkdir(self.checkpoint_dir)
71 | if os.path.exists(self.vis_dir) is False:
72 | os.mkdir(self.vis_dir)
73 |
74 | # visualize model
75 | if args.print_models:
76 | self._visualize_models()
77 |
78 |
79 | def _visualize_models(self):
80 |
81 | from torchviz import make_dot
82 |
83 | # visualize models with the package torchviz
84 | data = next(iter(self.dataloaders['train']))
85 | y = self.net_G(data['A'].to(device))
86 | mygraph = make_dot(y.mean(), params=dict(self.net_G.named_parameters()))
87 | mygraph.render('G')
88 |
89 |
90 | def _load_checkpoint(self):
91 |
92 | if os.path.exists(os.path.join(self.checkpoint_dir, 'last_ckpt.pt')):
93 | print('loading last checkpoint...')
94 | # load the entire checkpoint
95 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'last_ckpt.pt'))
96 |
97 | # update net_G states
98 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
99 | self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
100 | self.exp_lr_scheduler_G.load_state_dict(
101 | checkpoint['exp_lr_scheduler_G_state_dict'])
102 | self.net_G.to(device)
103 |
104 | # update some other states
105 | self.epoch_to_start = checkpoint['epoch_id'] + 1
106 | self.best_val_acc = checkpoint['best_val_acc']
107 | self.best_epoch_id = checkpoint['best_epoch_id']
108 |
109 | print('Epoch_to_start = %d, Historical_best_acc = %.4f (at epoch %d)' %
110 | (self.epoch_to_start, self.best_val_acc, self.best_epoch_id))
111 | print()
112 |
113 | else:
114 | print('training from scratch...')
115 |
116 |
117 | def _save_checkpoint(self, ckpt_name):
118 | torch.save({
119 | 'epoch_id': self.epoch_id,
120 | 'best_val_acc': self.best_val_acc,
121 | 'best_epoch_id': self.best_epoch_id,
122 | 'model_G_state_dict': self.net_G.state_dict(),
123 | 'optimizer_G_state_dict': self.optimizer_G.state_dict(),
124 | 'exp_lr_scheduler_G_state_dict': self.exp_lr_scheduler_G.state_dict()
125 | }, os.path.join(self.checkpoint_dir, ckpt_name))
126 |
127 |
128 | def _update_lr_schedulers(self):
129 | self.exp_lr_scheduler_G.step()
130 |
131 |
132 | def _compute_acc(self):
133 |
134 | target_foreground = self.batch['B'].to(device).detach()
135 | target_alpha_map = self.batch['ALPHA'].to(device).detach()
136 | foreground = self.G_pred_foreground.detach()
137 | alpha_map = self.G_pred_alpha.detach()
138 | psnr1 = utils.cpt_batch_psnr(foreground, target_foreground, PIXEL_MAX=1.0)
139 | psnr2 = utils.cpt_batch_psnr(alpha_map, target_alpha_map, PIXEL_MAX=1.0)
140 | return (psnr1 + psnr2)/2.0
141 |
142 |
143 | def _collect_running_batch_states(self):
144 | self.running_acc.append(self._compute_acc().item())
145 |
146 | m = len(self.dataloaders['train'])
147 | if self.is_training is False:
148 | m = len(self.dataloaders['val'])
149 |
150 | if np.mod(self.batch_id, 100) == 1:
151 | print('Is_training: %s. [%d,%d][%d,%d], G_loss: %.5f, running_acc: %.5f'
152 | % (self.is_training, self.epoch_id, self.max_num_epochs-1, self.batch_id, m,
153 | self.G_loss.item(), np.mean(self.running_acc)))
154 |
155 | if np.mod(self.batch_id, 1000) == 1:
156 | vis_pred_foreground = utils.make_numpy_grid(self.G_pred_foreground)
157 | vis_gt_foreground = utils.make_numpy_grid(self.batch['B'])
158 | vis_pred_alpha = utils.make_numpy_grid(self.G_pred_alpha)
159 | vis_gt_alpha = utils.make_numpy_grid(self.batch['ALPHA'])
160 | vis = np.concatenate([vis_pred_foreground, vis_gt_foreground,
161 | vis_pred_alpha, vis_gt_alpha], axis=0)
162 | vis = np.clip(vis, a_min=0.0, a_max=1.0)
163 | file_name = os.path.join(
164 | self.vis_dir, 'istrain_'+str(self.is_training)+'_'+
165 | str(self.epoch_id)+'_'+str(self.batch_id)+'.jpg')
166 | plt.imsave(file_name, vis)
167 |
168 |
169 |
170 | def _collect_epoch_states(self):
171 |
172 | self.epoch_acc = np.mean(self.running_acc)
173 | print('Is_training: %s. Epoch %d / %d, epoch_acc= %.5f' %
174 | (self.is_training, self.epoch_id, self.max_num_epochs-1, self.epoch_acc))
175 | print()
176 |
177 |
178 | def _update_checkpoints(self):
179 |
180 | # save current model
181 | self._save_checkpoint(ckpt_name='last_ckpt.pt')
182 |         print('Latest model updated. Epoch_acc=%.4f, Historical_best_acc=%.4f (at epoch %d)'
183 | % (self.epoch_acc, self.best_val_acc, self.best_epoch_id))
184 | print()
185 |
186 | self.VAL_ACC = np.append(self.VAL_ACC, [self.epoch_acc])
187 | np.save(os.path.join(self.checkpoint_dir, 'val_acc.npy'), self.VAL_ACC)
188 |
189 | # update the best model (based on eval acc)
190 | if self.epoch_acc > self.best_val_acc:
191 | self.best_val_acc = self.epoch_acc
192 | self.best_epoch_id = self.epoch_id
193 | self._save_checkpoint(ckpt_name='best_ckpt.pt')
194 | print('*' * 10 + 'Best model updated!')
195 | print()
196 |
197 |
198 | def _clear_cache(self):
199 | self.running_acc = []
200 |
201 |
202 | def _forward_pass(self, batch):
203 | self.batch = batch
204 | z_in = batch['A'].to(device)
205 | self.G_pred_foreground, self.G_pred_alpha = self.net_G(z_in)
206 |
207 |
208 | def _backward_G(self):
209 |
210 | gt_foreground = self.batch['B'].to(device)
211 | gt_alpha = self.batch['ALPHA'].to(device)
212 |
213 | pixel_loss1 = self._pxl_loss(self.G_pred_foreground, gt_foreground)
214 | pixel_loss2 = self._pxl_loss(self.G_pred_alpha, gt_alpha)
215 | self.G_loss = 100 * (pixel_loss1 + pixel_loss2) / 2.0
216 | self.G_loss.backward()
217 |
218 |
219 | def train_models(self):
220 |
221 | self._load_checkpoint()
222 |
223 | # loop over the dataset multiple times
224 | for self.epoch_id in range(self.epoch_to_start, self.max_num_epochs):
225 |
226 | ################## train #################
227 | ##########################################
228 | self._clear_cache()
229 | self.is_training = True
230 | self.net_G.train() # Set model to training mode
231 | # Iterate over data.
232 | for self.batch_id, batch in enumerate(self.dataloaders['train'], 0):
233 | self._forward_pass(batch)
234 | # update G
235 | self.optimizer_G.zero_grad()
236 | self._backward_G()
237 | self.optimizer_G.step()
238 | self._collect_running_batch_states()
239 | self._collect_epoch_states()
240 | self._update_lr_schedulers()
241 |
242 | ########### Update_Checkpoints ###########
243 | ##########################################
244 | self._update_checkpoints()
245 |
246 |
--------------------------------------------------------------------------------
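
For completeness, a hedged sketch of loading a renderer trained by Imitator outside the class, using the same calls imitator.py makes (renderer.Renderer, networks.define_G) and the checkpoint keys written by _save_checkpoint. It assumes training has already produced ./checkpoints_G/last_ckpt.pt:

    import torch
    import renderer
    from networks import define_G

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Rebuild the renderer and generator exactly as Imitator.__init__ does.
    rderr = renderer.Renderer(renderer='oilpaintbrush')
    net_G = define_G(rdrr=rderr, netG='zou-fusion-net').to(device)

    # Restore the trained weights from the checkpoint dict saved by _save_checkpoint.
    ckpt = torch.load('./checkpoints_G/last_ckpt.pt', map_location=device)
    net_G.load_state_dict(ckpt['model_G_state_dict'])
    net_G.eval()
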
/image_to_paint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "image-to-paint",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "machine_shape": "hm",
11 | "include_colab_link": true
12 | },
13 | "kernelspec": {
14 | "display_name": "Python 3",
15 | "name": "python3"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "LFLeFyHW7AfS"
33 | },
34 | "source": [
35 | "# Githubからコードをコピー\n",
36 | "\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "metadata": {
42 | "id": "h72s1Mk26j4V"
43 | },
44 | "source": [
45 | "# githubのコードをコピー\n",
46 | "!git clone https://github.com/cedro3/stylized-neural-painting.git\n",
47 | "%cd stylized-neural-painting"
48 | ],
49 | "execution_count": null,
50 | "outputs": []
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "source": [
55 | "# 学習済みモデルのダウンロード"
56 | ],
57 | "metadata": {
58 | "id": "FDboP9pR76ma"
59 | }
60 | },
61 | {
62 | "cell_type": "code",
63 | "source": [
64 | "# 学習済みパラメータのダウンロード\n",
65 | "! pip install --upgrade gdown\n",
66 | "import gdown\n",
67 | "gdown.download('https://drive.google.com/uc?id=1sqWhgBKqaBJggl2A8sD1bLSq2_B1ScMG', './checkpoints_G_oilpaintbrush.zip', quiet=False)\n",
68 | "! unzip checkpoints_G_oilpaintbrush.zip"
69 | ],
70 | "metadata": {
71 | "id": "3EWAl4aq73JT"
72 | },
73 | "execution_count": null,
74 | "outputs": []
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {
79 | "id": "RwEdnkcI5fFE"
80 | },
81 | "source": [
82 | "# コード本体"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "metadata": {
88 | "id": "eD8de5Il7JdO"
89 | },
90 | "source": [
91 | "import argparse\n",
92 | "import torch\n",
93 | "torch.cuda.current_device()\n",
94 | "import torch.optim as optim\n",
95 | "from painter import *\n",
96 | "\n",
97 | "# Decide which device we want to run on\n",
98 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
99 | "\n",
100 | "# settings\n",
101 | "parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')\n",
102 | "args = parser.parse_args(args=[])\n",
103 | "args.img_path = './test_images/kasumi.png' # path to input photo\n",
104 | "args.renderer = 'oilpaintbrush' # [watercolor, markerpen, oilpaintbrush, rectangle]\n",
105 | "args.canvas_color = 'black' # [black, white]\n",
106 | "args.canvas_size = 512 # size of the canvas for stroke rendering'\n",
107 | "args.max_m_strokes = 500 # max number of strokes\n",
108 | "args.max_divide = 5 # divide an image up-to max_divide x max_divide patches\n",
109 | "args.beta_L1 = 1.0 # weight for L1 loss\n",
110 |         "args.with_ot_loss = False # set True to improve convergence with the optimal transportation loss, at the cost of slower optimization\n",
111 | "args.beta_ot = 0.1 # weight for optimal transportation loss\n",
112 | "args.net_G = 'zou-fusion-net' # renderer architecture\n",
113 | "args.renderer_checkpoint_dir = './checkpoints_G_oilpaintbrush' # dir to load the pretrained neu-renderer\n",
114 | "args.lr = 0.005 # learning rate for stroke searching\n",
115 | "args.output_dir = './output' # dir to save painting results\n",
116 | "\n",
117 | "\n",
118 | "def _drawing_step_states(pt):\n",
119 | " acc = pt._compute_acc().item()\n",
120 | " print('iteration step %d, G_loss: %.5f, step_acc: %.5f, grid_scale: %d / %d, strokes: %d / %d'\n",
121 | " % (pt.step_id, pt.G_loss.item(), acc,\n",
122 | " pt.m_grid, pt.max_divide,\n",
123 | " pt.anchor_id, pt.m_strokes_per_block))\n",
124 | " vis2 = utils.patches2img(pt.G_final_pred_canvas, pt.m_grid).clip(min=0, max=1)\n",
125 | "\n",
126 | "\n",
127 | "def optimize_x(pt):\n",
128 | " pt._load_checkpoint()\n",
129 | " pt.net_G.eval()\n",
130 | " print('begin drawing...')\n",
131 | "\n",
132 | " PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)\n",
133 | "\n",
134 | " if pt.rderr.canvas_color == 'white':\n",
135 | " CANVAS_tmp = torch.ones([1, 3, 128, 128]).to(device)\n",
136 | " else:\n",
137 | " CANVAS_tmp = torch.zeros([1, 3, 128, 128]).to(device)\n",
138 | "\n",
139 | " for pt.m_grid in range(1, pt.max_divide + 1):\n",
140 | "\n",
141 | " pt.img_batch = utils.img2patches(pt.img_, pt.m_grid).to(device)\n",
142 | " pt.G_final_pred_canvas = CANVAS_tmp\n",
143 | "\n",
144 | " pt.initialize_params()\n",
145 | " pt.x_ctt.requires_grad = True\n",
146 | " pt.x_color.requires_grad = True\n",
147 | " pt.x_alpha.requires_grad = True\n",
148 | " utils.set_requires_grad(pt.net_G, False)\n",
149 | "\n",
150 | " pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)\n",
151 | "\n",
152 | " pt.step_id = 0\n",
153 | " for pt.anchor_id in range(0, pt.m_strokes_per_block):\n",
154 | " pt.stroke_sampler(pt.anchor_id)\n",
155 | " iters_per_stroke = 80\n",
156 | " for i in range(iters_per_stroke):\n",
157 | " pt.G_pred_canvas = CANVAS_tmp\n",
158 | "\n",
159 | " # update x\n",
160 | " pt.optimizer_x.zero_grad()\n",
161 | "\n",
162 | " pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)\n",
163 | " pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)\n",
164 | " pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)\n",
165 | "\n",
166 | " pt._forward_pass()\n",
167 | " _drawing_step_states(pt)\n",
168 | " pt._backward_x()\n",
169 | "\n",
170 | " pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)\n",
171 | " pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)\n",
172 | " pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)\n",
173 | "\n",
174 | " pt.optimizer_x.step()\n",
175 | " pt.step_id += 1\n",
176 | "\n",
177 | " v = pt._normalize_strokes(pt.x)\n",
178 | " PARAMS = np.concatenate([PARAMS, np.reshape(v, [1, -1, pt.rderr.d])], axis=1)\n",
179 | " CANVAS_tmp = pt._render(PARAMS)[-1]\n",
180 | " CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, to_tensor=True).to(device)\n",
181 | "\n",
182 | " pt._save_stroke_params(PARAMS)\n",
183 | " pt.final_rendered_images = pt._render(PARAMS)\n",
184 | " pt._save_rendered_images()"
185 | ],
186 | "execution_count": null,
187 | "outputs": []
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {
192 | "id": "b_QxLKdc-7nr"
193 | },
194 | "source": [
195 |         "# Rendering"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "metadata": {
201 | "id": "7yGUv65G-UCm"
202 | },
203 | "source": [
204 |         "pt = ProgressivePainter(args=args)\n",
205 |         "optimize_x(pt)"
206 | ],
207 | "execution_count": null,
208 | "outputs": []
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {
213 | "id": "-p4XHLUh_ppN"
214 | },
215 | "source": [
216 |         "# Display the results"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "metadata": {
222 | "id": "lmYpxdcW_Azv"
223 | },
224 | "source": [
225 | "# show picture\n",
226 | "fig = plt.figure(figsize=(8,4))\n",
227 | "plt.subplot(1,2,1)\n",
228 | "plt.imshow(pt.img_), plt.title('input')\n",
229 | "plt.subplot(1,2,2)\n",
230 | "plt.imshow(pt.final_rendered_images[-1]), plt.title('generated')\n",
231 | "plt.show()"
232 | ],
233 | "execution_count": null,
234 | "outputs": []
235 | },
236 | {
237 | "cell_type": "code",
238 | "metadata": {
239 | "id": "5LAIizo7_rxz"
240 | },
241 | "source": [
242 | "# make animation\n",
243 | "import matplotlib.animation as animation\n",
244 | "from IPython.display import HTML\n",
245 | "\n",
246 | "fig = plt.figure(figsize=(8,8))\n",
247 | "plt.axis('off')\n",
248 | "ims = [[plt.imshow(img, animated=True)] for img in pt.final_rendered_images[::10]]\n",
249 | "ani = animation.ArtistAnimation(fig, ims, interval=100)\n",
250 | "\n",
251 | "HTML(ani.to_jshtml())"
252 | ],
253 | "execution_count": null,
254 | "outputs": []
255 | },
256 | {
257 | "cell_type": "code",
258 | "metadata": {
259 | "id": "AlDC5MmeYuYL"
260 | },
261 | "source": [
262 | "# save animation\n",
263 | "ani.save('anime.mp4', writer='ffmpeg')"
264 | ],
265 | "execution_count": null,
266 | "outputs": []
267 | }
268 | ]
269 | }
--------------------------------------------------------------------------------
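The last cell saves the animation with the ffmpeg writer, which must be installed in the runtime. If it is not, a GIF via matplotlib's built-in Pillow writer is a drop-in alternative (a small sketch reusing the ani object from the cell above):

# save the animation as a GIF when ffmpeg is unavailable
ani.save('anime.gif', writer='pillow', fps=10)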
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import cv2
4 | #from skimage.measure import compare_ssim as sk_cpt_ssim
5 | from skimage.metrics import structural_similarity as sk_cpt_ssim
6 |
7 |
8 | import os
9 | import glob
10 | import random
11 |
12 | import torch
13 | torch.cuda.current_device()
14 | import torchvision.transforms.functional as TF
15 | from torch.utils.data import Dataset, DataLoader, Subset
16 | from torchvision import transforms, utils
17 | import renderer
18 |
19 |
20 |
21 | M_RENDERING_SAMPLES_PER_EPOCH = 50000
22 |
23 | class PairedDataAugmentation:
24 |
25 | def __init__(
26 | self,
27 | img_size,
28 | with_random_hflip=False,
29 | with_random_vflip=False,
30 | with_random_rot90=False,
31 | with_random_rot180=False,
32 | with_random_rot270=False,
33 | with_random_crop=False,
34 | with_random_patch=False
35 | ):
36 | self.img_size = img_size
37 | self.with_random_hflip = with_random_hflip
38 | self.with_random_vflip = with_random_vflip
39 | self.with_random_rot90 = with_random_rot90
40 | self.with_random_rot180 = with_random_rot180
41 | self.with_random_rot270 = with_random_rot270
42 | self.with_random_crop = with_random_crop
43 | self.with_random_patch = with_random_patch
44 |
45 | def transform(self, img1, img2):
46 |
47 |         # convert to PIL image and resize
48 | img1 = TF.to_pil_image(img1)
49 | img1 = TF.resize(img1, [self.img_size, self.img_size], interpolation=3)
50 | img2 = TF.to_pil_image(img2)
51 | img2 = TF.resize(img2, [self.img_size, self.img_size], interpolation=3)
52 |
53 | if self.with_random_hflip and random.random() > 0.5:
54 | img1 = TF.hflip(img1)
55 | img2 = TF.hflip(img2)
56 |
57 | if self.with_random_vflip and random.random() > 0.5:
58 | img1 = TF.vflip(img1)
59 | img2 = TF.vflip(img2)
60 |
61 | if self.with_random_rot90 and random.random() > 0.5:
62 | img1 = TF.rotate(img1, 90)
63 | img2 = TF.rotate(img2, 90)
64 |
65 | if self.with_random_rot180 and random.random() > 0.5:
66 | img1 = TF.rotate(img1, 180)
67 | img2 = TF.rotate(img2, 180)
68 |
69 | if self.with_random_rot270 and random.random() > 0.5:
70 | img1 = TF.rotate(img1, 270)
71 | img2 = TF.rotate(img2, 270)
72 |
73 | if self.with_random_crop and random.random() > 0.5:
74 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
75 | get_params(img=img1, scale=(0.5, 1.0), ratio=(0.9, 1.1))
76 | img1 = TF.resized_crop(
77 | img1, i, j, h, w, size=(self.img_size, self.img_size))
78 | img2 = TF.resized_crop(
79 | img2, i, j, h, w, size=(self.img_size, self.img_size))
80 |
81 | if self.with_random_patch:
82 | i, j, h, w = transforms.RandomResizedCrop(size=self.img_size). \
83 | get_params(img=img1, scale=(1/16.0, 1/9.0), ratio=(0.9, 1.1))
84 | img1 = TF.resized_crop(
85 | img1, i, j, h, w, size=(self.img_size, self.img_size))
86 | img2 = TF.resized_crop(
87 | img2, i, j, h, w, size=(self.img_size, self.img_size))
88 |
89 | # to tensor
90 | img1 = TF.to_tensor(img1)
91 | img2 = TF.to_tensor(img2)
92 |
93 | return img1, img2
94 |
95 |
96 |
97 | class StrokeDataset(Dataset):
98 |
99 | def __init__(self, renderer_type, is_train=True):
100 | self.rderr = renderer.Renderer(renderer=renderer_type, CANVAS_WIDTH=128, train=True)
101 | self.is_train = is_train
102 |
103 | def __len__(self):
104 | if self.is_train:
105 | return M_RENDERING_SAMPLES_PER_EPOCH
106 | else:
107 | return int(M_RENDERING_SAMPLES_PER_EPOCH / 20)
108 |
109 | def __getitem__(self, idx):
110 |
111 | self.rderr.foreground = None
112 | self.rderr.stroke_alpha_map = None
113 |
114 | self.rderr.random_stroke_params()
115 | self.rderr.draw_stroke()
116 |
117 | # to tensor
118 | params = torch.tensor(np.array(self.rderr.stroke_params, dtype=np.float32))
119 | params = torch.reshape(params, [-1, 1, 1])
120 | foreground = TF.to_tensor(np.array(self.rderr.foreground, dtype=np.float32))
121 | stroke_alpha_map = TF.to_tensor(np.array(self.rderr.stroke_alpha_map, dtype=np.float32))
122 |
123 | data = {'A': params, 'B': foreground, 'ALPHA': stroke_alpha_map}
124 |
125 | return data
126 |
127 |
128 |
129 |
130 | def get_renderer_loaders(args):
131 |
132 | training_set = StrokeDataset(renderer_type=args.renderer, is_train=True)
133 | val_set = StrokeDataset(renderer_type=args.renderer, is_train=False)
134 |
135 | datasets = {'train': training_set, 'val': val_set}
136 | dataloaders = {x: DataLoader(datasets[x], batch_size=args.batch_size,
137 | shuffle=True, num_workers=4)
138 | for x in ['train', 'val']}
139 |
140 | return dataloaders
141 |
142 |
143 |
144 | def set_requires_grad(nets, requires_grad=False):
145 |     """Set requires_grad=False for all the networks to avoid unnecessary computations
146 | Parameters:
147 | nets (network list) -- a list of networks
148 | requires_grad (bool) -- whether the networks require gradients or not
149 | """
150 | if not isinstance(nets, list):
151 | nets = [nets]
152 | for net in nets:
153 | if net is not None:
154 | for param in net.parameters():
155 | param.requires_grad = requires_grad
156 |
157 |
158 |
159 | def make_numpy_grid(tensor_data):
160 | tensor_data = tensor_data.detach()
161 | vis = utils.make_grid(tensor_data)
162 | vis = np.array(vis.cpu()).transpose((1,2,0))
163 | if vis.shape[2] == 1:
164 |         vis = np.concatenate([vis, vis, vis], axis=-1)  # replicate the single channel to RGB
165 | return vis.clip(min=0, max=1)
166 |
167 |
168 |
169 | def tensor2img(tensor_data):
170 | if tensor_data.shape[0] > 1:
171 | raise NotImplementedError('batch size > 1, please use make_numpy_grid')
172 | tensor_data = tensor_data.detach()[0, :]
173 | img = np.array(tensor_data.cpu()).transpose((1, 2, 0))
174 | if img.shape[2] == 1:
175 |         img = np.concatenate([img, img, img], axis=-1)  # replicate the single channel to RGB
176 | return img.clip(min=0, max=1)
177 |
178 |
179 |
180 | def cpt_ssim(img, img_gt, normalize=False):
181 |
182 | if normalize:
183 | img = (img - img.min()) / (img.max() - img.min() + 1e-9)
184 | img_gt = (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9)
185 |
186 | SSIM = sk_cpt_ssim(img, img_gt, data_range=img_gt.max() - img_gt.min())
187 |
188 | return SSIM
189 |
190 |
191 | def cpt_psnr(img, img_gt, PIXEL_MAX=1.0, normalize=False):
192 |
193 | if normalize:
194 | img = (img - img.min()) / (img.max() - img.min() + 1e-9)
195 | img_gt = (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9)
196 |
197 | mse = np.mean((img - img_gt) ** 2)
198 | psnr = 20 * np.log10(PIXEL_MAX / np.sqrt(mse))
199 |
200 | return psnr
201 |
202 |
203 | def cpt_cos_similarity(img, img_gt, normalize=False):
204 |
205 | if normalize:
206 | img = (img - img.min()) / (img.max() - img.min() + 1e-9)
207 | img_gt = (img_gt - img_gt.min()) / (img_gt.max() - img_gt.min() + 1e-9)
208 |
209 | cos_dist = np.sum(img*img_gt) / np.sqrt(np.sum(img**2)*np.sum(img_gt**2) + 1e-9)
210 |
211 | return cos_dist
212 |
213 |
214 | def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
215 | mse = torch.mean((img - img_gt) ** 2)
216 | psnr = 20 * torch.log10(PIXEL_MAX / torch.sqrt(mse))
217 | return psnr
218 |
219 |
220 | def rotate_pt(pt, rotate_center, theta, return_int=True):
221 |
222 | # theta in [0, pi]
223 | x, y = pt[0], pt[1]
224 | xc, yc = rotate_center[0], rotate_center[1]
225 |
226 | x_ = (x-xc) * np.cos(theta) + (y-yc) * np.sin(theta) + xc
227 | y_ = -1 * (x-xc) * np.sin(theta) + (y-yc) * np.cos(theta) + yc
228 |
229 | if return_int:
230 | x_, y_ = int(x_), int(y_)
231 |
232 | pt_ = (x_, y_)
233 |
234 | return pt_
235 |
236 |
237 | def img2patches(img, m_grid, to_tensor=True):
238 | # input img: h, w, 3 (np.float32)
239 | # output patches: N, 3, 128, 128 (tensor)
240 |
241 | img = cv2.resize(img, (m_grid * 128, m_grid * 128))
242 | img_batch = np.zeros([m_grid ** 2, 3, 128, 128])
243 | for y_id in range(m_grid):
244 | for x_id in range(m_grid):
245 | patch = img[y_id * 128:y_id * 128 + 128,
246 | x_id * 128:x_id * 128 + 128, :].transpose([2, 0, 1])
247 | img_batch[y_id * m_grid + x_id, :, :, :] = patch
248 |
249 | if to_tensor:
250 | img_batch = torch.tensor(img_batch)
251 |
252 | return img_batch
253 |
254 |
255 |
256 | def patches2img(img_batch, m_grid, to_numpy=True):
257 | # input patches: m_grid**2, 3, 128, 128 (tensor)
258 | # output img: 128*m_grid, 128*m_grid, 3 (np.float32)
259 |
260 | img = torch.zeros([128*m_grid, 128*m_grid, 3])
261 |
262 | for y_id in range(m_grid):
263 | for x_id in range(m_grid):
264 | patch = img_batch[y_id * m_grid + x_id, :, :, :]
265 | img[y_id * 128:y_id * 128 + 128, x_id * 128:x_id * 128 + 128, :] \
266 | = patch.permute([1, 2, 0])
267 | if to_numpy:
268 | img = img.detach().numpy()
269 | else:
270 | img = img.permute([2,0,1]).unsqueeze(0)
271 |
272 | return img
273 |
274 |
275 |
276 |
277 | def create_transformed_brush(brush, canvas_w, canvas_h,
278 | x0, y0, w, h, theta, R0, G0, B0, R2, G2, B2):
279 |
280 | brush_alpha = np.stack([brush, brush, brush], axis=-1)
281 | brush_alpha = (brush_alpha > 0).astype(np.float32)
282 | brush_alpha = (brush_alpha*255).astype(np.uint8)
283 | colormap = np.zeros([brush.shape[0], brush.shape[1], 3], np.float32)
284 | for ii in range(brush.shape[0]):
285 | t = ii / brush.shape[0]
286 | this_color = [(1 - t) * R0 + t * R2,
287 | (1 - t) * G0 + t * G2,
288 | (1 - t) * B0 + t * B2]
289 | colormap[ii, :, :] = np.expand_dims(this_color, axis=0)
290 |
291 | brush = np.expand_dims(brush, axis=-1).astype(np.float32) / 255.
292 | brush = (brush * colormap * 255).astype(np.uint8)
293 | # plt.imshow(brush), plt.show()
294 |
295 | M1 = build_transformation_matrix([-brush.shape[1]/2, -brush.shape[0]/2, 0])
296 | M2 = build_scale_matrix(sx=w/brush.shape[1], sy=h/brush.shape[0])
297 | M3 = build_transformation_matrix([0,0,theta])
298 | M4 = build_transformation_matrix([x0, y0, 0])
299 |
300 | M = update_transformation_matrix(M1, M2)
301 | M = update_transformation_matrix(M, M3)
302 | M = update_transformation_matrix(M, M4)
303 |
304 | brush = cv2.warpAffine(
305 | brush, M, (canvas_w, canvas_h),
306 | borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_AREA)
307 | brush_alpha = cv2.warpAffine(
308 | brush_alpha, M, (canvas_w, canvas_h),
309 | borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_AREA)
310 |
311 | return brush, brush_alpha
312 |
313 |
314 | def build_scale_matrix(sx, sy):
315 | transform_matrix = np.zeros((2, 3))
316 | transform_matrix[0, 0] = sx
317 | transform_matrix[1, 1] = sy
318 | return transform_matrix
319 |
320 |
321 | def update_transformation_matrix(M, m):
322 |
323 | # extend M and m to 3x3 by adding an [0,0,1] to their 3rd row
324 | M_ = np.concatenate([M, np.zeros([1,3])], axis=0)
325 | M_[-1, -1] = 1
326 | m_ = np.concatenate([m, np.zeros([1,3])], axis=0)
327 | m_[-1, -1] = 1
328 |
329 | M_new = np.matmul(m_, M_)
330 | return M_new[0:2, :]
331 | #
332 |
333 | def build_transformation_matrix(transform):
334 | """Convert transform list to transformation matrix
335 |
336 | :param transform: transform list as [dx, dy, da]
337 | :return: transform matrix as 2d (2, 3) numpy array
338 | """
339 | transform_matrix = np.zeros((2, 3))
340 |
341 | transform_matrix[0, 0] = np.cos(transform[2])
342 | transform_matrix[0, 1] = -np.sin(transform[2])
343 | transform_matrix[1, 0] = np.sin(transform[2])
344 | transform_matrix[1, 1] = np.cos(transform[2])
345 | transform_matrix[0, 2] = transform[0]
346 | transform_matrix[1, 2] = transform[1]
347 |
348 | return transform_matrix
349 |
350 |
351 |
352 |
--------------------------------------------------------------------------------
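img2patches and patches2img above are inverses of each other at a fixed grid scale. A quick shape check with a hypothetical 256x256 RGB input (importing utils assumes a CUDA-enabled runtime, since the module calls torch.cuda.current_device() at import time):

import numpy as np
import utils

img = np.random.rand(256, 256, 3).astype(np.float32)   # dummy image in [0, 1]
patches = utils.img2patches(img, m_grid=2)              # -> (4, 3, 128, 128) tensor
recon = utils.patches2img(patches.float(), m_grid=2)    # -> (256, 256, 3) float array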
/renderer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import random
4 | import utils
5 |
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def _random_floats(low, high, size):
10 | return [random.uniform(low, high) for _ in range(size)]
11 |
12 |
13 | def _normalize(x, width):
14 | return (int)(x * (width - 1) + 0.5)
15 |
16 |
17 | class Renderer():
18 |
19 | def __init__(self, renderer='bezier', CANVAS_WIDTH=128, train=False, canvas_color='black'):
20 |
21 | self.CANVAS_WIDTH = CANVAS_WIDTH
22 | self.renderer = renderer
23 | self.stroke_params = None
24 | self.canvas_color = canvas_color
25 |
26 | self.canvas = None
27 | self.create_empty_canvas()
28 |
29 | self.train = train
30 |
31 | if self.renderer in ['markerpen']:
32 | self.d = 12 # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A
33 | self.d_shape = 8
34 | self.d_color = 3
35 | self.d_alpha = 1
36 | elif self.renderer in ['watercolor']:
37 | self.d = 15 # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A
38 | self.d_shape = 8
39 | self.d_color = 6
40 | self.d_alpha = 1
41 | elif self.renderer in ['oilpaintbrush']:
42 | self.d = 12 # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
43 | self.d_shape = 5
44 | self.d_color = 6
45 | self.d_alpha = 1
46 | self.brush_small_vertical = cv2.imread(
47 | r'./brushes/brush_fromweb2_small_vertical.png', cv2.IMREAD_GRAYSCALE)
48 | self.brush_small_horizontal = cv2.imread(
49 | r'./brushes/brush_fromweb2_small_horizontal.png', cv2.IMREAD_GRAYSCALE)
50 | self.brush_large_vertical = cv2.imread(
51 | r'./brushes/brush_fromweb2_large_vertical.png', cv2.IMREAD_GRAYSCALE)
52 | self.brush_large_horizontal = cv2.imread(
53 | r'./brushes/brush_fromweb2_large_horizontal.png', cv2.IMREAD_GRAYSCALE)
54 | elif self.renderer in ['rectangle']:
55 | self.d = 9 # xc, yc, w, h, theta, R, G, B, A
56 | self.d_shape = 5
57 | self.d_color = 3
58 | self.d_alpha = 1
59 | else:
60 | raise NotImplementedError(
61 | 'Wrong renderer name %s (choose one from [watercolor, markerpen, oilpaintbrush, rectangle] ...)'
62 | % self.renderer)
63 |
64 | def create_empty_canvas(self):
65 | if self.canvas_color == 'white':
66 | self.canvas = np.ones(
67 | [self.CANVAS_WIDTH, self.CANVAS_WIDTH, 3]).astype('float32')
68 | else:
69 | self.canvas = np.zeros(
70 | [self.CANVAS_WIDTH, self.CANVAS_WIDTH, 3]).astype('float32')
71 |
72 |
73 | def random_stroke_params(self):
74 | self.stroke_params = np.array(_random_floats(0, 1, self.d), dtype=np.float32)
75 |
76 | def random_stroke_params_sampler(self, err_map, img):
77 |
78 | map_h, map_w, c = img.shape
79 |
80 | err_map = cv2.resize(err_map, (self.CANVAS_WIDTH, self.CANVAS_WIDTH))
81 | err_map[err_map < 0] = 0
82 | if np.all((err_map == 0)):
83 | err_map = np.ones_like(err_map)
84 | err_map = err_map / (np.sum(err_map) + 1e-99)
85 |
86 | index = np.random.choice(range(err_map.size), size=1, p=err_map.ravel())[0]
87 |
88 | cy = (index // self.CANVAS_WIDTH) / self.CANVAS_WIDTH
89 | cx = (index % self.CANVAS_WIDTH) / self.CANVAS_WIDTH
90 |
91 | if self.renderer in ['markerpen']:
92 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A
93 | x0, y0, x1, y1, x2, y2 = cx, cy, cx, cy, cx, cy
94 | x = [x0, y0, x1, y1, x2, y2]
95 | r = _random_floats(0.1, 0.5, 2)
96 | color = img[int(cy*map_h), int(cx*map_w), :].tolist()
97 | alpha = _random_floats(0.8, 0.98, 1)
98 | self.stroke_params = np.array(x + r + color + alpha, dtype=np.float32)
99 | elif self.renderer in ['watercolor']:
100 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A
101 | x0, y0, x1, y1, x2, y2 = cx, cy, cx, cy, cx, cy
102 | x = [x0, y0, x1, y1, x2, y2]
103 | r = _random_floats(0.1, 0.5, 2)
104 | color = img[int(cy*map_h), int(cx*map_w), :].tolist()
105 | color = color + color
106 | alpha = _random_floats(0.98, 1.0, 1)
107 | self.stroke_params = np.array(x + r + color + alpha, dtype=np.float32)
108 | elif self.renderer in ['oilpaintbrush']:
109 | # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
110 | x = [cx, cy]
111 | wh = _random_floats(0.1, 0.5, 2)
112 | theta = _random_floats(0, 1, 1)
113 | color = img[int(cy*map_h), int(cx*map_w), :].tolist()
114 | color = color + color
115 | alpha = _random_floats(0.98, 1.0, 1)
116 | self.stroke_params = np.array(x + wh + theta + color + alpha, dtype=np.float32)
117 | elif self.renderer in ['rectangle']:
118 | # xc, yc, w, h, theta, R, G, B, A
119 | x = [cx, cy]
120 | wh = _random_floats(0.1, 0.5, 2)
121 | theta = [0]
122 | color = img[int(cy*map_h), int(cx*map_w), :].tolist()
123 | alpha = _random_floats(0.8, 0.98, 1)
124 | self.stroke_params = np.array(x + wh + theta + color + alpha, dtype=np.float32)
125 |
126 |
127 | def check_stroke(self):
128 | r_ = 1.0
129 | if self.renderer in ['markerpen', 'watercolor']:
130 | r_ = max(self.stroke_params[6], self.stroke_params[7])
131 | elif self.renderer in ['oilpaintbrush']:
132 | r_ = max(self.stroke_params[2], self.stroke_params[3])
133 | elif self.renderer in ['rectangle']:
134 | r_ = max(self.stroke_params[2], self.stroke_params[3])
135 | if r_ > 3/128.:
136 | return True
137 | else:
138 | return False
139 |
140 |
141 | def draw_stroke(self):
142 |
143 | if self.renderer == 'watercolor':
144 | return self._draw_watercolor()
145 | elif self.renderer == 'markerpen':
146 | return self._draw_markerpen()
147 | elif self.renderer == 'oilpaintbrush':
148 | return self._draw_oilpaintbrush()
149 | elif self.renderer == 'rectangle':
150 | return self._draw_rectangle()
151 |
152 |
153 | def _draw_watercolor(self):
154 |
155 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R0, G0, B0, R2, G2, B2, A
156 | x0, y0, x1, y1, x2, y2, radius0, radius2 = self.stroke_params[0:8]
157 | R0, G0, B0, R2, G2, B2, ALPHA = self.stroke_params[8:]
158 | x1 = x0 + (x2 - x0) * x1
159 | y1 = y0 + (y2 - y0) * y1
160 | x0 = _normalize(x0, self.CANVAS_WIDTH)
161 | x1 = _normalize(x1, self.CANVAS_WIDTH)
162 | x2 = _normalize(x2, self.CANVAS_WIDTH)
163 | y0 = _normalize(y0, self.CANVAS_WIDTH)
164 | y1 = _normalize(y1, self.CANVAS_WIDTH)
165 | y2 = _normalize(y2, self.CANVAS_WIDTH)
166 | radius0 = (int)(1 + radius0 * self.CANVAS_WIDTH // 4)
167 | radius2 = (int)(1 + radius2 * self.CANVAS_WIDTH // 4)
168 |
169 | stroke_alpha_value = self.stroke_params[-1]
170 |
171 | self.foreground = np.zeros_like(
172 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
173 | self.stroke_alpha_map = np.zeros_like(
174 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
175 |
176 | alpha = (stroke_alpha_value * 255,
177 | stroke_alpha_value * 255,
178 | stroke_alpha_value * 255)
179 | tmp = 1. / 100
180 | for i in range(100):
181 | t = i * tmp
182 | x = (int)((1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2)
183 | y = (int)((1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2)
184 | radius = (int)((1 - t) * radius0 + t * radius2)
185 | color = ((1-t)*R0*255 + t*R2*255,
186 | (1-t)*G0*255 + t*G2*255,
187 | (1-t)*B0*255 + t*B2*255)
188 | cv2.circle(self.foreground, (x, y), radius, color, -1, lineType=cv2.LINE_AA)
189 | cv2.circle(self.stroke_alpha_map, (x, y), radius, alpha, -1, lineType=cv2.LINE_AA)
190 |
191 | if not self.train:
192 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2]))
193 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2]))
194 |
195 | self.foreground = np.array(self.foreground, dtype=np.float32)/255.
196 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255.
197 | self.canvas = self._update_canvas()
198 |
199 |
200 | def _draw_rectangle(self):
201 |
202 | # xc, yc, w, h, theta, R, G, B, A
203 | x0, y0, w, h, theta = self.stroke_params[0:5]
204 | R0, G0, B0, ALPHA = self.stroke_params[5:]
205 | x0 = _normalize(x0, self.CANVAS_WIDTH)
206 | y0 = _normalize(y0, self.CANVAS_WIDTH)
207 | w = (int)(1 + w * self.CANVAS_WIDTH // 4)
208 | h = (int)(1 + h * self.CANVAS_WIDTH // 4)
209 | theta = np.pi*theta
210 | stroke_alpha_value = self.stroke_params[-1]
211 |
212 | self.foreground = np.zeros_like(
213 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
214 | self.stroke_alpha_map = np.zeros_like(
215 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
216 |
217 | color = (R0 * 255, G0 * 255, B0 * 255)
218 | alpha = (stroke_alpha_value * 255,
219 | stroke_alpha_value * 255,
220 | stroke_alpha_value * 255)
221 | ptc = (x0, y0)
222 | pt0 = utils.rotate_pt(pt=(x0 - w, y0 - h), rotate_center=ptc, theta=theta)
223 | pt1 = utils.rotate_pt(pt=(x0 + w, y0 - h), rotate_center=ptc, theta=theta)
224 | pt2 = utils.rotate_pt(pt=(x0 + w, y0 + h), rotate_center=ptc, theta=theta)
225 | pt3 = utils.rotate_pt(pt=(x0 - w, y0 + h), rotate_center=ptc, theta=theta)
226 |
227 | ppt = np.array([pt0, pt1, pt2, pt3], np.int32)
228 | ppt = ppt.reshape((-1, 1, 2))
229 | cv2.fillPoly(self.foreground, [ppt], color, lineType=cv2.LINE_AA)
230 | cv2.fillPoly(self.stroke_alpha_map, [ppt], alpha, lineType=cv2.LINE_AA)
231 |
232 | if not self.train:
233 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2]))
234 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2]))
235 |
236 | self.foreground = np.array(self.foreground, dtype=np.float32)/255.
237 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255.
238 | self.canvas = self._update_canvas()
239 |
240 |
241 | def _draw_markerpen(self):
242 |
243 | # x0, y0, x1, y1, x2, y2, radius0, radius2, R, G, B, A
244 | x0, y0, x1, y1, x2, y2, radius, _ = self.stroke_params[0:8]
245 | R0, G0, B0, ALPHA = self.stroke_params[8:]
246 | x1 = x0 + (x2 - x0) * x1
247 | y1 = y0 + (y2 - y0) * y1
248 | x0 = _normalize(x0, self.CANVAS_WIDTH)
249 | x1 = _normalize(x1, self.CANVAS_WIDTH)
250 | x2 = _normalize(x2, self.CANVAS_WIDTH)
251 | y0 = _normalize(y0, self.CANVAS_WIDTH)
252 | y1 = _normalize(y1, self.CANVAS_WIDTH)
253 | y2 = _normalize(y2, self.CANVAS_WIDTH)
254 | radius = (int)(1 + radius * self.CANVAS_WIDTH // 4)
255 |
256 | stroke_alpha_value = self.stroke_params[-1]
257 |
258 | self.foreground = np.zeros_like(
259 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
260 | self.stroke_alpha_map = np.zeros_like(
261 | self.canvas, dtype=np.uint8) # uint8 for antialiasing
262 |
263 |         if abs(x0-x2) + abs(y0-y2) < 4: # too small, don't draw
264 | self.foreground = np.array(self.foreground, dtype=np.float32) / 255.
265 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32) / 255.
266 | self.canvas = self._update_canvas()
267 | return
268 |
269 | color = (R0 * 255, G0 * 255, B0 * 255)
270 | alpha = (stroke_alpha_value * 255,
271 | stroke_alpha_value * 255,
272 | stroke_alpha_value * 255)
273 | tmp = 1. / 100
274 | for i in range(100):
275 | t = i * tmp
276 | x = (1 - t) * (1 - t) * x0 + 2 * t * (1 - t) * x1 + t * t * x2
277 | y = (1 - t) * (1 - t) * y0 + 2 * t * (1 - t) * y1 + t * t * y2
278 |
279 | ptc = (x, y)
280 | dx = 2 * (t - 1) * x0 + 2 * (1 - 2 * t) * x1 + 2 * t * x2
281 | dy = 2 * (t - 1) * y0 + 2 * (1 - 2 * t) * y1 + 2 * t * y2
282 |
283 | theta = np.arctan2(dx, dy) - np.pi/2
284 | pt0 = utils.rotate_pt(pt=(x - radius, y - radius), rotate_center=ptc, theta=theta)
285 | pt1 = utils.rotate_pt(pt=(x + radius, y - radius), rotate_center=ptc, theta=theta)
286 | pt2 = utils.rotate_pt(pt=(x + radius, y + radius), rotate_center=ptc, theta=theta)
287 | pt3 = utils.rotate_pt(pt=(x - radius, y + radius), rotate_center=ptc, theta=theta)
288 | ppt = np.array([pt0, pt1, pt2, pt3], np.int32)
289 | ppt = ppt.reshape((-1, 1, 2))
290 | cv2.fillPoly(self.foreground, [ppt], color, lineType=cv2.LINE_AA)
291 | cv2.fillPoly(self.stroke_alpha_map, [ppt], alpha, lineType=cv2.LINE_AA)
292 |
293 | if not self.train:
294 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2]))
295 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2]))
296 |
297 | self.foreground = np.array(self.foreground, dtype=np.float32)/255.
298 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255.
299 | self.canvas = self._update_canvas()
300 |
301 |
302 |
303 | def _draw_oilpaintbrush(self):
304 |
305 | # xc, yc, w, h, theta, R0, G0, B0, R2, G2, B2, A
306 | x0, y0, w, h, theta = self.stroke_params[0:5]
307 | R0, G0, B0, R2, G2, B2, ALPHA = self.stroke_params[5:]
308 | x0 = _normalize(x0, self.CANVAS_WIDTH)
309 | y0 = _normalize(y0, self.CANVAS_WIDTH)
310 | w = (int)(1 + w * self.CANVAS_WIDTH)
311 | h = (int)(1 + h * self.CANVAS_WIDTH)
312 | theta = np.pi*theta
313 |
314 | if w * h / (self.CANVAS_WIDTH**2) > 0.1:
315 | if h > w:
316 | brush = self.brush_large_vertical
317 | else:
318 | brush = self.brush_large_horizontal
319 | else:
320 | if h > w:
321 | brush = self.brush_small_vertical
322 | else:
323 | brush = self.brush_small_horizontal
324 | self.foreground, self.stroke_alpha_map = utils.create_transformed_brush(
325 | brush, self.CANVAS_WIDTH, self.CANVAS_WIDTH,
326 | x0, y0, w, h, theta, R0, G0, B0, R2, G2, B2)
327 |
328 | if not self.train:
329 | self.foreground = cv2.dilate(self.foreground, np.ones([2, 2]))
330 | self.stroke_alpha_map = cv2.erode(self.stroke_alpha_map, np.ones([2, 2]))
331 |
332 | self.foreground = np.array(self.foreground, dtype=np.float32)/255.
333 | self.stroke_alpha_map = np.array(self.stroke_alpha_map, dtype=np.float32)/255.
334 | self.canvas = self._update_canvas()
335 |
336 |
337 | def _update_canvas(self):
338 | return self.foreground * self.stroke_alpha_map + \
339 | self.canvas * (1 - self.stroke_alpha_map)
340 |
--------------------------------------------------------------------------------
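Each _draw_* method rasterizes one stroke into self.foreground / self.stroke_alpha_map and alpha-blends it onto self.canvas via _update_canvas(). A minimal sketch that paints a few random strokes (the watercolor renderer is used because, unlike oilpaintbrush, it needs no brush textures from ./brushes; a CUDA-enabled runtime is assumed as above):

import matplotlib.pyplot as plt
from renderer import Renderer

rderr = Renderer(renderer='watercolor', CANVAS_WIDTH=128, train=False)
for _ in range(20):
    rderr.random_stroke_params()      # sample a random stroke parameter vector
    if rderr.check_stroke():          # skip strokes that are too thin to matter
        rderr.draw_stroke()           # rasterize and blend onto the canvas
plt.imshow(rderr.canvas)
plt.show()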
/painter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import random
4 |
5 | import utils
6 | import loss
7 | from networks import *
8 | import morphology
9 |
10 | import renderer
11 |
12 | import torch
13 | torch.cuda.current_device()
14 |
15 | # Decide which device we want to run on
16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17 |
18 |
19 | class PainterBase():
20 | def __init__(self, args):
21 | self.args = args
22 |
23 | self.rderr = renderer.Renderer(renderer=args.renderer,
24 | CANVAS_WIDTH=args.canvas_size, canvas_color=args.canvas_color)
25 |
26 | # define G
27 | self.net_G = define_G(rdrr=self.rderr, netG=args.net_G).to(device)
28 |
29 | # define some other vars to record the training states
30 | self.x_ctt = None
31 | self.x_color = None
32 | self.x_alpha = None
33 |
34 | self.G_pred_foreground = None
35 | self.G_pred_alpha = None
36 | self.G_final_pred_canvas = torch.zeros([1, 3, 128, 128]).to(device)
37 |
38 | self.G_loss = torch.tensor(0.0)
39 | self.step_id = 0
40 | self.anchor_id = 0
41 | self.renderer_checkpoint_dir = args.renderer_checkpoint_dir
42 | self.output_dir = args.output_dir
43 | self.lr = args.lr
44 |
45 | # define the loss functions
46 | self._pxl_loss = loss.PixelLoss(p=1)
47 | self._sinkhorn_loss = loss.SinkhornLoss(epsilon=0.01, niter=5, normalize=False)
48 |
49 | # some other vars to be initialized in child classes
50 | self.img_path = None
51 | self.img_batch = None
52 | self.img_ = None
53 | self.final_rendered_images = None
54 | self.m_grid = None
55 | self.m_strokes_per_block = None
56 |
57 | if os.path.exists(self.output_dir) is False:
58 | os.mkdir(self.output_dir)
59 |
60 | def _load_checkpoint(self):
61 |
62 | # load renderer G
63 | if os.path.exists((os.path.join(
64 | self.renderer_checkpoint_dir, 'last_ckpt.pt'))):
65 | print('loading renderer from pre-trained checkpoint...')
66 | # load the entire checkpoint
67 | checkpoint = torch.load(os.path.join(
68 | self.renderer_checkpoint_dir, 'last_ckpt.pt'))
69 | # update net_G states
70 | self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
71 | self.net_G.to(device)
72 | self.net_G.eval()
73 | else:
74 | print('pre-trained renderer does not exist...')
75 | exit()
76 |
77 |
78 | def _compute_acc(self):
79 |
80 | target = self.img_batch.detach()
81 | canvas = self.G_pred_canvas.detach()
82 | psnr = utils.cpt_batch_psnr(canvas, target, PIXEL_MAX=1.0)
83 |
84 | return psnr
85 |
86 | def _save_stroke_params(self, v):
87 |
88 | d_shape = self.rderr.d_shape
89 | d_color = self.rderr.d_color
90 | d_alpha = self.rderr.d_alpha
91 |
92 | x_ctt = v[:, :, 0:d_shape]
93 | x_color = v[:, :, d_shape:d_shape+d_color]
94 | x_alpha = v[:, :, d_shape+d_color:d_shape+d_color+d_alpha]
95 | print('saving stroke parameters...')
96 | file_name = os.path.join(
97 | self.output_dir, self.img_path.split('/')[-1][:-4])
98 | np.savez(file_name + '_strokes.npz', x_ctt=x_ctt,
99 | x_color=x_color, x_alpha=x_alpha)
100 |
101 |
102 | def _save_rendered_images(self):
103 | print('saving rendered images...')
104 | file_name = os.path.join(
105 | self.output_dir, self.img_path.split('/')[-1][:-4])
106 | plt.imsave(file_name+'_input.png', self.img_)
107 | for i in range(len(self.final_rendered_images)):
108 | plt.imsave(file_name + '_rendered_stroke_' + str((i+1)).zfill(4) +
109 | '.png', self.final_rendered_images[i])
110 |
111 |
112 | def _normalize_strokes(self, v):
113 |
114 | v = np.array(v.detach().cpu())
115 |
116 | if self.rderr.renderer in ['watercolor', 'markerpen']:
117 | # x0, y0, x1, y1, x2, y2, radius0, radius2, ...
118 | xs = np.array([0, 4])
119 | ys = np.array([1, 5])
120 | rs = np.array([6, 7])
121 | elif self.rderr.renderer in ['oilpaintbrush', 'rectangle']:
122 | # xc, yc, w, h, theta ...
123 | xs = np.array([0])
124 | ys = np.array([1])
125 | rs = np.array([2, 3])
126 | else:
127 | raise NotImplementedError('renderer [%s] is not implemented' % self.rderr.renderer)
128 |
129 | for y_id in range(self.m_grid):
130 | for x_id in range(self.m_grid):
131 | y_bias = y_id / self.m_grid
132 | x_bias = x_id / self.m_grid
133 | v[y_id * self.m_grid + x_id, :, ys] = \
134 | y_bias + v[y_id * self.m_grid + x_id, :, ys] / self.m_grid
135 | v[y_id * self.m_grid + x_id, :, xs] = \
136 | x_bias + v[y_id * self.m_grid + x_id, :, xs] / self.m_grid
137 | v[y_id * self.m_grid + x_id, :, rs] /= self.m_grid
138 |
139 | return v
140 |
141 |
142 | def initialize_params(self):
143 |
144 | self.x_ctt = np.random.rand(
145 | self.m_grid*self.m_grid, self.m_strokes_per_block,
146 | self.rderr.d_shape).astype(np.float32)
147 | self.x_ctt = torch.tensor(self.x_ctt).to(device)
148 |
149 | self.x_color = np.random.rand(
150 | self.m_grid*self.m_grid, self.m_strokes_per_block,
151 | self.rderr.d_color).astype(np.float32)
152 | self.x_color = torch.tensor(self.x_color).to(device)
153 |
154 | self.x_alpha = np.random.rand(
155 | self.m_grid*self.m_grid, self.m_strokes_per_block,
156 | self.rderr.d_alpha).astype(np.float32)
157 | self.x_alpha = torch.tensor(self.x_alpha).to(device)
158 |
159 |
160 | def stroke_sampler(self, anchor_id):
161 |
162 | if anchor_id == self.m_strokes_per_block:
163 | return
164 |
165 | err_maps = torch.sum(
166 | torch.abs(self.img_batch - self.G_final_pred_canvas),
167 | dim=1, keepdim=True).detach()
168 |
169 | for i in range(self.m_grid*self.m_grid):
170 | this_err_map = err_maps[i,0,:,:].cpu().numpy()
171 | this_err_map = cv2.blur(this_err_map, (20, 20))
172 | this_err_map = this_err_map ** 4
173 | this_img = self.img_batch[i, :, :, :].detach().permute([1, 2, 0]).cpu().numpy()
174 |
175 | self.rderr.random_stroke_params_sampler(
176 | err_map=this_err_map, img=this_img)
177 |
178 | self.x_ctt.data[i, anchor_id, :] = torch.tensor(
179 | self.rderr.stroke_params[0:self.rderr.d_shape])
180 | self.x_color.data[i, anchor_id, :] = torch.tensor(
181 | self.rderr.stroke_params[self.rderr.d_shape:self.rderr.d_shape+self.rderr.d_color])
182 | self.x_alpha.data[i, anchor_id, :] = torch.tensor(self.rderr.stroke_params[-1])
183 |
184 |
185 | def _backward_x(self):
186 |
187 | self.G_loss = 0
188 | self.G_loss += self.args.beta_L1 * self._pxl_loss(
189 | canvas=self.G_final_pred_canvas, gt=self.img_batch)
190 | if self.args.with_ot_loss:
191 | self.G_loss += self.args.beta_ot * self._sinkhorn_loss(
192 | self.G_final_pred_canvas, self.img_batch)
193 | self.G_loss.backward()
194 |
195 |
196 | def _forward_pass(self):
197 |
198 | self.x = torch.cat([self.x_ctt, self.x_color, self.x_alpha], dim=-1)
199 |
200 | v = torch.reshape(self.x[:, 0:self.anchor_id+1, :],
201 | [self.m_grid*self.m_grid*(self.anchor_id+1), -1, 1, 1])
202 | self.G_pred_foregrounds, self.G_pred_alphas = self.net_G(v)
203 |
204 | self.G_pred_foregrounds = morphology.Dilation2d(m=1)(self.G_pred_foregrounds)
205 | self.G_pred_alphas = morphology.Erosion2d(m=1)(self.G_pred_alphas)
206 |
207 | self.G_pred_foregrounds = torch.reshape(
208 | self.G_pred_foregrounds, [self.m_grid*self.m_grid, self.anchor_id+1, 3, 128, 128])
209 | self.G_pred_alphas = torch.reshape(
210 | self.G_pred_alphas, [self.m_grid*self.m_grid, self.anchor_id+1, 3, 128, 128])
211 |
212 | for i in range(self.anchor_id+1):
213 | G_pred_foreground = self.G_pred_foregrounds[:, i]
214 | G_pred_alpha = self.G_pred_alphas[:, i]
215 | self.G_pred_canvas = G_pred_foreground * G_pred_alpha \
216 | + self.G_pred_canvas * (1 - G_pred_alpha)
217 |
218 | self.G_final_pred_canvas = self.G_pred_canvas
219 |
220 |
221 |
222 |
223 | class Painter(PainterBase):
224 |
225 | def __init__(self, args):
226 | super(Painter, self).__init__(args=args)
227 |
228 | self.m_grid = args.m_grid
229 |
230 | self.max_m_strokes = args.max_m_strokes
231 |
232 | self.img_path = args.img_path
233 | self.img_ = cv2.imread(args.img_path, cv2.IMREAD_COLOR)
234 | self.img_ = cv2.cvtColor(self.img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
235 | self.img_ = cv2.resize(self.img_, (128 * args.m_grid, 128 * args.m_grid))
236 |
237 | self.m_strokes_per_block = int(args.max_m_strokes / (args.m_grid * args.m_grid))
238 |
239 | self.img_batch = utils.img2patches(self.img_, args.m_grid).to(device)
240 |
241 | self.final_rendered_images = None
242 |
243 |
244 | def _drawing_step_states(self):
245 | acc = self._compute_acc().item()
246 | print('iteration step %d, G_loss: %.5f, step_psnr: %.5f, strokes: %d / %d'
247 | % (self.step_id, self.G_loss.item(), acc,
248 | (self.anchor_id+1)*self.m_grid*self.m_grid,
249 | self.max_m_strokes))
250 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1)
251 | cv2.imshow('G_pred', vis2[:,:,::-1])
252 | cv2.imshow('input', self.img_[:, :, ::-1])
253 | cv2.waitKey(1)
254 |
255 | def _render_on_grids(self, v):
256 |
257 | rendered_imgs = []
258 |
259 | self.rderr.create_empty_canvas()
260 |
261 | grid_idx = list(range(self.m_grid ** 2))
262 | random.shuffle(grid_idx)
263 | for j in range(v.shape[1]): # for each group of stroke
264 | for i in range(len(grid_idx)): # for each random patch
265 | self.rderr.stroke_params = v[grid_idx[i], j, :]
266 | if self.rderr.check_stroke():
267 | self.rderr.draw_stroke()
268 | rendered_imgs.append(self.rderr.canvas)
269 |
270 | return rendered_imgs
271 |
272 |
273 |
274 |
275 | class ProgressivePainter(PainterBase):
276 |
277 | def __init__(self, args):
278 | super(ProgressivePainter, self).__init__(args=args)
279 |
280 | self.max_divide = args.max_divide
281 |
282 | self.max_m_strokes = args.max_m_strokes
283 |
284 | self.m_strokes_per_block = self.stroke_parser()
285 |
286 | self.m_grid = 1
287 |
288 | self.img_path = args.img_path
289 | self.img_ = cv2.imread(args.img_path, cv2.IMREAD_COLOR)
290 | self.img_ = cv2.cvtColor(self.img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
291 | self.img_ = cv2.resize(self.img_, (128 * args.max_divide, 128 * args.max_divide))
292 |
293 |
294 |
295 | def _render(self, v):
296 |
297 | v = v[0,:,:]
298 |
299 | rendered_imgs = []
300 |
301 | self.rderr.create_empty_canvas()
302 |
303 | for i in range(v.shape[0]): # for each stroke
304 | self.rderr.stroke_params = v[i, :]
305 | if self.rderr.check_stroke():
306 | self.rderr.draw_stroke()
307 | rendered_imgs.append(self.rderr.canvas)
308 |
309 | return rendered_imgs
310 |
311 |
312 | def stroke_parser(self):
313 |
314 | total_blocks = 0
315 | for i in range(0, self.max_divide + 1):
316 | total_blocks += i ** 2
317 |
318 | return int(self.max_m_strokes / total_blocks)
319 |
320 |
321 | def _drawing_step_states(self):
322 | acc = self._compute_acc().item()
323 | print('iteration step %d, G_loss: %.5f, step_acc: %.5f, grid_scale: %d / %d, strokes: %d / %d'
324 | % (self.step_id, self.G_loss.item(), acc,
325 | self.m_grid, self.max_divide,
326 | self.anchor_id, self.m_strokes_per_block))
327 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1)
328 | cv2.imshow('G_pred', vis2[:,:,::-1])
329 | cv2.imshow('input', self.img_[:, :, ::-1])
330 | cv2.waitKey(1)
331 |
332 |
333 |
334 | class NeuralStyleTransfer(PainterBase):
335 |
336 | def __init__(self, args):
337 | super(NeuralStyleTransfer, self).__init__(args=args)
338 |
339 | self._style_loss = loss.VGGStyleLoss(transfer_mode=args.transfer_mode, resize=True)
340 |
341 | npzfile = np.load(args.vector_file)
342 |
343 | self.x_ctt = torch.tensor(npzfile['x_ctt']).to(device)
344 | self.x_color = torch.tensor(npzfile['x_color']).to(device)
345 | self.x_alpha = torch.tensor(npzfile['x_alpha']).to(device)
346 | self.m_grid = int(np.sqrt(self.x_ctt.shape[0]))
347 |
348 | self.anchor_id = self.x_ctt.shape[1] - 1
349 |
350 | img_ = cv2.imread(args.content_img_path, cv2.IMREAD_COLOR)
351 | img_ = cv2.cvtColor(img_, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
352 | self.img_ = cv2.resize(img_, (128*self.m_grid, 128*self.m_grid))
353 | self.img_batch = utils.img2patches(self.img_, self.m_grid).to(device)
354 |
355 | style_img = cv2.imread(args.style_img_path, cv2.IMREAD_COLOR)
356 | self.style_img_ = cv2.cvtColor(style_img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
357 | self.style_img = cv2.blur(cv2.resize(self.style_img_, (128, 128)), (2, 2))
358 | self.style_img = torch.tensor(self.style_img).permute([2, 0, 1]).unsqueeze(0).to(device)
359 |
360 |
361 | def _style_transfer_step_states(self):
362 | acc = self._compute_acc().item()
363 | print('running style transfer... iteration step %d, G_loss: %.5f, step_psnr: %.5f'
364 | % (self.step_id, self.G_loss.item(), acc))
365 | vis2 = utils.patches2img(self.G_final_pred_canvas, self.m_grid).clip(min=0, max=1)
366 | cv2.imshow('G_pred', vis2[:,:,::-1])
367 | cv2.imshow('input', self.img_[:, :, ::-1])
368 | cv2.waitKey(1)
369 |
370 |
371 | def _backward_x_sty(self):
372 | canvas = utils.patches2img(
373 | self.G_final_pred_canvas, self.m_grid, to_numpy=False).to(device)
374 | self.G_loss = self.args.beta_L1 * self._pxl_loss(
375 | canvas=self.G_final_pred_canvas, gt=self.img_batch, ignore_color=True)
376 | self.G_loss += self.args.beta_sty * self._style_loss(canvas, self.style_img)
377 | self.G_loss.backward()
378 |
379 |
380 | def _render_on_grids(self, v):
381 |
382 | rendered_imgs = []
383 |
384 | self.rderr.create_empty_canvas()
385 |
386 | grid_idx = list(range(self.m_grid ** 2))
387 | random.shuffle(grid_idx)
388 | for j in range(v.shape[1]): # for each group of stroke
389 | for i in range(len(grid_idx)): # for each random patch
390 | self.rderr.stroke_params = v[grid_idx[i], j, :]
391 | if self.rderr.check_stroke():
392 | self.rderr.draw_stroke()
393 | rendered_imgs.append(self.rderr.canvas)
394 |
395 | return rendered_imgs
396 |
397 |
398 |
--------------------------------------------------------------------------------
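ProgressivePainter.stroke_parser() spreads the total stroke budget evenly over every block of every grid scale, where scale k contributes k x k blocks. A worked check with the notebook's settings (max_divide = 5, max_m_strokes = 500):

total_blocks = sum(k ** 2 for k in range(1, 5 + 1))   # 1 + 4 + 9 + 16 + 25 = 55
print(500 // total_blocks)                            # 9 strokes per block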
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torchvision import models
6 | import torch.nn.functional as F
7 | from torch.optim import lr_scheduler
8 | import math
9 | import utils
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 |
13 | # Decide which device we want to run on
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 |
16 | PI = math.pi
17 | ###############################################################################
18 | # Helper Functions
19 | ###############################################################################
20 |
21 |
22 | class Identity(nn.Module):
23 | def forward(self, x):
24 | return x
25 |
26 | def get_norm_layer(norm_type='instance'):
27 | """Return a normalization layer
28 |
29 | Parameters:
30 | norm_type (str) -- the name of the normalization layer: batch | instance | none
31 |
32 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
33 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
34 | """
35 | if norm_type == 'batch':
36 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
37 | elif norm_type == 'instance':
38 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
39 | elif norm_type == 'none':
40 | norm_layer = lambda x: Identity()
41 | else:
42 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
43 | return norm_layer
44 |
45 |
46 | def init_weights(net, init_type='normal', init_gain=0.02):
47 | """Initialize network weights.
48 |
49 | Parameters:
50 | net (network) -- network to be initialized
51 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
52 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
53 |
54 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
55 | work better for some applications. Feel free to try yourself.
56 | """
57 | def init_func(m): # define the initialization function
58 | classname = m.__class__.__name__
59 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
60 | if init_type == 'normal':
61 | init.normal_(m.weight.data, 0.0, init_gain)
62 | elif init_type == 'xavier':
63 | init.xavier_normal_(m.weight.data, gain=init_gain)
64 | elif init_type == 'kaiming':
65 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
66 | elif init_type == 'orthogonal':
67 | init.orthogonal_(m.weight.data, gain=init_gain)
68 | else:
69 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
70 | if hasattr(m, 'bias') and m.bias is not None:
71 | init.constant_(m.bias.data, 0.0)
72 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
73 | init.normal_(m.weight.data, 1.0, init_gain)
74 | init.constant_(m.bias.data, 0.0)
75 |
76 | print('initialize network with %s' % init_type)
77 | net.apply(init_func) # apply the initialization function
78 |
79 |
80 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
81 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
82 | Parameters:
83 | net (network) -- the network to be initialized
84 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
85 | gain (float) -- scaling factor for normal, xavier and orthogonal.
86 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
87 |
88 | Return an initialized network.
89 | """
90 | if len(gpu_ids) > 0:
91 | assert(torch.cuda.is_available())
92 | net.to(gpu_ids[0])
93 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
94 | init_weights(net, init_type, init_gain=init_gain)
95 | return net
96 |
97 |
98 | def define_G(rdrr, netG, init_type='normal', init_gain=0.02, gpu_ids=[]):
99 | net = None
100 | if netG == 'plain-dcgan':
101 | net = DCGAN(rdrr)
102 | elif netG == 'plain-unet':
103 | net = UNet(rdrr)
104 | elif netG == 'huang-net':
105 | net = HuangNet(rdrr)
106 | elif netG == 'zou-fusion-net':
107 | net = ZouFCNFusion(rdrr)
108 | else:
109 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
110 | return init_net(net, init_type, init_gain, gpu_ids)
111 |
112 |
113 | class DCGAN(nn.Module):
114 | def __init__(self, rdrr, ngf=64):
115 | super(DCGAN, self).__init__()
116 | input_nc = rdrr.d
117 | self.main = nn.Sequential(
118 | # input is Z, going into a convolution
119 | nn.ConvTranspose2d(input_nc, ngf * 8, 4, 1, 0, bias=False),
120 | nn.BatchNorm2d(ngf * 8),
121 | nn.ReLU(True),
122 | # state size. (ngf*8) x 4 x 4
123 |
124 | nn.ConvTranspose2d(ngf * 8, ngf * 8, 4, 2, 1, bias=False),
125 | nn.BatchNorm2d(ngf * 8),
126 | nn.ReLU(True),
127 |             # state size. (ngf*8) x 8 x 8
128 |
129 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
130 | nn.BatchNorm2d(ngf * 4),
131 | nn.ReLU(True),
132 |             # state size. (ngf*4) x 16 x 16
133 |
134 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
135 | nn.BatchNorm2d(ngf * 2),
136 | nn.ReLU(True),
137 | # state size. (ngf*2) x 32 x 32
138 |
139 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
140 | nn.BatchNorm2d(ngf),
141 | nn.ReLU(True),
142 |             # state size. (ngf) x 64 x 64
143 |
144 | nn.ConvTranspose2d(ngf, 6, 4, 2, 1, bias=False),
145 |             # state size. 6 x 128 x 128 (3 foreground + 3 alpha channels)
146 | )
147 |
148 | def forward(self, input):
149 | output_tensor = self.main(input)
150 | return output_tensor[:,0:3,:,:], output_tensor[:,3:6,:,:]
151 |
152 |
153 |
154 | class PixelShuffleNet(nn.Module):
155 | def __init__(self, input_nc):
156 | super(PixelShuffleNet, self).__init__()
157 | self.fc1 = (nn.Linear(input_nc, 512))
158 | self.fc2 = (nn.Linear(512, 1024))
159 | self.fc3 = (nn.Linear(1024, 2048))
160 | self.fc4 = (nn.Linear(2048, 4096))
161 | self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
162 | self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
163 | self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
164 | self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
165 | self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
166 | self.conv6 = (nn.Conv2d(8, 4*3, 3, 1, 1))
167 | self.pixel_shuffle = nn.PixelShuffle(2)
168 |
169 | def forward(self, x):
170 | x = x.squeeze()
171 | x = F.relu(self.fc1(x))
172 | x = F.relu(self.fc2(x))
173 | x = F.relu(self.fc3(x))
174 | x = F.relu(self.fc4(x))
175 | x = x.view(-1, 16, 16, 16)
176 | x = F.relu(self.conv1(x))
177 | x = self.pixel_shuffle(self.conv2(x))
178 | x = F.relu(self.conv3(x))
179 | x = self.pixel_shuffle(self.conv4(x))
180 | x = F.relu(self.conv5(x))
181 | x = self.pixel_shuffle(self.conv6(x))
182 | x = x.view(-1, 3, 128, 128)
183 | return x
184 |
185 |
186 |
187 | class HuangNet(nn.Module):
188 | def __init__(self, rdrr):
189 | super(HuangNet, self).__init__()
190 | self.rdrr = rdrr
191 | self.fc1 = (nn.Linear(rdrr.d, 512))
192 | self.fc2 = (nn.Linear(512, 1024))
193 | self.fc3 = (nn.Linear(1024, 2048))
194 | self.fc4 = (nn.Linear(2048, 4096))
195 | self.conv1 = (nn.Conv2d(16, 32, 3, 1, 1))
196 | self.conv2 = (nn.Conv2d(32, 32, 3, 1, 1))
197 | self.conv3 = (nn.Conv2d(8, 16, 3, 1, 1))
198 | self.conv4 = (nn.Conv2d(16, 16, 3, 1, 1))
199 | self.conv5 = (nn.Conv2d(4, 8, 3, 1, 1))
200 | self.conv6 = (nn.Conv2d(8, 4 * 6, 3, 1, 1))
201 | self.pixel_shuffle = nn.PixelShuffle(2)
202 |
203 |
204 | def forward(self, x):
205 | x = x.squeeze()
206 | x = F.relu(self.fc1(x))
207 | x = F.relu(self.fc2(x))
208 | x = F.relu(self.fc3(x))
209 | x = F.relu(self.fc4(x))
210 | x = x.view(-1, 16, 16, 16)
211 | x = F.relu(self.conv1(x))
212 | x = self.pixel_shuffle(self.conv2(x))
213 | x = F.relu(self.conv3(x))
214 | x = self.pixel_shuffle(self.conv4(x))
215 | x = F.relu(self.conv5(x))
216 | x = self.pixel_shuffle(self.conv6(x))
217 | output_tensor = x.view(-1, 6, 128, 128)
218 | return output_tensor[:,0:3,:,:], output_tensor[:,3:6,:,:]
219 |
220 |
221 |
222 | class ZouFCNFusion(nn.Module):
223 | def __init__(self, rdrr):
224 | super(ZouFCNFusion, self).__init__()
225 | self.rdrr = rdrr
226 | self.huangnet = PixelShuffleNet(rdrr.d_shape)
227 | self.dcgan = DCGAN(rdrr)
228 |
229 | def forward(self, x):
230 | x_shape = x[:, 0:self.rdrr.d_shape, :, :]
231 | x_alpha = x[:, [-1], :, :]
232 | if self.rdrr.renderer in ['oilpaintbrush', 'airbrush']:
233 | x_alpha = torch.tensor(1.0).to(device)
234 |
235 | mask = self.huangnet(x_shape)
236 | color, _ = self.dcgan(x)
237 |
238 | return color * mask, x_alpha * mask
239 |
240 |
241 |
242 |
243 | class UNet(torch.nn.Module):
244 | def __init__(self, rdrr):
245 | """
246 |         In the constructor we instantiate the U-Net generator and assign it as a
247 |         member variable.
248 | """
249 | super(UNet, self).__init__()
250 | norm_layer = get_norm_layer(norm_type='batch')
251 | self.unet = UnetGenerator(rdrr.d, 6, 7, norm_layer=norm_layer, use_dropout=False)
252 |
253 | def forward(self, x):
254 | """
255 | In the forward function we accept a Tensor of input data and we must return
256 | a Tensor of output data. We can use Modules defined in the constructor as
257 | well as arbitrary operators on Tensors.
258 | """
259 |         # broadcast the stroke parameter vector to a 128x128 map for the U-Net
260 | x = x.repeat(1, 1, 128, 128)
261 | output_tensor = self.unet(x)
262 | return output_tensor[:,0:3,:,:], output_tensor[:,3:6,:,:]
263 |
264 |
265 |
266 | class UnetGenerator(nn.Module):
267 | """Create a Unet-based generator"""
268 |
269 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
270 | """Construct a Unet generator
271 | Parameters:
272 | input_nc (int) -- the number of channels in input images
273 | output_nc (int) -- the number of channels in output images
274 |             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
275 |                                 an image of size 128x128 will become of size 1x1 at the bottleneck
276 | ngf (int) -- the number of filters in the last conv layer
277 | norm_layer -- normalization layer
278 |
279 | We construct the U-Net from the innermost layer to the outermost layer.
280 | It is a recursive process.
281 | """
282 | super(UnetGenerator, self).__init__()
283 | # construct unet structure
284 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
285 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
286 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
287 | # gradually reduce the number of filters from ngf * 8 to ngf
288 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
289 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
290 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
291 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
292 |
293 | def forward(self, input):
294 | """Standard forward"""
295 | return self.model(input)
296 |
297 |
298 | class UnetSkipConnectionBlock(nn.Module):
299 | """Defines the Unet submodule with skip connection.
300 | X -------------------identity----------------------
301 | |-- downsampling -- |submodule| -- upsampling --|
302 | """
303 |
304 | def __init__(self, outer_nc, inner_nc, input_nc=None,
305 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
306 | """Construct a Unet submodule with skip connections.
307 |
308 | Parameters:
309 | outer_nc (int) -- the number of filters in the outer conv layer
310 | inner_nc (int) -- the number of filters in the inner conv layer
311 | input_nc (int) -- the number of channels in input images/features
312 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
313 | outermost (bool) -- if this module is the outermost module
314 | innermost (bool) -- if this module is the innermost module
315 | norm_layer -- normalization layer
316 |             use_dropout (bool) -- whether to use dropout layers.
317 | """
318 | super(UnetSkipConnectionBlock, self).__init__()
319 | self.outermost = outermost
320 | if type(norm_layer) == functools.partial:
321 | use_bias = norm_layer.func == nn.InstanceNorm2d
322 | else:
323 | use_bias = norm_layer == nn.InstanceNorm2d
324 | if input_nc is None:
325 | input_nc = outer_nc
326 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
327 | stride=2, padding=1, bias=use_bias)
328 | downrelu = nn.LeakyReLU(0.2, True)
329 | downnorm = norm_layer(inner_nc)
330 | uprelu = nn.ReLU(True)
331 | upnorm = norm_layer(outer_nc)
332 |
333 | if outermost:
334 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
335 | kernel_size=4, stride=2,
336 | padding=1)
337 | down = [downconv]
338 | # up = [uprelu, upconv, nn.Tanh()]
339 | # up = [uprelu, upconv, nn.Sigmoid()] # ZZX
340 | up = [uprelu, upconv] # ZZX
341 | model = down + [submodule] + up
342 | elif innermost:
343 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
344 | kernel_size=4, stride=2,
345 | padding=1, bias=use_bias)
346 | down = [downrelu, downconv]
347 | up = [uprelu, upconv, upnorm]
348 | model = down + up
349 | else:
350 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
351 | kernel_size=4, stride=2,
352 | padding=1, bias=use_bias)
353 | down = [downrelu, downconv, downnorm]
354 | up = [uprelu, upconv, upnorm]
355 |
356 | if use_dropout:
357 | model = down + [submodule] + up + [nn.Dropout(0.5)]
358 | else:
359 | model = down + [submodule] + up
360 |
361 | self.model = nn.Sequential(*model)
362 |
363 | def forward(self, x):
364 | if self.outermost:
365 | return self.model(x)
366 | else: # add skip connections
367 | return torch.cat([x, self.model(x)], 1)
368 |
--------------------------------------------------------------------------------
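A minimal smoke test for the renderer networks above, assuming an untrained zou-fusion-net with random weights (shape checking only; the watercolor renderer avoids the oilpaintbrush brush textures and the CUDA-only device constant in its forward pass, and a CUDA-enabled runtime is assumed as above):

import torch
import renderer
from networks import define_G

rderr = renderer.Renderer(renderer='watercolor', CANVAS_WIDTH=128)
net_G = define_G(rdrr=rderr, netG='zou-fusion-net')

v = torch.rand(4, rderr.d, 1, 1)   # a batch of 4 random stroke parameter vectors
foreground, alpha = net_G(v)       # each has shape (4, 3, 128, 128)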