├── .gitignore
├── .gitmodules
├── beziermatrix.py
├── README.md
├── bezierloss.py
├── npz.py
├── infer_bezierae.py
├── beziercurve.py
├── train_bezierae.py
├── infer_beziersketch.py
├── train_beziersketch.py
└── bezierae.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs/
3 | junks/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "quickdraw"]
2 | 	path = quickdraw
3 | 	url = git@github.com:dasayan05/quickdraw_nn_dataset.git
4 | 
--------------------------------------------------------------------------------
/beziermatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import comb as choose
3 | 
4 | def bezier_matrix(degree):
5 |     m = degree
6 |     Q = np.zeros((degree + 1, degree + 1))
7 |     for i in range(degree + 1):
8 |         for j in range(degree + 1):
9 |             if (0 <= (i+j)) and ((i+j) <= degree):
10 |                 Q[i,j] = choose(m, j) * choose(m-j, m-i-j) * ((-1)**(m-i-j))
11 |     return Q
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Official repository for "BézierSketch: A generative model for scalable vector sketches" accepted @ ECCV 2020
2 | 
3 | arXiv link: [https://arxiv.org/abs/2007.02190](https://arxiv.org/abs/2007.02190)
4 | 
5 | **Abstract:** The study of neural generative models of human sketches is a fascinating contemporary modeling problem due to the links between sketch image generation and the human drawing process. The landmark SketchRNN provided a breakthrough by sequentially generating sketches as a sequence of waypoints. However, this leads to low-resolution image generation and failure to model long sketches. In this paper we present BézierSketch, a novel generative model for fully vector sketches that are automatically scalable and high-resolution. To this end, we first introduce a novel inverse graphics approach to stroke embedding that trains an encoder to embed each stroke to its best fit Bézier curve. This enables us to treat sketches as short sequences of parameterized strokes and thus train a recurrent sketch generator with greater capacity for longer sketches, while producing scalable high-resolution results. We report qualitative and quantitative results on the Quick, Draw! benchmark.
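6 | 
7 | ### A minimal Bézier-fitting sketch
8 | 
9 | The stroke embedding above is *learned* (see `bezierae.py` / `train_bezierae.py`), but the underlying parameterization is the standard matrix form of a Bézier curve, B(t) = T(t)·M·P, with M given by `beziermatrix.bezier_matrix`. The snippet below is only an illustrative closed-form least-squares fit of control points to a single stroke; it is not the trained encoder, and `fit_bezier` and the toy stroke are made up for demonstration.
10 | 
11 | ```python
12 | import numpy as np
13 | from beziermatrix import bezier_matrix
14 | 
15 | def fit_bezier(xy, degree=5):
16 |     # chord-length parameterization of the stroke points in [0, 1]
17 |     d = np.linalg.norm(np.diff(xy, axis=0), axis=1)
18 |     t = np.concatenate(([0.], np.cumsum(d) / d.sum()))
19 |     # T(t) with descending powers [t^n, ..., t, 1], matching beziercurve.py
20 |     T = np.stack([t ** n for n in range(degree, -1, -1)], axis=1)
21 |     # least-squares control points: P = argmin ||T M P - xy||^2
22 |     P, *_ = np.linalg.lstsq(T @ bezier_matrix(degree), xy, rcond=None)
23 |     return P
24 | 
25 | # toy stroke: a noisy arc
26 | s = np.linspace(0., np.pi, 50)
27 | stroke = np.stack([np.cos(s), np.sin(s)], 1) + 0.01 * np.random.randn(50, 2)
28 | print(fit_bezier(stroke).shape)   # (degree + 1, 2)
29 | ```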
-------------------------------------------------------------------------------- /bezierloss.py: -------------------------------------------------------------------------------- 1 | import torch, numpy as np 2 | import torch.nn as nn 3 | from beziermatrix import bezier_matrix 4 | 5 | class BezierLoss(nn.Module): 6 | def __init__(self, degree, reg_weight_p = 1e-2, reg_weight_r = 1e-2): 7 | super().__init__() 8 | self.degree = degree 9 | self.M = self._M(self.degree) 10 | if torch.cuda.is_available(): 11 | self.M = self.M.cuda() 12 | self.reg_weight_p = reg_weight_p 13 | self.reg_weight_r = reg_weight_r 14 | 15 | def _consecutive_dist(self, XY): 16 | return (((XY[1:,:] - XY[0:-1,:])**2).sum(axis=1))**0.5 17 | 18 | def _heuristic_ts(self, XY): 19 | ds = self._consecutive_dist(XY) 20 | ds = ds / ds.sum() 21 | return torch.cumsum(torch.tensor([0., *ds]), 0) 22 | 23 | def _T(self, ts, d, dtype=torch.float32): 24 | ts = ts[..., np.newaxis] 25 | Q = [ts**n for n in range(d, -1, -1)] 26 | Q = torch.cat(Q, 1) 27 | if torch.cuda.is_available(): 28 | Q = Q.cuda() 29 | return Q 30 | 31 | def _M(self, d: 'degree'): 32 | return torch.tensor(bezier_matrix(d), dtype=torch.float32) 33 | 34 | def forward(self, P, R, XY, ts=None): 35 | # breakpoint() 36 | if R is not None: 37 | C = torch.mm(self._T(ts, self.degree), torch.mm(self.M, torch.diag(R))) 38 | C = C / C.sum(1).unsqueeze(1) 39 | C = torch.mm(C, P) 40 | else: 41 | C = torch.mm(self._T(ts, self.degree), torch.mm(self.M, P)) 42 | 43 | if XY is None: 44 | return C 45 | else: 46 | l = ((C - XY)**2).mean() + self.reg_weight_p * (self._consecutive_dist(P)**2).mean() 47 | return l -------------------------------------------------------------------------------- /npz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class NPZWriter(object): 5 | def __init__(self, filepath): 6 | super().__init__() 7 | # Track parameters 8 | self.filepath = filepath 9 | # internal list 10 | self.tr, self.ts, self.vl = [], [], [] 11 | 12 | def add(self, ctrlpt_batch, start_batch, n_strokes): 13 | for ctrlpt, start, n_stroke in zip(ctrlpt_batch, start_batch, n_strokes): 14 | # ctrlpts = torch.unbind(ctrlpt[:n_stroke.item()], dim=0) 15 | # starts = torch.unbind(start[:n_stroke.item()], dim=0) 16 | 17 | # populate this 18 | sketch = np.empty((0, 3), dtype=np.float32) 19 | 20 | for c, s in zip(ctrlpt, start): 21 | c = c.detach().cpu().numpy().reshape((-1, 2)) 22 | s = s.detach().cpu().numpy() 23 | 24 | P0 = np.array([[0., 0.]]) # start P 25 | c = np.cumsum(np.concatenate((P0, c), 0), 0) 26 | c = c + s 27 | 28 | q = np.zeros((c.shape[0], 1), dtype=np.float32); q[-1, 0] = 1. 29 | 30 | sketch = np.vstack((sketch, np.hstack((c, q)))) 31 | 32 | sketch[:,:2] *= 255. 
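            # The lines below convert the accumulated (x, y, pen) rows into the
            # SketchRNN-style stroke-3 format: quantize to int16 and replace the
            # absolute coordinates with per-step offsets (dx, dy); the third
            # column (set above) is the pen-lift bit marking the last point of
            # each stroke. The random draw then splits sketches roughly 90/5/5
            # into train/test/valid.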
33 |             sketch = sketch.astype(np.int16)
34 |             sketch[:,:2] -= sketch[0,:2]
35 |             sketch[1:,:2] -= sketch[:-1,:2]
36 | 
37 |             R = np.random.rand()
38 | 
39 |             if R < 0.9:
40 |                 self.tr.append( sketch[1:, :] )
41 |             elif R >= 0.9 and R < 0.95:
42 |                 self.ts.append( sketch[1:, :] )
43 |             else:
44 |                 self.vl.append( sketch[1:, :])
45 | 
46 |     def flush(self):
47 |         tr = np.array(self.tr, dtype=object)
48 |         ts = np.array(self.ts, dtype=object)
49 |         vl = np.array(self.vl, dtype=object)
50 | 
51 |         with open(self.filepath, 'wb') as f:
52 |             np.savez(f, train=tr, test=ts, valid=vl)
--------------------------------------------------------------------------------
/infer_bezierae.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch, numpy as np
3 | import matplotlib.pyplot as plt
4 | from torch.nn.utils.rnn import pad_packed_sequence
5 | 
6 | from quickdraw.quickdraw import QuickDraw
7 | from beziercurve import draw_bezier
8 | 
9 | def inference(qdl, model, layers, hidden, nsamples, bezier_degree_low, bezier_degree_high, savefile):
10 |     with torch.no_grad():
11 |         rsamples = bezier_degree_high - bezier_degree_low + 1
12 |         fig, ax = plt.subplots(nsamples, (rsamples + 1), figsize=((rsamples + 1) * 2, nsamples * 2))
13 |         for i, (X, _) in enumerate(qdl):
14 |             if i >= nsamples:
15 |                 break
16 | 
17 |             h_initial = torch.zeros(layers * 2, 1, hidden, dtype=torch.float32)
18 |             c_initial = torch.zeros(layers * 2, 1, hidden, dtype=torch.float32)
19 |             if torch.cuda.is_available():
20 |                 X, h_initial, c_initial = X.cuda(), h_initial.cuda(), c_initial.cuda()
21 | 
22 |             X_, l_ = pad_packed_sequence(X)
23 | 
24 |             if torch.cuda.is_available():
25 |                 X_numpy = X_.squeeze().cpu().numpy() # drop the batch dimension, which is 1 here (batch_size == 1)
26 |             else:
27 |                 X_numpy = X_.squeeze().numpy()
28 | 
29 |             if model.rational:
30 |                 ctrlpt, ratw = model(X, h_initial, c_initial)
31 |             else:
32 |                 ctrlpt = model(X, h_initial, c_initial)
33 | 
34 |             # normal = torch.distributions.Normal(ctrlpt.squeeze(), torch.zeros_like(ctrlpt.squeeze()))
35 | 
36 |             ax[i, 0].scatter(X_numpy[:, 0], X_numpy[:,1])
37 |             ax[i, 0].plot(X_numpy[:,0], X_numpy[:,1])
38 |             ax[i, 0].set_xticks([]); ax[i, 0].set_yticks([])
39 | 
40 |             for z in range(bezier_degree_low, bezier_degree_high + 1):
41 |                 ctrlpt_ = ctrlpt[z - bezier_degree_low].squeeze()
42 | 
43 |                 if model.rational:
44 |                     ratw_ = ratw[z - bezier_degree_low].squeeze()
45 |                     ratw_ = torch.cat([torch.tensor([5.,], device=ratw_.device), ratw_, torch.tensor([5.,], device=ratw_.device)], 0)
46 |                     ratw_ = torch.sigmoid(ratw_)
47 |                 else:
48 |                     ratw_ = None
49 | 
50 |                 # Decode the encoded DelP1..DelPn
51 |                 P0 = torch.zeros(1, ctrlpt_.shape[1], device=ctrlpt_.device)
52 |                 ctrlpt_ = torch.cat([P0, ctrlpt_], 0)
53 |                 ctrlpt_ = torch.cumsum(ctrlpt_, 0)
54 |                 draw_bezier(ctrlpt_.cpu().numpy(), ratw_.cpu().numpy() if ratw_ is not None else ratw_,
55 |                     annotate=False, draw_axis=ax[i, z - bezier_degree_low + 1],
56 |                     ctrlPointPlotKwargs=dict(color='g', linestyle='--', marker='X', alpha=0.4),
57 |                     curvePlotKwagrs=dict(color='r'))
58 |                 ax[i, z - bezier_degree_low + 1].set_xticks([])
59 |                 ax[i, z - bezier_degree_low + 1].set_yticks([])
60 |                 if i == 0:
61 |                     ax[i, z - bezier_degree_low + 1].set_title(f'$n={z}$')
62 | 
63 |         plt.savefig(savefile, bbox_inches='tight', pad_inches=0)
64 |         plt.close()
--------------------------------------------------------------------------------
/beziercurve.py:
--------------------------------------------------------------------------------
1 | 
import matplotlib.pyplot as plt 2 | import numpy as np 3 | from scipy.special import comb 4 | from beziermatrix import bezier_matrix 5 | 6 | def bij(t, i, n): 7 | # binomial coefficients 8 | return comb(n, i) * (t ** i) * ((1-t) ** (n-i)) 9 | 10 | def draw_bezier(ctrlPoints, rWeights = None, nCtrlPoints = 0, nPointsCurve = 100, start_xy = None, annotate = True, return_curve=False, 11 | ctrlPointPlotKwargs = dict(marker='X', color='r', linestyle='--'), curvePlotKwagrs = dict(color='g'), 12 | draw_axis = plt): 13 | ''' 14 | Draws a Bezier curve with given control points 15 | 16 | ctrlPoints: shape (n+1, 2) matrix containing all control points 17 | nCtrlPoints: No. of control points. If 0, infered from 'ctrlPoints', otherwise consideres first 'nCtrlPoints' points from 'ctrlPoints' 18 | nPointsCurve: granularity of the Bezier curve 19 | return_curve: returns the points on the curve rather than drawing them 20 | 21 | ctrlPointPlotKwargs: The **kwargs for control point's plot() function 22 | curvePlotKwagrs: The **kwargs for curve's plot() function 23 | ''' 24 | 25 | def T(ts: 'time points', d: 'degree'): 26 | # 'ts' is a vector (np.array) of time points 27 | ts = ts[..., np.newaxis] 28 | Q = tuple(ts**n for n in range(d, -1, -1)) 29 | return np.concatenate(Q, 1) 30 | 31 | if nCtrlPoints == 0: 32 | # Infer the no. of control points 33 | nCtrlPoints, _ = ctrlPoints.shape 34 | else: 35 | # If given, pick first `nCtrlPoints` control points 36 | ctrlPoints = ctrlPoints[0:nCtrlPoints, :] 37 | 38 | # curve = np.zeros((nPointsCurve, 2)) 39 | # for step, t in enumerate(np.linspace(0.0, 1.0, num = nPointsCurve)): 40 | # s = np.zeros_like(ctrlPoints[0]) # Basically [0., 0.] 41 | # for pointID, point in enumerate(ctrlPoints): 42 | # # 'point' has shape (2,) 43 | # s += bij(t, pointID, nCtrlPoints-1) * point 44 | # curve[step] = s 45 | ts = np.linspace(0., 1., num = nPointsCurve) 46 | 47 | if rWeights is None: 48 | curve = np.matmul( 49 | T(ts, nCtrlPoints - 1), 50 | np.matmul( 51 | bezier_matrix(nCtrlPoints-1), 52 | ctrlPoints 53 | ) 54 | ) 55 | else: 56 | curve = np.matmul( 57 | T(ts, nCtrlPoints - 1), 58 | np.matmul( 59 | bezier_matrix(nCtrlPoints-1), 60 | np.diag(rWeights) 61 | ) 62 | ) 63 | curve = curve / np.expand_dims(curve.sum(1), 1) 64 | curve = np.matmul(curve, ctrlPoints) 65 | 66 | if return_curve: # Return the points of the curve as 'np.array' 67 | return curve 68 | 69 | if start_xy is not None: 70 | curve = curve + start_xy 71 | ctrlPoints = ctrlPoints + start_xy 72 | 73 | # Plot the curve 74 | draw_axis.plot(ctrlPoints[:,0], ctrlPoints[:,1], **ctrlPointPlotKwargs) 75 | for n, ctrlPoint in enumerate(ctrlPoints): 76 | if annotate: 77 | draw_axis.annotate(str(n), (ctrlPoint[0], ctrlPoint[1]), color=ctrlPointPlotKwargs['color']) 78 | 79 | 80 | draw_axis.plot(curve[:,0], curve[:,1], **curvePlotKwagrs) 81 | for n, curvePoint in enumerate(curve): 82 | if n % 10 == 0 and annotate: 83 | draw_axis.annotate(str(n), (curvePoint[0], curvePoint[1]), color=curvePlotKwagrs['color']) 84 | 85 | if __name__ == '__main__': 86 | ## Sample usage of the 'draw_bezier()' function 87 | 88 | # few definitions 89 | degree = 4 90 | 91 | # random control points over [-30,30] range 92 | ctrlPoints = np.random.randint(-30, 30, (degree + 1, 2)).astype(np.float_) 93 | 94 | fig = plt.figure() 95 | draw_bezier(ctrlPoints, draw_axis=plt.gca()) 96 | plt.show() -------------------------------------------------------------------------------- /train_bezierae.py: 
-------------------------------------------------------------------------------- 1 | import sys, os, random 2 | import torch, numpy as np 3 | import torch.utils.tensorboard as tb 4 | from torch.nn.utils.rnn import pad_packed_sequence 5 | 6 | from quickdraw.quickdraw import QuickDraw 7 | from bezierae import RNNBezierAE 8 | from infer_bezierae import inference 9 | 10 | def length_gt(s, f): 11 | if len(s[0]) > f: 12 | return True, s 13 | else: 14 | return False, None 15 | 16 | def main( args ): 17 | chosen_classes = [ 'cat', 'chair', 'face' , 'firetruck', 'mosquito', 'owl', 'pig', 'purse', 'shoe' ] 18 | if args.iam: 19 | chosen_classes = ['iam'] 20 | 21 | qds = QuickDraw(args.root, categories=[chosen_classes[args.n_class],], raw=args.raw, npz=args.npz, 22 | max_sketches_each_cat=args.max_sketches_each_cat, mode=QuickDraw.STROKE, start_from_zero=True, verbose=True, problem=QuickDraw.ENCDEC) 23 | qdl = qds.get_dataloader(args.batch_size) 24 | 25 | qds_infer = QuickDraw(args.root, categories=[chosen_classes[args.n_class],], filter_func=lambda s: length_gt(s, 5), 26 | raw=args.raw, npz=args.npz, max_sketches_each_cat=100, mode=QuickDraw.STROKE, start_from_zero=True, verbose=True, problem=QuickDraw.ENCDEC) 27 | 28 | # chosen device 29 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 30 | 31 | model = RNNBezierAE(2, args.hidden, args.layers, args.latent, args.bezier_degree_low, args.bezier_degree_high, 32 | bidirectional=True, dropout=args.dropout, rational=args.rational) 33 | 34 | model = model.float() 35 | if torch.cuda.is_available(): 36 | model = model.cuda() 37 | 38 | mseloss = torch.nn.MSELoss() 39 | optim = torch.optim.Adam(model.parameters(), lr=args.lr) 40 | sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.8) 41 | 42 | writer = tb.SummaryWriter(os.path.join(args.base, 'logs', args.tag)) 43 | 44 | count = 0 45 | for e in range(args.epochs): 46 | 47 | model.train() 48 | for i, (X, _) in enumerate(qdl): 49 | # break 50 | h_initial = torch.zeros(args.layers * 2, args.batch_size, args.hidden, dtype=torch.float32) 51 | c_initial = torch.zeros(args.layers * 2, args.batch_size, args.hidden, dtype=torch.float32) 52 | 53 | if torch.cuda.is_available(): 54 | X, h_initial, c_initial = X.cuda(), h_initial.cuda(), c_initial.cuda() 55 | 56 | # Unpacking the X, nothing more 57 | X_, L_ = pad_packed_sequence(X, batch_first=True) 58 | 59 | out, regu = model(X, h_initial, c_initial) 60 | 61 | batch_losses = [] 62 | for z_out in out: 63 | for o, x_, l_ in zip(z_out, X_, L_): 64 | # per sample iteration 65 | batch_losses.append( mseloss(o[:l_, :], x_[:l_, :]) ) 66 | 67 | # breakpoint() 68 | REC_loss = sum(batch_losses) / len(batch_losses) 69 | REC_loss = REC_loss + regu * args.regp 70 | 71 | loss = REC_loss 72 | 73 | optim.zero_grad() 74 | loss.backward() 75 | optim.step() 76 | 77 | if i % args.interval == 0: 78 | count += 1 79 | print(f'[Training: {i}/{e}/{args.epochs}] -> Loss: {REC_loss:.4f}') 80 | writer.add_scalar('train/loss/total', loss.item(), global_step=count) 81 | 82 | # save after every epoch 83 | torch.save(model.state_dict(), os.path.join(args.base, args.modelname)) 84 | 85 | model.eval() 86 | savefile = os.path.join(args.base, 'logs', args.tag, str(e) + '.pdf') 87 | inference(qds_infer.get_dataloader(1), model, layers=args.layers, hidden=args.hidden, 88 | bezier_degree_low=args.bezier_degree_low, bezier_degree_high=args.bezier_degree_high, 89 | savefile=savefile, nsamples=args.nsample) 90 | 91 | # invoke scheduler 92 | sched.step() 93 | 
94 | if __name__ == '__main__': 95 | import argparse 96 | parser = argparse.ArgumentParser() 97 | 98 | parser.add_argument('--root', type=str, required=True, help='quickdraw binary file') 99 | parser.add_argument('--iam', action='store_true', help='Use IAM dataset') 100 | parser.add_argument('--base', type=str, required=False, default='.', help='base folder of operation (needed for condor)') 101 | parser.add_argument('--n_class', '-c', type=int, required=False, default=0, help='no. of classes') 102 | parser.add_argument('--raw', action='store_true', help='Use raw QuickDraw data') 103 | parser.add_argument('--npz', action='store_true', help='Use .npz QuickDraw data') 104 | parser.add_argument('--max_sketches_each_cat', '-n', type=int, required=False, default=25000, help='Max no. of sketches each category') 105 | 106 | parser.add_argument('-R', '--rational', action='store_true', help='Rational bezier curve ?') 107 | parser.add_argument('--hidden', type=int, required=False, default=16, help='no. of hidden neurons') 108 | parser.add_argument('--layers', type=int, required=False, default=1, help='no of layers in encoder RNN') 109 | parser.add_argument('--latent', type=int, required=False, default=256, help='length of the degree agnostic latent vector') 110 | parser.add_argument('-y', '--bezier_degree_low', type=int, required=False, default=9, help='lowest degree of the bezier') 111 | parser.add_argument('-z', '--bezier_degree_high', type=int, required=False, default=9, help='highest degree of the bezier') 112 | 113 | parser.add_argument('-b','--batch_size', type=int, required=False, default=128, help='batch size') 114 | parser.add_argument('--dropout', type=float, required=False, default=0.8, help='Dropout rate') 115 | parser.add_argument('--lr', type=float, required=False, default=1e-4, help='learning rate') 116 | parser.add_argument('-e', '--epochs', type=int, required=False, default=40, help='no of epochs') 117 | parser.add_argument('--regp', type=float, required=False, default=1e-2, help='Regularizer weight on control points') 118 | 119 | parser.add_argument('--tag', type=str, required=False, default='main', help='run identifier') 120 | parser.add_argument('-m', '--modelname', type=str, required=False, default='model', help='name of saved model') 121 | parser.add_argument('-i', '--interval', type=int, required=False, default=100, help='logging interval') 122 | parser.add_argument('--nsample', type=int, required=False, default=6, help='no. of data samples for inference') 123 | args = parser.parse_args() 124 | 125 | main( args ) -------------------------------------------------------------------------------- /infer_beziersketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch, numpy as np 3 | dist = torch.distributions 4 | import matplotlib.pyplot as plt 5 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 6 | 7 | from quickdraw.quickdraw import QuickDraw 8 | from beziercurve import draw_bezier 9 | 10 | def drawsketch(ctrlpts, ratws, st_starts, n_stroke, draw_axis=plt.gca(), invert_y=True): 11 | ctrlpts, ratws, st_starts = ctrlpts[:n_stroke], ratws[:n_stroke], st_starts[:n_stroke] 12 | # ctrlpts = ctrlpts.view(-1, ctrlpts.shape[-1] // 2, 2) 13 | 14 | z_ = torch.ones((ratws.shape[0], 1), device=ratws.device) * 5. # sigmoid(5.) 
is close to 1 15 | ratws = torch.cat([z_, ratws, z_], 1) 16 | for ctrlpt, ratw, st_start in zip(ctrlpts, torch.sigmoid(ratws), st_starts): 17 | 18 | if len(ctrlpt.shape) == 1: 19 | ctrlpt = ctrlpt.view(-1, 2) 20 | 21 | # Decode the DelP1..DelPn 22 | P0 = torch.zeros(1, 2, device=ctrlpts[0].device) 23 | # breakpoint() 24 | ctrlpt = torch.cat([P0, ctrlpt], 0) 25 | ctrlpt = torch.cumsum(ctrlpt, 0) 26 | 27 | ctrlpt = ctrlpt.detach().cpu().numpy() 28 | ratw = ratw.detach().cpu().numpy() 29 | st_start = st_start.detach().cpu().numpy() 30 | # over-writing this for now 31 | 32 | draw_bezier(ctrlpt, rWeights=None, start_xy=st_start, draw_axis=draw_axis, annotate=False, 33 | ctrlPointPlotKwargs=dict(color='g', linestyle='--', marker='X', alpha=0.4), 34 | curvePlotKwagrs=dict(color='r')) 35 | if invert_y: 36 | draw_axis.invert_yaxis() 37 | 38 | def stroke_embed(batch, initials, embedder, bezier_degree, bezier_degree_low, variational=False, inf_loss=False): 39 | h_initial, c_initial = initials 40 | # Redundant, but thats fine 41 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 42 | 43 | # accumulate all info into these empty lists 44 | sketches_ctrlpt, sketches_ratw, sketches_st_starts, sketches_stopbits = [], [], [], [] 45 | deg_losses = [] 46 | n_strokes = [] 47 | 48 | for sk, _ in batch: 49 | # for each sketch in the batch 50 | st_starts = torch.tensor([st[0,:2] for st in sk], device=device) 51 | sk = [torch.tensor(st[:,:-1], device=device) - st_start for st, st_start in zip(sk, st_starts)] 52 | ls = [st.shape[0] for st in sk] 53 | sk = pad_sequence(sk, batch_first=True) 54 | sk = pack_padded_sequence(sk, ls, batch_first=True, enforce_sorted=False) 55 | 56 | if embedder.rational: 57 | emb_ctrlpt, emb_ratw = embedder(sk, h_initial, c_initial) 58 | else: 59 | if not inf_loss: 60 | emb_ctrlpt = embedder(sk, h_initial, c_initial, inf_loss=False) 61 | else: 62 | emb_ctrlpt, deg_loss = embedder(sk, h_initial, c_initial, inf_loss=True) 63 | # breakpoint() 64 | 65 | if not inf_loss: 66 | emb_ctrlpt = emb_ctrlpt[bezier_degree - bezier_degree_low] 67 | sketches_ctrlpt.append(emb_ctrlpt.view(len(ls), -1)) 68 | else: 69 | sketches_ctrlpt.append(emb_ctrlpt) 70 | deg_losses.append(deg_loss) 71 | # breakpoint() 72 | 73 | if embedder.rational: 74 | sketches_ratw.append(emb_ratw) 75 | sketches_st_starts.append(st_starts) 76 | # create stopbits 77 | stopbit = torch.zeros(len(ls), 1, device=device); stopbit[-1, 0] = 1. 78 | sketches_stopbits.append(stopbit) 79 | n_strokes.append(len(ls)) 80 | 81 | n_strokes = torch.tensor(n_strokes, device=device) 82 | if not inf_loss: 83 | sketches_ctrlpt = pad_sequence(sketches_ctrlpt, batch_first=True) 84 | 85 | if embedder.rational: 86 | sketches_ratw = pad_sequence(sketches_ratw, batch_first=True) 87 | sketches_st_starts = pad_sequence(sketches_st_starts, batch_first=True) 88 | sketches_stopbits = pad_sequence(sketches_stopbits, batch_first=True, padding_value=1.0) 89 | 90 | # For every sketch in a batch: 91 | # For every stroke in the sketch: 92 | # 1. (Control Point, Rational Weights) pair 93 | # 2. 
Start location of the stroke with respect to a global reference (of the sketch) 94 | if embedder.rational: 95 | return sketches_ctrlpt, sketches_ratw, sketches_st_starts, sketches_stopbits, n_strokes 96 | else: 97 | if not inf_loss: 98 | return sketches_ctrlpt, sketches_st_starts, sketches_stopbits, n_strokes 99 | else: 100 | return (sketches_ctrlpt, deg_losses), sketches_st_starts, sketches_stopbits, n_strokes 101 | 102 | def inference(qdl, model, embedder, emblayers, embhidden, layers, hidden, n_mix, 103 | nsamples, rsamples, variational, bezier_degree, bezier_degree_low, savefile, device, invert_y): 104 | with torch.no_grad(): 105 | fig, ax = plt.subplots(nsamples, (rsamples + 1), figsize=(rsamples * 8, nsamples * 4)) 106 | for i, B in enumerate(qdl): 107 | 108 | h_initial_emb = torch.zeros(emblayers * 2, 256, embhidden, dtype=torch.float32) 109 | c_initial_emb = torch.zeros(emblayers * 2, 256, embhidden, dtype=torch.float32) 110 | h_initial = torch.zeros(layers * 2, 1, hidden, dtype=torch.float32) 111 | c_initial = torch.zeros(layers * 2, 1, hidden, dtype=torch.float32) 112 | if torch.cuda.is_available(): 113 | h_initial, h_initial_emb, c_initial, c_initial_emb = h_initial.cuda(), h_initial_emb.cuda(), c_initial.cuda(), c_initial_emb.cuda() 114 | 115 | with torch.no_grad(): 116 | if model.rational: 117 | ctrlpts, ratws, starts, _, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder, bezier_degree, bezier_degree_low) 118 | else: 119 | ctrlpts, starts, _, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder, bezier_degree, bezier_degree_low) 120 | ratws = torch.ones(ctrlpts.shape[0], ctrlpts.shape[1], model.n_ratw, device=ctrlpts.device) 121 | 122 | _cpad = torch.zeros(ctrlpts.shape[0], 1, ctrlpts.shape[2], device=device) 123 | _rpad = torch.zeros(ratws.shape[0], 1, ratws.shape[2], device=device) 124 | _spad = torch.zeros(starts.shape[0], 1, starts.shape[2], device=device) 125 | ctrlpts = torch.cat([_cpad, ctrlpts], dim=1) 126 | ratws = torch.cat([_rpad, ratws], dim=1) 127 | starts = torch.cat([_spad, starts], dim=1) 128 | 129 | for i in range(256): 130 | if i == nsamples: 131 | break 132 | 133 | n_stroke = n_strokes[i] 134 | drawsketch(ctrlpts[i,1:n_stroke+1,:], ratws[i,1:n_stroke+1,:], starts[i,1:n_stroke+1,:], 135 | n_stroke, ax[i, 0], invert_y=invert_y) 136 | 137 | for r in range(rsamples): 138 | if model.rational: 139 | out_param_mu, out_param_std, out_param_mix, _ = model((h_initial, c_initial), 140 | ctrlpts[i,:n_stroke,:].unsqueeze(0), ratws[i,:n_stroke,:].unsqueeze(0), starts[i,:n_stroke,:].unsqueeze(0)) 141 | n_stroke = out_ctrlpts.shape[0] 142 | drawsketch(out_ctrlpts, out_ratws, out_starts, n_stroke, ax[i, 1+r], invert_y=invert_y) 143 | else: 144 | if model.variational: 145 | out_ctrlpts, out_starts = model((h_initial, c_initial), ctrlpts[i,:n_stroke,:].unsqueeze(0), None, starts[i,:n_stroke,:].unsqueeze(0), inference=True) 146 | else: 147 | out_ctrlpts, out_starts= model((h_initial, c_initial), ctrlpts[i,:n_stroke,:].unsqueeze(0), None, starts[i,:n_stroke,:].unsqueeze(0), inference=True) 148 | out_ratws = torch.ones(out_ctrlpts.shape[1], model.n_ratw) # FAKE IT 149 | 150 | n_stroke = out_ctrlpts.shape[0] 151 | drawsketch(out_ctrlpts, out_ratws, out_starts, n_stroke, ax[i, 1+r], invert_y=invert_y) 152 | 153 | break # just one batch enough 154 | 155 | plt.xticks([]); plt.yticks([]) 156 | plt.savefig(savefile) 157 | plt.close() -------------------------------------------------------------------------------- /train_beziersketch.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch, os, numpy as np 3 | from torch.distributions import Normal 4 | from torch.utils import tensorboard as tb 5 | 6 | from quickdraw.quickdraw import QuickDraw 7 | from bezierae import RNNBezierAE, RNNSketchAE, gmm_loss 8 | from infer_beziersketch import inference, drawsketch, stroke_embed 9 | from npz import NPZWriter 10 | 11 | def select_degree(ctrlpts, deg_loss): 12 | batch = [] 13 | for cpts, degloss in zip(ctrlpts, deg_loss): 14 | sketch = [] 15 | for i_stroke, dloss in enumerate(degloss): 16 | opt_degree = (dloss < 5e-5).nonzero()[0] 17 | if opt_degree.size != 0: 18 | opt_degree = opt_degree[0] 19 | else: 20 | opt_degree = len(dloss) - 1 21 | 22 | t = cpts[opt_degree][i_stroke,:] 23 | sketch.append( t ) 24 | batch.append(sketch) 25 | return batch 26 | 27 | def main( args ): 28 | chosen_classes = [ 'cat', 'chair', 'mosquito', 'firetruck', 'owl', 'pig', 'face', 'purse', 'shoe' ] 29 | if args.iam: 30 | chosen_classes = ['iam'] 31 | 32 | qd = QuickDraw(args.root, categories=[chosen_classes[args.n_class],], max_sketches_each_cat=args.max_sketches_each_cat, 33 | verbose=True, normalize_xy=True, start_from_zero=False, mode=QuickDraw.STROKESET, raw=args.raw, npz=args.npz) 34 | 35 | qdtrain, qdtest = qd.split(0.8) 36 | qdltrain = qdtrain.get_dataloader(args.batch_size) 37 | 38 | # chosen device 39 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 40 | 41 | # Embedder model (pretrained and freezed) 42 | embedder = RNNBezierAE(2, args.embhidden, args.emblayers, args.emblatent, args.bezier_degree_low, args.bezier_degree_high, 43 | bidirectional=True, rational=args.rational) 44 | embmodel = os.path.join(args.base, args.embmodel) 45 | if os.path.exists(embmodel): 46 | embedder.load_state_dict(torch.load(embmodel)) 47 | else: 48 | raise FileNotFoundError('Embedding model not found') 49 | h_initial_emb = torch.zeros(args.emblayers * 2, args.batch_size, args.embhidden, dtype=torch.float32) 50 | c_initial_emb = torch.zeros(args.emblayers * 2, args.batch_size, args.embhidden, dtype=torch.float32) 51 | if torch.cuda.is_available(): 52 | embedder, h_initial_emb, c_initial_emb = embedder.cuda(), h_initial_emb.cuda(), c_initial_emb.cuda() 53 | embedder.eval() 54 | 55 | # RNN Sketch model 56 | n_ratw = args.bezier_degree + 1 - 2 57 | n_ctrlpt = (args.bezier_degree + 1 - 1) * 2 58 | model = RNNSketchAE((n_ctrlpt, n_ratw, 2), args.hidden, dropout=args.dropout, n_mixture=args.n_mix, 59 | rational=args.rational, variational=args.variational, concatz=args.concatz) 60 | 61 | h_initial = torch.zeros(args.layers * 2, args.batch_size, args.hidden, dtype=torch.float32) 62 | c_initial = torch.zeros(args.layers * 2, args.batch_size, args.hidden, dtype=torch.float32) 63 | if torch.cuda.is_available(): 64 | model, h_initial, c_initial = model.cuda(), h_initial.cuda(), c_initial.cuda() 65 | 66 | optim = torch.optim.Adam(model.parameters(), lr=args.lr) 67 | 68 | writer = tb.SummaryWriter(os.path.join(args.base, 'logs', args.tag)) 69 | npzwriter = NPZWriter(args.npzfile) 70 | linear = lambda e, e0, T: max(min((e - e0) / float(T), 1.), 0.) 
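    # KL annealing: with e0 = 0 and T = 10 (as used below via linear(e, 0, 10)),
    # the factor ramps linearly from 0.0 at epoch 0 to 0.5 at epoch 5 and
    # saturates at 1.0 from epoch 10 onward; it scales args.wkl * KLD in the loss.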
71 | 72 | count, best_loss = 0, np.inf 73 | for e in range(args.epochs): 74 | model.train() 75 | for i, B in enumerate(qdltrain): 76 | with torch.no_grad(): 77 | if args.rational: 78 | ctrlpts, ratws, starts, stopbits, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder, args.bezier_degree, args.bezier_degree_low) 79 | else: 80 | ctrlpts, starts, stopbits, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder, args.bezier_degree, args.bezier_degree_low, 81 | inf_loss=args.producenpz or args.rendersketch) 82 | ratws = torch.ones(args.batch_size, starts.shape[1], n_ratw, device=device) # FAKE IT 83 | if e == 0 and args.producenpz: 84 | ctrlpts = select_degree(*ctrlpts) 85 | ## DO THIS 86 | npzwriter.add(ctrlpts, starts, n_strokes) 87 | if i % 10 == 0: 88 | npzwriter.flush() 89 | continue 90 | 91 | if args.rendersketch: 92 | ctrlpts = select_degree(*ctrlpts) 93 | for b in range(len(ctrlpts)): 94 | fig, ax = plt.subplots(1, 1, figsize=(4, 4)) 95 | if b > 20: 96 | break 97 | drawsketch(ctrlpts[b], ratws[b], starts[b], n_strokes[b], draw_axis=ax, invert_y=not args.raw) 98 | ax.set_xticks([]); ax.set_yticks([]) 99 | plt.savefig(f'junks/{e}_{i}_{b}.png', bbox_inches='tight', inches=0) 100 | plt.close() 101 | continue 102 | 103 | 104 | _cpad = torch.zeros(ctrlpts.shape[0], 1, ctrlpts.shape[2], device=device) 105 | _rpad = torch.zeros(ratws.shape[0], 1, ratws.shape[2], device=device) 106 | _spad = torch.zeros(starts.shape[0], 1, starts.shape[2], device=device) 107 | _bpad = torch.zeros(stopbits.shape[0], 1, stopbits.shape[2], device=device) 108 | ctrlpts, ratws, starts, stopbits = torch.cat([_cpad, ctrlpts], dim=1), \ 109 | torch.cat([_rpad, ratws], dim=1), \ 110 | torch.cat([_spad, starts], dim=1), \ 111 | torch.cat([_bpad, stopbits], dim=1) 112 | 113 | if args.rational: 114 | out_param_mu, out_param_std, out_param_mix, out_stopbits = model((h_initial, c_initial), ctrlpts, ratws, starts) 115 | else: 116 | if args.variational: 117 | out_param_mu, out_param_std, out_param_mix, out_stopbits, KLD = model((h_initial, c_initial), ctrlpts, None, starts) 118 | ann = linear(e, 0, 10) 119 | else: 120 | out_param_mu, out_param_std, out_param_mix, out_stopbits = model((h_initial, c_initial), ctrlpts, None, starts) 121 | KLD = 0. 122 | ann = 0. 123 | 124 | loss = [] 125 | for mu_, std_, mix_, b_, c, r, s, b, l in zip(out_param_mu, out_param_std, out_param_mix, out_stopbits, 126 | ctrlpts, ratws, starts, stopbits, n_strokes): 127 | if l >= 1: 128 | c, r, s, b = c[1:l.item()+1, ...], r[1:l.item()+1, ...], s[1:l.item()+1, ...], b[1:l.item()+1, ...] 129 | mu_, std_, mix_, b_ = mu_[:l.item(), ...], std_[:l.item(), ...], mix_[:l.item(), ...], b_[:l.item(), ...] 
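                    # Shapes for the MDN/GMM loss below (L = no. of strokes in this sketch):
                    #   param_ : (1, L, n_params)        -- target [ctrlpt, start] (+ ratw if rational)
                    #   mu_    : (1, L, n_mix, n_params) -- per-component means
                    #   std_   : (1, L, n_mix, n_params) -- per-component std-devs
                    #   mix_   : (1, L, n_mix)           -- log mixture weights
                    # gmm_loss() returns the negative log-likelihood of param_ under this mixture.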
130 | # preparing for mdn loss calc 131 | mu_ = mu_.view(1, l.item(), args.n_mix, -1) 132 | std_ = std_.view(1, l.item(), args.n_mix, -1) 133 | if args.rational: 134 | param_ = torch.cat([c, r, s], -1).view(1, l.item(), -1) 135 | else: 136 | param_ = torch.cat([c, s], -1).view(1, l.item(), -1) 137 | mix_ = mix_.log().view(1, l.item(), args.n_mix) 138 | gmml = gmm_loss(param_, mu_, std_, mix_, reduce=True) 139 | stopbitloss = (b - b_).pow(2).mean() 140 | loss.append( gmml + stopbitloss ) 141 | 142 | recon = sum(loss) / len(loss) 143 | loss = recon + KLD * args.wkl * ann 144 | 145 | if i % args.interval == 0: 146 | print(f'[Training: {i}/{e}/{args.epochs}] -> Loss: {recon:.4f} + {ann:.4f} x {KLD:.4f} = {loss:.4f}') 147 | writer.add_scalar('train-loss', loss.item(), global_step=count) 148 | count += 1 149 | 150 | optim.zero_grad() 151 | loss.backward() 152 | optim.step() 153 | 154 | # flush the npz 155 | if e == 0 and args.producenpz: 156 | npzwriter.flush() 157 | exit() 158 | 159 | # # evaluation phase 160 | # avg_loss = 0. 161 | model.eval() 162 | # for i, B in enumerate(qdltest): 163 | # with torch.no_grad(): 164 | # if args.rational: 165 | # ctrlpts, ratws, starts, stopbits, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder) 166 | # else: 167 | # ctrlpts, starts, stopbits, n_strokes = stroke_embed(B, (h_initial_emb, c_initial_emb), embedder) 168 | # ratws = torch.ones(args.batch_size, ctrlpts.shape[1], n_ratw, device=device) # FAKE IT 169 | 170 | # if args.rational: 171 | # out_param_mu, out_param_std, out_param_mix, out_stopbits = model((h_initial, c_initial), ctrlpts, ratws, starts) 172 | # else: 173 | # if args.variational: 174 | # out_param_mu, out_param_std, out_param_mix, out_stopbits, KLD = model((h_initial, c_initial), ctrlpts, None, starts) 175 | # ann = linear(e, 0, 10) 176 | # else: 177 | # out_param_mu, out_param_std, out_param_mix, out_stopbits = model((h_initial, c_initial), ctrlpts, None, starts) 178 | # KLD = 0. 179 | # ann = 0. 180 | 181 | # _cpad = torch.zeros(ctrlpts.shape[0], 1, ctrlpts.shape[2], device=device) 182 | # _rpad = torch.zeros(ratws.shape[0], 1, ratws.shape[2], device=device) 183 | # _spad = torch.zeros(starts.shape[0], 1, starts.shape[2], device=device) 184 | # _bpad = torch.zeros(stopbits.shape[0], 1, stopbits.shape[2], device=device) 185 | # ctrlpts, ratws, starts, stopbits = torch.cat([_cpad, ctrlpts], dim=1), \ 186 | # torch.cat([_rpad, ratws], dim=1), \ 187 | # torch.cat([_spad, starts], dim=1), \ 188 | # torch.cat([_bpad, stopbits], dim=1) 189 | 190 | # loss = [] 191 | # for mu_, std_, mix_, b_, c, r, s, b, l in zip(out_param_mu, out_param_std, out_param_mix, out_stopbits, 192 | # ctrlpts, ratws, starts, stopbits, n_strokes): 193 | # if l >= 1: 194 | # c, r, s, b = c[1:l.item()+1, ...], r[1:l.item()+1, ...], s[1:l.item()+1, ...], b[1:l.item()+1, ...] 195 | # mu_, std_, mix_, b_ = mu_[:l.item(), ...], std_[:l.item(), ...], mix_[:l.item(), ...], b_[:l.item(), ...] 
196 | # # preparing for mdn loss calc 197 | # mu_ = mu_.view(1, l.item(), args.n_mix, -1) 198 | # std_ = std_.view(1, l.item(), args.n_mix, -1) 199 | # if args.rational: 200 | # param_ = torch.cat([c, r, s], -1).view(1, l.item(), -1) 201 | # else: 202 | # param_ = torch.cat([c, s], -1).view(1, l.item(), -1) 203 | # mix_ = mix_.log().view(1, l.item(), args.n_mix) 204 | # gmml = gmm_loss(param_, mu_, std_, mix_, reduce=True) 205 | # stopbitloss = (-b*torch.log(b_)).mean() 206 | # loss.append( gmml + stopbitloss ) 207 | 208 | # loss = sum(loss) / len(loss) + KLD * args.wkl * ann 209 | 210 | # avg_loss = ((avg_loss * i) + loss.item()) / (i + 1) 211 | 212 | # print(f'[Testing: -/{e}/{args.epochs}] -> Loss: {avg_loss:.4f}') 213 | # writer.add_scalar('test-loss', avg_loss, global_step=e) 214 | torch.save(model.state_dict(), os.path.join(args.base, args.modelname)) 215 | 216 | savefile = os.path.join(args.base, 'logs', args.tag, str(e) + '.png') 217 | inference(qdtest.get_dataloader(args.batch_size), model, embedder, emblayers=args.emblayers, embhidden=args.embhidden, 218 | layers=args.layers, hidden=args.hidden, variational=False, bezier_degree=args.bezier_degree, bezier_degree_low=args.bezier_degree_low, 219 | n_mix=args.n_mix, nsamples=args.nsamples, rsamples=args.rsamples, savefile=savefile, device=device, invert_y=not args.raw) 220 | 221 | 222 | if __name__ == '__main__': 223 | import argparse 224 | parser = argparse.ArgumentParser() 225 | 226 | parser.add_argument('--root', type=str, required=True, help='quickdraw binary file') 227 | parser.add_argument('--base', type=str, required=False, default='.', help='base folder of operation (needed for condor)') 228 | parser.add_argument('--n_class', '-c', type=int, required=False, default=0, help='no. of classes') 229 | parser.add_argument('--iam', action='store_true', help='Use IAM dataset') 230 | parser.add_argument('--raw', action='store_true', help='Use raw QuickDraw data') 231 | parser.add_argument('--npz', action='store_true', help='Use .npz QuickDraw data') 232 | parser.add_argument('--max_sketches_each_cat', '-n', type=int, required=False, default=25000, help='Max no. of sketches each category') 233 | 234 | parser.add_argument('--embvariational', action='store_true', help='Impose prior on latent space (in embedder)') 235 | parser.add_argument('--embhidden', type=int, required=False, default=16, help='no. of hidden neurons (in embedder)') 236 | parser.add_argument('--emblayers', type=int, required=False, default=1, help='no of layers (in embedder)') 237 | parser.add_argument('--emblatent', type=int, required=False, default=256, help='dim of latent vector (in embedder)') 238 | parser.add_argument('--embmodel', type=str, required=True, help='path to the pre-trained embedder') 239 | parser.add_argument('-T', '--stochastic_t', action='store_true', help='Use stochastic t-values') 240 | parser.add_argument('-R', '--rational', action='store_true', help='Rational bezier curve ?') 241 | parser.add_argument('--concatz', action='store_true', help='concat z with all inputs in decoder') 242 | parser.add_argument('--hidden', type=int, required=False, default=256, help='no. of hidden neurons') 243 | parser.add_argument('-x', '--n_mix', type=int, required=False, default=3, help='no. 
of GMM mixtures') 244 | parser.add_argument('--layers', type=int, required=False, default=2, help='no of layers in encoder RNN') 245 | 246 | parser.add_argument('--bezier_degree', type=int, required=False, default=9, help='degree of the bezier') 247 | parser.add_argument('-y', '--bezier_degree_low', type=int, required=False, default=9, help='lowest degree of the bezier') 248 | parser.add_argument('-z', '--bezier_degree_high', type=int, required=False, default=9, help='highest degree of the bezier') 249 | 250 | parser.add_argument('-V', '--variational', action='store_true', help='Impose prior on latent space') 251 | parser.add_argument('--wkl', type=float, required=False, default=1.0, help='weight of the KL term') 252 | 253 | parser.add_argument('-b','--batch_size', type=int, required=False, default=128, help='batch size') 254 | parser.add_argument('--dropout', type=float, required=False, default=0.8, help='Dropout rate') 255 | parser.add_argument('--lr', type=float, required=False, default=1e-4, help='learning rate') 256 | parser.add_argument('-e', '--epochs', type=int, required=False, default=40, help='no of epochs') 257 | # parser.add_argument('--anneal_KLD', action='store_true', help='Increase annealing factor of KLD gradually') 258 | 259 | parser.add_argument('--tag', type=str, required=False, default='main', help='run identifier') 260 | parser.add_argument('--rendersketch', action='store_true', help='Render the sketches (debugging purpose)') 261 | parser.add_argument('-m', '--modelname', type=str, required=False, default='model', help='name of saved model') 262 | parser.add_argument('--npzfile', type=str, required=False, default='ctrlpt.npz', help='SketchRNN style .npz for control points') 263 | parser.add_argument('--producenpz', action='store_true', help='Produce npz') 264 | parser.add_argument('-i', '--interval', type=int, required=False, default=50, help='logging interval') 265 | parser.add_argument('--nsamples', type=int, required=False, default=6, help='no. of data samples for inference') 266 | parser.add_argument('--rsamples', type=int, required=False, default=5, help='no. 
of distribution samples for inference')
267 |     args = parser.parse_args()
268 | 
269 |     main( args )
--------------------------------------------------------------------------------
/bezierae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import matplotlib.pyplot as plt
5 | import torch.nn.functional as F
6 | import torch.distributions as dist
7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
8 | 
9 | from bezierloss import BezierLoss
10 | 
11 | class RNNBezierAE(nn.Module):
12 |     def __init__(self, n_input, n_hidden, n_layer, n_latent, bezier_degree_low, bezier_degree_high,
13 |                  dtype=torch.float32, bidirectional=True, dropout=0.8, rational=True):
14 |         super().__init__()
15 | 
16 |         # Track parameters
17 |         self.n_input, self.n_hidden, self.n_layer = n_input, n_hidden, n_layer
18 |         self.n_latent = n_latent
19 |         self.bezier_degree_low = bezier_degree_low
20 |         self.bezier_degree_high = bezier_degree_high
21 | 
22 |         self.bezier_degree = list(range(bezier_degree_low, bezier_degree_high + 1))
23 |         self.n_latent_ctrl = [(z + 1 - 1) * 2 for z in self.bezier_degree] # the '-1' is for the Delta_P encoding (P0 is fixed at the origin)
24 |         self.n_latent_ratw = [z + 1 - 2 for z in self.bezier_degree]
25 | 
26 |         self.bidirectional = 2 if bidirectional else 1
27 |         self.dtype = dtype
28 |         self.dropout = dropout
29 |         self.rational = rational
30 | 
31 |         # The t-network
32 |         self.tcell = nn.LSTM(self.n_input, self.n_hidden, self.n_layer,
33 |                              bidirectional=bidirectional, dropout=self.dropout)
34 | 
35 |         self.t_logits = nn.ModuleList([torch.nn.Linear(self.bidirectional * self.n_hidden, 1) for _ in self.bezier_degree])
36 | 
37 |         # project the concatenated LSTM end-states into the latent space
38 |         n_hc = 2 * self.bidirectional * self.n_hidden
39 |         self.hc_project = nn.Linear(n_hc, self.n_latent)
40 | 
41 |         self.ctrlpt_arms = nn.ModuleList([nn.Linear(self.n_latent, c) for c in self.n_latent_ctrl])
42 |         if self.rational:
43 |             self.ratw_arms = nn.ModuleList([nn.Linear(self.n_latent, r) for r in self.n_latent_ratw])
44 | 
45 |         # Bezier mechanics
46 |         self.bezierlosses = nn.ModuleList([BezierLoss(z, reg_weight_p=1e-3, reg_weight_r=None) for z in self.bezier_degree])
47 | 
48 |     def constraint_t(self, ts, lens):
49 |         ts = ts.squeeze(-1)
50 |         csm = []
51 |         for t, l in zip(ts, lens):
52 |             csm.append( torch.cumsum(torch.softmax(t[:l.item()], 0), 0) )
53 |         csm = pad_sequence(csm, batch_first=True, padding_value=0.)
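        # softmax makes the per-step increments positive and sum to 1, so the
        # cumulative sum yields a monotonically increasing sequence of t-values
        # ending at 1 -- valid Bézier parameters for each (unpadded) stroke.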
54 |         return csm
55 | 
56 |     def reparam(self, mu, logvar):
57 |         std = torch.exp(0.5 * logvar)
58 |         eps = torch.randn_like(std)
59 |         return (mu + eps * std)
60 | 
61 |     def forward(self, x, h_initial, c_initial, inf_loss=False):
62 |         out, (h_final, c_final) = self.tcell(x, (h_initial, c_initial))
63 |         hns, lens = pad_packed_sequence(out, batch_first=True)
64 | 
65 |         t_logits = [t_logit(hns) for t_logit in self.t_logits]
66 |         ts = [self.constraint_t(t_logit, lens) for t_logit in t_logits]
67 | 
68 | 
69 |         # latent space
70 |         h_final = h_final.view(self.n_layer, self.bidirectional, -1, self.n_hidden)
71 |         c_final = c_final.view(self.n_layer, self.bidirectional, -1, self.n_hidden)
72 |         H = torch.cat([h_final[-1, 0], h_final[-1, 1]], 1)
73 |         C = torch.cat([c_final[-1, 0], c_final[-1, 1]], 1)
74 |         HC = torch.cat([H, C], 1) # concat all "states" of the LSTM
75 | 
76 |         hc_projection = F.relu(self.hc_project(HC))
77 |         latent_ctrlpt = [ctrlpt_arm(hc_projection) for ctrlpt_arm in self.ctrlpt_arms]
78 | 
79 |         # 'P's should be encoded as [P0=0, DelP1, DelP2, ..]
80 |         # latent_ctrlpt = latent_ctrlpt.view(-1, self.n_latent_ctrl // 2, 2)
81 |         latent_ctrlpt = [ctrlpt.view(-1, ctrlpt.shape[1] // 2, 2) for ctrlpt in latent_ctrlpt]
82 |         latent_ctrlpt_return = latent_ctrlpt
83 |         P0 = torch.zeros(latent_ctrlpt[0].shape[0], 1, 2, device=latent_ctrlpt[0].device)
84 |         latent_ctrlpt = [torch.cat([P0, ctrlpt], 1) for ctrlpt in latent_ctrlpt]
85 |         latent_ctrlpt = [torch.cumsum(ctrlpt, 1) for ctrlpt in latent_ctrlpt]
86 |         # breakpoint()
87 | 
88 |         if self.rational:
89 |             latent_ratw = [ratw_arm(hc_projection) for ratw_arm in self.ratw_arms]
90 |             z_ = torch.ones((latent_ratw[0].shape[0], 1), device=latent_ratw[0].device) * 5. # sigmoid(5.) is close to 1
91 |             latent_ratw_padded = [torch.cat([z_, ratw, z_], 1) for ratw in latent_ratw]
92 | 
93 |         if self.training:
94 |             out, regu = [], []
95 |             if self.rational:
96 |                 latent_ratw_padded_sigm = [torch.sigmoid(r) for r in latent_ratw_padded]
97 |                 for loss, z_ts, z_latent_ctrlpt, z_latent_ratw in zip(self.bezierlosses, ts, latent_ctrlpt, latent_ratw_padded_sigm):
98 |                     z_out, z_regu = [], []
99 |                     for t, p, r, l in zip(z_ts, z_latent_ctrlpt, z_latent_ratw, lens):
100 |                         z_out.append( loss(p, r, None, ts=t[:l]) )
101 |                         z_regu.append( (loss._consecutive_dist(p)**2).mean() )
102 |                     out.append(z_out)
103 |                     regu.append(z_regu)
104 |             else:
105 |                 for loss, z_ts, z_latent_ctrlpt in zip(self.bezierlosses, ts, latent_ctrlpt):
106 |                     z_out, z_regu = [], []
107 |                     for t, p, l in zip(z_ts, z_latent_ctrlpt, lens):
108 |                         z_out.append( loss(p, None, None, ts=t[:l]) )
109 |                         z_regu.append( (loss._consecutive_dist(p)**2).mean() )
110 |                     out.append(z_out)
111 |                     regu.append(z_regu)
112 | 
113 |             return out, sum([sum(z_regu)/len(z_regu) for z_regu in regu]) / len(self.bezier_degree)
114 | 
115 |         else:
116 |             if self.rational:
117 |                 return latent_ctrlpt_return, latent_ratw
118 |             else:
119 |                 if not inf_loss:
120 |                     return latent_ctrlpt_return
121 |                 else:
122 |                     out = []
123 |                     XY, _ = pad_packed_sequence(x, batch_first=True)
124 |                     for loss, z_ts, z_latent_ctrlpt in zip(self.bezierlosses, ts, latent_ctrlpt):
125 |                         z_out = []
126 |                         for t, p, x, l in zip(z_ts, z_latent_ctrlpt, XY, lens):
127 |                             x = x[:l,:]
128 |                             z_out.append( loss(p, None, x, ts=t[:l]) )
129 |                         out.append(z_out)
130 | 
131 |                     # choose the right degree based on some heuristics
132 |                     n_degrees = self.bezier_degree_high - self.bezier_degree_low + 1
133 |                     # i_degree_range = np.arange(self.bezier_degree_low, self.bezier_degree_high + 1)
134 |                     loss_degs = []
135 |                     for i in
range(lens.shape[0]): 136 | loss_deg = np.array([out[j][i].item() for j in range(n_degrees)]) 137 | # plt.plot(i_degree_range * (1./25.), loss_deg) 138 | # plt.show() 139 | # breakpoint() 140 | loss_degs.append(loss_deg) 141 | 142 | return latent_ctrlpt_return, loss_degs 143 | 144 | class RNNSketchAE(nn.Module): 145 | def __init__(self, n_inps, n_hidden, n_layer = 2, n_mixture = 3, dropout = 0.8, eps = 1e-8, rational = True, 146 | variational = False, concatz = False): 147 | super().__init__() 148 | 149 | # Track parameters 150 | self.n_ctrlpt, self.n_ratw, self.n_start = n_inps 151 | self.n_hidden = n_hidden 152 | self.n_layer = 2 153 | self.n_hc = 2 * 2 * self.n_hidden 154 | self.n_latent = self.n_hc // 2 155 | self.dropout = dropout 156 | self.n_params = self.n_ctrlpt + (self.n_ratw if rational else 0) + self.n_start 157 | self.n_mixture = n_mixture 158 | self.rational = rational 159 | self.variational = variational 160 | self.concatz = concatz 161 | 162 | self.eps = eps 163 | 164 | # Layer definition 165 | self.encoder = nn.LSTM(self.n_params, self.n_hidden, self.n_layer, bidirectional=True, batch_first=True, dropout=dropout) 166 | if not self.concatz: 167 | self.decoder = nn.LSTM(self.n_params, 2 * self.n_hidden, self.n_layer, bidirectional=False, batch_first=True, dropout=dropout) 168 | else: 169 | self.decoder = nn.LSTM(self.n_params + self.n_latent, 2 * self.n_hidden, self.n_layer, bidirectional=False, batch_first=True, dropout=dropout) 170 | 171 | # Other transformations 172 | self.hc_to_latent = nn.Linear(self.n_hc, self.n_latent) # encoder side 173 | if self.variational: 174 | self.hc_to_latent_logvar = nn.Linear(self.n_hc, self.n_latent) # encoder side 175 | self.latent_to_h0_1 = nn.Linear(self.n_latent, self.n_hidden * 2) # decoder side 176 | self.latent_to_c0_1 = nn.Linear(self.n_latent, self.n_hidden * 2) # decoder side 177 | self.latent_to_h0_2 = nn.Linear(self.n_latent, self.n_hidden * 2) # decoder side 178 | self.latent_to_c0_2 = nn.Linear(self.n_latent, self.n_hidden * 2) # decoder side 179 | self.tanh = nn.Tanh() 180 | 181 | self.param_mu_arm = nn.Linear(self.n_hidden * 2, self.n_params * self.n_mixture) 182 | self.param_std_arm = nn.Linear(self.n_hidden * 2, self.n_params * self.n_mixture) # put through exp() 183 | self.param_mix_arm = nn.Linear(self.n_hidden * 2, self.n_mixture) # put through softmax 184 | self.stopbit_arm = nn.Linear(self.n_hidden * 2, 1) 185 | 186 | def reparam(self, mu, logvar): 187 | std = torch.exp(0.5 * logvar) 188 | eps = torch.randn_like(std) 189 | return (mu + eps * std) 190 | 191 | def forward(self, initials, ctrlpt, ratw, start, inference=False): 192 | h_initial, c_initial = initials 193 | if self.rational: 194 | input = torch.cat([ctrlpt, ratw, start], -1) 195 | else: 196 | input = torch.cat([ctrlpt, start], -1) 197 | _, (hn, cn) = self.encoder(input, (h_initial, c_initial)) 198 | hn = hn.view(self.n_layer, 2, -1, self.n_hidden) 199 | cn = cn.view(self.n_layer, 2, -1, self.n_hidden) 200 | hn, cn = hn[-1,...], cn[-1,...] 
# only from the topmost layer 201 | 202 | hc = torch.cat([hn[0], hn[1], cn[0], cn[1]], -1) # concat all of 'em 203 | latent = self.hc_to_latent(hc) 204 | if self.variational: 205 | latent_mean = latent 206 | latent_logvar = self.hc_to_latent_logvar(hc) 207 | latent = self.reparam(latent, latent_logvar) 208 | 209 | KLD = -0.5 * torch.mean(1 + latent_logvar - latent.pow(2) - latent_logvar.exp()) 210 | #### encoder ends here #### 211 | 212 | h01, c01 = self.latent_to_h0_1(latent), self.latent_to_c0_1(latent) 213 | h02, c02 = self.latent_to_h0_2(latent), self.latent_to_c0_2(latent) 214 | h0 = self.tanh(torch.stack([h01, h02], 0)) 215 | c0 = self.tanh(torch.stack([c01, c02], 0)) 216 | 217 | if self.concatz: 218 | latent_c = latent.view(-1, 1, self.n_latent).repeat(1, input.shape[1], 1) 219 | input = torch.cat([input, latent_c], -1) 220 | state, _ = self.decoder(input, (h0, c0)) 221 | 222 | # out_ctrlpt = self.ctrlpt_arm(state) 223 | # out_ratw = self.ratw_arm(state) 224 | # out_start = self.start_arm(state) 225 | out_param_mu = self.param_mu_arm(state) 226 | out_param_std = torch.exp(self.param_std_arm(state)) 227 | out_param_mix = torch.softmax(self.param_mix_arm(state), -1) 228 | out_stopbit = torch.sigmoid(self.stopbit_arm(state)) 229 | 230 | if self.training: 231 | if not self.variational: 232 | return out_param_mu, out_param_std, out_param_mix, out_stopbit 233 | else: 234 | return out_param_mu, out_param_std, out_param_mix, out_stopbit, KLD 235 | else: 236 | 237 | if inference: 238 | L = input.shape[1] # just as a safety (see the for loop) 239 | input = torch.zeros(1, 1, self.n_params, device=input.device) 240 | stop = False 241 | 242 | out_ctrlpts, out_ratws, out_starts = [], [], [] 243 | for _ in range(L): 244 | if self.concatz: 245 | latent_c = latent.view(1, 1, self.n_latent) 246 | input = torch.cat([input, latent_c], -1) 247 | state, (h1, c1) = self.decoder(input, (h0, c0)) 248 | 249 | out_param_mu = self.param_mu_arm(state).squeeze() 250 | out_param_std = torch.exp(self.param_std_arm(state)).squeeze() 251 | out_param_mix = torch.softmax(self.param_mix_arm(state), -1).squeeze() 252 | out_stopbit = torch.sigmoid(self.stopbit_arm(state)).squeeze() 253 | 254 | # reshape to make the n_mix visible 255 | out_param_mu = out_param_mu.view(self.n_mixture, out_param_mu.shape[-1] // self.n_mixture) 256 | out_param_std = out_param_std.view(self.n_mixture, out_param_std.shape[-1] // self.n_mixture) 257 | 258 | mix_id = dist.Categorical(out_param_mix.squeeze()).sample() 259 | 260 | mu, std = out_param_mu[mix_id.item(), :], out_param_std[mix_id.item(), :] 261 | sample = dist.Normal(mu, std).sample() 262 | out_ctrlpts.append(sample[:self.n_ctrlpt]) 263 | if self.rational: 264 | out_ratws.append(sample[self.n_ctrlpt:self.n_ctrlpt+self.n_ratw]) 265 | out_starts.append(sample[self.n_ctrlpt+self.n_ratw:]) 266 | input = torch.cat([out_ctrlpts[-1], out_ratws[-1], out_starts[-1]], -1) 267 | else: 268 | out_starts.append(sample[self.n_ctrlpt:]) 269 | input = torch.cat([out_ctrlpts[-1], out_starts[-1]], -1) 270 | 271 | input = input.unsqueeze(0).unsqueeze(0) 272 | h0, c0 = h1, c1 273 | 274 | if out_stopbit.item() >= 0.99: 275 | break 276 | 277 | out_ctrlpts = torch.stack(out_ctrlpts, 0) 278 | if self.rational: 279 | out_ratws = torch.stack(out_ratws, 0) 280 | out_starts = torch.stack(out_starts, 0) 281 | 282 | if self.rational: 283 | return out_ctrlpts, out_ratws, out_starts 284 | else: 285 | return out_ctrlpts, out_starts 286 | 287 | if not self.variational: 288 | # as of now, teacher-frocing even in testing 
289 | return out_param_mu, out_param_std, out_param_mix, out_stopbit 290 | else: 291 | return out_param_mu, out_param_std, out_param_mix, out_stopbit, KLD 292 | 293 | # def gmm_loss(mu, std, mix, n_mix, ctrlpt, ratw, start): 294 | # param = torch.cat([ctrlpt, ratw, start], -1) 295 | # mus = torch.split(mu, mu.shape[-1]//n_mix, -1) 296 | # stds = torch.split(std, std.shape[-1]//n_mix, -1) 297 | # mixs = torch.split(mix, mix.shape[-1]//n_mix, -1) 298 | # Ns = [dist.Normal(m, s) for m, s in zip(mus, stds)] 299 | # pdfs = [] 300 | # for N, pi in zip(Ns, mixs): 301 | # pdfs.append((N.log_prob(param).sum(-1).exp() + 1e-10) * pi.view(-1,)) 302 | # breakpoint() 303 | # return -sum(pdfs).log().mean() 304 | 305 | def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments 306 | # TAKEN FROM: https://github.com/ctallec/world-models/blob/master/models/mdrnn.py 307 | ## NOT MY CODE 308 | 309 | """ Computes the gmm loss. 310 | Compute minus the log probability of batch under the GMM model described 311 | by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch 312 | dimensions (several batch dimension are useful when you have both a batch 313 | axis and a time step axis), gs the number of mixtures and fs the number of 314 | features. 315 | :args batch: (bs1, bs2, *, fs) torch tensor 316 | :args mus: (bs1, bs2, *, gs, fs) torch tensor 317 | :args sigmas: (bs1, bs2, *, gs, fs) torch tensor 318 | :args logpi: (bs1, bs2, *, gs) torch tensor 319 | :args reduce: if not reduce, the mean in the following formula is ommited 320 | :returns: 321 | loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log( 322 | sum_{k=1..gs} pi[i1, i2, ..., k] * N( 323 | batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :])) 324 | NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearily 325 | with fs). 326 | """ 327 | batch = batch.unsqueeze(-2) 328 | normal_dist = dist.Normal(mus, sigmas) 329 | g_log_probs = normal_dist.log_prob(batch) 330 | g_log_probs = logpi + torch.sum(g_log_probs, dim=-1) 331 | max_log_probs = torch.max(g_log_probs, dim=-1, keepdim=True)[0] 332 | g_log_probs = g_log_probs - max_log_probs 333 | 334 | g_probs = torch.exp(g_log_probs) 335 | probs = torch.sum(g_probs, dim=-1) 336 | 337 | log_prob = max_log_probs.squeeze() + torch.log(probs) 338 | if reduce: 339 | return - torch.mean(log_prob) 340 | return - log_prob --------------------------------------------------------------------------------