├── exp
    ├── __init__.py
    ├── dog12.py
    ├── face48.py
    ├── face1.py
    ├── face12.py
    └── face24.py
├── rpglib
    ├── __init__.py
    ├── utils.py
    ├── disc0.py
    ├── real.py
    ├── gen.py
    ├── genx.py
    └── disc.py
├── .gitignore
├── imgs
    └── interp.jpg
├── filts
    ├── K12_dm12.npz
    ├── K24_dm12.npz
    ├── K48_dm12.npz
    ├── README.md
    └── rproj.py
├── fixbn.py
├── sample.py
├── interp.py
├── baseline_train.py
├── train.py
└── README.md


/exp/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/rpglib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
__pycache__
models
data
--------------------------------------------------------------------------------
/imgs/interp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayanc/rpgan/HEAD/imgs/interp.jpg
--------------------------------------------------------------------------------
/filts/K12_dm12.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayanc/rpgan/HEAD/filts/K12_dm12.npz
--------------------------------------------------------------------------------
/filts/K24_dm12.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayanc/rpgan/HEAD/filts/K24_dm12.npz
--------------------------------------------------------------------------------
/filts/K48_dm12.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ayanc/rpgan/HEAD/filts/K48_dm12.npz
--------------------------------------------------------------------------------
/exp/dog12.py:
--------------------------------------------------------------------------------
# Generator Parameters
ksz=4
zlen = 200 # Dimensionality of z
f1 = 2048 # Features in first layer of Gen output

# Discriminator Parameters
df = 64 # No. of hidden features (at first layer of D)

# Training set
imsz = 128
bsz = 64
lfile='data/dogs.txt'
crop=True

wts_dir='models/dog12'
SAVEFREQ=1e3
MAXITER=1e5
--------------------------------------------------------------------------------
/exp/face48.py:
--------------------------------------------------------------------------------
# Generator Parameters
ksz=4
zlen = 100 # Dimensionality of z
f1 = 1024 # Features in first layer of Gen output
# Discriminator Parameters
df = 128 # No. of hidden features (at first layer of D)

# Training set
imsz = 64
bsz = 64
lfile='data/faces.txt'
crop=False

# Learning parameters
wts_dir='models/face48'
SAVEFREQ=1e3
MAXITER=1e5
--------------------------------------------------------------------------------
/exp/face1.py:
--------------------------------------------------------------------------------
# Generator Parameters
ksz=4
zlen = 100 # Dimensionality of z
f1 = 1024 # Features in first layer of Gen output

# Discriminator Parameters
df = 64 # No. of hidden features (at first layer of D)

# Training set
imsz = 64
bsz = 64
lfile='data/faces.txt'
crop=False

# Learning parameters
wts_dir='models/face1'
SAVEFREQ=1e3
MAXITER=1e5
--------------------------------------------------------------------------------
/exp/face12.py:
--------------------------------------------------------------------------------
# Generator Parameters
ksz=4
zlen = 100 # Dimensionality of z
f1 = 1024 # Features in first layer of Gen output

# Discriminator Parameters
df = 128 # No. of hidden features (at first layer of D)

# Training set
imsz = 64
bsz = 64
lfile='data/faces.txt'
crop=False

# Learning parameters
wts_dir='models/face12'
SAVEFREQ=1e3
MAXITER=1e5
--------------------------------------------------------------------------------
/exp/face24.py:
--------------------------------------------------------------------------------
# Generator Parameters
ksz=4
zlen = 100 # Dimensionality of z
f1 = 1024 # Features in first layer of Gen output

# Discriminator Parameters
df = 128 # No. of hidden features (at first layer of D)

# Training set
imsz = 64
bsz = 64
lfile='data/faces.txt'
crop=False

# Learning parameters
wts_dir='models/face24'
SAVEFREQ=1e3
MAXITER=1e5
--------------------------------------------------------------------------------
/filts/README.md:
--------------------------------------------------------------------------------
# Generating Random Projection Filters

This directory contains the filter weight files used to train the models in the paper. The filter files are named `Knn_dm12.npz`, where `nn` denotes the number of projections / discriminators.
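
Each `.npz` file is a plain NumPy archive with the keys written by `rproj.py`: `nfilt` (number of projections), `p<i>` (filter `i`, shaped `ksize x ksize x 3 x numf`) and `s<i>` (its stride). The sketch below is illustrative only: the helper `make_projections` is hypothetical and merely mirrors `rproj.py`'s generation step with a fixed seed; it writes the archive to memory instead of disk, reads it back the way `rpglib/disc.py` does, and checks the spatial size produced by the padding rule in `disc.py`.

```python
import io
import numpy as np

def make_projections(ksize=8, stride=2, numf=1, numk=12, seed=0):
    # Hypothetical helper mirroring rproj.py: numk Gaussian filters, each
    # output filter scaled to unit L2 norm over (rows, cols, input channels).
    rng = np.random.RandomState(seed)
    wts = {}
    for i in range(numk):
        pfi = np.float32(rng.normal(size=(ksize, ksize, 3, numf)))
        pfi = pfi / np.sqrt(np.sum(pfi * pfi, axis=(0, 1, 2)))
        wts['p%d' % i] = pfi
        wts['s%d' % i] = stride
    wts['nfilt'] = numk
    return wts

# Write to an in-memory buffer instead of a file such as K12_dm12.npz.
buf = io.BytesIO()
np.savez(buf, **make_projections())
buf.seek(0)

# Read the archive back the way rpglib/disc.py does.
f = np.load(buf)
nfilt = int(f['nfilt'])
p0 = f['p0']
stride = int(f['s0'])
norms = np.sqrt(np.sum(p0 * p0, axis=(0, 1, 2)))

# disc.py pads by (ksize - stride)//2 before a VALID strided convolution,
# so a 64x64 input maps to a 32x32 single-channel projection.
imsz = 64
pad = (p0.shape[0] - stride) // 2
out_sz = (imsz + 2 * pad - p0.shape[0]) // stride + 1
print(nfilt, p0.shape, stride, out_sz)
```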
Note that all filters downsample spatially by 2 and map a 3-channel image to a single-channel image through convolution with 8x8 filters.

These filters were generated with the `rproj.py` script provided here as:
```bash
$ ./rproj.py K12_dm12.npz 8,2,1,12
$ ./rproj.py K24_dm12.npz 8,2,1,24
$ ./rproj.py K48_dm12.npz 8,2,1,48
```
The second argument is a comma-separated list of the kernel size, stride, number of output channels, and number of projections/discriminators.

To use these weights during training, copy the chosen file to `filts.npz` in the experiment's weights directory (`wts_dir`).
--------------------------------------------------------------------------------
/filts/rproj.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

from __future__ import print_function
import numpy as np
import sys

if len(sys.argv) < 3:
    sys.exit("USAGE: outfile.npz [ksize,stride,numf,numk]_repeated")

ofile = sys.argv[1]

wts = {}
nfilt = 0

# Each remaining argument is a ksize,stride,numf,numk group
for i in range(len(sys.argv)-2):
    args = sys.argv[i+2].split(',')
    ksize = int(args[0])
    stride = int(args[1])
    numf = int(args[2])
    numk = int(args[3])

    for j in range(numk):
        pfi = np.random.normal(size=(ksize,ksize,3,numf))
        pfi = np.float32(pfi)
        # Normalize each output filter to unit L2 norm
        pfi = pfi / np.sqrt(np.sum(pfi*pfi,axis=(0,1,2)))
        wts['p%d' % nfilt] = pfi
        wts['s%d' % nfilt] = stride
        nfilt = nfilt + 1

wts['nfilt'] = nfilt
np.savez(ofile,**wts)
print("Wrote %d discriminator projections to %s." % (nfilt,ofile))
--------------------------------------------------------------------------------
/rpglib/utils.py:
--------------------------------------------------------------------------------
#-- Ayan Chakrabarti
import re
import os
from glob import glob
import numpy as np

# Manage checkpoint files; the iteration number is read off the filename.
# clean() keeps the latest `last` checkpoints plus every checkpoint whose
# iteration is a multiple of `every`, and deletes the rest.
class ckpter:
    def __init__(self,wcard):
        self.wcard = wcard
        self.load()

    def load(self):
        lst = glob(self.wcard)
        if len(lst) > 0:
            lst=[(l,int(re.match(r'.*/.*_(\d+)',l).group(1)))
                 for l in lst]
            self.lst=sorted(lst,key=lambda x: x[1])

            self.iter = self.lst[-1][1]
            self.latest = self.lst[-1][0]
        else:
            self.lst=[]
            self.iter=0
            self.latest=None

    def clean(self,every=0,last=1):
        self.load()
        old = self.lst[:-last]
        for j in old:
            if every == 0 or j[1] % every != 0:
                os.remove(j[0])

# Load weights from an npz file into the network's variables
def netload(net,fname,sess):
    wts = np.load(fname)
    for k in wts.keys():
        wvar = net.weights[k]
        wk = wts[k].reshape(wvar.get_shape())
        sess.run(wvar.assign(wk))

# Save weights to an npz file
def netsave(net,fname,sess):
    wts = {}
    for k in net.weights.keys():
        wts[k] = net.weights[k].eval(sess)
    np.savez(fname,**wts)
--------------------------------------------------------------------------------
/fixbn.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

from __future__ import print_function

import os
import sys
import tensorflow as tf
import numpy as np
from skimage.io import imsave

from rpglib import utils as ut
from rpglib import genx as gen


if len(sys.argv) < 2:
    sys.exit("USAGE: fixbn.py exp [iteration]")


from importlib import import_module
p = import_module("exp." + sys.argv[1])
p.bsz = 2048 # Large batch for estimating BN statistics (run in CPU mode)

#########################################################################
if len(sys.argv) == 2:
    gsave = ut.ckpter(p.wts_dir + '/iter_*.gmodel.npz')
    mfile = gsave.latest
    if mfile is None:
        sys.exit("Could not find anything in " + p.wts_dir)
    niter = gsave.iter
else:
    mfile = p.wts_dir + '/iter_' + sys.argv[2] + '.gmodel.npz'
    niter = int(sys.argv[2])

ofile = p.wts_dir + '/iter_' + ('%d' % niter) + '.bgmodel.npz'
#########################################################################

# Set up Generator (setbn=True: compute and store batch-norm moments)
Z = tf.random_uniform([p.bsz,1,1,p.zlen],-1.0,1.0)
G = gen.Gnet(p,Z,True)

#########################################################################
# Start TF session (respecting OMP_NUM_THREADS)
nthr = os.getenv('OMP_NUM_THREADS')
if nthr is None:
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
sess.run(tf.initialize_all_variables())

#########################################################################

print("Restoring G from " + mfile )
ut.netload(G,mfile,sess)
print("Done!")

#########################################################################

print("Running forward pass.")
_=sess.run(G.bnops)
print("Saving to %s."%ofile)
ut.netsave(G,ofile,sess)
print("Done!\n")
--------------------------------------------------------------------------------
/rpglib/disc0.py:
--------------------------------------------------------------------------------
#-- Ayan Chakrabarti
import numpy as np
import tensorflow as tf

class Dnet:

    def __init__(self,params):

        self.weights = {}

        self.bsz = params.bsz
        self.f = params.df
        imsz = params.imsz

        self.numd = 1

        self.numc = 3
        self.numl = int(np.log2(imsz))


        for j in range(self.numl):
            if j == 0:
                f = self.numc
            else:
                f = self.f*(2**(j-1))

            if j < self.numl-1:
                ksz = 4
                f1 = self.f*(2**j)
            else:
                ksz = 2
                f1 = 1

            sq = np.sqrt(3.0 / np.float32(ksz*ksz*f))
            w = tf.Variable(tf.random_uniform([ksz,ksz,f,f1],\
                minval=-sq,maxval=sq,dtype=tf.float32))
            b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))

            self.weights['c%d_w'%j] = w
            self.weights['c%d_b'%j] = b


    def dloss(self,im,floss):
        out = im

        for j in range(self.numl):
            w = self.weights['c%d_w'%j]
            b = self.weights['c%d_b'%j]

            if j < self.numl-1:
                strides = [1,2,2,1]
                out = tf.pad(out,[[0,0],[1,1],[1,1],[0,0]])
            else:
                strides = [1,1,1,1]

            out = tf.nn.conv2d(out,w,strides,'VALID')
            out = out + b
            # Leaky ReLU (slope 0.2) on all but the last layer
            if j < self.numl-1:
                out = tf.maximum(0.2*out,out)

        # loss = -log(sigmoid(out)): loss for labelling the input real
        # loss2 = -log(1-sigmoid(out)): loss for labelling it fake
        if floss:
            loss2 = tf.reduce_mean(tf.nn.softplus(out))
        else:
            loss2 = None
        loss = tf.reduce_mean(tf.nn.softplus(-out))

        return loss, loss2
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

from __future__ import print_function

import os
import sys
import tensorflow as tf
import numpy as np
from skimage.io import imsave

from rpglib import utils as ut
from rpglib import gen

#########################################################################
if len(sys.argv) < 3:
    sys.exit("USAGE: sample.py exp[,seed] out.jpg [iteration]")

arg1 = sys.argv[1].split(",")
ename = arg1[0]
if len(arg1) == 1:
    seed = 0
else:
    seed = int(arg1[1])

from importlib import import_module
p = import_module("exp." + ename)
p.bsz = 150
layout = [10,15] # 10 rows x 15 columns = p.bsz samples


fname = sys.argv[2]

if len(sys.argv) == 3:
    gsave = ut.ckpter(p.wts_dir + '/iter_*.gmodel.npz')
    mfile = gsave.latest
    if mfile is None:
        sys.exit("Could not find anything in " + p.wts_dir)
else:
    mfile = p.wts_dir + '/iter_' + sys.argv[3] + '.gmodel.npz'

#########################################################################

# Initialize generator

Z = tf.placeholder(shape=[p.bsz,1,1,p.zlen],dtype=tf.float32)

G = gen.Gnet(p,Z)
img = G.out

#########################################################################
# Start TF session (respecting OMP_NUM_THREADS)
nthr = os.getenv('OMP_NUM_THREADS')
if nthr is None:
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
sess.run(tf.initialize_all_variables())

#########################################################################

print("Restoring G from " + mfile )
ut.netload(G,mfile,sess)
print("Done!")

#########################################################################

print("Generating " + fname)
zval = np.float32(np.random.RandomState(seed).rand(p.bsz,1,1,p.zlen)*2.0-1.0)
imval = sess.run(G.out,feed_dict={Z: zval})
imval = np.uint8( (imval*0.5+0.5)*255.0)

# Tile the batch into a layout[0] x layout[1] grid (row-major)
imval = imval.reshape(layout + [p.imsz,p.imsz,3])
imval = imval.transpose([0,2,1,3,4]).copy()
imval = imval.reshape([layout[0]*p.imsz,layout[1]*p.imsz,3])
imsave(fname, imval)
--------------------------------------------------------------------------------
/rpglib/real.py:
--------------------------------------------------------------------------------
# Ayan Chakrabarti
import tensorflow as tf
import numpy as np

class Real:
    def graph(self,bsz,imsz,crop=False):
        self.names = []
        # Create placeholders
        for i in range(bsz):
            self.names.append(tf.placeholder(tf.string))

        batch = []
        for i in range(bsz):
            # Load image (PNG files start with byte 0x89 = 137; else JPEG)
            img = tf.read_file(self.names[i])
            code = tf.decode_raw(img,tf.uint8)[0]
            img = tf.cond(tf.equal(code,137),
                          lambda: tf.image.decode_png(img,channels=3),
                          lambda: tf.image.decode_jpeg(img,channels=3))

            if crop:
                # Resize so the shorter side is imsz+1, then random crop
                in_s = tf.to_float(tf.shape(img)[:2])
                min_s = tf.minimum(in_s[0],in_s[1])
                new_s = tf.to_int32((float(imsz+1)/min_s)*in_s)
                img = tf.image.resize_images(img,new_s[0],new_s[1])
                img = tf.random_crop(img,[imsz,imsz,3])

            batch.append(tf.expand_dims(img,0))

        # Scale pixel values to [-1,1]
        batch = tf.to_float(tf.concat(0,batch))*(2.0/255.0) - 1.0

        # Fetching logic
        self.batch = tf.Variable(tf.zeros([bsz,imsz,imsz,3],dtype=tf.float32),trainable=False)
        self.fetchOp = tf.assign(self.batch,batch).op

    def fdict(self):
        fd = {}

        for i in range(len(self.names)):
            idx = self.idx[self.niter % self.ndata]
            self.niter = self.niter + 1
            if self.niter % self.ndata == 0:
                self.idx = np.int32(self.rand.permutation(self.ndata))

            fd[self.names[i]] = self.files[idx]
        return fd

    def __init__(self,lfile,bsz,imsz,niter,crop=False):

        # Setup fetch graph
        self.graph(bsz,imsz,crop)

        # Load file list
        self.files = []
        for line in open(lfile).readlines():
            self.files.append(line.strip())
        self.ndata = len(self.files)

        # Setup shuffling: self.niter counts images, so replay one
        # permutation per completed epoch to resume the same image order
        self.niter = niter*bsz
        self.rand = np.random.RandomState(0)
        idx = self.rand.permutation(self.ndata)
        for i in range(self.niter // self.ndata):
            idx = self.rand.permutation(self.ndata)
        self.idx = np.int32(idx)
--------------------------------------------------------------------------------
/rpglib/gen.py:
--------------------------------------------------------------------------------
#-- Ayan Chakrabarti
import numpy as np
import tensorflow as tf

class Gnet:

    def __init__(self,params,z):

        self.weights = {}

        f = params.zlen
        f1 = params.f1
        ksz = params.ksz
        bsz = params.bsz

        sz = ksz-ksz%2
        lnum = 1


        ####### First block is FC

        # Initialize
        sq = np.sqrt(3.0 / np.float32(f))
        w = tf.Variable(tf.random_uniform([1,1,f,f1*sz*sz],\
            minval=-sq,maxval=sq,dtype=tf.float32))
        b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))
        self.weights['c%d_w'%lnum] = w
        self.weights['c%d_b'%lnum] = b

        out = tf.nn.conv2d(z,w,[1,1,1,1],'VALID')
        out = tf.reshape(out,[-1,sz,sz,f1])

        # Batch norm with the current batch's statistics
        # (genx.py can use stored statistics instead)
        om,ov = tf.nn.moments(out,[0,1,2])
        out = tf.nn.batch_normalization(out,om,ov,None,None,1e-3)
        out = out + b
        #out = tf.maximum(0.2*out,out)
        out = tf.nn.relu(out)

        lnum = lnum+1
        sz = sz*2
        f = f1
        f1 = f1//2

        # Subsequent blocks are deconv
        while sz <= params.imsz:
            # Initialize
            sq = np.sqrt(3.0 / np.float32(ksz*ksz*f))
            w = tf.Variable(tf.random_uniform([ksz,ksz,f1,f],\
                minval=-sq,maxval=sq,dtype=tf.float32))
            b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))

            self.weights['c%d_w'%lnum] = w
            self.weights['c%d_b'%lnum] = b
            lnum = lnum+1

            out = tf.nn.conv2d_transpose(out,w,[bsz,sz,sz,f1],[1,2,2,1],'SAME')
            if sz < params.imsz:
                om,ov = tf.nn.moments(out,[0,1,2])
                out = tf.nn.batch_normalization(out,om,ov,None,None,1e-3)

            out = out + b

            # tanh on the output layer, ReLU elsewhere
            if sz == params.imsz:
                out = tf.nn.tanh(out)
            else:
                #out = tf.maximum(0.2*out,out)
                out = tf.nn.relu(out)

            sz = sz*2
            f = f1
            if sz == params.imsz:
                f1 = 3
            else:
                f1 = f1 //2

        self.out = out

--------------------------------------------------------------------------------
/interp.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

from __future__ import print_function

import os
import sys
import tensorflow as tf
import numpy as np
from skimage.io import imsave

from rpglib import utils as ut
from rpglib import genx as gen

OLEN=150
N=11

#########################################################################
if len(sys.argv) < 4:
    sys.exit("USAGE: interp.py exp[,seed[,iteration]] out.png lid,rid lid,rid lid,rid ")

arg1 = sys.argv[1].split(",")
ename = arg1[0]
if len(arg1) == 1:
    seed = 0
else:
    seed = int(arg1[1])
if len(arg1) < 3:
    niter = None
else:
    niter = arg1[2]

from importlib import import_module
p = import_module("exp." + ename)

fname = sys.argv[2]

npair=len(sys.argv)-3
p.bsz = N*npair
layout = [npair,N]

# Each pair holds 1-based indices into the OLEN latent codes drawn with this
# seed (the same codes sample.py draws for its 10x15 grid, row-major).
# The first index of a pair ends up in the leftmost column.
lid = []
rid = []
for i in range(npair):
    ri,li = [int(x)-1 for x in sys.argv[i+3].split(',')]
    lid.append(li)
    rid.append(ri)

zval = np.float32(np.random.RandomState(seed).rand(OLEN,1,1,p.zlen)*2.0-1.0)
zleft = zval[lid,...]
zright = zval[rid,...]
# Linear interpolation: N steps per pair, broadcast to [npair,N,1,zlen]
sm = np.float32(np.linspace(0.0,1.0,N).reshape([1,N,1,1]))
zval = zleft*sm + zright*(1.0-sm)
zval = zval.reshape([p.bsz,1,1,p.zlen])

if niter is None:
    gsave = ut.ckpter(p.wts_dir + '/iter_*.bgmodel.npz')
    mfile = gsave.latest
    if mfile is None:
        sys.exit("Could not find any .bgmodel.npz file in " + p.wts_dir + " (run fixbn.py first)")
else:
    mfile = p.wts_dir + '/iter_' + niter + '.bgmodel.npz'

#########################################################################

# Initialize generator (using stored batch-norm statistics)

Z = tf.placeholder(shape=[p.bsz,1,1,p.zlen],dtype=tf.float32)

G = gen.Gnet(p,Z)
img = G.out

#########################################################################
# Start TF session (respecting OMP_NUM_THREADS)
nthr = os.getenv('OMP_NUM_THREADS')
if nthr is None:
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
sess.run(tf.initialize_all_variables())

#########################################################################

print("Restoring G from " + mfile )
ut.netload(G,mfile,sess)
print("Done!")

#########################################################################

print("Generating " + fname)
imval = sess.run(G.out,feed_dict={Z: zval})
imval = np.uint8( (imval*0.5+0.5)*255.0)

imval = imval.reshape(layout + [p.imsz,p.imsz,3])
imval = imval.transpose([0,2,1,3,4]).copy()
imval = imval.reshape([layout[0]*p.imsz,layout[1]*p.imsz,3])
imsave(fname, imval)
--------------------------------------------------------------------------------
/rpglib/genx.py:
--------------------------------------------------------------------------------
#-- Ayan Chakrabarti
import numpy as np
import tensorflow as tf

# Generator with stored batch-norm statistics.
# setbn=True: normalize with batch moments and expose ops (bnops) that copy
# them into the bnm_*/bnv_* variables (used by fixbn.py).
# setbn=False: normalize with the stored moments.
class Gnet:

    def __init__(self,params,z,setbn=False):
        self.params = params
        self.weights = {}
        self.bnops = []

        params = self.params

        f = params.zlen
        f1 = params.f1
        ksz = params.ksz
        bsz = params.bsz

        sz = ksz-ksz%2
        lnum = 1

        ####### First block is FC

        # Initialize
        sq = np.sqrt(3.0 / np.float32(f))
        w = tf.Variable(tf.random_uniform([1,1,f,f1*sz*sz],\
            minval=-sq,maxval=sq,dtype=tf.float32))
        b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))
        self.weights['c%d_w'%lnum] = w
        self.weights['c%d_b'%lnum] = b

        out = tf.nn.conv2d(z,w,[1,1,1,1],'VALID')
        out = tf.reshape(out,[-1,sz,sz,f1])

        om_ = tf.Variable(tf.zeros([f1]))
        ov_ = tf.Variable(tf.zeros([f1]))
        self.weights['bnm_c%d_w'%lnum] = om_
        self.weights['bnv_c%d_w'%lnum] = ov_

        if setbn:
            om,ov = tf.nn.moments(out,[0,1,2])
            self.bnops.append(tf.assign(om_,om).op)
            self.bnops.append(tf.assign(ov_,ov).op)
        else:
            om = om_
            ov = ov_

        out = tf.nn.batch_normalization(out,om,ov,None,None,1e-3)
        out = out + b
        #out = tf.maximum(0.2*out,out)
        out = tf.nn.relu(out)

        lnum = lnum+1
        sz = sz*2
        f = f1
        f1 = f1//2

        # Subsequent blocks are deconv
        while sz <= params.imsz:
            # Initialize
            sq = np.sqrt(3.0 / np.float32(ksz*ksz*f))
            w = tf.Variable(tf.random_uniform([ksz,ksz,f1,f],\
                minval=-sq,maxval=sq,dtype=tf.float32))
            b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))
            self.weights['c%d_w'%lnum] = w
            self.weights['c%d_b'%lnum] = b

            out = tf.nn.conv2d_transpose(out,w,[bsz,sz,sz,f1],[1,2,2,1],'SAME')
            if sz < params.imsz:
                om_ = tf.Variable(tf.zeros([f1]))
                ov_ = tf.Variable(tf.zeros([f1]))
                self.weights['bnm_c%d_w'%lnum] = om_
                self.weights['bnv_c%d_w'%lnum] = ov_

                if setbn:
                    om,ov = tf.nn.moments(out,[0,1,2])
                    self.bnops.append(tf.assign(om_,om).op)
                    self.bnops.append(tf.assign(ov_,ov).op)
                else:
                    om = om_
                    ov = ov_

                out = tf.nn.batch_normalization(out,om,ov,None,None,1e-3)

            out = out + b

            if sz == params.imsz:
                out = tf.nn.tanh(out)
            else:
                #out = tf.maximum(0.2*out,out)
                out = tf.nn.relu(out)

            lnum = lnum+1
            sz = sz*2
            f = f1
            if sz == params.imsz:
                f1 = 3
            else:
                f1 = f1 //2

        self.out = out

--------------------------------------------------------------------------------
/rpglib/disc.py:
--------------------------------------------------------------------------------
#-- Ayan Chakrabarti
import numpy as np
import tensorflow as tf
import sys

class Dnet:

    def __init__(self,params):

        self.weights = {}

        self.bsz = params.bsz
        self.f = params.df
        imsz = params.imsz

        # Random projection filters (see filts/README.md)
        fdata = np.load(params.wts_dir + '/filts.npz')
        self.numd = int(fdata['nfilt'])

        # Expect all filters to be same stride/numc
        fi = fdata['p0']
        self.numc = fi.shape[-1]
        stride = int(fdata['s0'])
        self.stride = [1,stride,stride,1]
        wsz = imsz // stride
        self.numl = int(np.log2(wsz))

        f0 = fi.shape[0]
        pad = (f0-stride)//2
        if pad == 0:
            self.pad = None
        else:
            self.pad = [[0,0],[pad,pad],[pad,pad],[0,0]]

        # v0: the single "active" discriminator used to build the graph.
        # vk[i]: weights of discriminator i; sOps[i] loads projection i into
        # self.filt and copies vk[i] into v0.
        self.v0 = []
        self.vk = []
        self.sOps = []
        self.filt = tf.Variable(tf.zeros(fi.shape,dtype=tf.float32))

        for i in range(self.numd):
            fi = fdata['p%d' % i]
            if int(fdata['s%d' % i]) != stride:
                sys.exit("Expect all filters to be same stride.")
            if fi.shape[-1] != self.numc:
                sys.exit("Expect all filters to have same number of channels.")
            if fi.shape[0] != f0:
                sys.exit("Expect all filters to be same size.")

            self.sOps.append([self.filt.assign(tf.constant(fi))])
            self.vk.append([])


        for j in range(self.numl):
            if j == 0:
                f = self.numc
            else:
                f = self.f*(2**(j-1))

            if j < self.numl-1:
                ksz = 4
                f1 = self.f*(2**j)
            else:
                ksz = 2
                f1 = 1

            sq = np.sqrt(3.0 / np.float32(ksz*ksz*f))


            w0 = tf.Variable(tf.random_uniform([ksz,ksz,f,f1],\
                minval=-sq,maxval=sq,dtype=tf.float32))
            self.v0.append(w0)
            b0 = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))
            self.v0.append(b0)

            for i in range(self.numd):
                w = tf.Variable(tf.random_uniform([ksz,ksz,f,f1],\
                    minval=-sq,maxval=sq,dtype=tf.float32))
                self.vk[i].append(w)
                self.sOps[i].append(w0.assign(tf.identity(w)))

                b = tf.Variable(tf.constant(0,shape=[f1],dtype=tf.float32))
                self.vk[i].append(b)
                self.sOps[i].append(b0.assign(tf.identity(b)))

                self.weights['c%d_%d_w'%(i,j)] = w
                self.weights['c%d_%d_b'%(i,j)] = b


    def dloss(self,im,floss):

        loss2=None
        if self.pad is None:
            out = im
        else:
            out = tf.pad(im,self.pad)

        # Apply the currently loaded random projection
        out = tf.nn.conv2d(out,self.filt,self.stride,'VALID')

        idx = 0
        for j in range(self.numl):

            if j < self.numl-1:
                strides = [1,2,2,1]
                out = tf.pad(out,[[0,0],[1,1],[1,1],[0,0]])
            else:
                strides = [1,1,1,1]

            out = tf.nn.conv2d(out,self.v0[idx],strides,'VALID')
            out = out + self.v0[idx+1]
            idx = idx + 2
            if j < self.numl-1:
                out = tf.maximum(0.2*out,out)

        if floss:
            loss2 = tf.reduce_mean(tf.nn.softplus(out))
        loss = tf.reduce_mean(tf.nn.softplus(-out))

        return loss,loss2
--------------------------------------------------------------------------------
/baseline_train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

import sys
import os
import time
import tensorflow as tf
import numpy as np

from rpglib import utils as ut
from rpglib import real

from rpglib import gen
from rpglib import disc0 as disc

if len(sys.argv) < 2:
    sys.exit("USAGE: baseline_train.py exp")

from importlib import import_module
p = import_module("exp." + sys.argv[1])

def mprint(s):
    sys.stdout.write(time.strftime("%Y-%m-%d %H:%M:%S ") + s + "\n")
    sys.stdout.flush()


#########################################################################

# Check for saved weights & find iter
dsave = ut.ckpter(p.wts_dir + '/iter_*.dmodel.npz')
gsave = ut.ckpter(p.wts_dir + '/iter_*.gmodel.npz')

niter = gsave.iter

#########################################################################

# Initialize loader, generator, discriminator

imgs = real.Real(p.lfile,p.bsz,p.imsz,niter,p.crop)
Z = tf.random_uniform([p.bsz,1,1,p.zlen],-1.0,1.0)

G = gen.Gnet(p,Z)
D = disc.Dnet(p)

or2r,_ = D.dloss(imgs.batch,False)
of2r,of2f = D.dloss(G.out,True)
dloss = (or2r+of2f) / 2.0
gloss = of2r

#########################################################################

# Set up optimizer steps

# For D
opt = tf.train.AdamOptimizer(2e-4,0.5)
dstep = opt.minimize(dloss,var_list=[D.weights[k] for k in D.weights.keys()])

# For G
opt = tf.train.AdamOptimizer(2e-4,0.5)
gstep = opt.minimize(gloss,var_list=[G.weights[k] for k in G.weights.keys()])


#########################################################################
# Start TF session (respecting OMP_NUM_THREADS)
nthr = os.getenv('OMP_NUM_THREADS')
if nthr is None:
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
sess.run(tf.initialize_all_variables())

#########################################################################

# Load saved weights if any
if dsave.latest is not None:
    mprint("Restoring D from " + dsave.latest )
    ut.netload(D,dsave.latest,sess)
    mprint("Done!")

if gsave.latest is not None:
    mprint("Restoring G from " + gsave.latest )
    ut.netload(G,gsave.latest,sess)
    mprint("Done!")

#########################################################################

# Main Training loop

stop=False
mprint("Starting from Iteration %d" % niter)
try:
    while niter < p.MAXITER and not stop:

        # Run gstep and fetch images
        f2rv = sess.run([gloss,gstep,imgs.fetchOp],feed_dict=imgs.fdict())
        glv = f2rv[0]
        # Run dstep
        dlv,_ = sess.run([dloss,dstep])
        mprint("[%09d] Adam Loss: G=%.6f,D=%.6f"
               % (niter,glv,dlv))

        niter=niter+1

        ## Save model weights if needed
        # (only G is saved periodically; D is saved on exit below)
        if p.SAVEFREQ > 0 and niter % p.SAVEFREQ == 0:
            dname = p.wts_dir + "/iter_%d.dmodel.npz" % niter
            gname = p.wts_dir + "/iter_%d.gmodel.npz" % niter

            ut.netsave(G,gname,sess)
            gsave.clean(every=p.SAVEFREQ,last=1)
            mprint("Saved G weights to " + gname )


except KeyboardInterrupt: # Catch ctrl+c/SIGINT
    mprint("Stopped!")
    stop = True
    pass

# Save last
if gsave.iter < niter:
    dname = p.wts_dir + "/iter_%d.dmodel.npz" % niter
    gname = p.wts_dir + "/iter_%d.gmodel.npz" % niter

    ut.netsave(D,dname,sess)
    dsave.clean(every=p.SAVEFREQ,last=1)
    mprint("Saved D weights to " + dname )

    ut.netsave(G,gname,sess)
    gsave.clean(every=p.SAVEFREQ,last=1)
    mprint("Saved G weights to " + gname )
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
#-- Ayan Chakrabarti

import sys
import os
import time
import tensorflow as tf
import numpy as np

if len(sys.argv) < 2:
    sys.exit("USAGE: train.py exp")

from importlib import import_module
p = import_module("exp." + sys.argv[1])

from rpglib import utils as ut
from rpglib import real
from rpglib import gen
from rpglib import disc

def mprint(s):
    sys.stdout.write(time.strftime("%Y-%m-%d %H:%M:%S ") + s + "\n")
    sys.stdout.flush()


#########################################################################

# Check for saved weights & find iter
dsave = ut.ckpter(p.wts_dir + '/iter_*.dmodel.npz')
gsave = ut.ckpter(p.wts_dir + '/iter_*.gmodel.npz')

niter = gsave.iter

#########################################################################

# Initialize loader, generator, discriminator

# Real
imgs = real.Real(p.lfile,p.bsz,p.imsz,niter,p.crop)

# Noise
Z = tf.Variable(tf.random_uniform([p.bsz,1,1,p.zlen],-1.0,1.0))
G = gen.Gnet(p,Z)
Gz = tf.Variable(tf.zeros([p.bsz,p.imsz,p.imsz,3],dtype=tf.float32))
gfwd = Gz.assign(G.out)

# Discriminator
D = disc.Dnet(p)

or2r,_ = D.dloss(imgs.batch,False)
of2r,of2f = D.dloss(Gz,True)
dloss = (or2r+of2f) / 2.0
gloss = of2r / float(D.numd)

#########################################################################

# Set up optimizer steps

# For D
opt0 = tf.train.GradientDescentOptimizer(1.0)
gv = opt0.compute_gradients(dloss,D.v0)
dsteps = []
for i in range(D.numd):
    opt = tf.train.AdamOptimizer(2e-4,0.5)
    gvi = [(gv[j][0],D.vk[i][j]) for j in range(len(gv))]
    dsteps.append(opt.apply_gradients(gvi))

# For G
GzGrad = tf.Variable(tf.zeros([p.bsz,p.imsz,p.imsz,3],dtype=tf.float32))
gstep0 = GzGrad.initializer

opt0 = tf.train.GradientDescentOptimizer(1.0)
gv = opt0.compute_gradients(gloss,[Gz])
gstepi = GzGrad.assign_add(gv[0][0])

opt = tf.train.AdamOptimizer(2e-4,0.5)
gstepF = opt.minimize(tf.reduce_sum(GzGrad*G.out),\
                      var_list=[G.weights[k] for k in G.weights.keys()])


#########################################################################
# Start TF session (respecting OMP_NUM_THREADS)
nthr = os.getenv('OMP_NUM_THREADS')
if nthr is None:
    sess = tf.Session()
else:
    sess = tf.Session(config=tf.ConfigProto(
        intra_op_parallelism_threads=int(nthr)))
sess.run(tf.initialize_all_variables())

#########################################################################

# Load saved weights if any
if dsave.latest is not None:
    mprint("Restoring D from " + dsave.latest )
    ut.netload(D,dsave.latest,sess)
    mprint("Done!")

if gsave.latest is not None:
    mprint("Restoring G from " + gsave.latest )
    ut.netload(G,gsave.latest,sess)
    mprint("Done!")

#########################################################################

# Main Training loop

stop=False
mprint("Starting from Iteration %d" % niter)
try:
    while niter < p.MAXITER and not stop:

        # Run GStep
        sess.run(Z.initializer)
        sess.run([gfwd,gstep0])

        gl = 0.
        gli = []
        for i in range(D.numd):
            sess.run(D.sOps[i])
            glv,_ = sess.run([gloss,gstepi])
            gl = gl + glv
            gli.append(glv*float(D.numd))
        sess.run(gstepF)

        # Run DStep
        sess.run(Z.initializer)
        sess.run([gfwd,imgs.fetchOp],feed_dict=imgs.fdict())
        dl = 0.
        for i in range(D.numd):
            sess.run(D.sOps[i])
            dlv,_ = sess.run([dloss,dsteps[i]])
            dl = dl+dlv
        dl = dl/float(D.numd)

        mprint("[%09d] Adam Loss: G=%.6f,D=%.6f"
               % (niter,gl,dl))

        # Display all outputs
        ostr = '[%09d]* ' % niter
        for j in range(D.numd):
            ostr = ostr + ("L%02d=%.3f," % (j,gli[j]))
        mprint(ostr[:-1])

        niter=niter+1

        ## Save model weights if needed
        if p.SAVEFREQ > 0 and niter % p.SAVEFREQ == 0:
            gname = p.wts_dir + "/iter_%d.gmodel.npz" % niter

            ut.netsave(G,gname,sess)
            gsave.clean(every=p.SAVEFREQ,last=1)
            mprint("Saved G weights to " + gname )


except KeyboardInterrupt: # Catch ctrl+c/SIGINT
    mprint("Stopped!")
    stop = True
    pass

# Save last
dname = p.wts_dir + "/iter_%d.dmodel.npz" % niter
gname = p.wts_dir + "/iter_%d.gmodel.npz" % niter

if gsave.iter < niter:
    ut.netsave(G,gname,sess)
    gsave.clean(every=p.SAVEFREQ,last=1)
    mprint("Saved G weights to " + gname )

ut.netsave(D,dname,sess)
dsave.clean(every=p.SAVEFREQ,last=1)
mprint("Saved D weights to " + dname )

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RP-GAN: Stable GAN Training with Random Projections
![Interpolated images from our GAN](imgs/interp.jpg)

This repository contains a reference implementation of the algorithm described in the paper:

> Behnam Neyshabur, Srinadh Bhojanapalli, and Ayan Chakrabarti, "[**Stabilizing GAN Training with Multiple Random Projections**](http://arxiv.org/abs/1705.07831)," arXiv:1705.07831 [cs.LG], 2017.

Pre-trained generator models are not included in the repository due to their size, but are available as binary downloads as part of the [release](https://github.com/ayanc/rpgan/releases). This code and data are being released for research use. If you use the code in research that results in a publication, we request that you kindly cite the above paper. Please direct any questions to .

##### Requirements

The code uses the [tensorflow](https://www.tensorflow.org/) library, and has been tested with versions `0.9` and `0.11` with both Python2 and Python3. You will need a modern GPU to train in a reasonable amount of time, but the sampling code should work on a CPU.

## Sampling with Trained Models

We first describe how to use the scripts for sampling from trained models. You can use these scripts for models you train yourself, or with the provided pre-trained models.

##### Pre-trained Models
We provide a number of pre-trained models in the [release](https://github.com/ayanc/rpgan/releases), corresponding to the experiments in the paper. The parameters of each model (both for training and sampling) are described in `.py` files in the `exp/` directory. `face1.py` describes a face image model trained in the traditional setting with a single discriminator, while the `faceNN.py` files describe models trained with multiple discriminators, each acting on one of *NN* random low-dimensional projections. `face48.py` describes the main face model used in our experiments, while `dog12.py` is the model trained with 12 discriminators on the Imagenet-Canines set. After downloading the trained model archive files, unzip them in the repository root directory. This should create files in sub-directories of `models/`.
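
The experiment files are plain Python modules of constants: `train.py` imports one with `import_module("exp." + sys.argv[1])` and reads its values as attributes such as `p.zlen` or `p.wts_dir`. The sketch below loads an experiment file the same way and checks that the parameters `train.py` reads directly are present. `load_exp` and `TRAIN_KEYS` are hypothetical names, not part of the repository, and the list of keys is only what `train.py` itself accesses; the generator and discriminator receive the same module and presumably read further entries.

```python
import importlib
import sys

# Parameters that train.py itself reads from the experiment module. The
# generator and discriminator (gen.Gnet(p, Z), disc.Dnet(p)) are handed the
# same module and presumably read others as well, e.g. ksz, f1 and df.
TRAIN_KEYS = ("wts_dir", "lfile", "bsz", "imsz", "crop", "zlen",
              "SAVEFREQ", "MAXITER")


def load_exp(name, repo_root="."):
    """Hypothetical helper: import exp/<name>.py the way train.py does."""
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
    importlib.invalidate_caches()  # notice experiment files created just now
    p = importlib.import_module("exp." + name)
    missing = [k for k in TRAIN_KEYS if not hasattr(p, k)]
    if missing:
        raise AttributeError("exp/%s.py is missing: %s"
                             % (name, ", ".join(missing)))
    return p
```

Run from the repository root, `load_exp("face48")` would return the module defined in `exp/face48.py`, so `p.imsz` would be `64`. Failing early on a missing key is cheaper than discovering it after the TensorFlow graph has been built.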

##### Generating Samples
Use `sample.py` to generate samples from any of the trained models:
```bash
$ ./sample.py expName[,seed] out.png [iteration]
```
where `expName` is the name of the experiment file (without the `.py` extension), and `out.png` is the file to save the generated samples to. The script accepts two optional parameters: `seed` (default: 0) specifies the random seed used to generate the noise vectors provided to the generator, and `iteration` (default: the latest iteration available as a saved file) specifies which model file to use when multiple snapshots are available. E.g.,

```bash
$ ./sample.py face48 out.png       # Sample from the face48 experiment, using
                                   # seed 0, and the latest model file.
$ ./sample.py face48,100 out.png   # Sample from the face48 experiment, using
                                   # seed 100, and the latest model file.
$ ./sample.py face1 out.png        # Sample from the single discriminator face
                                   # experiment, and the latest model file.
$ ./sample.py face1 out.png 40000  # Sample from the single discriminator face
                                   # experiment, and the 40k-iteration model.
```

##### Interpolating in Latent Space

We also provide a script to produce interpolated images like the ones at the top of this page. However, before you can use this script, you need to create a version of the model file that contains the population mean-variance statistics of the activations, to be used in the batch-norm layers. (`sample.py` above uses batch statistics, which is fine since it works with a large batch of noise vectors. For interpolation, however, you will typically be working with smaller, more correlated batches, and should therefore use population statistics.)
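
To see why the choice of statistics matters, here is a minimal pure-Python sketch for a single feature. It is not the code in `fixbn.py` or `rpglib`: it omits batch norm's learned scale and shift, and the `EPS` value, the toy "large batch" and the three-sample "interpolation" batch are all made-up assumptions.

```python
import math

EPS = 1e-5  # assumed value; guards against division by a near-zero variance


def moments(xs):
    """Mean and (population) variance of one feature's activations."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var


def batch_norm(xs, mean, var):
    """Normalize activations with a given mean/variance (no scale/shift)."""
    return [(x - mean) / math.sqrt(var + EPS) for x in xs]


# Toy stand-in for statistics gathered once over a large batch of noise.
large_batch = [2.0 if i % 2 else -2.0 for i in range(1000)]
pop_mean, pop_var = moments(large_batch)

# A small, highly correlated batch, like neighbouring interpolation steps.
small_batch = [2.0, 2.1, 2.2]

with_batch_stats = batch_norm(small_batch, *moments(small_batch))
with_pop_stats = batch_norm(small_batch, pop_mean, pop_var)
print(with_batch_stats)
print(with_pop_stats)
```

With batch statistics, the three nearly identical activations are stretched to roughly -1.2, 0 and 1.2, so neighbouring interpolation steps get pushed apart. With the stored population statistics they stay close together (about 1.0, 1.05 and 1.1), which is the behaviour you want when walking smoothly through latent space.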

To create this version of the model file, use the provided script `fixbn.py` as:
```bash
$ CUDA_VISIBLE_DEVICES= ./fixbn.py expName [iteration]
```
This will create a second version of the model weights file (with extension `.bgmodel.npz` instead of `.gmodel.npz`) that also stores the population statistics. As with `sample.py`, you can provide a second optional argument to select the model snapshot for a specific iteration number.

Note that we call the script with `CUDA_VISIBLE_DEVICES=` to force **tensorflow** to use the CPU instead of the GPU. This is because we compute these statistics over a relatively large batch, which typically doesn't fit in GPU memory (and since it's only one forward pass, running time isn't really an issue).

You only need to call `fixbn.py` once; after that, you can use the script `interp.py` to create interpolated samples. The script generates multiple rows of images, each row showing samples from noise vectors interpolated, from left to right, between a pair of endpoints. The script lets you specify these pairs of noise vectors as IDs:
```bash
$ ./interp.py expName[,seed[,iteration]] out.png lid,rid lid,rid ....
```
The first parameter now accepts two optional comma-separated arguments after the model name: the seed and the iteration. After this and the output file name, the script accepts an arbitrary number of left-right pairs of image IDs, one for each row of images in the output. These IDs correspond to the number of the image, in reading order, in the output generated by `sample.py` (with the same seed). For example, to create the images at the top of the page, use:
```bash
$ ./interp.py face48 out.png 137,65 146,150 15,138 54,72 38,123 36,93
```

## Training

To train your own model, you will need to create a new model file (say `myown.py`) in the `exp/` directory. See the existing model files for reference.
Here is an explanation of some of the key parameters:

- `wts_dir`: Directory in which to store model weights. This directory must already exist.
- `imsz`: Resolution / size of the images (these will be square color images of size `imsz x imsz`).
- `lfile`: Path to a list file for the images you want to train on, where each line of the file contains a path to an image.
- `crop`: Boolean (`True` or `False`). Set it to `True` if the images need to be resized and cropped, and to `False` if they are already at the correct resolution. If `True`, each image will first be resized so that its smaller side matches `imsz`, and then a random crop along the other dimension will be used for training.

Before you begin training, you will need to create a file called `filts.npz` in the weights directory, which defines the convolutional filters for the random projections. See the `filts/` directory for the filters used for the pre-trained models, as well as instructions on a script for creating your own.

Once you have created the model file and prepared the directory, you can begin training by using the `train.py` script as:
```bash
$ ./train.py myown
```
where the first parameter is the name of your model file.

We also provide a script for traditional training, `baseline_train.py`, with a single discriminator acting on the original image. It is used in the same way, except that it doesn't require a `filts.npz` file in the weights directory.

***

### Acknowledgments

This work was supported by the National Science Foundation under award no. [IIS-1820693](http://www.nsf.gov/awardsearch/showAward?AWD_ID=1820693). Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors, and do not necessarily reflect the views of the National Science Foundation.

--------------------------------------------------------------------------------