├── .gitignore ├── LICENSE ├── README.md ├── graphy ├── __init__.py ├── function.py ├── graphics.py ├── misc │ ├── __init__.py │ ├── data.py │ ├── logger.py │ └── optim.py ├── ndict.py ├── nodes │ ├── __init__.py │ ├── ar.py │ ├── conv.py │ └── rand.py └── png.py ├── models.py ├── tf_train.py ├── tf_utils ├── __init__.py ├── adamax.py ├── cifar10_data.py ├── common.py ├── data_utils.py ├── distributions.py ├── distributions_test.py ├── hparams.py ├── hparams_test.py └── layers.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 openai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 |
3 | # Improving Variational Inference with Inverse Autoregressive Flow
4 |
5 | Code for reproducing key results in the paper [Improving Variational Inference with Inverse Autoregressive Flow](http://arxiv.org/abs/1606.04934) by Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling.
6 |
7 | ## Prerequisites
8 |
9 | 1. Make sure that recent versions of the following are installed:
10 | - Python (version 2.7 or higher)
11 | - NumPy (e.g. `pip install numpy`)
12 | - Theano (e.g. `pip install Theano`)
13 |
14 | 2. Set `floatX = float32` in the `[global]` section of the Theano config (usually `~/.theanorc`). Alternatively, you could prepend `THEANO_FLAGS=floatX=float32 ` to the python commands below.
15 |
16 | 3. Clone this repository, e.g.:
17 | ```sh
18 | git clone https://github.com/openai/iaf.git
19 | ```
20 |
21 | 4. Download the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) (get the *Python* version) and create an environment variable `CIFAR10_PATH` that points to the subdirectory with CIFAR-10 data. For example:
22 | ```sh
23 | export CIFAR10_PATH="$HOME/cifar-10"
24 | ```
25 |
26 | ## Syntax of train.py
27 |
28 | Example:
29 | ```sh
30 | python train.py with problem=cifar10 n_z=32 n_h=64 depths=[2,2,2] margs.depth_ar=1 margs.posterior=down_iaf2_NL margs.kl_min=0.25
31 | ```
32 |
33 | `problem` is the problem (dataset) to train on. I only tested `cifar10` for this release.
34 |
35 | `n_z` is the number of stochastic featuremaps in each layer.
36 |
37 | `n_h` is the number of deterministic featuremaps used throughout the model.
38 |
39 | `depths` is an array of integers that denotes the depths of the *levels* in the model. Each level is a sequence of layers. Each subsequent level operates over spatially smaller featuremaps. In the case of CIFAR-10, the first level operates over 16x16 featuremaps, the second over 8x8 featuremaps, etc.
40 |
41 | Some possible choices for `margs.posterior` are:
42 | - `up_diag`: bottom-up factorized Gaussian
43 | - `up_iaf1_nl`: bottom-up IAF, mean-only perturbation
44 | - `up_iaf2_nl`: bottom-up IAF
45 | - `down_diag`: top-down factorized Gaussian
46 | - `down_iaf1_nl`: top-down IAF, mean-only perturbation
47 | - `down_iaf2_nl`: top-down IAF
48 |
49 | `margs.depth_ar` is the number of hidden layers within IAF, and can be any non-negative integer.
50 |
51 | `margs.kl_min`: the minimum information constraint. Should be a non-negative float (where 0 is no constraint).
52 |
53 | ## Results of Table 3
54 |
55 | (3.28 bits/dim)
56 |
57 | ```sh
58 | python train.py with problem=cifar10 n_h=160 depths=[10,10] margs.depth_ar=2 margs.posterior=down_iaf2_nl margs.prior=diag margs.kl_min=0.25
59 | ```
60 |
61 | More instructions will follow.
62 |
63 |
64 | ## Multi-GPU TensorFlow implementation
65 |
66 | ### Prerequisites
67 |
68 | Make sure that recent versions of the following are installed:
69 | - Python (version 2.7 or higher)
70 | - TensorFlow
71 | - tqdm
72 |
73 | The `CIFAR10_PATH` environment variable should point to the dataset location.
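As an optional sanity check (an illustrative sketch, not part of the repository), the following snippet verifies that `CIFAR10_PATH` points at the extracted *Python*-version batch files (`data_batch_1` ... `data_batch_5` and `test_batch`) that `graphy/misc/data.py` reads. Note that the Theano data loader appends file names directly to this path, so it is safest to let the path end with a `/`.

```python
# Optional sanity check for CIFAR10_PATH (illustrative, not repository code).
import os

path = os.environ['CIFAR10_PATH']
expected = ['data_batch_%d' % i for i in range(1, 6)] + ['test_batch']
missing = [f for f in expected if not os.path.exists(os.path.join(path, f))]
if missing:
    raise Exception('CIFAR10_PATH is missing: ' + ', '.join(missing))
print('Found all CIFAR-10 batch files in ' + path)
```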
74 |
75 | ### Syntax of tf_train.py
76 |
77 | Training script:
78 | ```sh
79 | python tf_train.py --logdir <logdir> --hpconfig depth=1,num_blocks=20,kl_min=0.1,learning_rate=0.002,batch_size=32 --num_gpus 8 --mode train
80 | ```
81 |
82 | It will run the training procedure on the given number of GPUs. Model checkpoints will be stored in the `<logdir>/train` directory, along with TensorBoard summaries that are useful for monitoring and debugging issues.
83 |
84 | Evaluation script:
85 | ```sh
86 | python tf_train.py --logdir <logdir> --hpconfig depth=1,num_blocks=20,kl_min=0.1,learning_rate=0.002,batch_size=32 --num_gpus 1 --mode eval_test
87 | ```
88 |
89 | It will run the evaluation on the test set using a single GPU and will produce a TensorBoard summary with the results and generated samples.
90 |
91 | To start TensorBoard:
92 | ```sh
93 | tensorboard --logdir <logdir>
94 | ```
95 |
96 | For a description of the hyper-parameters, take a look at the `get_default_hparams` function in `tf_train.py`.
97 |
98 |
99 | ### Loading from the checkpoint
100 |
101 | The best IAF model trained on CIFAR-10 reached 3.15 bits/dim when evaluated with a single sample. With 10,000 samples, the log-likelihood estimate is 3.111 bits/dim.
102 | The checkpoint is available at [link](https://drive.google.com/file/d/0B-pv8mYT4p0OOXFfWElyeUs0bUk/view?usp=sharing).
103 | Steps to use it:
104 | - download the file
105 | - create the directory `<logdir>/train/` and copy the checkpoint there
106 | - run the following command:
107 | ```sh
108 | python tf_train.py --logdir <logdir> --hpconfig depth=1,num_blocks=20,kl_min=0.1,learning_rate=0.002,batch_size=32 --num_gpus 1 --mode eval_test
109 | ```
110 |
111 | The script will run the evaluation on the test set and generate samples, which are stored in a TensorFlow events file that can be accessed using TensorBoard.
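The `--hpconfig` flag takes a comma-separated list of `key=value` overrides for the hyper-parameters returned by `get_default_hparams`. Purely as an illustration of that format (the repository's own parsing lives in `tf_utils/hparams.py` and may differ in detail), a minimal parser could look like this:

```python
# Illustrative parser for the comma-separated key=value format used by --hpconfig.
# Not the repository's implementation; see tf_utils/hparams.py for the real one.
def parse_hpconfig(s):
    out = {}
    for item in s.split(','):
        if not item:
            continue
        key, value = item.split('=', 1)
        try:
            out[key] = int(value)
        except ValueError:
            try:
                out[key] = float(value)
            except ValueError:
                out[key] = value  # keep as a raw string
    return out

print(parse_hpconfig("depth=1,num_blocks=20,kl_min=0.1,learning_rate=0.002,batch_size=32"))
# e.g. {'depth': 1, 'num_blocks': 20, 'kl_min': 0.1, 'learning_rate': 0.002, 'batch_size': 32}
```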
-------------------------------------------------------------------------------- /graphy/__init__.py: -------------------------------------------------------------------------------- 1 | import ndict 2 | 3 | import numpy as np 4 | import theano 5 | import theano.tensor.shared_randomstreams 6 | import theano.compile 7 | import math, time, sys 8 | 9 | # Change recursion limit (for deep theano models) 10 | import sys 11 | sys.setrecursionlimit(10000) 12 | 13 | # some config 14 | floatX = theano.config.floatX # @UndefinedVariable 15 | print '[graphy] floatX = '+floatX 16 | 17 | rng = theano.tensor.shared_randomstreams.RandomStreams(0) 18 | rng_curand = rng 19 | if 'gpu' in theano.config.device: # @UndefinedVariable 20 | import theano.sandbox.cuda.rng_curand 21 | rng_curand = theano.sandbox.cuda.rng_curand.CURAND_RandomStreams(0) 22 | 23 | # Shared floating-point Theano variable from a numpy variable 24 | def sharedf(x, target=None, name=None, borrow=False, broadcastable=None): 25 | if target == None: 26 | return theano.shared(np.asarray(x, dtype=floatX), name=name, borrow=borrow, broadcastable=broadcastable) 27 | else: 28 | return theano.shared(np.asarray(x, dtype=floatX), target=target, name=name, borrow=borrow, broadcastable=broadcastable) 29 | 30 | # Shared random normal variable 31 | def sharedrandf(scale, size): 32 | return sharedf(np.random.normal(0, scale, size=size)) 33 | 34 | # Construct object from keyword arguments or dictionary 35 | class Struct: 36 | def __init__(self, **entries): 37 | self.__dict__.update(entries) 38 | def __repr__(self): # nice printing 39 | return '<%s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.iteritems())) 40 | 41 | # Import rest of the files 42 | from function import * 43 | 44 | import misc 45 | import misc.data 46 | import misc.optim 47 | import misc.logger 48 | 49 | import nodes 50 | import graphics 51 | -------------------------------------------------------------------------------- /graphy/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import ndict 4 | import math, time, sys 5 | import graphy as G 6 | from collections import OrderedDict 7 | 8 | ''' 9 | NaN detection for theano functions 10 | from: http://deeplearning.net/software/theano/tutorial/debug_faq.html 11 | ''' 12 | def nan_detection_mode(): 13 | def detect_nan(i, node, fn): 14 | for output in fn.outputs: 15 | if np.isnan(output[0]).any(): 16 | print '*** NaN detected ***' 17 | theano.printing.debugprint(node) 18 | print 'Inputs : %s' % [input[0] for input in fn.inputs] 19 | print 'Outputs: %s' % [output[0] for output in fn.outputs] 20 | #break 21 | raise Exception() 22 | return theano.compile.MonitorMode(post_func=detect_nan) # @UndefinedVariable 23 | 24 | default_function_mode = 'FAST_RUN' 25 | #default_function_mode = nan_detection_mode() 26 | 27 | ''' 28 | Graphy function 29 | Wraps theano function, same API, except: 30 | - input x is a dict or sequence of dicts, so no worries about ordering 31 | (as with regular theano functions) 32 | - output y can be either a dict of Theano vars or a single Theano variable 33 | - Supports lazy compilation 34 | - Supports minibatches 35 | - Checks whether input keys match at compile- and runtime 36 | ''' 37 | def function(x, y, lazy=False, _debug=False, checknan='raise', **kwargs): 38 | # Default keyword arguments 39 | if not kwargs.has_key('on_unused_input'): 40 | kwargs['on_unused_input'] = 'warn' 41 | if not kwargs.has_key('mode'): 42 
| kwargs['mode'] = default_function_mode 43 | # Order the input dict 44 | x = ndict.ordered(ndict.flatten(x)) 45 | # Check the output dict 46 | return_single_y = False 47 | if not isinstance(y, dict): 48 | return_single_y = True 49 | y = {str(y): y} 50 | y = ndict.ordered(y) 51 | # Lazily compiled function (saves a lot of time) 52 | f = [None] 53 | def _compile(verbose=True): 54 | t0 = time.time() 55 | print 'Compiling... ', 56 | #print '[graphy] Compiling function '+str(x.keys())+' => '+str(y.keys())+' ...' 57 | sys.stdout.flush() 58 | f[0] = theano.function(x.values(), y.values(), **kwargs) 59 | print "%.2f" % (time.time()-t0), 's' 60 | if not lazy: 61 | _compile() 62 | # The function to be called 63 | def func(data, n_batch=0, randomorder=True, data_global={}): 64 | data = ndict.ordered(ndict.flatten(data)) 65 | data_global = ndict.ordered(ndict.flatten(data_global)) 66 | # Check if keys of 'x' and 'inputs' match 67 | allkeys = (data.keys() + data_global.keys()) 68 | for i in range(len(data)): 69 | if x.keys()[i] not in allkeys: 70 | raise Exception('Non-matching keys:'+str(allkeys)+' vs. '+str(x.keys())) 71 | # Compile function if not already done 72 | if f[0] == None: 73 | _compile() 74 | if n_batch <= 0: 75 | # Get results 76 | _data = data.copy() 77 | _data.update(data_global) 78 | inputs_ordered = ndict.orderedvals((_data,)) 79 | _result = f[0](*inputs_ordered) 80 | # Put it in a dictionary with the corresponding keys 81 | result = {y.keys()[i]: _result[i] for i in range(len(y))} 82 | else: 83 | # Minibatch-based evaluation. 84 | # This assumes that input and output are tensors, and the first dimension iterates of datapoints 85 | n_tot = data.itervalues().next().shape[0] 86 | n_minibatches = int(math.ceil(1. * n_tot / n_batch)) 87 | 88 | n_tile = 1 89 | if n_batch > n_tot: 90 | assert n_batch%n_tot == 0 91 | n_tile = n_batch/n_tot 92 | 93 | indices = np.tile(np.arange(n_tot),n_tile) 94 | if randomorder: 95 | np.random.shuffle(indices) 96 | adict = dict(zip(np.tile(np.arange(n_tot),n_tile),indices)) 97 | indices_inverse = sorted(adict, key=adict.get) 98 | 99 | results = [] 100 | for i in range(n_minibatches): 101 | data_minibatch = ndict.getRowsFromIndices(data, indices[i*n_batch:(i+1)*n_batch]) 102 | data_minibatch.update(data_global) 103 | inputs_ordered = ndict.orderedvals((data_minibatch,)) 104 | results.append(f[0](*inputs_ordered)) 105 | if _debug: 106 | print 'Function debug', i, results[-1] 107 | if checknan == 'raise': 108 | if np.isnan(np.sum(results[-1])): 109 | print results[-1] 110 | raise Exception("NaN detected") 111 | result = {y.keys()[i]: np.concatenate([results[j][i] for j in range(n_minibatches)]) for i in range(len(y))} 112 | if randomorder: 113 | result = ndict.getRowsFromIndices(result, indices_inverse) 114 | 115 | result = OrderedDict(sorted(result.items())) 116 | 117 | # Return result 118 | #raise Exception() 119 | if return_single_y: 120 | return result[result.keys()[0]] 121 | return result 122 | # Return the func 123 | return G.Struct(__call__=func, f=f) 124 | 125 | # f: a function (as above) that returns ndict 126 | # function: a function that returns arguments for f 127 | # concat_axis: axis over which to concatenate the results 128 | def loop(f, f_data, n_batch, n_its, concat_axis=0): 129 | assert n_its >= 1 130 | results = [f(f_data(), n_batch=n_batch) for i in range(n_its)] 131 | result = {i: np.concatenate([results[j][i] for j in range(n_its)], axis=0) for i in results[0].keys()} 132 | return result 133 | 
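For reference, here is a minimal usage sketch of the `function` wrapper above. It is illustrative only: it assumes Theano is installed and the `graphy` package is importable, and the variable names and shapes are made up.

```python
# Illustrative usage of graphy.function; not part of the repository.
import numpy as np
import theano.tensor as T
import graphy as G

x = T.matrix('x')
w = T.matrix('w')

# Inputs and outputs are dicts, so argument order does not matter.
f = G.function({'x': x, 'w': w}, {'y': T.dot(x, w.T)}, lazy=True)

data = {'x': np.ones((6, 4), dtype=G.floatX),
        'w': np.ones((3, 4), dtype=G.floatX)}

out = f(data)              # compiles on the first call (lazy=True)
print(out['y'].shape)      # (6, 3)

# Minibatched evaluation over the first axis of 'x'; 'w' is passed to every
# minibatch unchanged via data_global.
out = f({'x': data['x']}, n_batch=2, data_global={'w': data['w']})
print(out['y'].shape)      # (6, 3)
```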
-------------------------------------------------------------------------------- /graphy/graphics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy, scipy.misc 3 | from PIL import Image 4 | 5 | def save_image(x, path): 6 | from graphy import png 7 | x = x.swapaxes(0, 2).swapaxes(0,1).reshape((x.shape[1],-1)) 8 | png.from_array(x, 'RGB').save(path) 9 | 10 | def save_raster(x, path, rescale=False, width=None): 11 | save_image(to_raster(x, rescale, width), path) 12 | 13 | #def save_raster(x, path, rescale=False): 14 | # return Image.fromarray(to_raster(x, rescale).swapaxes(0, 2)).save(path, 'PNG') 15 | 16 | # Shape: (n_patches,3,rows,columns) 17 | # Or: 18 | def to_raster(x, rescale=False, width=None): 19 | #x = x.swapaxes(2, 3) 20 | if len(x.shape) == 3: 21 | x = x.reshape((x.shape[0],1,x.shape[1],x.shape[2])) 22 | if x.shape[1] == 1: 23 | x = np.repeat(x, 3, axis=1) 24 | if rescale: 25 | x = (x - x.min()) / (x.max() - x.min()) * 255. 26 | x = np.clip(x, 0, 255) 27 | assert len(x.shape) == 4 28 | assert x.shape[1] == 3 29 | n_patches = x.shape[0] 30 | if width is None: 31 | width = np.sqrt(n_patches) #result width 32 | assert width == int(width) 33 | height = n_patches/width #result height 34 | tile_height = x.shape[2] 35 | tile_width = x.shape[3] 36 | result = np.zeros((3,height*tile_height,width*tile_width), dtype='uint8') 37 | for i in range(n_patches): 38 | _x = (i%width)*tile_width 39 | y = np.floor(i/width)*tile_height 40 | result[:,y:y+tile_height,_x:_x+tile_width] = x[i] 41 | return result 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /graphy/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/iaf/ad33fe4872bf6e4b4f387e709a625376bb8b0d9d/graphy/misc/__init__.py -------------------------------------------------------------------------------- /graphy/misc/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import graphy as G 5 | 6 | basepath = os.environ['ML_DATA_PATH'] 7 | 8 | ''' Standard datasets 9 | The first dimension of tensors goes over datapoints. 10 | ''' 11 | 12 | def mnist(with_y=True): 13 | 14 | n_y = 10 15 | import scipy.io 16 | data = scipy.io.loadmat(basepath+'/mnist_roweis/mnist_all.mat') 17 | train_x = [data['train'+str(i)] for i in range(n_y)] 18 | train_y = [(i*np.ones((train_x[i].shape[0],))).astype(np.uint8) for i in range(n_y)] 19 | test_x = [data['test'+str(i)] for i in range(n_y)] 20 | test_y = [(i*np.ones((test_x[i].shape[0],))).astype(np.uint8) for i in range(n_y)] 21 | 22 | train = {'x':np.concatenate(train_x)} 23 | test = {'x':np.concatenate(test_x)} 24 | 25 | if with_y: 26 | train['y'] = np.concatenate(train_y) 27 | test['y'] = np.concatenate(test_y) 28 | 29 | G.ndict.shuffle(train) #important!! 30 | G.ndict.shuffle(test) #important!! 
31 | return train, test 32 | 33 | ''' 34 | Binarized MNIST (by Hugo Larochelle) 35 | ''' 36 | def mnist_binarized(validset=False, flattened=True): 37 | path = basepath+'/mnist_binarized/' 38 | import h5py 39 | train = {'x':h5py.File(path+"binarized_mnist-train.h5")['data'][:].astype('uint8')*255} 40 | valid = {'x':h5py.File(path+"binarized_mnist-valid.h5")['data'][:].astype('uint8')*255} 41 | test = {'x':h5py.File(path+"binarized_mnist-test.h5")['data'][:].astype('uint8')*255} 42 | G.ndict.shuffle(train) 43 | G.ndict.shuffle(test) 44 | G.ndict.shuffle(valid) 45 | if not flattened: 46 | for data in [train,valid,test]: 47 | data['x'] = data['x'].reshape((-1,1,28,28)) 48 | if not validset: 49 | print "Full training set" 50 | train['x'] = np.concatenate((train['x'], valid['x'])) 51 | return train, test 52 | return train, valid, test 53 | 54 | # Converts integer labels to binarized labels (1-of-K coding) 55 | def binarize_labels(y, n_classes=10): 56 | new_y = np.zeros((y.shape[0], n_classes), dtype=G.floatX) 57 | for i in range(y.shape[0]): 58 | new_y[i, y[i]] = 1 59 | return new_y 60 | 61 | ''' 62 | Create semi-supervised sets of labeled and unlabeled data 63 | where there are equal number of labels from each class 64 | x: dict with dataset 65 | key_y: name (key) of label variable in x 66 | shuffle: whether to shuffle the input and output 67 | n_labeled: number of labeled instances 68 | ''' 69 | def create_semisupervised(x, key_y, n_labeled, shuffle=True): 70 | if shuffle: 71 | G.ndict.shuffle(x) 72 | n_classes = np.amax(x[key_y])+1 73 | if n_labeled%n_classes != 0: raise("Cannot create stratisfied semi-supervised set since n_labeled (wished number of labeled samples) not divisible by n_classes (number of classes)") 74 | n_labels_per_class = n_labeled/n_classes 75 | x_l = {j: [0]*n_classes for j in x} #labeled 76 | x_u = {j: [0]*n_classes for j in x} #unlabeld 77 | for i in range(n_classes): 78 | idx = x[key_y] == i 79 | for j in x: 80 | x_l[j][i] = x[j][idx][:n_labels_per_class] 81 | x_u[j][i] = x[j][idx][n_labels_per_class:] 82 | x_l = {i: np.concatenate(x_l[i]) for i in x} 83 | x_u = {i: np.concatenate(x_u[i]) for i in x} 84 | if shuffle: 85 | G.ndict.shuffle(x_l) 86 | G.ndict.shuffle(x_u) 87 | return x_l, x_u 88 | 89 | 90 | # from http://cs.nyu.edu/~roweis/data.html 91 | # returned pixels are uint8 92 | def cifar10(with_y=True, binarize_y=False): 93 | # Load the original images into numpy arrays 94 | def unpickle(file): 95 | import cPickle 96 | fo = open(file, 'rb') 97 | result = cPickle.load(fo) 98 | fo.close() 99 | return result 100 | path = os.environ['CIFAR10_PATH'] 101 | n_train = 5 102 | _train = [unpickle(path+'data_batch_'+str(i+1)) for i in range(n_train)] 103 | train = {'x':np.concatenate([_train[i]['data'] for i in range(n_train)])} 104 | _test = unpickle(path+'test_batch') 105 | test = {'x':_test['data']} 106 | 107 | train['x'] = train['x'].reshape((-1,3,32,32)) 108 | test['x'] = test['x'].reshape((-1,3,32,32)) 109 | 110 | if with_y: 111 | train['y'] = np.concatenate([_train[i]['labels'] for i in range(n_train)]) 112 | test['y'] = np.asarray(_test['labels']) 113 | if binarize_y: 114 | train['y'] = binarize_labels(train['y']) 115 | test['y'] = binarize_labels(test['y']) 116 | 117 | G.ndict.shuffle(train) 118 | G.ndict.shuffle(test) 119 | return train, test 120 | 121 | # SVHN data 122 | def svhn(with_y=True, with_extra=False, binarize_y=False): 123 | path = os.environ['ML_DATA_PATH']+'/svhn' 124 | import scipy.io 125 | train = scipy.io.loadmat(path+'/train_32x32.mat') 126 | 
train_x = train['X'].transpose((3,2,0,1)) 127 | if with_extra: 128 | assert not with_y 129 | extra_x = scipy.io.loadmat(path+'_extra/extra_32x32.mat')['X'].transpose((3,2,0,1)) 130 | train_x = np.concatenate((train_x,extra_x),axis=0) 131 | 132 | test = scipy.io.loadmat(path+'/test_32x32.mat') 133 | test_x = test['X'].transpose((3,2,0,1)) 134 | 135 | if with_y: 136 | train_y = train['y'].reshape((-1,)) - 1 137 | test_y = test['y'].reshape((-1,)) - 1 138 | if binarize_y: 139 | train['y'] = binarize_labels(train['y']) 140 | test['y'] = binarize_labels(test['y']) 141 | return {'x':train_x, 'y':train_y}, {'x':test_x, 'y':test_y} 142 | 143 | return {'x':train_x}, {'x':test_x} 144 | 145 | # SVHN data 146 | def lfw(with_y=True, pad=False): 147 | path = os.environ['ML_DATA_PATH']+'/lfw/' 148 | data = {'x':np.load(path+'lfw_62x47.npy').transpose((0,3,1,2))} 149 | if pad: 150 | padded = np.zeros((data['x'].shape[0],3,64,48),dtype='uint8') 151 | padded[:,:,:-2,:-1] = data['x'] 152 | data['x'] = padded 153 | 154 | if with_y: 155 | data['y'] = np.load(path+'lfw_labels.npy') 156 | return data 157 | -------------------------------------------------------------------------------- /graphy/misc/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | class Logger(object): 4 | def __init__(self, filename): 5 | self.terminal = sys.stdout 6 | self.filename = filename 7 | def write(self, message): 8 | self.terminal.write(message) 9 | with open(self.filename, "a") as log: 10 | log.write(message) 11 | def __getattr__(self, attr): 12 | return getattr(self.terminal, attr) 13 | -------------------------------------------------------------------------------- /graphy/misc/optim.py: -------------------------------------------------------------------------------- 1 | import graphy as G 2 | import theano 3 | import theano.tensor as T 4 | import math 5 | import numpy as np 6 | 7 | from collections import OrderedDict 8 | 9 | def SGD(w, objective, alpha=.1): 10 | print 'SGD', 'alpha:',alpha 11 | g = T.grad(objective.sum(), w, disconnected_inputs='warn') 12 | updates = OrderedDict() 13 | for i in range(len(g)): 14 | updates[w[i]] = w[i] + alpha * g[i] 15 | return updates 16 | 17 | # Adam 18 | def Adam(ws, objective, alpha=.0003, beta=.9, gamma=.999): 19 | print 'Adam', 'alpha:',alpha,'beta1:',beta,'gamma:',gamma 20 | 21 | new = OrderedDict() 22 | 23 | gs = G.ndict.T_grad(objective.sum(), ws, disconnected_inputs='warn') #warn/raise 24 | 25 | it = G.sharedf(0.) 26 | new[it] = it + 1. 27 | 28 | fix1 = 1-beta**(it+1.) 29 | fix2 = 1-gamma**(it+1.) 
# To make estimates unbiased 30 | lr_t = alpha * T.sqrt(fix2) / fix1 31 | 32 | ws_avg = [] 33 | for j in range(len(ws)): 34 | w_avg = {} 35 | for i in ws[j]: 36 | w = ws[j][i] 37 | g = gs[j][i] 38 | 39 | # Initial values 40 | shape = w.get_value().shape 41 | m = G.sharedf(np.zeros(shape)) 42 | v = G.sharedf(np.zeros(shape)) 43 | w_avg[i] = G.sharedf(np.zeros(shape)) 44 | 45 | # Updates 46 | new[m] = beta * m + (1-beta) * g 47 | new[v] = gamma * v + (1-gamma) * g**2 48 | new[w] = w + lr_t * new[m] / (T.sqrt(new[v]) + 1e-8) 49 | new[w_avg[i]] = gamma * new[w] + (1.-gamma) * w_avg[i] 50 | 51 | ws_avg += [w_avg] 52 | 53 | return new, ws_avg 54 | 55 | 56 | def AdaMax(w, objective, alpha=.01, beta1=.1, beta2=.001): 57 | print 'AdaMax', 'alpha:',alpha,'beta1:',beta1,'beta2:',beta2 58 | g = T.grad(objective.sum(), w, disconnected_inputs='warn') 59 | 60 | new = OrderedDict() 61 | 62 | for i in range(len(w)): 63 | #gi = T.switch(T.isnan(gi),T.zeros_like(gi),gi) #remove NaN's 64 | mom1 = G.sharedf(w[i].get_value() * 0.) 65 | _max = G.sharedf(w[i].get_value() * 0.) 66 | new[mom1] = (1-beta1) * mom1 + beta1 * g[i] 67 | new[_max] = T.maximum((1-beta2)*_max, abs(g[i]) + 1e-8) 68 | new[w[i]] = w[i] + alpha * new[mom1] / new[_max] 69 | 70 | return new 71 | 72 | # AdaMax that averages over multiple minibatches 73 | def AdaMax2(w, objective, alpha=.01, beta1=.1, beta2=.001, n_accum=2): 74 | print 'AdaMax2', 'alpha:',alpha,'beta1:',beta1,'beta2:',beta2, 'n_accum:', n_accum 75 | g = T.grad(objective.sum(), w, disconnected_inputs='warn') 76 | 77 | new = OrderedDict() 78 | 79 | from theano.ifelse import ifelse 80 | it = G.sharedf(0.) 81 | new[it] = it + 1 82 | reset = T.eq(T.mod(new[it],n_accum), 0) 83 | update = T.eq(T.mod(new[it],n_accum), n_accum-1) 84 | 85 | for i in range(len(w)): 86 | mom1 = G.sharedf(w[i].get_value() * 0.) 87 | _max = G.sharedf(w[i].get_value() * 0.) 88 | g_sum = G.sharedf(w[i].get_value() * 0.) 89 | 90 | #gi = T.switch(T.isnan(gi),T.zeros_like(gi),gi) #remove NaN's 91 | new[g_sum] = ifelse(reset, g[i], g_sum + g[i]) 92 | new[mom1] = ifelse(update, (1-beta1) * mom1 + beta1 * new[g_sum], mom1) 93 | new[_max] = ifelse(update, T.maximum((1-beta2)*_max, abs(new[g_sum]) + 1e-8), _max) 94 | new[w[i]] = ifelse(update, w[i] + alpha * new[mom1] / new[_max], w[i]) 95 | 96 | return new 97 | 98 | # AdaMax that keeps running average of parameter 99 | def AdaMaxAvg(ws, ws_avg, objective, alpha=.01, beta1=.1, beta2=.001, update_keys=None, disconnected_inputs='raise'): 100 | print 'AdaMax_Avg', 'alpha:',alpha,'beta1:',beta1,'beta2:',beta2 101 | 102 | gs = G.ndict.T_grad(objective.sum(), ws, disconnected_inputs=disconnected_inputs) #warn/raise 103 | 104 | if update_keys is None: 105 | update_keys = [ws[j].keys() for j in range(len(ws))] 106 | 107 | new = OrderedDict() 108 | for j in range(len(ws)): 109 | if ws_avg is not None: 110 | w_avg = ws_avg[j] 111 | for i in update_keys[j]: 112 | _w = ws[j][i] 113 | _g = gs[j][i] 114 | #_g = T.switch(T.isnan(_g),T.zeros_like(_g),_g) #remove NaN's 115 | mom1 = G.sharedf(_w.get_value() * 0.) 116 | _max = G.sharedf(_w.get_value() * 0. 
+ 1e-8) 117 | 118 | new[mom1] = (1-beta1) * mom1 + beta1 * _g 119 | new[_max] = T.maximum((1-beta2)*_max, abs(_g) + 1e-8) 120 | new[_w] = _w + alpha * new[mom1] / new[_max] 121 | if ws_avg is not None: 122 | new[w_avg[i]] = beta2 * _w + (1.-beta2) * w_avg[i] 123 | return new 124 | 125 | # Eve that keeps running average of parameter 126 | def Eve(w, w_avg, f, alpha=.01, beta1=.1, beta2=.001, beta3=0.01, disconnected_inputs='raise'): 127 | print 'Eve', 'alpha:',alpha,'beta1:',beta1,'beta2:',beta2,'beta3:',beta3 128 | 129 | mom = {} 130 | _max = {} 131 | delta = {} 132 | w_prime = {} 133 | for i in w: 134 | mom[i] = G.sharedf(w[i].get_value() * 0.) 135 | _max[i] = G.sharedf(w[i].get_value() * 0. + 1e-8) 136 | delta[i] = G.sharedf(w[i].get_value() * 0.) 137 | w_prime[i] = w[i] + (1-beta1)/beta1 * delta[i] 138 | 139 | train_cost = f(w_prime).mean() 140 | g = G.ndict.T_grad(train_cost, w, disconnected_inputs=disconnected_inputs) #warn/raise 141 | 142 | new = OrderedDict() 143 | for i in w: 144 | new[mom[i]] = (1-beta1) * mom[i] + beta1 * g[i] 145 | new[_max[i]] = T.maximum((1-beta2)*_max[i], abs(g[i]) + 1e-8) 146 | new[delta[i]] = alpha * new[mom[i]] / new[_max[i]] 147 | new[w[i]] = w[i] + new[delta[i]] 148 | 149 | for i in w: 150 | new[w_avg[i]] = beta3 * w[i] + (1.-beta3) * w_avg[i] 151 | return train_cost, new 152 | 153 | # AdaMax that keeps running average of parameter 154 | # Accumulates gradient over n_accum minibatches 155 | def AdaMaxAvg2(ws, objective, alpha=.01, beta1=.1, beta2=.001, beta3=0.01, n_accum=1): 156 | if n_accum == 1: 157 | return AdaMaxAvg(ws, objective, alpha, beta1, beta2, beta3) 158 | print 'AdaMax_Avg2', 'alpha:',alpha,'beta1:',beta1,'beta2:',beta2,'beta3:',beta3,'n_accum:',n_accum 159 | 160 | gs = G.ndict.T_grad(objective.sum(), ws, disconnected_inputs='raise') 161 | 162 | new = OrderedDict() 163 | 164 | from theano.ifelse import ifelse 165 | it = G.sharedf(0.) 166 | new[it] = it + 1 167 | reset = T.eq(T.mod(it,n_accum), 0) 168 | update = T.eq(T.mod(it,n_accum), n_accum-1) 169 | 170 | ws_avg = [] 171 | for j in range(len(ws)): 172 | w_avg = {} 173 | for i in ws[j]: 174 | _w = ws[j][i] 175 | _g = gs[j][i] 176 | #_g = T.switch(T.isnan(_g),T.zeros_like(_g),_g) #remove NaN's 177 | mom1 = G.sharedf(_w.get_value() * 0.) 178 | _max = G.sharedf(_w.get_value() * 0.) 179 | w_avg[i] = G.sharedf(_w.get_value()) 180 | g_sum = G.sharedf(_w.get_value() * 0.) 181 | 182 | new[g_sum] = ifelse(reset, _g, g_sum + _g) 183 | new[mom1] = ifelse(update, (1-beta1) * mom1 + beta1 * new[g_sum], mom1) 184 | new[_max] = ifelse(update, T.maximum((1-beta2)*_max, abs(new[g_sum]) + 1e-8), _max) 185 | new[_w] = ifelse(update, _w + alpha * new[mom1] / new[_max], _w) 186 | new[w_avg[i]] = ifelse(update, beta3 * new[_w] + (1.-beta3) * w_avg[i], w_avg[i]) 187 | ws_avg += [w_avg] 188 | return new, ws_avg 189 | 190 | 191 | -------------------------------------------------------------------------------- /graphy/ndict.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import collections as C 4 | import numpy as np 5 | 6 | ''' 7 | Operations on dictionaries of numpy tensors and/or theano symbolic variables. 8 | Keys are strings. 
9 | 10 | Functions prepended with 'np_' only work on Numpy arrays 11 | Functions appended with 'T_' only work on Theano symbolic variables 12 | ''' 13 | 14 | def size(d): 15 | result = 0 16 | for i in d: 17 | result += d[i].size 18 | return result 19 | 20 | #== Row operators 21 | 22 | def getRows(d, ifrom, ito): 23 | return {i: d[i][ifrom:ito] for i in d} 24 | 25 | def getRowsFromIndices(d, rowIndices): 26 | return {i: d[i][rowIndices] for i in d} 27 | 28 | # ds: multiple dictionaries 29 | # result: cols ifrom to ito from 'ds' 30 | def getRows_multiple(ds, ifrom, ito): 31 | return [{i: d[i][ifrom:ito] for i in d} for d in ds] 32 | 33 | #=== Shuffle along first dimension of dicts 34 | 35 | def shuffle(d): 36 | n_rows = d.itervalues().next().shape[0] 37 | idx = np.arange(n_rows) 38 | import time 39 | np.random.shuffle(idx) 40 | for i in d: 41 | t0 = time.time() 42 | d[i] = d[i][idx] 43 | #print i, time.time()-t0 44 | 45 | #== Clone operations 46 | 47 | def shallowClone(d): 48 | return {i: d[i] for i in d} 49 | 50 | def clone(d): 51 | result = {} 52 | for i in d: 53 | result[i] = d[i].copy() 54 | return result 55 | 56 | def cloneShared(d): 57 | result = {} 58 | for i in d: 59 | result[i] = theano.shared(d[i].get_value()) 60 | return result 61 | 62 | def np_cloneZeros(d): 63 | result = {} 64 | for i in d: 65 | result[i] = np.zeros(d[i].shape) 66 | return result 67 | 68 | def np_cloneOnesN(d): 69 | result = {} 70 | for i in d: 71 | result[i] = np.ones(d[i].shape) 72 | return result 73 | 74 | def T_cloneZeros(d): 75 | result = {} 76 | for i in d: 77 | result[i] = T.zeros_like(d[i]) 78 | return result 79 | 80 | def T_cloneOnes(d): 81 | result = {} 82 | for i in d: 83 | result[i] = T.ones_like(d[i]) 84 | return result 85 | 86 | def Tshared_cloneZeros(d): 87 | result = {} 88 | for i in d: 89 | result[i] = theano.shared(d[i].get_value() * 0.) 
90 | return result 91 | 92 | #=== Shape operations 93 | 94 | # Get shapes of elements of d as a dict 95 | def getShapes(d): 96 | shapes = {} 97 | for i in d: 98 | shapes[i] = d[i].shape 99 | return shapes 100 | 101 | # Set shapes of elements of d 102 | def setShapes(d, shapes): 103 | result = {} 104 | for i in d: 105 | result[i] = d[i].reshape(shapes[i]) 106 | return result 107 | 108 | #=== Ordering operations 109 | def ordered(d): 110 | return C.OrderedDict(sorted(d.items())) 111 | 112 | # converts normal dicts to ordered dicts, ordered by keys 113 | def ordereddicts(ds): 114 | return [ordered(d) for d in ds] 115 | def orderedvals(ds): 116 | vals = [] 117 | for d in ds: 118 | vals += ordered(d).values() 119 | return vals 120 | 121 | 122 | #=== Type operations 123 | 124 | def astype(d, _type): 125 | return {i: d[i].astype(_type) for i in d} 126 | 127 | #def type(d): 128 | # return {i: d[i].type for i in d} 129 | 130 | #=== Get/set value 131 | 132 | def get_value(d): 133 | return {i: d[i].get_value() for i in d} 134 | 135 | def set_value(d, d2, complete=True): 136 | for i in d: 137 | if i not in d2: 138 | if complete: raise Exception() 139 | continue 140 | d[i].set_value(d2[i]) 141 | 142 | #=== Merging/combining of multiple dicts 143 | 144 | # Flatten sequence of dicts into one dict 145 | # Input can also be nested sequence of sequences 146 | # by default raises when keys overlap 147 | def flatten(ds, raiseOnDuplicateKeys=True): 148 | if isinstance(ds, dict): return ds 149 | assert (isinstance(ds, list) or isinstance(ds, tuple)) 150 | result = {} 151 | for d in ds: 152 | if (isinstance(d, list) or isinstance(d, tuple)): 153 | # recursion 154 | d = flatten(d, raiseOnDuplicateKeys) 155 | assert isinstance(d, dict) 156 | if raiseOnDuplicateKeys and any(i in d.keys() for i in result.keys()): 157 | print d.keys() 158 | print result.keys() 159 | raise Exception("Keys overlap overlap") 160 | result.update(d) 161 | return result 162 | 163 | #=== Gradients 164 | 165 | # Return gradients of scalar 'y' w.r.t. elements of d 166 | # 'd' is a dict, or list of dicts 167 | def T_grad(y, d, **kwargs): 168 | if type(d) is list: 169 | d = ordereddicts(d) 170 | vals = orderedvals(d) 171 | g = T.grad(y, vals, **kwargs) 172 | g_list = [] 173 | idx = 0 174 | for i in range(len(d)): 175 | g_list += [dict(zip(d[i].keys(), g[idx:idx+len(d[i].keys())]))] 176 | idx += len(d[i].keys()) 177 | return g_list 178 | else: 179 | d = ordered(d) 180 | keys = d.keys() 181 | grads = T.grad(y, d.values(), **kwargs) 182 | g = {keys[i]: grads[i] for i in range(len(grads))} 183 | return g 184 | 185 | #=== Printing 186 | 187 | def p(d): 188 | for i in d: print i+'\n', d[i] 189 | 190 | def np_pNorm(d): 191 | for i in d: print i, np.linalg.norm(d[i]) 192 | 193 | def norm(d): 194 | return {i: np.linalg.norm(d[i]) for i in d} 195 | 196 | def pShape(d): 197 | for i in d: print i, d[i].shape 198 | 199 | def np_hasNaN(d): 200 | result = False 201 | for i in d: result = result or np.isnan(d[i]).any() 202 | return result 203 | 204 | #=== Saving/loading 205 | 206 | # Save/Load ndict to compressed file 207 | # (a gzipped tar file, i.e. 
.tar.gz) 208 | # if addext=True, then '.ndict' will be appended to filename 209 | def np_savez(d, filename, addext=True): 210 | import tarfile, os 211 | if addext: 212 | filename = filename + '.ndict.tar.gz' 213 | fname1 = 'arrays.npz' 214 | fname2 = 'names.txt' 215 | _d = ordered(d) 216 | # Write values (arrays) 217 | np.savez(filename+'.'+fname1, *_d.values()) 218 | # Write keys (names of arrays) 219 | with open(filename+'.'+fname2, 'w') as thefile: 220 | for key in _d.keys(): thefile.write("%s\n" % key) 221 | # Write TAR file 222 | tar = tarfile.open(filename, "w:gz") 223 | for fname in [fname1, fname2]: 224 | tar.add(filename+'.'+fname, fname) 225 | os.remove(filename+'.'+fname) 226 | tar.close() 227 | 228 | # Loads ndict from file written with savez 229 | def np_loadz(filename): 230 | import tarfile 231 | with tarfile.open(filename, 'r:gz') as tar: 232 | members = tar.getmembers() 233 | arrays = np.load(tar.extractfile(members[0])) 234 | names = tar.extractfile(members[1]).readlines() 235 | result = {names[i][:-1]: arrays['arr_'+str(i)] for i in range(len(names))} 236 | return ordered(result) 237 | 238 | def shared(d, dtype=theano.config.floatX): # @UndefinedVariable 239 | result = {} 240 | for i in d: 241 | result[i] = theano.shared(np.asarray(d[i], dtype)) 242 | return result 243 | -------------------------------------------------------------------------------- /graphy/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import graphy as G 4 | import numpy as np 5 | 6 | import sys 7 | 8 | N = sys.modules[__name__] 9 | 10 | # hyperparams 11 | l2norm = True 12 | logscale = True 13 | logscale_scale = 3. 14 | init_stdev = .1 15 | maxweight = 3. 16 | 17 | # Func for initializing parameters with random orthogonal matrix 18 | def randorth(shape): 19 | from scipy.linalg import sqrtm, inv 20 | assert len(shape) == 2 21 | w = np.random.normal(0, size=shape) 22 | w = w.dot(inv(sqrtm(w.T.dot(w)))) 23 | return G.sharedf(w) 24 | 25 | # Softmax function 26 | def softmax(x): 27 | e_x = T.exp(x - x.max(axis=1, keepdims=True)) 28 | return e_x / e_x.sum(axis=1, keepdims=True) 29 | 30 | def to_one_hot(x, n_y): 31 | # TODO: Replace this with built-in Theano function in extra_ops 32 | assert type(n_y) == int 33 | return T.eye(n_y)[x] 34 | 35 | def dropout(x, p): 36 | if p > 0.: 37 | retain_p = 1-p 38 | x *= G.rng.binomial(x.shape, p=retain_p, dtype=G.floatX) / retain_p 39 | return x 40 | 41 | # dropout where variable is replaced by noise with same marginal as input 42 | def dropout_bernoulli(x, p): 43 | if p > 0: 44 | mask = G.rng.binomial(x.shape, p=p, dtype=G.floatX) 45 | p_noise = T.mean(x, axis=0, keepdims=True) 46 | noise = G.rng.binomial(x.shape, p=p_noise, dtype=G.floatX) 47 | x = (mask < .5) * x + (mask > .5) * noise 48 | return x 49 | 50 | # Linear layer 51 | def linear_l2(name, n_in, n_out, w): 52 | 53 | # L2 normalization of weights 54 | def l2normalize(_w): 55 | targetnorm=1. 
56 | norm = T.sqrt((_w**2).sum(axis=0, keepdims=True)) 57 | return _w * (targetnorm / norm) 58 | def maxconstraint(_w): 59 | return _w * (maxweight / T.maximum(maxweight, abs(_w).max(axis=0, keepdims=True))) 60 | 61 | w[name+'_w'] = G.sharedf(0.05*np.random.randn(n_in,n_out)) 62 | 63 | if maxweight > 0: 64 | w[name+'_w'].set_value(maxconstraint(w[name+'_w']).tag.test_value) 65 | w[name+'_b'] = G.sharedf(np.zeros((n_out,))) 66 | if l2norm: 67 | if logscale: 68 | w[name+'_s'] = G.sharedf(np.zeros((n_out,))) 69 | else: 70 | w[name+'_s'] = G.sharedf(np.ones((n_out,))) 71 | else: 72 | print 'WARNING: constant rescale, these weights arent saved' 73 | constant_rescale = G.sharedf(np.zeros((n_out,))) 74 | 75 | 76 | def f(h, w): 77 | _w = w[name+'_w'] 78 | if l2norm: 79 | _w = l2normalize(_w) 80 | h = T.dot(h, _w) 81 | if l2norm: 82 | if logscale: 83 | h *= T.exp(logscale_scale*w[name+'_s']) 84 | else: 85 | h *= abs(w[name+'_s']) 86 | else: 87 | h *= T.exp(constant_rescale) 88 | h += w[name+'_b'] 89 | 90 | if '__init' in w: 91 | # Std 92 | std = (1./init_stdev) * h.std(axis=0) + 1e-8 93 | if name+'_s' in w: 94 | if logscale: 95 | w[name+'_s'].set_value(-T.log(std).tag.test_value/logscale_scale) 96 | else: 97 | w[name+'_s'].set_value((1./std).tag.test_value) 98 | else: 99 | constant_rescale.set_value(-T.log(std).tag.test_value) 100 | #w[name+'_w'].set_value((_w / std.dimshuffle('x',0)).tag.test_value) 101 | 102 | h /= std.dimshuffle('x',0) 103 | 104 | # Mean 105 | mean = h.mean(axis=0) 106 | w[name+'_b'].set_value(-mean.tag.test_value) 107 | h -= mean.dimshuffle('x',0) 108 | 109 | #print name, abs(w[name+'_w']).get_value().mean(), w[name+'_w'].get_value().std(), w[name+'_w'].get_value().max() 110 | 111 | #print name, abs(h).max().tag.test_value, abs(h).min().tag.test_value 112 | #h = T.printing.Print(name)(h) 113 | 114 | return h 115 | 116 | # Post updates: normalize weights to unit L2 norm 117 | def postup(updates, w): 118 | if l2norm and maxweight>0: 119 | updates[w[name+'_w']] = maxconstraint(updates[w[name+'_w']]) 120 | return updates 121 | 122 | return G.Struct(__call__=f, postup=postup, w=w) 123 | 124 | # Mean-only batchnorm, and a bias unit 125 | def batchnorm_meanonly(name, n_h, w={}): 126 | w[name+'_b'] = G.sharedf(np.zeros((n_h,))) 127 | def f(h, w): 128 | h -= h.mean(axis=(0,2,3), keepdims=True) 129 | h += w[name+'_b'].dimshuffle('x',0,'x','x') 130 | return h 131 | return G.Struct(__call__=f, w=w) 132 | 133 | 134 | ''' 135 | Nonlinear functions 136 | (including parameterized ones) 137 | ''' 138 | def nonlinearity(name, which, shape=None, w={}): 139 | 140 | if which == 'prelu': 141 | w[name] = G.sharedf(np.zeros(shape)) 142 | if which == 'pelu': 143 | w[name] = G.sharedf(np.zeros(shape)) 144 | if which == 'softplus2': 145 | w[name] = G.sharedf(np.zeros(shape)) 146 | if which == 'softplus_shiftscale': 147 | w[name+'_in_s'] = G.sharedf(np.zeros(shape)) 148 | w[name+'_in_b'] = G.sharedf(np.zeros(shape)) 149 | if which == 'linearsigmoid': 150 | w[name+'_a'] = G.sharedf(.5*np.ones(shape)) 151 | w[name+'_b'] = G.sharedf(.5*np.ones(shape)) 152 | if which == 'meanonlybatchnorm_softplus': 153 | assert type(shape) == int 154 | w[name+'_b'] = G.sharedf(np.zeros(shape)) 155 | if which == 'meanonlybatchnorm_relu': 156 | assert type(shape) == int 157 | w[name+'_b'] = G.sharedf(np.zeros(shape)) 158 | 159 | def f(h, w=None): 160 | if which == None or which == 'None': 161 | return h 162 | elif which == 'tanh': 163 | return T.tanh(h) 164 | elif which == 'softmax': 165 | return T.nnet.softmax(h) 166 | elif 
which == 'prelu': 167 | return w[name]*h*(h<0.) + h*(h>=0.) 168 | elif which == 'relu': 169 | return h*(h>=0.) 170 | elif which == 'shiftedrelu': 171 | return T.switch(h < -1., -1., h) 172 | elif which == 'leakyrelu': 173 | return 0.01 * h*(h<0.) + h*(h>=0.) 174 | elif which == 'elu': 175 | return T.switch(h < 0., T.exp(h)-1, h) 176 | elif which == 'softplus': 177 | return T.nnet.softplus(h) 178 | elif which == 'softplus_shiftscale': 179 | return T.nnet.softplus(T.exp(w[name+'_in_s']) * h + w[name+'_in_b']) 180 | elif which == 'softplus2': 181 | return T.nnet.softplus(h) - w[name] * T.nnet.softplus(-h) 182 | elif which == 'linearsigmoid': 183 | return w[name+'_a'] * h + w[name+'_b'] * T.nnet.sigmoid(h) 184 | elif which == 'meanonlybatchnorm_softplus': 185 | h -= h.mean(axis=(0,2,3), keepdims=True) 186 | h += w[name+'_b'].dimshuffle('x',0,'x','x') 187 | return T.nnet.softplus(h) 188 | elif which == 'meanonlybatchnorm_relu': 189 | h -= h.mean(axis=(0,2,3), keepdims=True) 190 | h += w[name+'_b'].dimshuffle('x',0,'x','x') 191 | return T.nnet.relu(h) 192 | else: 193 | raise Exception("Unrecognized nonlinearity: "+which) 194 | 195 | 196 | return G.Struct(__call__=f, w=w) 197 | 198 | 199 | # n_in is an int 200 | # n_h is a list of ints 201 | # n_out is an int or list of ints 202 | # nl_h: nonlinearity of hidden units 203 | # nl_out: nonlinearity of output 204 | def mlp_l2(name, n_in, n_h, n_out, nl_h, nl_out=None, nl_in=None, w={}): 205 | 206 | if not isinstance(n_out, list) and isinstance(n_out, int): 207 | n_out = [n_out] 208 | 209 | # parameters for input perturbation 210 | if nl_in != None: 211 | f_nl_in = N.nonlinearity(name+'_in_nl', nl_in, (n_in,), w) 212 | 213 | # parameters for hidden units 214 | nh = [n_in]+n_h 215 | linear_h = [] 216 | f_nl_h = [] 217 | for i in range(len(n_h)): 218 | s = name+'_'+str(i) 219 | linear_h.append(N.linear_l2(s, nh[i], nh[i+1], w)) 220 | f_nl_h.append(N.nonlinearity(s+'_nl', nl_h, (nh[i+1],), w)) 221 | 222 | # parameters for output 223 | f_nl_out = [] 224 | linear_out = [] 225 | for i in range(len(n_out)): 226 | s = name+'_out_'+str(i) 227 | linear_out.append(N.linear_l2(s, n_h[-1], n_out[i], w)) 228 | f_nl_out.append(N.nonlinearity(s+'nl', nl_out, (n_out[i],), w)) 229 | 230 | def f(h, w, return_hiddens=False): 231 | 232 | if nl_in != None: 233 | h = f_nl_in(h, w) 234 | 235 | hiddens = [] 236 | for i in range(len(n_h)): 237 | h = linear_h[i](h, w) 238 | h = f_nl_h[i](h, w) 239 | hiddens.append(h) 240 | 241 | out = [] 242 | for i in range(len(n_out)): 243 | _out = linear_out[i](h, w) 244 | _out = f_nl_out[i](_out, w) 245 | out.append(_out) 246 | 247 | if len(n_out) == 1: out = out[0] 248 | 249 | if return_hiddens: 250 | return hiddens, out 251 | 252 | return out 253 | 254 | def postup(updates, w): 255 | for l in linear_h: updates = l.postup(updates, w) 256 | for l in linear_out: updates = l.postup(updates, w) 257 | return updates 258 | 259 | return G.Struct(__call__=f, w=w, postup=postup) 260 | 261 | -------------------------------------------------------------------------------- /graphy/nodes/ar.py: -------------------------------------------------------------------------------- 1 | import theano.tensor as T 2 | import numpy as np 3 | import graphy as G 4 | import graphy.nodes as N 5 | import graphy.nodes.conv 6 | 7 | # hyperparams 8 | logscale = True #Really works better! 9 | logscale_scale = 3. 10 | init_stdev = .1 11 | maxweight = 0. 
12 | bn = False #mean-only batchnorm 13 | do_constant_rescale = False 14 | 15 | # auto-regressive linear layer 16 | def linear(name, n_in, n_out, diagonalzeros, l2norm=True, w={}): 17 | assert n_in % n_out == 0 or n_out % n_in == 0 18 | 19 | mask = np.ones((n_in, n_out),dtype=G.floatX) 20 | if n_out >= n_in: 21 | k = n_out / n_in 22 | for i in range(n_in): 23 | mask[i+1:,i*k:(i+1)*k] = 0 24 | if diagonalzeros: 25 | mask[i:i+1,i*k:(i+1)*k] = 0 26 | else: 27 | k = n_in / n_out 28 | for i in range(n_out): 29 | mask[(i+1)*k:,i:i+1] = 0 30 | if diagonalzeros: 31 | mask[i*k:(i+1)*k:,i:i+1] = 0 32 | 33 | # L2 normalization of weights 34 | def l2normalize(_w, axis=0): 35 | if diagonalzeros: 36 | # to prevent NaN gradients 37 | # TODO: smarter solution (also see below) 38 | if n_out >= n_in: 39 | _w = T.set_subtensor(_w[:,:n_out/n_in], 0.) 40 | else: 41 | _w = T.set_subtensor(_w[:,:1], 0.) 42 | targetnorm = 1. 43 | norm = T.sqrt((_w**2).sum(axis=axis, keepdims=True)) 44 | norm += 1e-8 45 | new_w = _w * (targetnorm / norm) 46 | return new_w 47 | def maxconstraint(_w): 48 | return _w * (maxweight / T.maximum(maxweight, abs(_w).max(axis=0, keepdims=True))) 49 | 50 | w[name+'_w'] = G.sharedf(mask * 0.05 * np.random.randn(n_in, n_out)) 51 | if maxweight > 0: 52 | w[name+'_w'].set_value(maxconstraint(w[name+'_w']).tag.test_value) 53 | 54 | w[name+'_b'] = G.sharedf(np.zeros((n_out,))) 55 | if l2norm: 56 | if logscale: 57 | w[name+'_s'] = G.sharedf(np.zeros((n_out,))) 58 | else: 59 | w[name+'_s'] = G.sharedf(np.ones((n_out,))) 60 | elif do_constant_rescale: 61 | print 'WARNING: constant rescale, these weights arent saved' 62 | constant_rescale = G.sharedf(np.zeros((n_out,))) 63 | 64 | 65 | def f(h, w): 66 | _input = h 67 | _w = mask * w[name+'_w'] 68 | if l2norm: 69 | _w = l2normalize(_w) 70 | h = T.dot(h, _w) 71 | if l2norm: 72 | if logscale: 73 | h *= T.exp(logscale_scale*w[name+'_s']) 74 | else: 75 | h *= abs(w[name+'_s']) 76 | elif do_constant_rescale: 77 | h *= T.exp(constant_rescale) 78 | 79 | h += w[name+'_b'] 80 | 81 | if '__init' in w: 82 | # Std 83 | std = (1./init_stdev) * h.std(axis=0) 84 | std += (std <= 0) 85 | std += 1e-8 86 | if name+'_s' in w: 87 | if logscale: 88 | w[name+'_s'].set_value(-T.log(std).tag.test_value/logscale_scale) 89 | else: 90 | w[name+'_s'].set_value((1./std).tag.test_value) 91 | elif do_constant_rescale: 92 | constant_rescale.set_value(-T.log(std).tag.test_value) 93 | #w[name+'_w'].set_value((_w / std.dimshuffle('x',0)).tag.test_value) 94 | 95 | h /= std.dimshuffle('x',0) 96 | 97 | # Mean 98 | mean = h.mean(axis=0) 99 | w[name+'_b'].set_value(-mean.tag.test_value) 100 | h -= mean.dimshuffle('x',0) 101 | 102 | #print name, w[name+'_w'].get_value().mean(), w[name+'_w'].get_value().std(), w[name+'_w'].get_value().max() 103 | 104 | #print name, abs(h).max().tag.test_value, abs(h).min().tag.test_value 105 | #h = T.printing.Print(name)(h) 106 | 107 | return h 108 | 109 | # Post updates: normalize weights to unit L2 norm 110 | def postup(updates, w): 111 | updates[w[name+'_w']] = mask * updates[w[name+'_w']] 112 | if l2norm and maxweight>0.: 113 | updates[w[name+'_w']] = maxconstraint(updates[w[name+'_w']]) 114 | return updates 115 | 116 | return G.Struct(__call__=f, postup=postup, w=w) 117 | 118 | # Auto-Regressive MLP with l2 normalization 119 | # n_in is an int 120 | # n_h is a list of ints 121 | # n_out is an int or list of ints 122 | # nl_h: nonlinearity of hidden units 123 | def mlp(name, n_in, n_context, n_h, n_out, nl, w={}): 124 | 125 | if not isinstance(n_out, list) 
and isinstance(n_out, int): 126 | n_out = [n_out] 127 | 128 | if n_context > 0: 129 | # parameters for context input 130 | linear_context = N.linear_l2(name+'_context', n_context, n_h[0], w) 131 | 132 | # parameters for hidden units 133 | nh = [n_in]+n_h 134 | linear_h = [] 135 | f_nl_h = [] 136 | for i in range(len(n_h)): 137 | s = name+'_'+str(i) 138 | linear_h.append(linear(s, nh[i], nh[i+1], False, True, w)) 139 | f_nl_h.append(N.nonlinearity(s+'_nl', nl, (nh[i+1],), w)) 140 | 141 | # parameters for output 142 | linear_out = [] 143 | for i in range(len(n_out)): 144 | s = name+'_out_'+str(i) 145 | linear_out.append(linear(s, n_h[-1], n_out[i], True, True, w)) 146 | 147 | def f(h, h_context, w, return_hiddens=False): 148 | # h_context can be None if n_context == 0 149 | 150 | hiddens = [] 151 | for i in range(len(n_h)): 152 | h = linear_h[i](h, w) 153 | if i == 0 and n_context > 0: 154 | h += linear_context(h_context, w) 155 | h = f_nl_h[i](h, w) 156 | hiddens.append(h) 157 | 158 | out = [] 159 | for i in range(len(n_out)): 160 | _out = linear_out[i](h, w) 161 | out.append(_out) 162 | 163 | if len(n_out) == 1: out = out[0] 164 | 165 | if return_hiddens: 166 | return hiddens, out 167 | 168 | return out 169 | 170 | def postup(updates, w): 171 | if n_context > 0: 172 | updates = linear_context.postup(updates, w) 173 | for l in linear_h: updates = l.postup(updates, w) 174 | for l in linear_out: updates = l.postup(updates, w) 175 | return updates 176 | 177 | return G.Struct(__call__=f, w=w, postup=postup) 178 | 179 | 180 | def msconv2d(name, n_scales, n_in, n_out, size_kernel=(3,3), zerodiagonal=True, flipmask=False, pad_channel=True, border_mode='valid', w={}): 181 | convs = [conv2d(name+"_s"+str(i), n_in, n_out, size_kernel, zerodiagonal, flipmask, pad_channel, border_mode, w) for i in range(n_scales)] 182 | def f(h, w): 183 | results = [] 184 | for i in range(n_scales-1): 185 | results.append(convs[i](h, w)) 186 | h = N.conv.downsample2d_nearest_neighbour(h, scale=2) 187 | result = convs[-1](h, w) 188 | for i in range(n_scales-1): 189 | result = N.conv.upsample2d_nearest_neighbour(result) 190 | result += results[-1-i] 191 | return result 192 | 193 | def postup(updates, w): 194 | for conv in convs: 195 | updates = conv.postup(updates, w) 196 | return updates 197 | 198 | return G.Struct(__call__=f, w=w, postup=postup) 199 | 200 | def conv2d(name, n_in, n_out, size_kernel=(3,3), zerodiagonal=True, flipmask=False, pad_channel=True, border_mode='valid', zeroinit=False, l2norm=True, w={}): 201 | 202 | do_scale = False 203 | if zeroinit: 204 | l2norm = False 205 | do_scale = True 206 | 207 | if not pad_channel: 208 | border_mode = 'same' 209 | print 'No pad_channel, changing border_mode to same' 210 | 211 | #if 'whitener' not in name: 212 | # pad_channel = False 213 | # border_mode = 'same' 214 | 215 | if '[sharedw]' in name and '[/sharedw]' in name: 216 | name_w = name 217 | pre, b = name.split("[sharedw]") 218 | c, post = b.split("[/sharedw]") 219 | name_w = pre+"[s]"+post 220 | name = pre+c+post # Don't share the bias and scales 221 | #name = name_w # Also share the bias and scales 222 | else: 223 | name_w = name 224 | 225 | assert border_mode in ['valid','full','same'] 226 | 227 | _n_in = n_in 228 | 229 | if pad_channel: 230 | if size_kernel[0] > 1 or size_kernel[1] > 1: 231 | assert size_kernel[0] == size_kernel[1] 232 | assert border_mode == 'valid' 233 | _n_in += 1 234 | else: 235 | pad_channel = False 236 | 237 | if border_mode == 'same': 238 | assert size_kernel[0]%2 == 1 239 | 
border_mode = ((size_kernel[0]-1)/2,(size_kernel[1]-1)/2) 240 | 241 | if True: 242 | # Build autoregressive mask 243 | l = (size_kernel[0]-1)/2 244 | m = (size_kernel[1]-1)/2 245 | mask = np.ones((n_out, _n_in, size_kernel[0], size_kernel[1]),dtype=G.floatX) 246 | mask[:,:,:l,:] = 0 247 | mask[:,:,l,:m] = 0 248 | 249 | if n_out >= n_in: 250 | assert n_out%n_in == 0 251 | k = n_out / n_in 252 | for i in range(n_in): 253 | mask[i*k:(i+1)*k,i+1:,l,m] = 0 254 | if zerodiagonal: 255 | mask[i*k:(i+1)*k,i:i+1,l,m] = 0 256 | else: 257 | assert n_in%n_out == 0 258 | k = n_in / n_out 259 | for i in range(n_out): 260 | mask[i:i+1,(i+1)*k:,l,m] = 0 261 | if zerodiagonal: 262 | mask[i:i+1,i*k:(i+1)*k:,l,m] = 0 263 | if flipmask: 264 | mask = mask[::-1,::-1,::-1,::-1] 265 | 266 | 267 | def l2normalize(kerns): 268 | if zerodiagonal: 269 | # to prevent NaN gradients 270 | # TODO: smarter solution (also see below) 271 | l = (size_kernel[0]-1)/2 272 | m = (size_kernel[1]-1)/2 273 | if n_out >= n_in: 274 | kerns = T.set_subtensor(kerns[:n_out/n_in,:,l,m], 0.) 275 | else: 276 | kerns = T.set_subtensor(kerns[:1,:,l,m], 0.) 277 | 278 | targetnorm = 1. 279 | norm = T.sqrt((kerns**2).sum(axis=(1,2,3), keepdims=True)) 280 | norm += 1e-8 281 | return kerns * (targetnorm / norm) 282 | def maxconstraint(kerns): 283 | return kerns * (maxweight / T.maximum(maxweight, abs(kerns).max(axis=(1,2,3), keepdims=True))) 284 | 285 | if zeroinit: 286 | w[name_w+'_w'] = G.sharedf(np.zeros((n_out, _n_in, size_kernel[0], size_kernel[1]))) 287 | else: 288 | w[name_w+'_w'] = G.sharedf(mask * 0.05*np.random.randn(n_out, _n_in, size_kernel[0], size_kernel[1])) 289 | if maxweight > 0: 290 | w[name_w+'_w'].set_value(maxconstraint(w[name_w+'_w']).tag.test_value) 291 | 292 | w[name+'_b'] = G.sharedf(np.zeros((n_out,))) 293 | 294 | if l2norm or do_scale: 295 | if logscale: 296 | w[name+'_s'] = G.sharedf(np.zeros((n_out,))) 297 | else: 298 | w[name+'_s'] = G.sharedf(np.ones((n_out,))) 299 | elif do_constant_rescale: 300 | print 'WARNING: constant rescale, these weights arent saved' 301 | constant_rescale = G.sharedf(np.ones((n_out,))) 302 | 303 | 304 | def f(h, w): 305 | input_shape = h.tag.test_value.shape[1:] 306 | 307 | _input = h 308 | 309 | if pad_channel: 310 | h = N.conv.pad2dwithchannel(h, size_kernel) 311 | 312 | kerns = mask * w[name_w+'_w'] 313 | if l2norm: 314 | kerns = l2normalize(kerns) 315 | if l2norm or do_scale: 316 | if logscale: 317 | kerns *= T.exp(logscale_scale*w[name+'_s']).dimshuffle(0,'x','x','x') 318 | else: 319 | kerns *= w[name+'_s'].dimshuffle(0,'x','x','x') 320 | elif do_constant_rescale: 321 | kerns *= constant_rescale.dimshuffle(0,'x','x','x') 322 | 323 | h = N.conv.dnn_conv(h, kerns, border_mode=border_mode) 324 | 325 | # Center 326 | if bn: # mean-only batch norm 327 | h -= h.mean(axis=(0,2,3), keepdims=True) 328 | 329 | h += w[name+'_b'].dimshuffle('x',0,'x','x') 330 | 331 | if '__init' in w and not zeroinit: 332 | 333 | # Std 334 | data_std = h.std(axis=(0,2,3)) 335 | num_zeros = (data_std.tag.test_value == 0).sum() 336 | if num_zeros > 0: 337 | print "Warning: Stdev=0 for "+str(num_zeros)+" features in "+name+". Skipping data-dependent init." 
338 | else: 339 | if name+'_s' in w: 340 | if logscale: 341 | w[name+'_s'].set_value(-T.log(data_std).tag.test_value/logscale_scale) 342 | else: 343 | w[name+'_s'].set_value((1./data_std).tag.test_value) 344 | elif do_constant_rescale: 345 | constant_rescale.set_value((1./data_std).tag.test_value) 346 | #w[name+'_w'].set_value((kerns / std.dimshuffle(0,'x','x','x')).tag.test_value) 347 | 348 | h /= data_std.dimshuffle('x',0,'x','x') 349 | 350 | # Mean 351 | mean = h.mean(axis=(0,2,3)) 352 | w[name+'_b'].set_value(-mean.tag.test_value) 353 | h -= mean.dimshuffle('x',0,'x','x') 354 | 355 | #print name, w[name+'_w'].get_value().mean(), w[name+'_w'].get_value().std(), w[name+'_w'].get_value().max() 356 | 357 | if not '__init' in w: 358 | output_shape = h.tag.test_value.shape[1:] 359 | print 'ar.conv2d', name, input_shape, output_shape, size_kernel, zerodiagonal, flipmask, pad_channel, border_mode, zeroinit, l2norm 360 | 361 | #print name, abs(h).max().tag.test_value, abs(h).min().tag.test_value 362 | #h = T.printing.Print(name)(h) 363 | 364 | return h 365 | 366 | # Normalize weights to _norm L2 norm 367 | # TODO: check whether only_upper_bounds here really helps 368 | # (the effect is a higher learning rate in the beginning of training) 369 | def postup(updates, w): 370 | updates[w[name_w+'_w']] = mask * updates[w[name_w+'_w']] 371 | if l2norm and maxweight>0.: 372 | updates[w[name_w+'_w']] = maxconstraint(updates[w[name_w+'_w']]) 373 | return updates 374 | 375 | return G.Struct(__call__=f, w=w, postup=postup) 376 | 377 | # Auto-Regressive convnet with l2 normalization 378 | def multiconv2d(name, n_in, n_h, n_out, size_kernel, flipmask, nl='relu', w={}): 379 | 380 | if not isinstance(n_out, list) and isinstance(n_out, int): 381 | n_out = [n_out] 382 | 383 | # parameters for hidden units 384 | sizes = [n_in]+n_h 385 | conv_h = [] 386 | f_nl_h = [] 387 | for i in range(len(n_h)): 388 | conv_h.append(conv2d(name+'_'+str(i), sizes[i], sizes[i+1], size_kernel, False, flipmask, w=w)) 389 | f_nl_h.append(N.nonlinearity(name+'_'+str(i)+'_nl', nl, sizes[i+1], w=w)) 390 | 391 | # parameters for output 392 | conv_out = [] 393 | for i in range(len(n_out)): 394 | conv_out.append(conv2d(name+'_out_'+str(i), sizes[-1], n_out[i], size_kernel, True, flipmask, w=w)) 395 | 396 | def f(h, context, w, return_hiddens=False): 397 | # h_context can be None if n_context == 0 398 | 399 | hiddens = [] 400 | for i in range(len(n_h)): 401 | h = conv_h[i](h, w) # + context 402 | if i == 0: h += context 403 | h = f_nl_h[i](h, w) 404 | hiddens.append(h) 405 | 406 | out = [] 407 | for i in range(len(n_out)): 408 | _out = conv_out[i](h, w) 409 | out.append(_out) 410 | 411 | if len(n_out) == 1: out = out[0] 412 | 413 | if return_hiddens: 414 | return hiddens, out 415 | 416 | return out 417 | 418 | def postup(updates, w): 419 | for l in conv_h: updates = l.postup(updates, w) 420 | for l in conv_out: updates = l.postup(updates, w) 421 | return updates 422 | 423 | return G.Struct(__call__=f, w=w, postup=postup) 424 | 425 | 426 | 427 | # ResNet V3 layer 428 | def resnet_layer_a(name, n_feats, nl='elu', w={}): 429 | 430 | f_nl1 = N.nonlinearity(name+"_nl1", nl) 431 | f_nl2 = N.nonlinearity(name+"_nl2", nl) 432 | 433 | # either no change in shape, or subsampling 434 | conv1 = conv2d(name+'_conv1', n_feats, n_feats, (3,3), zerodiagonal=False, w=w) 435 | conv2 = conv2d(name+'_conv2', n_feats, n_feats, (3,3), zerodiagonal=False, w=w) 436 | 437 | def f(_input, w): 438 | h = f_nl1(_input) 439 | h = f_nl2(conv1(h, w)) 440 | h = conv2(h, w) 
441 | return _input + .1 * h 442 | 443 | def postup(updates, w): 444 | updates = conv1.postup(updates, w) 445 | updates = conv2.postup(updates, w) 446 | return updates 447 | 448 | return G.Struct(__call__=f, w=w, postup=postup) 449 | 450 | # ResNet V3 layer 451 | def resnet_layer_b(name, n_feats, factor=4, nl='elu', w={}): 452 | 453 | f_nl1 = N.nonlinearity(name+"_nl1", nl) 454 | f_nl2 = N.nonlinearity(name+"_nl2", nl) 455 | f_nl3 = N.nonlinearity(name+"_nl3", nl) 456 | 457 | # either no change in shape, or subsampling 458 | conv1 = conv2d(name+'_conv1', n_feats, n_feats/factor, (1,1), zerodiagonal=False, w=w) 459 | conv2 = conv2d(name+'_conv2', n_feats/factor, n_feats/factor, (3,3), zerodiagonal=False, w=w) 460 | conv3 = conv2d(name+'_conv3', n_feats/factor, n_feats, (1,1), zerodiagonal=False, w=w) 461 | 462 | def f(_input, w): 463 | h = f_nl1(_input) 464 | h = f_nl2(conv1(h, w)) 465 | h = f_nl3(conv2(h, w)) 466 | h = conv3(h, w) 467 | return _input + .1 * h 468 | 469 | def postup(updates, w): 470 | updates = conv1.postup(updates, w) 471 | updates = conv2.postup(updates, w) 472 | updates = conv3.postup(updates, w) 473 | return updates 474 | 475 | return G.Struct(__call__=f, w=w, postup=postup) 476 | 477 | # Auto-Regressive convnet with l2 normalization 478 | def resnet(name, depth, n_in, n_h, n_out, size_kernel=(3,3), flipmask=False, nl='elu', layertype='a', factor=4, weightsharing=False, w={}): 479 | 480 | if not isinstance(n_out, list) and isinstance(n_out, int): 481 | n_out = [n_out] 482 | 483 | conv_input = conv2d(name+'_input', n_in, n_h, size_kernel, False, flipmask, w=w) 484 | 485 | # parameters for hidden units 486 | resnet = [] 487 | for i in range(depth): 488 | _name = name+'_'+str(i) 489 | if weightsharing: 490 | _name = name+'[sharedw]_'+str(i)+'[/sharedw]' 491 | if layertype == 'a': 492 | resnet.append(resnet_layer_a(_name, n_h, nl, w)) 493 | elif layertype == 'b': 494 | resnet.append(resnet_layer_b(_name, n_h, factor, nl, w)) 495 | else: raise Exception() 496 | 497 | # parameters for output 498 | conv_out = [conv2d(name+'_out_'+str(i), n_h, n_out[i], size_kernel, True, flipmask, w=w) for i in range(len(n_out))] 499 | 500 | def f(h, h_context, w, return_hiddens=False): 501 | 502 | h = conv_input(h, w) 503 | if h_context != None: 504 | h += h_context 505 | 506 | hiddens = [] 507 | for i in range(len(resnet)): 508 | h = resnet[i](h, w) 509 | hiddens.append(h) 510 | 511 | out = [] 512 | for i in range(len(n_out)): 513 | _out = conv_out[i](h, w) 514 | out.append(_out) 515 | 516 | if len(n_out) == 1: out = out[0] 517 | 518 | if return_hiddens: 519 | return hiddens, out 520 | 521 | return out 522 | 523 | def postup(updates, w): 524 | for l in resnet: updates = l.postup(updates, w) 525 | for l in conv_out: updates = l.postup(updates, w) 526 | return updates 527 | 528 | return G.Struct(__call__=f, w=w, postup=postup) 529 | 530 | -------------------------------------------------------------------------------- /graphy/nodes/conv.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Convolutional functions 4 | ''' 5 | import numpy as np 6 | import theano 7 | import theano.tensor as T 8 | if 'gpu' in theano.config.device: # @UndefinedVariable 9 | from theano.sandbox.cuda.dnn import dnn_conv 10 | from theano.sandbox.cuda.dnn import dnn_pool 11 | elif 'cuda' in theano.config.device: # @UndefinedVariable 12 | from theano.sandbox.gpuarray.dnn import dnn_conv 13 | from theano.sandbox.gpuarray.dnn import dnn_pool 14 | else: raise Exception() 15 | 
import graphy as G 16 | import graphy.nodes as N 17 | 18 | # hyperparams 19 | logscale = True #Really works better! 20 | bias_logscale = False 21 | logscale_scale = 3. 22 | init_stdev = .1 23 | maxweight = 0. 24 | bn = False #mean-only batchnorm 25 | 26 | # General de-pooling inspired by Jascha Sohl-Dickstein's code 27 | # Divides n_features by factor**2, multiplies width/height factor 28 | def depool2d_split(x, factor=2): 29 | assert factor >= 1 30 | if factor == 1: return x 31 | #assert x.shape[1] >= 4 and x.shape[1]%4 == 0 32 | x = x.reshape((x.shape[0], x.shape[1]/factor**2, factor, factor, x.shape[2], x.shape[3])) 33 | x = x.dimshuffle(0, 1, 4, 2, 5, 3) 34 | x = x.reshape((x.shape[0], x.shape[1], x.shape[2]*x.shape[3], x.shape[4]*x.shape[5])) 35 | return x 36 | 37 | # General nearest-neighbour downsampling inspired by Jascha Sohl-Dickstein's code 38 | def downsample2d_nearest_neighbour(x, scale=2): 39 | x = x.reshape((x.shape[0], x.shape[1], x.shape[2]/scale, scale, x.shape[3]/scale, scale)) 40 | x = T.mean(x, axis=5) 41 | x = T.mean(x, axis=3) 42 | return x 43 | 44 | # 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code 45 | def upsample2d_nearest_neighbour(x): 46 | shape = x.shape 47 | x = x.reshape((shape[0], shape[1], shape[2], 1, shape[3], 1)) 48 | x = T.concatenate((x, x), axis=5) 49 | x = T.concatenate((x, x), axis=3) 50 | x = x.reshape((shape[0], shape[1], shape[2]*2, shape[3]*2)) 51 | return x 52 | 53 | # 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code 54 | def upsample2d_perforated(x): 55 | shape = x.shape 56 | x = x.reshape((shape[0], shape[1], shape[2], 1, shape[3], 1)) 57 | y = T.zeros((shape[0], shape[1], shape[2], 2, shape[3], 2),dtype=G.floatX) 58 | x = T.set_subtensor(y[:,:,:,0:1,:,0:1], x) 59 | x = x.reshape((shape[0], shape[1], shape[2]*2, shape[3]*2)) 60 | return x 61 | 62 | # Pad input 63 | def pad2d(x, n_padding): 64 | result_shape = (x.shape[0],x.shape[1],x.shape[2]+2*n_padding,x.shape[3]+2*n_padding) 65 | result = T.zeros(result_shape, dtype=G.floatX) 66 | result = T.set_subtensor(result[:,:,n_padding:-n_padding,n_padding:-n_padding], x) 67 | return result 68 | 69 | 70 | # Pad input, add extra channel 71 | def pad2dwithchannel(x, size_kernel): 72 | assert size_kernel[0]>1 or size_kernel[1]>1 73 | assert size_kernel[0]%2 == 1 74 | assert size_kernel[1]%2 == 1 75 | a = (size_kernel[0]-1)/2 76 | b = (size_kernel[1]-1)/2 77 | if True: 78 | n_channels = x.shape[1] 79 | result_shape = (x.shape[0],x.shape[1]+1,x.shape[2]+2*a,x.shape[3]+2*b) 80 | result = T.zeros(result_shape, dtype=G.floatX) 81 | result = T.set_subtensor(result[:,n_channels,:,:], 1.) 82 | result = T.set_subtensor(result[:,n_channels,a:-a,b:-b], 0.) 83 | result = T.set_subtensor(result[:,:n_channels,a:-a,b:-b], x) 84 | else: 85 | # new code, requires that the minibatch size 'x.tag.test_value.shape[0]' is the same during execution 86 | # I thought this would be more memory-efficient, but seems not the case in practice 87 | print 'new code, requires that the minibatch size "x.tag.test_value.shape[0]" is the same during execution' 88 | x_shape = x.tag.test_value.shape 89 | n_channels = x_shape[1] 90 | result_shape = (x_shape[0],x_shape[1]+1,x_shape[2]+2*a,x_shape[3]+2*b) 91 | result = np.zeros(result_shape,dtype=G.floatX) 92 | result[:,n_channels,:,:] = 1. 93 | result[:,n_channels,a:-a,b:-b] = 0. 
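        # Either branch builds the same tensor: spatial zero-padding of width
        # (kernel_size-1)/2 plus one extra indicator channel that is 1 on the
        # padded border and 0 over the original input region, so the
        # convolution can tell real zeros apart from padding.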
94 | result = T.constant(result) 95 | result = T.set_subtensor(result[:,:n_channels,a:-a,b:-b], x) 96 | return result 97 | 98 | 99 | # Multi-scale conv 100 | def msconv2d(name, n_scales, n_in, n_out, size_kernel=(3,3), pad_channel=True, border_mode='valid', downsample=1, upsample=1, w={}): 101 | convs = [conv2d(name+"_s"+str(i), n_in, n_out, size_kernel, pad_channel, border_mode, downsample, upsample, w) for i in range(n_scales)] 102 | def f(h, w): 103 | results = [] 104 | for i in range(n_scales-1): 105 | results.append(convs[i](h, w)) 106 | h = N.conv.downsample2d_nearest_neighbour(h, scale=2) 107 | result = convs[-1](h, w) 108 | for i in range(n_scales-1): 109 | result = N.conv.upsample2d_nearest_neighbour(result) 110 | result += results[-1-i] 111 | return result 112 | 113 | def postup(updates, w): 114 | for conv in convs: 115 | updates = conv.postup(updates, w) 116 | return updates 117 | 118 | return G.Struct(__call__=f, w=w, postup=postup) 119 | 120 | # 2D conv with input bias 121 | # size_kernel = (n_rows, n_cols) 122 | def conv2d(name, n_in, n_out, size_kernel=(3,3), pad_channel=True, border_mode='valid', downsample=1, upsample=1, datainit=True, zeroinit=False, l2norm=True, w={}): 123 | 124 | # TODO FIX: blows up parameters if all inputs are 0 125 | 126 | if not pad_channel: 127 | border_mode = 'same' 128 | print 'No pad_channel, changing border_mode to same' 129 | 130 | if '[sharedw]' in name and '[/sharedw]' in name: 131 | name_w = name 132 | pre, b = name.split("[sharedw]") 133 | number, post = b.split("[/sharedw]") 134 | name_w = pre+"[s]"+post 135 | name = pre+number+post # Don't share the bias and scales 136 | #name = name_w # Also share the bias and scales 137 | else: 138 | name_w = name 139 | 140 | if type(downsample) == int: 141 | downsample = (downsample,downsample) 142 | assert type(downsample) == tuple 143 | assert border_mode in ['valid','full','same'] 144 | 145 | _n_in = n_in 146 | _n_out = n_out 147 | if upsample > 1: 148 | _n_out = n_out * upsample**2 149 | 150 | if pad_channel: 151 | if size_kernel[0] > 1 or size_kernel[1] > 1: 152 | assert size_kernel[0] == size_kernel[1] 153 | assert border_mode == 'valid' 154 | _n_in += 1 155 | else: 156 | pad_channel = False 157 | 158 | if border_mode == 'same': 159 | assert size_kernel[0]%2 == 1 160 | border_mode = ((size_kernel[0]-1)/2,(size_kernel[1]-1)/2) 161 | 162 | def l2normalize(kerns): 163 | norm = T.sqrt((kerns**2).sum(axis=(1,2,3), keepdims=True)) 164 | return kerns / norm 165 | def maxconstraint(kerns): 166 | return kerns * (maxweight / T.maximum(maxweight, abs(kerns).max(axis=(1,2,3), keepdims=True))) 167 | 168 | if zeroinit: 169 | w[name_w+'_w'] = G.sharedf(np.zeros((_n_out, _n_in, size_kernel[0], size_kernel[1]))) 170 | datainit = False 171 | else: 172 | w[name_w+'_w'] = G.sharedf(0.05*np.random.randn(_n_out, _n_in, size_kernel[0], size_kernel[1])) 173 | if maxweight > 0: 174 | w[name_w+'_w'].set_value(maxconstraint(w[name_w+'_w']).tag.test_value) 175 | 176 | w[name+'_b'] = G.sharedf(np.zeros((_n_out,))) 177 | if bias_logscale: 178 | w[name+'_bs'] = G.sharedf(0.) 
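    # Weight-normalization-style parametrization: when l2norm is set, each
    # output-channel kernel is L2-normalized and multiplied by a learned
    # per-channel scale, stored in log-space (scaled by logscale_scale) when
    # `logscale` is True. The scale and bias then get a data-dependent
    # initialization further below, triggered by '__init' being present in w.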
179 | 180 | if l2norm: 181 | if logscale: 182 | w[name+'_s'] = G.sharedf(np.zeros((_n_out,))) 183 | else: 184 | w[name+'_s'] = G.sharedf(np.ones((_n_out,))) 185 | elif do_constant_rescale: 186 | print 'WARNING: constant rescale, these weights arent saved' 187 | constant_rescale = G.sharedf(np.ones((_n_out,))) 188 | 189 | 190 | def f(h, w): 191 | 192 | input_shape = h.tag.test_value.shape[1:] 193 | 194 | _input = h 195 | 196 | if pad_channel: 197 | h = pad2dwithchannel(h, size_kernel) 198 | 199 | kerns = w[name_w+'_w'] 200 | #if name == '1_down_conv1': 201 | # kerns = T.printing.Print('kerns 1')(kerns) 202 | if l2norm: 203 | kerns = l2normalize(kerns) 204 | if logscale: 205 | kerns *= T.exp(logscale_scale*w[name+'_s']).dimshuffle(0,'x','x','x') 206 | else: 207 | kerns *= w[name+'_s'].dimshuffle(0,'x','x','x') 208 | elif do_constant_rescale: 209 | kerns *= constant_rescale.dimshuffle(0,'x','x','x') 210 | 211 | #if name == '1_down_conv1': 212 | # kerns = T.printing.Print('kerns 2')(kerns) 213 | 214 | h = dnn_conv(h, kerns, border_mode=border_mode, subsample=downsample) 215 | 216 | # Mean-only batch norm 217 | if bn: 218 | h -= h.mean(axis=(0,2,3), keepdims=True) 219 | 220 | _b = w[name+'_b'].dimshuffle('x',0,'x','x') 221 | if bias_logscale: 222 | _b *= T.exp(logscale_scale * w[name+'_bs']) 223 | h += _b 224 | 225 | if '__init' in w and datainit: 226 | 227 | # Std 228 | data_std = h.std(axis=(0,2,3)) 229 | num_zeros = (data_std.tag.test_value == 0).sum() 230 | if num_zeros > 0: 231 | print "Warning: Stdev=0 for "+str(num_zeros)+" features in "+name+". Skipping data-dependent init." 232 | else: 233 | 234 | std = (1./init_stdev) * data_std 235 | std += 1e-7 236 | 237 | if name+'_s' in w: 238 | if logscale: 239 | w[name+'_s'].set_value(-T.log(std).tag.test_value/logscale_scale) 240 | else: 241 | w[name+'_s'].set_value((1./std).tag.test_value) 242 | elif do_constant_rescale: 243 | constant_rescale.set_value((1./std).tag.test_value) 244 | 245 | h /= std.dimshuffle('x',0,'x','x') 246 | 247 | # Mean 248 | mean = h.mean(axis=(0,2,3)) 249 | w[name+'_b'].set_value(-mean.tag.test_value) 250 | h -= mean.dimshuffle('x',0,'x','x') 251 | 252 | #print name, w[name+'_w'].get_value().mean(), w[name+'_w'].get_value().std(), w[name+'_w'].get_value().max() 253 | 254 | if upsample>1: 255 | h = depool2d_split(h, factor=upsample) 256 | 257 | if not '__init' in w: 258 | output_shape = h.tag.test_value.shape[1:] 259 | print 'conv2d', name, input_shape, output_shape, size_kernel, pad_channel, border_mode, downsample, upsample 260 | 261 | #print name, abs(h).max().tag.test_value, abs(h).min().tag.test_value 262 | #h = T.printing.Print(name)(h) 263 | 264 | return h 265 | 266 | # Normalize weights to _norm L2 norm 267 | # TODO: check whether only_upper_bounds here really helps 268 | # (the effect is a higher learning rate in the beginning of training) 269 | def postup(updates, w): 270 | if l2norm and maxweight>0.: 271 | updates[w[name_w+'_w']] = maxconstraint(updates[w[name_w+'_w']]) 272 | return updates 273 | 274 | return G.Struct(__call__=f, w=w, postup=postup) 275 | 276 | # ResNet layer 277 | def resnetv1_layer(name, n_in, n_out, size_kernel=(3,3), downsample=1, upsample=1, nl='relu', w={}): 278 | #print 'resnet_layer', name, shape_in, shape_out, size_kernel, downsample, upsample 279 | 280 | f_nl = N.nonlinearity(name+"_nl", nl) 281 | 282 | border_mode = 'valid' 283 | 284 | if upsample == 1: 285 | # either no change in shape, or subsampling 286 | conv1 = conv2d(name+'_conv1', n_in, n_out, size_kernel, True, border_mode, 
downsample, upsample, w=w) 287 | conv2 = conv2d(name+'_conv2', n_out, n_out, size_kernel, True, border_mode, downsample=1, upsample=1, w=w) 288 | conv3 = None 289 | if downsample>1 or upsample>1 or n_out != n_in: 290 | conv3 = conv2d(name+'_conv3', n_in, n_out, (downsample, downsample), None, 'valid', downsample, upsample, w=w) 291 | else: 292 | # upsampling 293 | assert downsample == 1 294 | conv1 = conv2d(name+'_conv1', n_in, n_in, size_kernel, True, border_mode, downsample=1, upsample=1, w=w) 295 | conv2 = conv2d(name+'_conv2', n_in, n_out, size_kernel, True, border_mode, downsample, upsample, w=w) 296 | conv3 = None 297 | if downsample>1 or upsample>1 or n_out != n_in: 298 | conv3 = conv2d(name+'_conv3', n_in, n_out, (downsample, downsample), None, 'valid', downsample, upsample, w=w) 299 | 300 | def f(_input, w): 301 | hidden = f_nl(conv1(_input, w)) 302 | _output = .1 * conv2(hidden, w) 303 | if conv3 != None: 304 | return T.nnet.relu(conv3(_input, w) + _output) 305 | return T.nnet.relu(_input + _output) 306 | 307 | def postup(updates, w): 308 | updates = conv1.postup(updates, w) 309 | updates = conv2.postup(updates, w) 310 | if conv3 != None: 311 | updates = conv3.postup(updates, w) 312 | return updates 313 | 314 | return G.Struct(__call__=f, w=w, postup=postup) 315 | 316 | # ResNet v1 with n_layers layers 317 | # Support sub/upsampling 318 | # In case of subsampling, first layer does subsampling (like in the ResNet paper) 319 | # In case of upsampling, the last layer does the upsampling (to make the net symmetrical) 320 | def resnetv1(name, n_layers, n_in, n_out, size_kernel=(3,3), downsample=1, upsample=1, nl='relu', w={}): 321 | layers = [] 322 | for i in range(n_layers): 323 | _n_in = n_in 324 | _n_out = n_out 325 | _downsample = downsample 326 | _upsample = upsample 327 | if _downsample > 1 and i > 0: 328 | _downsample = 1 329 | _n_in = n_out 330 | if _upsample > 1 and i < n_layers-1: 331 | _upsample = 1 332 | _n_out = n_in 333 | 334 | layer = resnetv1_layer(name+'_'+str(i), _n_in, _n_out, size_kernel, _downsample, _upsample, nl, w) 335 | layers.append(layer) 336 | 337 | def f(h, w): 338 | for i in range(n_layers): 339 | h = layers[i](h, w) 340 | return h 341 | 342 | def postup(updates, w): 343 | for i in range(n_layers): 344 | updates = layers[i].postup(updates, w) 345 | return updates 346 | 347 | return G.Struct(__call__=f, w=w, postup=postup) 348 | 349 | 350 | 351 | 352 | # ResNet V2 layer 353 | def resnetv2_layer_a(name, n_feats, nl='relu', w={}): 354 | 355 | f_nl = N.nonlinearity(name+"_nl", nl) 356 | 357 | # either no change in shape, or subsampling 358 | conv1 = conv2d(name+'_conv1', n_feats, n_feats, (3,3), w=w) 359 | conv2 = conv2d(name+'_conv2', n_feats, n_feats, (3,3), w=w) 360 | 361 | def f(_input, w): 362 | h = _input 363 | h = f_nl(conv1(h, w)) 364 | h = conv2(h, w) 365 | return T.nnet.relu(_input + .1 * h) 366 | 367 | def postup(updates, w): 368 | updates = conv1.postup(updates, w) 369 | updates = conv2.postup(updates, w) 370 | return updates 371 | 372 | return G.Struct(__call__=f, w=w, postup=postup) 373 | 374 | # ResNet V2 layer 375 | def resnetv2_layer_b(name, n_feats, factor=4, nl='relu', w={}): 376 | 377 | f_nl = N.nonlinearity(name+"_nl", nl) 378 | 379 | # either no change in shape, or subsampling 380 | conv1 = conv2d(name+'_conv1', n_feats, n_feats/factor, (1,1), w=w) 381 | conv2 = conv2d(name+'_conv2', n_feats/factor, n_feats/factor, (3,3), w=w) 382 | conv3 = conv2d(name+'_conv3', n_feats/factor, n_feats, (1,1), w=w) 383 | 384 | def f(_input, w): 385 | h = 
_input 386 | h = f_nl(conv1(h, w)) 387 | h = f_nl(conv2(h, w)) 388 | h = conv3(h, w) 389 | return T.nnet.relu(_input + .1 * h) 390 | 391 | def postup(updates, w): 392 | updates = conv1.postup(updates, w) 393 | updates = conv2.postup(updates, w) 394 | updates = conv3.postup(updates, w) 395 | return updates 396 | 397 | return G.Struct(__call__=f, w=w, postup=postup) 398 | 399 | 400 | # ResNet V2 with n_layers layers 401 | # V2: no sub/upsampling, not changing nr of features, bottleneck layer, fixed kernel size (1x1 and 3x3) 402 | def resnetv2(name, n_layers, n_feats, layertype='a', factor=4, nl='relu', w={}): 403 | 404 | layers = [] 405 | for i in range(n_layers): 406 | if layertype == 'a': 407 | layers.append(resnetv2_layer_a(name+'_'+str(i), n_feats, nl, w)) 408 | if layertype == 'b': 409 | layers.append(resnetv2_layer_b(name+'_'+str(i), n_feats, factor, nl, w)) 410 | 411 | def f(h, w): 412 | for i in range(n_layers): 413 | h = layers[i](h, w) 414 | return h 415 | 416 | def postup(updates, w): 417 | for i in range(n_layers): 418 | updates = layers[i].postup(updates, w) 419 | return updates 420 | 421 | return G.Struct(__call__=f, w=w, postup=postup) 422 | 423 | # ResNet V3 layer 424 | def resnetv3_layer_a(name, n_feats, nl='softplus', alpha=.1, w={}): 425 | 426 | f_nl1 = N.nonlinearity(name+"_nl1", nl) 427 | f_nl2 = N.nonlinearity(name+"_nl2", nl) 428 | 429 | # either no change in shape, or subsampling 430 | conv1 = conv2d(name+'_conv1', n_feats, n_feats, (3,3), w=w) 431 | conv2 = conv2d(name+'_conv2', n_feats, n_feats, (3,3), w=w) 432 | 433 | def f(_input, w): 434 | h = f_nl1(_input) 435 | h = f_nl2(conv1(h, w)) 436 | h = conv2(h, w) 437 | return _input + alpha * h 438 | 439 | def postup(updates, w): 440 | updates = conv1.postup(updates, w) 441 | updates = conv2.postup(updates, w) 442 | return updates 443 | 444 | return G.Struct(__call__=f, w=w, postup=postup) 445 | 446 | # ResNet V3 layer 447 | def resnetv3_layer_b(name, n_feats, factor=4, nl='softplus', alpha=.1, w={}): 448 | 449 | f_nl1 = N.nonlinearity(name+"_nl1", nl) 450 | f_nl2 = N.nonlinearity(name+"_nl2", nl) 451 | f_nl3 = N.nonlinearity(name+"_nl3", nl) 452 | 453 | # either no change in shape, or subsampling 454 | conv1 = conv2d(name+'_conv1', n_feats, n_feats/factor, (1,1), w=w) 455 | conv2 = conv2d(name+'_conv2', n_feats/factor, n_feats/factor, (3,3), w=w) 456 | conv3 = conv2d(name+'_conv3', n_feats/factor, n_feats, (1,1), w=w) 457 | 458 | def f(_input, w): 459 | h = f_nl1(_input) 460 | h = f_nl2(conv1(h, w)) 461 | h = f_nl3(conv2(h, w)) 462 | h = conv3(h, w) 463 | return _input + alpha * h 464 | 465 | def postup(updates, w): 466 | updates = conv1.postup(updates, w) 467 | updates = conv2.postup(updates, w) 468 | updates = conv3.postup(updates, w) 469 | return updates 470 | 471 | return G.Struct(__call__=f, w=w, postup=postup) 472 | 473 | # ResNet V3 with n_layers layers 474 | # V3: like V2, but nonlinearity applied in more logical manner: as first element of inner functions 475 | def resnetv3(name, n_layers, n_feats, nl='softplus', layertype='a', factor=4, w={}): 476 | 477 | layers = [] 478 | for i in range(n_layers): 479 | if layertype == 'a': 480 | layers.append(resnetv3_layer_a(name+'_'+str(i), n_feats, nl, .1/n_layers, w)) 481 | if layertype == 'b': 482 | layers.append(resnetv3_layer_b(name+'_'+str(i), n_feats, factor, nl, .1/n_layers, w)) 483 | 484 | def f(h, w): 485 | for i in range(n_layers): 486 | h = layers[i](h, w) 487 | return h 488 | 489 | def postup(updates, w): 490 | for i in range(n_layers): 491 | updates = 
layers[i].postup(updates, w) 492 | return updates 493 | 494 | return G.Struct(__call__=f, w=w, postup=postup) 495 | -------------------------------------------------------------------------------- /graphy/nodes/rand.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import graphy as G 4 | import math 5 | import numpy as np 6 | import collections 7 | 8 | def RandomVariable(sample, logp, entr, **params): 9 | return G.Struct(sample=sample, logp=logp, entr=entr, **params) 10 | 11 | # TODO: turn these random variables functions into constructors 12 | 13 | ''' 14 | Bernoulli variable 15 | ''' 16 | def bernoulli(p, sample=None): 17 | if sample is None: 18 | sample = G.rng.binomial(p=p, dtype=G.floatX) 19 | logp = - T.nnet.binary_crossentropy(p, sample).flatten(2).sum(axis=1) 20 | entr = - (p * T.log(p) + (1-p) * T.log(1-p)).flatten(2).sum(axis=1) 21 | return RandomVariable(sample, logp, entr, p=p) 22 | 23 | ''' 24 | Categorical variable 25 | p: matrix with probabilities (each row should sum to one) 26 | ''' 27 | def categorical(p, sample=None): 28 | if sample is None: 29 | sample = G.rng.multinomial(pvals=p, dtype='int32').argmax(axis=1) 30 | logp = - T.nnet.categorical_crossentropy(p, sample.flatten()) 31 | entr = - (p * T.log(p)).sum(axis=1) 32 | return G.Struct(**{'sample':sample,'logp':logp,'entr':entr,'p':p}) 33 | return RandomVariable(sample, logp, entr, p=p) 34 | 35 | ''' 36 | 4D Categorical variable 37 | ulogp4d: 4D tensor with unnormalized log-probabilities at 2nd dimension 38 | 1st dimension goes over datapoints 39 | 2nd dimension has size n_vars*n_categories 40 | 3rd and 4th dimensions are the spatial dimensions 41 | 42 | sample: 4D tensor of integer values (each from 0 to n_categories-1) 43 | ''' 44 | def categorical4d(ulogp4d, n_vars=3, n_classes=256, sample=None): 45 | shape4d = ulogp4d.shape 46 | ulogp_2d = ulogp4d.reshape((shape4d[0],n_vars,n_classes,shape4d[2],shape4d[3])) 47 | ulogp_2d = ulogp_2d.dimshuffle(0,1,3,4,2) 48 | ulogp_2d = ulogp_2d.reshape((shape4d[0]*n_vars*shape4d[2]*shape4d[3],n_classes)) 49 | p_2d = T.nnet.softmax(ulogp_2d) 50 | if sample is None: 51 | sample_1d = G.rng.multinomial(pvals=p_2d, dtype='int32').argmax(axis=1) 52 | sample = sample_1d.reshape((shape4d[0],n_vars,shape4d[2],shape4d[3])) 53 | logp = - T.nnet.categorical_crossentropy(p_2d, sample.flatten()) 54 | logp = logp.reshape((shape4d[0],n_vars*shape4d[2]*shape4d[3])).sum(axis=1) 55 | entr = - (p_2d * T.log(p_2d)).sum(axis=1) 56 | return RandomVariable(sample, logp, entr, ulogp4d=ulogp4d) 57 | 58 | 59 | ''' 60 | Uniform random variable 61 | [a, b]: domain 62 | b > a 63 | ''' 64 | def uniform(a, b, sample=None): 65 | logp = None 66 | if sample is None: 67 | sample = G.rng_curand.uniform(size=a.shape, low=a, high=b, dtype=G.floatX) 68 | # Warning: logp incorrect of y is outside of scope 69 | logp = -T.log(b-a).flatten(2).sum(axis=1) 70 | entr = T.log(b-a).flatten(2).sum(axis=1) 71 | return RandomVariable(sample, logp, entr, a=a, b=b) 72 | 73 | ''' 74 | Diagonal Gaussian variable 75 | mean: mean 76 | logvar: log-variance 77 | ''' 78 | def gaussian_diag(mean, logvar, sample=None): 79 | eps = None 80 | if sample is None: 81 | eps = G.rng_curand.normal(size=mean.shape) 82 | sample = mean + T.exp(.5*logvar) * eps 83 | logps = -.5 * (T.log(2*math.pi) + logvar + (sample - mean)**2 / T.exp(logvar)) 84 | logp = logps.flatten(2).sum(axis=1) 85 | entr = (.5 * (T.log(2 * math.pi) + 1 + logvar)).flatten(2).sum(axis=1) 86 | kl = 
lambda p_mean, p_logvar: (.5 * (p_logvar - logvar) + (T.exp(logvar) + (mean-p_mean)**2)/(2*T.exp(p_logvar)) - .5).flatten(2).sum(axis=1) 87 | return RandomVariable(sample, logp, entr, mean=mean, logvar=logvar, kl=kl, logps=logps, eps=eps) 88 | 89 | 90 | ''' 91 | Full-covariance Gaussian using a cholesky factor 92 | mean (2D tensor): mean 93 | logvar (2D tensor): log-variance 94 | chol (3D tensor): cholesky factor minus the diagonal (upper triangular, zeros on diagonal) 95 | ''' 96 | def gaussian_chol(mean, logvar, chol, sample=None): 97 | if sample != None: 98 | raise Exception('Not implemented') 99 | diag = gaussian_diag(mean, logvar) 100 | mask = T.shape_padleft(T.triu(T.ones_like(chol[0]), 1)) 101 | sample = diag.sample + T.batched_dot(diag.sample, chol * mask) 102 | return RandomVariable(sample, diag.logp, diag.entr, mean=mean, logvar=logvar) 103 | 104 | ''' 105 | Diagonal Gaussian variables 106 | n_batch: batchsize 107 | y: output 108 | n_y: fixed parameter indicating dimensionality of the Gaussian 109 | ''' 110 | def gaussian_spherical(shape=None, sample=None): 111 | if sample is None: 112 | sample = G.rng_curand.normal(shape) 113 | if shape is None: 114 | assert sample != None 115 | shape = sample.shape 116 | logp = -.5 * (T.log(2*math.pi) + sample**2).flatten(2).sum(axis=1) 117 | entr = (1.*T.prod(shape[1:]).astype(G.floatX)) * T.ones((shape[0],), dtype=G.floatX) * G.sharedf(.5 * (np.log(2.*math.pi)+1.)) 118 | return RandomVariable(sample, logp, entr, shape=shape) 119 | 120 | ''' 121 | Diagonal Laplace variable 122 | mean: mean 123 | scale: scale 124 | ''' 125 | def laplace_diag(mean, logscale, sample=None): 126 | scale = .5*T.exp(logscale) 127 | if sample is None: 128 | u = G.rng_curand.uniform(size=mean.shape) - .5 129 | sample = mean - scale * T.sgn(u) * T.log(1-2*abs(u)) 130 | logp = (- logscale - abs(sample-mean) / scale).flatten(2).sum(axis=1) 131 | entr = (1 + logscale).flatten(2).sum(axis=1) 132 | return RandomVariable(sample, logp, entr, mean=mean, scale=scale) 133 | 134 | 135 | ''' 136 | Logistic random variable 137 | ''' 138 | def logistic(mean, logscale, sample=None): 139 | scale = T.exp(logscale) 140 | if sample is None: 141 | u = G.rng_curand.uniform(size=mean.shape) 142 | _y = T.log(-u/(u-1)) #inverse CDF of the logistic 143 | sample = mean + scale * _y 144 | else: 145 | _y = -(sample-mean)/scale 146 | _logp = -_y - logscale - 2*T.nnet.softplus(-_y) 147 | logp = _logp.flatten(2).sum(axis=1) 148 | entr = logscale.flatten(2) 149 | entr = entr.sum(axis=1) + 2. 
* entr.shape[1] 150 | return RandomVariable(sample, logp, entr, mean=mean, logscale=logscale, _logp=_logp) 151 | 152 | ''' 153 | Rectified Logistic random variable 154 | ''' 155 | def rectlogistic(mean, logscale, sample=None): 156 | if sample is None: 157 | sample = T.maximum(0, logistic(mean, logscale).sample) 158 | mass0 = 1./(1+T.exp(-mean/T.exp(logscale))) 159 | logp = ((sample<=0) * mass0 + (sample>0) * logistic(mean, logscale, sample)._logp).flatten(2).sum(axis=1) 160 | entr = "Not implemented" 161 | return RandomVariable(sample, logp, entr, mean=mean, logscale=logscale) 162 | 163 | 164 | ''' 165 | Discretized Logistic variable 166 | mean: mean 167 | logscale: logscale 168 | ''' 169 | def discretized_logistic(mean, logscale, binsize, sample=None): 170 | scale = T.exp(logscale) 171 | if sample is None: 172 | u = G.rng_curand.uniform(size=mean.shape) 173 | _y = T.log(-u/(u-1)) #inverse CDF of the logistic 174 | sample = mean + scale * _y #sample from the actual logistic 175 | sample = T.floor(sample/binsize)*binsize #discretize the sample 176 | _sample = (T.floor(sample/binsize)*binsize - mean)/scale 177 | logps = T.log( T.nnet.sigmoid(_sample + binsize/scale) - T.nnet.sigmoid(_sample) + 1e-7) 178 | logp = logps.flatten(2).sum(axis=1) 179 | #raise Exception() 180 | entr = logscale.flatten(2) 181 | entr = entr.sum(axis=1) + 2. * entr.shape[1].astype(G.floatX) 182 | return RandomVariable(sample, logp, entr, mean=mean, logscale=logscale, logps=logps) 183 | 184 | ''' 185 | Discretized Gaussian variable 186 | mean: mean 187 | logscale: logscale 188 | ''' 189 | def discretized_gaussian(mean, logvar, binsize, sample=None): 190 | scale = T.exp(.5*logvar) 191 | if sample is None: 192 | _y = G.rng_curand.normal(size=mean.shape) 193 | sample = mean + scale * _y #sample from the actual logistic 194 | sample = T.floor(sample/binsize)*binsize #discretize the sample 195 | _sample = (T.floor(sample/binsize)*binsize - mean)/scale 196 | def _erf(x): 197 | return T.erf(x/T.sqrt(2.)) 198 | logp = T.log( _erf(_sample + binsize/scale) - _erf(_sample) + 1e-7) + T.log(.5) 199 | logp = logp.flatten(2).sum(axis=1) 200 | #raise Exception() 201 | entr = (.5 * (T.log(2 * math.pi) + 1 + logvar)).flatten(2).sum(axis=1) 202 | return RandomVariable(sample, logp, entr, mean=mean, logvar=logvar) 203 | 204 | 205 | ''' 206 | Discretized Laplace variable 207 | mean: mean 208 | scale: scale 209 | ''' 210 | def discretized_laplace(mean, logscale, binsize, sample=None): 211 | scale = .5*T.exp(logscale) 212 | if sample is None: 213 | u = G.rng_curand.uniform(size=mean.shape) - .5 214 | sample = mean - scale * T.sgn(u) * T.log(1-2*abs(u)) 215 | sample = T.floor(sample/binsize)*binsize #discretize the sample 216 | 217 | d = .5*binsize 218 | def cdf(x): 219 | z = x-mean 220 | return .5 + .5 * T.sgn(z) * (1.-T.exp(-abs(z)/scale)) 221 | def logmass1(x): 222 | # General method for probability mass, but numerically unstable for large |x-mean|/scale 223 | return T.log(cdf(x+d) - cdf(x-d) + 1e-7) 224 | def logmass2(x): 225 | # Only valid for |x-mean| >= d 226 | return -abs(x-mean)/scale + T.log(T.exp(d/scale)-T.exp(-d/scale)) - np.log(2.).astype(G.floatX) 227 | def logmass_stable(x): 228 | switch = (abs(x-mean) < d) 229 | return switch * logmass1(x) + (1-switch) * logmass2(x) 230 | 231 | logp = logmass_stable(sample).flatten(2).sum(axis=1) 232 | entr = None #(1 + logscale).flatten(2).sum(axis=1) 233 | return RandomVariable(sample, logp, entr, mean=mean, scale=scale) 234 | 235 | ''' 236 | Laplace 237 | NOT CONVERTED YET 238 | ''' 239 | 
def zero_centered_laplace(name, w={}): 240 | w[name+'_logscale'] = G.sharedf(0.) 241 | def logp(v, w): 242 | return -abs(v).sum()/T.exp(w[name+'_logscale']) - v.size.astype(G.floatX) * (T.log(2.) + w[name+'_logscale']) 243 | postup = lambda updates, w:updates 244 | return G.Struct(logp=logp, postup=postup, w=w) 245 | 246 | ''' 247 | Diagonal Gaussian variable 248 | mean: mean 249 | logvar: log-variance 250 | NOT CONVERTED YET 251 | ''' 252 | def zero_centered_gaussian(name, w={}): 253 | w[name+'_logvar'] = G.sharedf(0.) 254 | def logp(v, w): 255 | logvar = w[name+'_logvar']*10 256 | return v.size.astype(G.floatX) * -.5 * (T.log(2.*math.pi) + logvar) - .5 * (v**2).sum() / T.exp(logvar) 257 | postup = lambda updates, w:updates 258 | return G.Struct(logp=logp, postup=postup, w=w) 259 | 260 | 261 | ''' 262 | Gaussian Scale Mixture 263 | NOT CONVERTED YET 264 | ''' 265 | def gsm(name, k, w={}, logvar_minmax=16): 266 | w[name+'_weight'] = G.sharedf(np.zeros((k,))) 267 | w[name+'_logvar'] = G.sharedf(np.random.randn(k)*.1) 268 | def logp(v, w): 269 | mixtureweights = T.exp(w[name+'_weight']) 270 | mixtureweights /= mixtureweights.sum() 271 | logvar = logvar_minmax*w[name+'_logvar'] 272 | var = T.exp(logvar) 273 | if k == 0: 274 | return 0. 275 | if k == 1: 276 | return -.5*(v**2).sum()/var[0] - v.size.astype(G.floatX) * (.5*T.log(2.*math.pi) + logvar[0]) 277 | p = 0. 278 | for i in range(k): 279 | p += mixtureweights[i] * T.exp(-.5*v**2/var[i]) / T.sqrt(2.*math.pi*var[i]) 280 | logp = T.log(p).sum() 281 | return logp 282 | 283 | def postup(updates, w): 284 | updates[w[name+'_logvar']] = T.clip(updates[w[name+'_logvar']], -1., 1.) 285 | return updates 286 | 287 | return G.Struct(logp=logp, postup=postup, w=w) 288 | 289 | 290 | 291 | 292 | 293 | 294 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import graphy as G 2 | import graphy.nodes as N 3 | import graphy.nodes.rand 4 | import graphy.nodes.conv 5 | import graphy.nodes.ar 6 | import numpy as np 7 | import theano 8 | import theano.tensor as T 9 | from collections import OrderedDict 10 | from pyexpat import model 11 | 12 | floatX = theano.config.floatX # @UndefinedVariable 13 | 14 | # CVAE ResNet layer of deterministic and stochastic units 15 | def cvae_layer(name, prior, posterior, n_h1, n_h2, n_z, depth_ar, downsample, nl, kernel, weightsharing, downsample_type, w): 16 | 17 | if False: 18 | # New such that we can recognize variational params later 19 | name_q = name+'_q_' 20 | name_p = name+'_p_' 21 | else: 22 | name_q = name 23 | name_p = name 24 | 25 | n_conv_up1 = n_h2+2*n_z 26 | n_conv_up2 = n_h2+n_z 27 | 28 | n_conv_down_posterior = 0 29 | n_conv_down_prior = n_h2+2*n_z 30 | 31 | # Prior 32 | prior_conv1 = None 33 | 34 | if prior in ['diag','diag2']: 35 | n_conv_down_prior = n_h2+2*n_z 36 | elif prior == 'made': 37 | prior_conv1 = N.ar.multiconv2d(name_p+'_prior_conv1', n_z, depth_ar*[n_h2], [n_z,n_z], kernel, False, nl=nl, w=w) 38 | n_conv_down_prior = n_h2+n_h2 39 | elif prior == 'bernoulli': 40 | n_conv_down_prior = n_h2+n_z 41 | prior_conv1 = N.conv.conv2d(name_p+'_prior_conv1', n_z, n_z, kernel, w=w) 42 | else: 43 | raise Exception("Unknown prior") 44 | 45 | # Posterior 46 | posterior_conv1 = None 47 | posterior_conv2 = None 48 | posterior_conv3 = None 49 | posterior_conv4 = None 50 | 51 | if posterior == 'up_diag': 52 | pass 53 | elif posterior == 'up_iaf1': 54 | posterior_conv1 = 
N.ar.conv2d(name_q+'_posterior_conv1', n_z, n_z, kernel, w=w) 55 | elif posterior == 'up_iaf2': 56 | posterior_conv1 = N.ar.conv2d(name_q+'_posterior_conv1', n_z, 2*n_z, kernel, w=w) 57 | 58 | elif posterior == 'up_iaf1_nl': 59 | n_conv_up1 = n_h2+2*n_z+n_h2 60 | posterior_conv1 = N.ar.multiconv2d(name_q+'_posterior_conv1', n_z, depth_ar*[n_h2], n_z, kernel, False, nl=nl, w=w) 61 | elif posterior == 'up_iaf2_nl': 62 | n_conv_up1 = n_h2+2*n_z+n_h2 63 | posterior_conv1 = N.ar.multiconv2d(name_q+'_posterior_conv1', n_z, depth_ar*[n_h2], [n_z,n_z], kernel, False, nl=nl, w=w) 64 | 65 | # elif posterior == 'down_diag': 66 | # n_conv_down1 = n_h2+4*n_z 67 | elif posterior == 'down_diag': 68 | n_conv_up2 = n_h2 69 | n_conv_down_posterior = 2*n_z 70 | elif posterior == 'down_bernoulli': 71 | n_conv_up2 = n_h2 72 | n_conv_down_posterior = n_z 73 | elif posterior == 'down_tim': 74 | pass 75 | elif posterior == 'down_iaf1': 76 | n_conv_up2 = n_h2 77 | n_conv_down_posterior = 2*n_z 78 | posterior_conv1 = N.ar.conv2d(name_q+'_posterior_conv1', n_z, n_z, kernel, w=w) 79 | elif posterior == 'down_iaf2': 80 | n_conv_up2 = n_h2 81 | n_conv_down_posterior = 2*n_z 82 | posterior_conv1 = N.ar.conv2d(name_q+'_posterior_conv1', n_z, 2*n_z, kernel, w=w) 83 | elif posterior == 'down_iaf1_nl': 84 | n_conv_up1 = n_h2+2*n_z+n_h2 85 | n_conv_up2 = n_h2 86 | n_conv_down_posterior = 2*n_z+n_h2 87 | posterior_conv1 = N.ar.multiconv2d(name_q+'_posterior_conv1', n_z, depth_ar*[n_h2], n_z, kernel, False, nl=nl, w=w) 88 | elif posterior == 'down_iaf2_nl': 89 | n_conv_up1 = n_h2+2*n_z+n_h2 90 | n_conv_up2 = n_h2 91 | n_conv_down_posterior = 2*n_z+n_h2 92 | posterior_conv1 = N.ar.multiconv2d(name_q+'_posterior_conv1', n_z, depth_ar*[n_h2], [n_z,n_z], kernel, False, nl=nl, w=w) 93 | elif posterior == 'down_iaf2_nl2': 94 | n_conv_up1 = n_h2+2*n_z+n_h2 95 | n_conv_up2 = n_h2 96 | n_conv_down_posterior = 2*n_z+n_h2 97 | posterior_conv1 = N.ar.multiconv2d(name_q+'_posterior_conv1', n_z, depth_ar*[n_h2], [n_z,n_z], kernel, False, nl=nl, w=w) 98 | posterior_conv2 = N.ar.multiconv2d(name_q+'_posterior_conv2', n_z, depth_ar*[n_h2], [n_z,n_z], kernel, True, nl=nl, w=w) 99 | elif posterior == 'down_iaf1_deep': 100 | n_conv_up1 = n_h2+2*n_z+n_h2 101 | n_conv_up2 = n_h2 102 | n_conv_down_posterior = 2*n_z+n_h2 103 | posterior_conv1 = N.ar.resnet(name_q+'_deepiaf', depth_ar, n_z, n_h2, n_z, kernel, False, nl=nl, weightsharing=weightsharing, w=w) 104 | elif posterior == 'down_iaf2_deep': 105 | n_conv_up1 = n_h2+2*n_z+n_h2 106 | n_conv_up2 = n_h2 107 | n_conv_down_posterior = 2*n_z+n_h2 108 | posterior_conv1 = N.ar.resnet(name_q+'_deepiaf', depth_ar, n_z, n_h2, [n_z,n_z], kernel, False, nl=nl, weightsharing=weightsharing, w=w) 109 | 110 | #elif posterior == 'iaf_deep1': 111 | # extra1 = N.ar.resnet(name+'_posterior_2', depth_iaf, n_z, 2*n_h, n_h, n_z, (3,3), False, nl=nl, w=w) 112 | #elif posterior == 'iaf_deep2': 113 | # extra1 = N.ar.resnet(name+'_posterior_2', depth_iaf, n_z, 2*n_h, n_h, [n_z,n_z], (3,3), False, nl=nl, w=w) 114 | else: 115 | raise Exception("Unknown posterior "+posterior) 116 | 117 | ds = 1 118 | if downsample: 119 | ds = 2 120 | if downsample_type == 'conv': 121 | up_conv3 = N.conv.conv2d(name_q+'_up_conv3', n_h1, n_h1, kernel, downsample=ds, w=w) 122 | down_conv3 = N.conv.conv2d(name_q+'_down_conv3', n_h1, n_h1, kernel, upsample=ds, w=w) 123 | 124 | up_nl1 = N.nonlinearity(name_q+"_up_nl1", nl) 125 | up_conv1 = N.conv.conv2d(name_q+'_up_conv1_'+str(ds), n_h1, n_conv_up1, kernel, downsample=ds, w=w) 126 | up_nl2 = 
N.nonlinearity(name_q+"_nl_up2", nl) 127 | up_conv2 = N.conv.conv2d(name_q+'_up_conv2', n_conv_up2, n_h1, kernel, w=w) 128 | 129 | down_nl1 = N.nonlinearity(name_p+"_down_nl1", nl) 130 | down_conv1 = N.conv.conv2d(name_p+'_down_conv1', n_h1, n_conv_down_prior+n_conv_down_posterior, kernel, w=w) 131 | down_nl2 = N.nonlinearity(name_p+"_down_nl2", nl) 132 | down_conv2 = N.conv.conv2d(name_p+'_down_conv2_'+str(ds), n_h2+n_z, n_h1, kernel, upsample=ds, w=w) 133 | 134 | up_output = [None] 135 | qz = [None] 136 | up_context = [None] 137 | 138 | def up(input, w): 139 | 140 | h = up_conv1(up_nl1(input, w), w) 141 | h_det = h[:,:n_h2,:,:] 142 | qz_mean = h[:,n_h2:n_h2+n_z,:,:] 143 | qz_logsd = h[:,n_h2+n_z:n_h2+2*n_z,:,:] 144 | qz[0] = N.rand.gaussian_diag(qz_mean, 2*qz_logsd) 145 | if posterior == 'up_diag': 146 | h = T.concatenate([h_det,qz[0].sample],axis=1) 147 | elif posterior == 'up_iaf1': 148 | arw_mean = posterior_conv1(qz[0].sample, w) 149 | arw_mean *= .1 150 | qz[0].sample = (qz[0].sample - arw_mean) 151 | h = T.concatenate([h_det,qz[0].sample],axis=1) 152 | elif posterior == 'up_iaf2': 153 | arw_mean_logsd = posterior_conv1(qz[0].sample, w) 154 | arw_mean = arw_mean_logsd[:,::2,:,:] 155 | arw_logsd = arw_mean_logsd[:,1::2,:,:] 156 | arw_mean *= .1 157 | arw_logsd *= .1 158 | qz[0].sample = (qz[0].sample - arw_mean) / T.exp(arw_logsd) 159 | qz[0].logps += arw_logsd 160 | qz[0].logp += arw_logsd.flatten(2).sum(axis=1) 161 | h = T.concatenate([h_det,qz[0].sample],axis=1) 162 | elif posterior == 'up_iaf1_nl': 163 | context = h[:,n_h2+2*n_z:n_h2+2*n_z+n_h2] 164 | arw_mean = posterior_conv1(qz[0].sample, context, w) 165 | arw_mean *= .1 166 | qz[0].sample = (qz[0].sample - arw_mean) 167 | h = T.concatenate([h_det,qz[0].sample],axis=1) 168 | elif posterior == 'up_iaf2_nl': 169 | context = h[:,n_h2+2*n_z:n_h2+2*n_z+n_h2] 170 | arw_mean, arw_logsd = posterior_conv1(qz[0].sample, context, w) 171 | arw_mean *= .1 172 | arw_logsd *= .1 173 | qz[0].sample = (qz[0].sample - arw_mean) / T.exp(arw_logsd) 174 | qz[0].logps += arw_logsd 175 | qz[0].logp += arw_logsd.flatten(2).sum(axis=1) 176 | h = T.concatenate([h_det,qz[0].sample],axis=1) 177 | elif posterior == 'down_tim': 178 | h = T.concatenate([h_det,qz[0].mean],axis=1) 179 | elif posterior in ['down_iaf1_nl','down_iaf2_nl','down_iaf2_nl2','down_iaf1_deep','down_iaf2_deep']: 180 | up_context[0] = h[:,n_h2+2*n_z:n_h2+2*n_z+n_h2] 181 | h = h_det 182 | elif posterior in ['down_diag','down_iaf1','down_iaf2','down_bernoulli']: 183 | h = h_det 184 | else: 185 | raise Exception() 186 | if downsample: 187 | if downsample_type == 'nn': 188 | input = N.conv.downsample2d_nearest_neighbour(input, 2) 189 | elif downsample_type == 'conv': 190 | input = up_conv3(input, w) 191 | output = input + .1 * up_conv2(up_nl2(h, w), w) 192 | up_output[0] = output 193 | 194 | return output 195 | 196 | def bernoulli_p(h): 197 | #p = T.clip(.5+.5*h, 1e-7, 1. 
- 1e-7) 198 | p = 1e-7 + (1-2e-7)*T.nnet.sigmoid(h) 199 | return p 200 | 201 | def down_q(input, train, w): 202 | 203 | #if name == '1': 204 | # print input.tag.test_value 205 | 206 | # prior 207 | h = down_nl1(input, w) 208 | #h = T.printing.Print('h1'+name)(h) 209 | h = down_conv1(h, w) 210 | #h = T.printing.Print('h2'+name)(h) 211 | 212 | logqs = 0 213 | 214 | # posterior 215 | if posterior in ['up_diag','up_iaf1','up_iaf2','up_iaf1_nl','up_iaf2_nl']: 216 | z = qz[0].sample 217 | logqs = qz[0].logps 218 | elif posterior == 'down_diag': 219 | rz_mean = h[:,n_conv_down_prior:n_conv_down_prior+n_z,:,:] 220 | rz_logsd = h[:,n_conv_down_prior+n_z:n_conv_down_prior+2*n_z,:,:] 221 | _qz = N.rand.gaussian_diag(qz[0].mean + rz_mean, qz[0].logvar + 2*rz_logsd) 222 | z = _qz.sample 223 | logqs = _qz.logps 224 | elif posterior == 'down_tim': 225 | assert prior == 'diag' 226 | pz_mean = h[:,n_h2:n_h2+n_z,:,:] 227 | pz_logsd = h[:,n_h2+n_z:n_h2+2*n_z,:,:] 228 | 229 | qz_prec = 1./T.exp(qz[0].logvar) 230 | pz_prec = 1./T.exp(2*pz_logsd) 231 | rz_prec = qz_prec + pz_prec 232 | rz_mean = (pz_prec/rz_prec) * pz_mean + (qz_prec/rz_prec) * qz[0].mean 233 | _qz = N.rand.gaussian_diag(rz_mean, -T.log(rz_prec)) 234 | z = _qz.sample 235 | logqs = _qz.logps 236 | elif posterior == 'down_iaf1': 237 | rz_mean = h[:,n_conv_down_prior:n_conv_down_prior+n_z,:,:] 238 | rz_logsd = h[:,n_conv_down_prior+n_z:n_conv_down_prior+2*n_z,:,:] 239 | _qz = N.rand.gaussian_diag(qz[0].mean + rz_mean, qz[0].logvar + 2*rz_logsd) 240 | z = _qz.sample 241 | logqs = _qz.logps 242 | # ARW transform 243 | arw_mean = posterior_conv1(z, w) 244 | arw_mean *= .1 245 | z = (z - arw_mean) 246 | elif posterior == 'down_iaf2': 247 | rz_mean = h[:,n_conv_down_prior:n_conv_down_prior+n_z,:,:] 248 | rz_logsd = h[:,n_conv_down_prior+n_z:n_conv_down_prior+2*n_z,:,:] 249 | _qz = N.rand.gaussian_diag(qz[0].mean + rz_mean, qz[0].logvar + 2*rz_logsd) 250 | z = _qz.sample 251 | logqs = _qz.logps 252 | # ARW transform 253 | arw_mean_logsd = posterior_conv1(z, w) 254 | arw_mean = arw_mean_logsd[:,::2,:,:] 255 | arw_logsd = arw_mean_logsd[:,1::2,:,:] 256 | arw_mean *= .1 257 | arw_logsd *= .1 258 | z = (z - arw_mean) / T.exp(arw_logsd) 259 | logqs += arw_logsd 260 | elif posterior in ['down_iaf1_nl','down_iaf1_deep']: 261 | rz_mean = h[:,n_conv_down_prior:n_conv_down_prior+n_z,:,:] 262 | rz_logsd = h[:,n_conv_down_prior+n_z:n_conv_down_prior+2*n_z,:,:] 263 | _qz = N.rand.gaussian_diag(qz[0].mean + rz_mean, qz[0].logvar + 2*rz_logsd) 264 | z = _qz.sample 265 | logqs = _qz.logps 266 | # ARW transform 267 | down_context = h[:,n_conv_down_prior+2*n_z:n_conv_down_prior+2*n_z+n_h2,:,:] 268 | context = up_context[0] + down_context 269 | arw_mean = posterior_conv1(z, context, w) 270 | arw_mean *= .1 271 | z = (z - arw_mean) 272 | elif posterior in ['down_iaf2_nl','down_iaf2_nl2','down_iaf2_deep']: 273 | rz_mean = h[:,n_conv_down_prior:n_conv_down_prior+n_z,:,:] 274 | rz_logsd = h[:,n_conv_down_prior+n_z:n_conv_down_prior+2*n_z,:,:] 275 | _qz = N.rand.gaussian_diag(qz[0].mean + rz_mean, qz[0].logvar + 2*rz_logsd) 276 | z = _qz.sample 277 | logqs = _qz.logps 278 | # ARW transform 279 | down_context = h[:,n_conv_down_prior+2*n_z:n_conv_down_prior+2*n_z+n_h2,:,:] 280 | context = up_context[0] + down_context 281 | arw_mean, arw_logsd = posterior_conv1(z, context, w) 282 | arw_mean *= .1 283 | arw_logsd *= .1 284 | z = (z - arw_mean) / T.exp(arw_logsd) 285 | logqs += arw_logsd 286 | if posterior == 'down_iaf2_nl2': 287 | arw_mean, arw_logsd = posterior_conv2(z, context, 
w) 288 | arw_mean *= .1 289 | arw_logsd *= .1 290 | z = (z - arw_mean) / T.exp(arw_logsd) 291 | logqs += arw_logsd 292 | 293 | 294 | # Prior 295 | if prior == 'diag': 296 | pz_mean = h[:,n_h2:n_h2+n_z,:,:] 297 | pz_logsd = h[:,n_h2+n_z:n_h2+2*n_z,:,:] 298 | logps = N.rand.gaussian_diag(pz_mean, 2*pz_logsd, z).logps 299 | elif prior == 'diag2': 300 | logps = N.rand.gaussian_diag(0*z, 0*z, z).logps 301 | pz_mean = h[:,n_h2:n_h2+n_z,:,:] 302 | pz_logsd = h[:,n_h2+n_z:n_h2+2*n_z,:,:] 303 | z = pz_mean + z * T.exp(pz_logsd) 304 | elif prior == 'made': 305 | made_context = h[:,n_h2:2*n_h2,:,:] 306 | made_mean, made_logsd = prior_conv1(z, made_context, w) 307 | made_mean *= .1 308 | made_logsd *= .1 309 | logps = N.rand.gaussian_diag(made_mean, 2*made_logsd, z).logps 310 | elif prior == 'bernoulli': 311 | assert posterior == 'down_bernoulli' 312 | pz_p = bernoulli_p(h[:,n_h2:n_h2+n_z,:,:]) 313 | logps = z01 * T.log(pz_p) + (1.-z01) * T.log(1.-pz_p) 314 | else: 315 | raise Exception() 316 | 317 | h_det = h[:,:n_h2,:,:] 318 | h = T.concatenate([h_det, z], axis=1) 319 | if downsample: 320 | if downsample_type == 'nn': 321 | input = N.conv.upsample2d_nearest_neighbour(input) 322 | elif downsample_type == 'conv': 323 | input = down_conv3(input, w) 324 | 325 | output = input + .1 * down_conv2(down_nl2(h, w), w) 326 | 327 | 328 | return output, logqs - logps 329 | 330 | def down_p(input, eps, w): 331 | # prior 332 | h = down_conv1(down_nl1(input, w), w) 333 | h_det = h[:,:n_h2,:,:] 334 | if prior in ['diag','diag2']: 335 | mean_prior = h[:,n_h2:n_h2+n_z,:,:] 336 | logsd_prior = h[:,n_h2+n_z:n_h2+2*n_z,:,:] 337 | z = mean_prior + eps * T.exp(logsd_prior) 338 | elif prior == 'made': 339 | print "TODO: SAMPLES FROM MADE PRIOR" 340 | z = eps 341 | elif prior == 'bernoulli': 342 | assert posterior == 'down_bernoulli' 343 | pz_p = bernoulli_p(h[:,n_h2:n_h2+n_z,:,:]) 344 | if False: 345 | z = N.rand.bernoulli(pz_p).sample 346 | else: 347 | print "Alert: Sampling using Gaussian approximation" 348 | z = pz_p + T.sqrt(pz_p * (1-pz_p)) * eps 349 | z = prior_conv1(2*z-1, w) 350 | 351 | h = T.concatenate([h_det, z], axis=1) 352 | if downsample: 353 | if downsample_type == 'nn': 354 | input = N.conv.upsample2d_nearest_neighbour(input) 355 | elif downsample_type == 'conv': 356 | input = down_conv3(input, w) 357 | 358 | output = input + .1 * down_conv2(down_nl2(h, w), w) 359 | return output 360 | 361 | def postup(updates, w): 362 | modules = [up_conv1,up_conv2,down_conv1,down_conv2] 363 | if downsample and downsample_type == 'conv': 364 | modules += [up_conv3,down_conv3] 365 | if prior_conv1 != None: 366 | modules.append(prior_conv1) 367 | if posterior_conv1 != None: 368 | modules.append(posterior_conv1) 369 | if posterior_conv2 != None: 370 | modules.append(posterior_conv2) 371 | if posterior_conv3 != None: 372 | modules.append(posterior_conv3) 373 | if posterior_conv3 != None: 374 | modules.append(posterior_conv4) 375 | for m in modules: 376 | updates = m.postup(updates, w) 377 | return updates 378 | 379 | return G.Struct(up=up, down_q=down_q, down_p=down_p, postup=postup, w=w) 380 | 381 | # Conv VAE 382 | # - Hybrid deterministic/stochastic ResNet block per layer 383 | 384 | def cvae1(shape_x, depths, depth_ar, n_h1, n_h2, n_z, prior='diag', posterior='down_diag', px='logistic', nl='softplus', kernel_x=(5,5), kernel_h=(3,3), kl_min=0, optim='adamax', alpha=0.002, beta1=0.1, beta2=0.001, weightsharing=None, pad_x = 0, data_init=None, downsample_type='nn'): 385 | _locals = locals() 386 | _locals.pop('data_init') 387 | 
print 'CVAE1 with ', _locals 388 | #assert posterior in ['diag1','diag2','iaf_linear','iaf_nonlinear'] 389 | assert px in ['logistic','bernoulli'] 390 | w = {} # model params 391 | if pad_x > 0: 392 | shape_x[1] += 2*pad_x 393 | shape_x[2] += 2*pad_x 394 | 395 | # Input whitening 396 | if px == 'logistic': 397 | w['logsd_x'] = G.sharedf(0.) 398 | 399 | # encoder 400 | x_enc = N.conv.conv2d('x_enc', shape_x[0], n_h1, kernel_x, downsample=2, w=w) 401 | x_dec = N.conv.conv2d('x_dec', n_h1, shape_x[0], kernel_x, upsample=2, w=w) 402 | x_dec_nl = N.nonlinearity('x_dec_nl', nl, n_h1, w) 403 | 404 | layers = [] 405 | for i in range(len(depths)): 406 | layers.append([]) 407 | for j in range(depths[i]): 408 | downsample = (i > 0 and j == 0) 409 | if weightsharing is None or not weightsharing: 410 | name = str(i)+'_'+str(j) 411 | elif weightsharing == 'all': 412 | name = '[sharedw]'+str(i)+'_'+str(j)+'[/sharedw]' 413 | elif weightsharing == 'acrosslevels': 414 | name = '[sharedw]'+str(i)+'[/sharedw]'+'_'+str(j) 415 | elif weightsharing == 'withinlevel': 416 | name = '[sharedw]'+str(i)+'[/sharedw]'+'_'+str(j) 417 | else: 418 | raise Exception() 419 | layers[i].append(cvae_layer(name, prior, posterior, n_h1, n_h2, n_z, depth_ar, downsample, nl, kernel_h, False, downsample_type, w)) 420 | 421 | # top-level value 422 | w['h_top'] = G.sharedf(np.zeros((n_h1,))) 423 | 424 | # Initialize variables 425 | x = T.tensor4('x', dtype='uint8') 426 | x.tag.test_value = data_init['x'] 427 | n_batch_test = data_init['x'].shape[0] 428 | _x = T.clip((x + .5) / 256., 0, 1) 429 | #_x = T.clip(x / 255., 0, 1) 430 | 431 | if pad_x > 0: 432 | _x = N.conv.pad2d(_x, pad_x) 433 | 434 | # Objective function 435 | def f_encode_decode(w, train=True): 436 | 437 | results = {} 438 | 439 | h = x_enc(_x - .5, w) 440 | 441 | obj_kl = G.sharedf(0.) 442 | 443 | # bottom-up encoders 444 | for i in range(len(depths)): 445 | for j in range(depths[i]): 446 | h = layers[i][j].up(h, w) 447 | 448 | # top-level activations 449 | h = T.tile(w['h_top'].dimshuffle('x',0,'x','x'), (_x.shape[0],1,shape_x[1]/2**len(depths), shape_x[2]/2**len(depths))) 450 | 451 | # top-down priors, posteriors and decoders 452 | for i in list(reversed(range(len(depths)))): 453 | for j in list(reversed(range(depths[i]))): 454 | h, kl = layers[i][j].down_q(h, train, w) 455 | kl_sum = kl.sum(axis=(1,2,3)) 456 | results['cost_z'+str(i).zfill(3)+'_'+str(j).zfill(3)] = kl_sum 457 | # Constraint: Minimum number of bits per featuremap, averaged across minibatch 458 | if kl_min > 0: 459 | if True: 460 | kl = kl.sum(axis=(2,3)).mean(axis=0,dtype=G.floatX) 461 | obj_kl += T.maximum(np.asarray(kl_min,G.floatX), kl).sum(dtype=G.floatX) 462 | else: 463 | kl = T.maximum(np.asarray(kl_min,G.floatX), kl.sum(axis=(2,3))).sum(axis=1,dtype=G.floatX) 464 | obj_kl += kl 465 | else: 466 | obj_kl += kl_sum 467 | 468 | output = .1 * x_dec(x_dec_nl(h, w), w) 469 | 470 | # empirical distribution 471 | if px == 'logistic': 472 | mean_x = T.clip(output+.5, 0+1/512., 1-1/512.) 
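            # Reconstruction term: a discretized logistic likelihood with bin
            # size 1/256 and a single shared log-scale w['logsd_x']; the total
            # objective is converted to bits per pixel below by dividing by
            # np.prod(shape_x) * log(2).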
473 | logsd_x = 0*mean_x + w['logsd_x'] 474 | obj_logpx = N.rand.discretized_logistic(mean_x, logsd_x, 1/256., _x).logp 475 | #obj_z = T.printing.Print('obj_z')(obj_z) 476 | obj = obj_logpx - obj_kl 477 | # Compute the bits per pixel 478 | obj *= (1./np.prod(shape_x) * 1./np.log(2.)).astype('float32') 479 | 480 | #if not '__init' in w: 481 | # raise Exception() 482 | 483 | elif px == 'bernoulli': 484 | prob_x = T.nnet.sigmoid(output) 485 | prob_x = T.maximum(T.minimum(prob_x, 1-1e-7), 1e-7) 486 | #prob_x = T.printing.Print('prob_x')(prob_x) 487 | obj_logpx = N.rand.bernoulli(prob_x, _x).logp 488 | 489 | #obj_logqz = T.printing.Print('obj_logqz')(obj_logqz) 490 | #obj_logpz = T.printing.Print('obj_logpz')(obj_logpz) 491 | #obj_logpx = T.printing.Print('obj_logpx')(obj_logpx) 492 | obj = obj_logpx - obj_kl 493 | #obj = T.printing.Print('obj')(obj) 494 | 495 | results['cost_x'] = -obj_logpx 496 | results['cost'] = -obj 497 | return results 498 | 499 | # Turns Gaussian noise 'eps' into a sample 500 | def f_decoder(eps, w): 501 | 502 | # top-level activations 503 | h = T.tile(w['h_top'].dimshuffle('x',0,'x','x'), (eps['eps_0_0'].shape[0],1,shape_x[1]/2**len(depths), shape_x[2]/2**len(depths))) 504 | 505 | # top-down priors, posteriors and decoders 506 | for i in list(reversed(range(len(depths)))): 507 | for j in list(reversed(range(depths[i]))): 508 | h = layers[i][j].down_p(h, eps['eps_'+str(i)+'_'+str(j)], w) 509 | 510 | output = .1 * x_dec(x_dec_nl(h, w), w) 511 | 512 | if px == 'logistic': 513 | mean_x = T.clip(output+.5, 0+1/512., 1-1/512.) 514 | elif px == 'bernoulli': 515 | mean_x = T.nnet.sigmoid(output) 516 | 517 | image = (256.*mean_x).astype('uint8') 518 | if pad_x > 0: 519 | image = image[:,:,pad_x:-pad_x,pad_x:-pad_x] 520 | 521 | return image 522 | 523 | def f_eps(n_batch, w): 524 | eps = {} 525 | for i in range(len(depths)): 526 | for j in range(depths[i]): 527 | eps['eps_'+str(i)+'_'+str(j)] = G.rng_curand.normal((n_batch,n_z,shape_x[1]/2**(i+1),shape_x[2]/2**(i+1)),dtype=floatX) 528 | return eps 529 | 530 | def postup(updates, w): 531 | nodes = [x_enc,x_dec] 532 | for n in nodes: 533 | updates = n.postup(updates, w) 534 | for i in range(len(depths)): 535 | for j in range(depths[i]): 536 | updates = layers[i][j].postup(updates, w) 537 | 538 | return updates 539 | 540 | # Compile init function 541 | if data_init != None: 542 | w['__init'] = OrderedDict() 543 | f_encode_decode(w) 544 | w.pop('__init') 545 | #for i in w: print i, abs(w[i].get_value()).min(), abs(w[i].get_value()).max(), abs(w[i].get_value()).mean() 546 | 547 | # Compile training function 548 | 549 | #todo: replace postup with below 550 | #w['_updates'] = updates 551 | #f_cost(w) 552 | #updates = w.pop('_updates') 553 | 554 | 555 | w_avg = {i: G.sharedf(w[i].get_value()) for i in w} 556 | 557 | def lazy(f): 558 | def newf(*args, **kws): 559 | if not hasattr(f, 'cache'): 560 | f.cache = f() 561 | return f.cache(*args, **kws) 562 | return newf 563 | 564 | @lazy 565 | def f_train(): 566 | if optim == 'adamax': 567 | train_cost = f_encode_decode(w)['cost'] 568 | updates = G.misc.optim.AdaMaxAvg([w],[w_avg], train_cost, alpha=-alpha, beta1=beta1, beta2=beta2, disconnected_inputs='ignore') 569 | elif optim == 'eve': 570 | f = lambda w: f_encode_decode(w)['cost'] 571 | train_cost, updates = G.misc.optim.Eve(w, w_avg, f, alpha=-alpha, beta1=beta1, beta2=beta2, disconnected_inputs='ignore') 572 | updates = postup(updates, w) 573 | return G.function({'x':x}, train_cost, updates=updates, lazy=lazy) 574 | 575 | @lazy 576 | def 
f_train_q(): 577 | keys_q = [] 578 | for i in w: 579 | if '_q_' in i: keys_q.append(i) 580 | train_cost = f_encode_decode(w)['cost'] 581 | updates = G.misc.optim.AdaMaxAvg([w],None, train_cost, alpha=-alpha, beta1=beta1, beta2=beta2, update_keys=keys_q, disconnected_inputs='ignore') 582 | updates = postup(updates, w) 583 | return G.function({'x':x}, train_cost, updates=updates, lazy=lazy) 584 | 585 | # Compile evaluation function 586 | @lazy 587 | def f_eval(): 588 | results = f_encode_decode(w_avg, False) 589 | return G.function({'x':x}, results) 590 | 591 | # Compile epsilon generating function 592 | @lazy 593 | def f_eps_(): 594 | n_batch = T.lscalar() 595 | n_batch.tag.test_value = 16 596 | eps = f_eps(n_batch, w) 597 | return G.function({'n_batch':n_batch}, eps, lazy=lazy) 598 | 599 | # Compile sampling function 600 | @lazy 601 | def f_decode(): 602 | eps = {} 603 | for i in range(len(depths)): 604 | for j in range(depths[i]): 605 | eps['eps_'+str(i)+'_'+str(j)] = T.tensor4('eps'+str(i)) 606 | eps['eps_'+str(i)+'_'+str(j)].tag.test_value = np.random.randn(n_batch_test,n_z,shape_x[1]/2**(i+1),shape_x[2]/2**(i+1)).astype(floatX) 607 | image = f_decoder(eps, w_avg) 608 | return G.function(eps, image, lazy=lazy) 609 | 610 | return G.Struct(train=f_train, eval=f_eval, decode=f_decode, eps=f_eps_, w=w, w_avg=w_avg) 611 | 612 | # Fully-connected VAE 613 | # - Hybrid deterministic/stochastic ResNet block per layer 614 | 615 | def fcvae(shape_x, depth_model, depth_ar, n_h1, n_h2, n_z, posterior, px='logistic', nl='softplus', alpha=0.002, beta1=0.1, beta2=0.001, share_w=False, data_init=None): 616 | _locals = locals() 617 | _locals.pop('data_init') 618 | print 'CVAE9 with ', _locals 619 | #assert posterior in ['diag1','diag2','iaf_linear','iaf_nonlinear'] 620 | assert px in ['logistic','bernoulli'] 621 | w = {} # model params 622 | 623 | kernel_h = (1,1) 624 | n_x = shape_x[0]*shape_x[1]*shape_x[2] 625 | 626 | # Input whitening 627 | if px == 'logistic': 628 | w['logsd_x'] = G.sharedf(0.) 
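    # logsd_x is the globally shared log-scale of the discretized logistic
    # output distribution used in f_cost below.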
629 | 630 | # encoder 631 | x_enc = N.conv.conv2d('x_enc', n_x, n_h1, (1,1), w=w) 632 | x_dec = N.conv.conv2d('x_dec', n_h1, n_x, (1,1), w=w) 633 | x_dec_nl = N.nonlinearity('x_dec_nl', nl, n_h1, w) 634 | 635 | layers = [] 636 | for i in range(depth_model): 637 | name = str(i) 638 | if share_w: 639 | name = '[sharedw]'+str(i)+'[/sharedw]' 640 | layers.append(cvae_layer(name, posterior, n_h1, n_h2, n_z, depth_ar, False, nl, kernel_h, share_w, w)) 641 | 642 | # top-level value 643 | #w['h_top'] = G.sharedf(np.zeros((n_h1,))) 644 | w['h_top'] = G.sharedf(np.random.normal(0,0.01,size=(n_h1,))) 645 | 646 | # Initialize variables 647 | x = T.tensor4('x') 648 | x.tag.test_value = data_init['x'] 649 | n_batch_test = data_init['x'].shape[0] 650 | _x = T.clip(x / 255., 0, 1) 651 | 652 | # Objective function 653 | def f_cost(w, train=True): 654 | 655 | results = {} 656 | 657 | h = x_enc(_x.reshape((-1,n_x,1,1)) - .5, w) 658 | 659 | obj_logpz = 0 660 | obj_logqz = 0 661 | 662 | # bottom-up encoders 663 | for i in range(depth_model): 664 | h = layers[i].up(h, w) 665 | 666 | # top-level activations 667 | h = T.tile(w['h_top'].dimshuffle('x',0,'x','x'), (_x.shape[0],1,1,1)) 668 | 669 | # top-down priors, posteriors and decoders 670 | for i in list(reversed(range(depth_model))): 671 | h, _obj_logqz, _obj_logpz = layers[i].down_q(h, train, w) 672 | obj_logqz += _obj_logqz 673 | obj_logpz += _obj_logpz 674 | results['cost_z'+str(i).zfill(3)] = _obj_logqz - _obj_logpz 675 | 676 | output = .1 * x_dec(x_dec_nl(h, w), w).reshape((-1,shape_x[0],shape_x[1],shape_x[2])) 677 | 678 | # empirical distribution 679 | if px == 'logistic': 680 | mean_x = T.clip(output, -.5, .5) 681 | logsd_x = 0*mean_x + w['logsd_x'] 682 | obj_logpx = N.rand.discretized_logistic(mean_x, logsd_x, 1/255., _x - .5).logp 683 | 684 | obj = obj_logpz - obj_logqz + obj_logpx 685 | # Compute the bits per pixel 686 | obj *= (1./np.prod(shape_x) * 1./np.log(2.)).astype('float32') 687 | 688 | elif px == 'bernoulli': 689 | prob_x = T.nnet.sigmoid(output) 690 | prob_x = T.minimum(prob_x, 1-1e-7) 691 | prob_x = T.maximum(prob_x, 1e-7) 692 | #prob_x = T.printing.Print('prob_x')(prob_x) 693 | obj_logpx = N.rand.bernoulli(prob_x, _x).logp 694 | 695 | #obj_logqz = T.printing.Print('obj_logqz')(obj_logqz) 696 | #obj_logpz = T.printing.Print('obj_logpz')(obj_logpz) 697 | #obj_logpx = T.printing.Print('obj_logpx')(obj_logpx) 698 | obj = obj_logpz - obj_logqz + obj_logpx 699 | #obj = T.printing.Print('obj')(obj) 700 | 701 | results['cost_x'] = -obj_logpx 702 | results['cost'] = -obj 703 | return results 704 | 705 | #print 'obj_logpz', obj_logpz.tag.test_value 706 | #print 'obj_logqz', obj_logqz.tag.test_value 707 | #print 'obj_logpx', obj_x.tag.test_value 708 | #obj_logpz = T.printing.Print('obj_logpz')(obj_logpz) 709 | #obj_logqz = T.printing.Print('obj_logqz')(obj_logqz) 710 | #obj_x = T.printing.Print('obj_logpx')(obj_x) 711 | 712 | 713 | 714 | 715 | # Turns Gaussian noise 'eps' into a sample 716 | def f_decoder(eps, w): 717 | 718 | # top-level activations 719 | h = T.tile(w['h_top'].dimshuffle('x',0,'x','x'), (eps['eps_0'].shape[0],1,1,1)) 720 | 721 | # top-down priors, posteriors and decoders 722 | for i in list(reversed(range(depth_model))): 723 | h = layers[i].down_p(h, eps['eps_'+str(i)], w) 724 | 725 | output = .1 * x_dec(x_dec_nl(h, w), w).reshape((-1,shape_x[0],shape_x[1],shape_x[2])) 726 | if px == 'logistic': 727 | mean_x = T.clip(output[:,:,:,:] + .5, 0, 1) 728 | elif px == 'bernoulli': 729 | mean_x = T.nnet.sigmoid(output) 730 | image = 
(255.*T.clip(mean_x, 0, 1)).astype('uint8') 731 | return image 732 | 733 | def f_eps(n_batch, w): 734 | eps = {} 735 | for i in range(depth_model): 736 | eps['eps_'+str(i)] = G.rng_curand.normal((n_batch,n_z,1,1),dtype=floatX) 737 | return eps 738 | 739 | def postup(updates, w): 740 | nodes = [x_enc,x_dec] 741 | for n in nodes: 742 | updates = n.postup(updates, w) 743 | for i in range(depth_model): 744 | updates = layers[i].postup(updates, w) 745 | 746 | return updates 747 | 748 | # Compile init function 749 | if data_init != None: 750 | w['__init'] = OrderedDict() 751 | f_cost(w) 752 | w.pop('__init') 753 | #for i in w: print i, abs(w[i].get_value()).min(), abs(w[i].get_value()).max(), abs(w[i].get_value()).mean() 754 | 755 | # Compile training function 756 | results = f_cost(w) 757 | updates, (w_avg,) = G.misc.optim.AdaMaxAvg([w], results['cost'], alpha=-alpha, beta1=beta1, beta2=beta2, disconnected_inputs='ignore') 758 | #todo: replace postup with below 759 | #w['_updates'] = updates 760 | #f_cost(w) 761 | #updates = w.pop('_updates') 762 | 763 | updates = postup(updates, w) 764 | f_train = G.function({'x':x}, results['cost'], updates=updates) 765 | 766 | # Compile evaluation function 767 | results = f_cost(w_avg, False) 768 | f_eval = G.function({'x':x}, results) 769 | 770 | # Compile epsilon generating function 771 | n_batch = T.lscalar() 772 | n_batch.tag.test_value = 16 773 | eps = f_eps(n_batch, w) 774 | f_eps = G.function({'n_batch':n_batch}, eps) 775 | 776 | # Compile sampling function 777 | eps = {} 778 | for i in range(depth_model): 779 | eps['eps_'+str(i)] = T.tensor4('eps'+str(i)) 780 | eps['eps_'+str(i)].tag.test_value = np.random.randn(n_batch_test,n_z,1,1).astype(floatX) 781 | image = f_decoder(eps, w_avg) 782 | f_decode = G.function(eps, image) 783 | 784 | return G.Struct(train=f_train, eval=f_eval, decode=f_decode, eps=f_eps, w=w, w_avg=w_avg) 785 | 786 | -------------------------------------------------------------------------------- /tf_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.contrib.framework.python.ops import arg_scope 5 | from tf_utils.adamax import AdamaxOptimizer 6 | from tf_utils.hparams import HParams 7 | from tf_utils.common import img_stretch, img_tile 8 | from tf_utils.common import assign_to_gpu, split, CheckpointLoader, average_grads, NotBuggySupervisor 9 | from tf_utils.layers import conv2d, deconv2d, ar_multiconv2d, resize_nearest_neighbor 10 | from tf_utils.distributions import DiagonalGaussian, discretized_logistic, compute_lowerbound, repeat 11 | from tf_utils.data_utils import get_inputs, get_images 12 | import tqdm 13 | 14 | # settings 15 | flags = tf.flags 16 | flags.DEFINE_string("logdir", "/tmp/vae", "Logging directory.") 17 | flags.DEFINE_string("hpconfig", "", "Overrides default hyper-parameters.") 18 | flags.DEFINE_string("mode", "train", "Whether to run 'train' or 'eval' model.") 19 | flags.DEFINE_integer("num_gpus", 8, "Number of GPUs used.") 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | class IAFLayer(object): 24 | def __init__(self, hps, mode, downsample): 25 | self.hps = hps 26 | self.mode = mode 27 | self.downsample = downsample 28 | 29 | def up(self, input, **_): 30 | hps = self.hps 31 | h_size = hps.h_size 32 | z_size = hps.z_size 33 | stride = [2, 2] if self.downsample else [1, 1] 34 | 35 | with arg_scope([conv2d]): 36 | x = tf.nn.elu(input) 37 | x = conv2d("up_conv1", x, 2 * z_size + 2 * h_size, 
stride=stride) 38 | self.qz_mean, self.qz_logsd, self.up_context, h = split(x, 1, [z_size, z_size, h_size, h_size]) 39 | 40 | h = tf.nn.elu(h) 41 | h = conv2d("up_conv3", h, h_size) 42 | if self.downsample: 43 | input = resize_nearest_neighbor(input, 0.5) 44 | return input + 0.1 * h 45 | 46 | def down(self, input): 47 | hps = self.hps 48 | h_size = hps.h_size 49 | z_size = hps.z_size 50 | 51 | with arg_scope([conv2d, ar_multiconv2d]): 52 | x = tf.nn.elu(input) 53 | x = conv2d("down_conv1", x, 4 * z_size + h_size * 2) 54 | pz_mean, pz_logsd, rz_mean, rz_logsd, down_context, h_det = split(x, 1, [z_size] * 4 + [h_size] * 2) 55 | 56 | prior = DiagonalGaussian(pz_mean, 2 * pz_logsd) 57 | posterior = DiagonalGaussian(rz_mean + self.qz_mean, 2 * (rz_logsd + self.qz_logsd)) 58 | context = self.up_context + down_context 59 | 60 | if self.mode in ["init", "sample"]: 61 | z = prior.sample 62 | else: 63 | z = posterior.sample 64 | 65 | if self.mode == "sample": 66 | kl_cost = kl_obj = tf.zeros([hps.batch_size * hps.k]) 67 | else: 68 | logqs = posterior.logps(z) 69 | x = ar_multiconv2d("ar_multiconv2d", z, context, [h_size, h_size], [z_size, z_size]) 70 | arw_mean, arw_logsd = x[0] * 0.1, x[1] * 0.1 71 | z = (z - arw_mean) / tf.exp(arw_logsd) 72 | logqs += arw_logsd 73 | logps = prior.logps(z) 74 | 75 | kl_cost = logqs - logps 76 | 77 | if hps.kl_min > 0: 78 | # [0, 1, 2, 3] -> [0, 1] -> [1] / (b * k) 79 | kl_ave = tf.reduce_mean(tf.reduce_sum(kl_cost, [2, 3]), [0], keep_dims=True) 80 | kl_ave = tf.maximum(kl_ave, hps.kl_min) 81 | kl_ave = tf.tile(kl_ave, [hps.batch_size * hps.k, 1]) 82 | kl_obj = tf.reduce_sum(kl_ave, [1]) 83 | else: 84 | kl_obj = tf.reduce_sum(kl_cost, [1, 2, 3]) 85 | kl_cost = tf.reduce_sum(kl_cost, [1, 2, 3]) 86 | 87 | h = tf.concat(1, [z, h_det]) 88 | h = tf.nn.elu(h) 89 | if self.downsample: 90 | input = resize_nearest_neighbor(input, 2) 91 | h = deconv2d("down_deconv2", h, h_size) 92 | else: 93 | h = conv2d("down_conv2", h, h_size) 94 | output = input + 0.1 * h 95 | return output, kl_obj, kl_cost 96 | 97 | 98 | def get_default_hparams(): 99 | return HParams( 100 | batch_size=16, # Batch size on one GPU. 101 | eval_batch_size=100, # Batch size for evaluation. 102 | num_gpus=8, # Number of GPUs (effectively increases batch size). 103 | learning_rate=0.01, # Learning rate. 104 | z_size=32, # Size of z variables. 105 | h_size=160, # Size of resnet block. 106 | kl_min=0.25, # Number of "free bits/nats". 107 | depth=2, # Number of downsampling blocks. 108 | num_blocks=2, # Number of resnet blocks for each downsampling layer. 109 | k=1, # Number of samples for IS objective. 110 | dataset="cifar10", # Dataset name. 111 | image_size=32, # Image size. 112 | ) 113 | 114 | 115 | class CVAE1(object): 116 | def __init__(self, hps, mode, x=None): 117 | self.hps = hps 118 | self.mode = mode 119 | input_shape = [hps.batch_size * hps.num_gpus, 3, hps.image_size, hps.image_size] 120 | self.x = tf.placeholder(tf.uint8, shape=input_shape) if x is None else x 121 | self.m_trunc = [] 122 | self.dec_log_stdv = tf.get_variable("dec_log_stdv", initializer=tf.constant(0.0)) 123 | 124 | losses = [] 125 | grads = [] 126 | xs = tf.split(0, hps.num_gpus, self.x) 127 | opt = AdamaxOptimizer(hps.learning_rate) 128 | 129 | num_pixels = 3 * hps.image_size * hps.image_size 130 | for i in range(hps.num_gpus): 131 | with tf.device(assign_to_gpu(i)): 132 | m, obj, loss = self._forward(xs[i], i) 133 | losses += [loss] 134 | self.m_trunc += [m] 135 | 136 | # obj /= (np.log(2.) 
* num_pixels * hps.batch_size) 137 | if mode == "train": 138 | grads += [opt.compute_gradients(obj)] 139 | 140 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.zeros_initializer, 141 | trainable=False) 142 | self.bits_per_dim = tf.add_n(losses) / (np.log(2.) * num_pixels * hps.batch_size * hps.num_gpus) 143 | 144 | if mode == "train": 145 | # add gradients together and get training updates 146 | grad = average_grads(grads) 147 | self.train_op = opt.apply_gradients(grad, global_step=self.global_step) 148 | tf.scalar_summary("model/bits_per_dim", self.bits_per_dim) 149 | tf.scalar_summary("model/dec_log_stdv", self.dec_log_stdv) 150 | self.summary_op = tf.merge_all_summaries() 151 | else: 152 | self.train_op = tf.no_op() 153 | 154 | if mode in ["train", "eval"]: 155 | with tf.name_scope(None): # This is needed due to EMA implementation silliness. 156 | # keep track of moving average 157 | ema = tf.train.ExponentialMovingAverage(decay=0.999) 158 | self.train_op = tf.group(*[self.train_op, ema.apply(tf.trainable_variables())]) 159 | self.avg_dict = ema.variables_to_restore() 160 | 161 | def _forward(self, x, gpu): 162 | hps = self.hps 163 | 164 | x = tf.to_float(x) 165 | x = tf.clip_by_value((x + 0.5) / 256.0, 0.0, 1.0) - 0.5 166 | 167 | # Input images are repeated k times on the input. 168 | # This is used for Importance Sampling loss (k is number of samples). 169 | data_size = hps.batch_size * hps.k 170 | x = repeat(x, hps.k) 171 | 172 | orig_x = x 173 | h_size = hps.h_size 174 | 175 | with arg_scope([conv2d, deconv2d], init=(self.mode == "init")): 176 | layers = [] 177 | for i in range(hps.depth): 178 | layers.append([]) 179 | for j in range(hps.num_blocks): 180 | downsample = (i > 0) and (j == 0) 181 | layers[-1].append(IAFLayer(hps, self.mode, downsample)) 182 | 183 | h = conv2d("x_enc", x, h_size, [5, 5], [2, 2]) # -> [16, 16] 184 | for i, layer in enumerate(layers): 185 | for j, sub_layer in enumerate(layer): 186 | with tf.variable_scope("IAF_%d_%d" % (i, j)): 187 | h = sub_layer.up(h) 188 | 189 | # top->down 190 | self.h_top = h_top = tf.get_variable("h_top", [h_size], initializer=tf.zeros_initializer) 191 | h_top = tf.reshape(h_top, [1, -1, 1, 1]) 192 | h = tf.tile(h_top, [data_size, 1, hps.image_size / 2 ** len(layers), hps.image_size / 2 ** len(layers)]) 193 | kl_cost = kl_obj = 0.0 194 | 195 | for i, layer in reversed(list(enumerate(layers))): 196 | for j, sub_layer in reversed(list(enumerate(layer))): 197 | with tf.variable_scope("IAF_%d_%d" % (i, j)): 198 | h, cur_obj, cur_cost = sub_layer.down(h) 199 | kl_obj += cur_obj 200 | kl_cost += cur_cost 201 | 202 | if self.mode == "train" and gpu == hps.num_gpus - 1: 203 | tf.scalar_summary("model/kl_obj_%02d_%02d" % (i, j), tf.reduce_mean(cur_obj)) 204 | tf.scalar_summary("model/kl_cost_%02d_%02d" % (i, j), tf.reduce_mean(cur_cost)) 205 | 206 | x = tf.nn.elu(h) 207 | x = deconv2d("x_dec", x, 3, [5, 5]) 208 | x = tf.clip_by_value(x, -0.5 + 1 / 512., 0.5 - 1 / 512.) 
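The lines that follow combine the reconstruction term `log_pxz` with the KL terms: `obj` drives the gradients (using the free-bits quantity `kl_obj`), while `loss` reports the bound itself through `compute_lowerbound` from `tf_utils/distributions.py`. A minimal NumPy sketch of that k-sample (importance-weighted) bound, with a numerically stable log-mean-exp and an illustrative function name:

```py
import numpy as np

def neg_lowerbound(log_pxz, sum_kl_costs, k=1):
    # Negative variational bound per example, in nats. For k > 1 this is the
    # importance-weighted bound -log (1/k) sum_i p(x|z_i) p(z_i) / q(z_i|x),
    # mirroring tf_utils/distributions.compute_lowerbound.
    if k == 1:
        return sum_kl_costs - log_pxz
    log_w = (log_pxz - sum_kl_costs).reshape(-1, k)   # log importance weights
    m = log_w.max(axis=1, keepdims=True)
    log_mean_w = m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1))
    return -log_mean_w
```

Dividing the summed loss by `np.log(2.) * num_pixels * batch_size`, as done for `bits_per_dim` above, converts the result to bits per dimension.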
209 | 210 | log_pxz = discretized_logistic(x, self.dec_log_stdv, sample=orig_x) 211 | obj = tf.reduce_sum(kl_obj - log_pxz) 212 | 213 | if self.mode == "train" and gpu == hps.num_gpus - 1: 214 | tf.scalar_summary("model/log_pxz", -tf.reduce_mean(log_pxz)) 215 | tf.scalar_summary("model/kl_obj", tf.reduce_mean(kl_obj)) 216 | tf.scalar_summary("model/kl_cost", tf.reduce_mean(kl_cost)) 217 | 218 | loss = tf.reduce_sum(compute_lowerbound(log_pxz, kl_cost, hps.k)) 219 | return x, obj, loss 220 | 221 | 222 | def run(hps): 223 | with tf.variable_scope("model") as vs: 224 | x = get_inputs(hps.dataset, "train", hps.batch_size * FLAGS.num_gpus, hps.image_size) 225 | 226 | hps.num_gpus = 1 227 | init_x = x[:hps.batch_size, :, :, :] 228 | init_model = CVAE1(hps, "init", init_x) 229 | 230 | vs.reuse_variables() 231 | hps.num_gpus = FLAGS.num_gpus 232 | model = CVAE1(hps, "train", x) 233 | 234 | saver = tf.train.Saver() 235 | 236 | total_size = 0 237 | for v in tf.trainable_variables(): 238 | total_size += np.prod([int(s) for s in v.get_shape()]) 239 | print("Num trainable variables: %d" % total_size) 240 | 241 | init_op = tf.initialize_all_variables() 242 | 243 | def init_fn(ses): 244 | print("Initializing parameters.") 245 | # XXX(rafal): TensorFlow bug?? Default initializer should handle things well.. 246 | ses.run(init_model.h_top.initializer) 247 | ses.run(init_op) 248 | print("Initialized!") 249 | 250 | sv = NotBuggySupervisor(is_chief=True, 251 | logdir=FLAGS.logdir + "/train", 252 | summary_op=None, # Automatic summaries don"t work with placeholders. 253 | saver=saver, 254 | global_step=model.global_step, 255 | save_summaries_secs=30, 256 | save_model_secs=0, 257 | init_op=None, 258 | init_fn=init_fn) 259 | 260 | print("starting training") 261 | local_step = 0 262 | begin = time.time() 263 | 264 | config = tf.ConfigProto(allow_soft_placement=True) 265 | with sv.managed_session(config=config) as sess: 266 | print("Running first iteration!") 267 | while not sv.should_stop(): 268 | fetches = [model.bits_per_dim, model.global_step, model.dec_log_stdv, model.train_op] 269 | 270 | should_compute_summary = (local_step % 20 == 19) 271 | if should_compute_summary: 272 | fetches += [model.summary_op] 273 | 274 | fetched = sess.run(fetches) 275 | 276 | if should_compute_summary: 277 | sv.summary_computed(sess, fetched[-1]) 278 | 279 | if local_step < 10 or should_compute_summary: 280 | print("Iteration %d, time = %.2fs, train bits_per_dim = %.4f, dec_log_stdv = %.4f" % ( 281 | fetched[1], time.time() - begin, fetched[0], fetched[2])) 282 | begin = time.time() 283 | if np.isnan(fetched[0]): 284 | print("NAN detected!") 285 | break 286 | if local_step % 100 == 0: 287 | saver.save(sess, sv.save_path, global_step=sv.global_step, write_meta_graph=False) 288 | 289 | local_step += 1 290 | sv.stop() 291 | 292 | 293 | def run_eval(hps, mode): 294 | hps.num_gpus = 1 295 | hps.batch_size = hps.eval_batch_size 296 | 297 | with tf.variable_scope("model") as vs: 298 | model = CVAE1(hps, "eval") 299 | vs.reuse_variables() 300 | sample_model = CVAE1(hps, "sample") 301 | 302 | saver = tf.train.Saver(model.avg_dict) 303 | # Use only 4 threads for the evaluation. 
304 | config = tf.ConfigProto(allow_soft_placement=True, 305 | intra_op_parallelism_threads=4, 306 | inter_op_parallelism_threads=4) 307 | sess = tf.Session(config=config) 308 | sw = tf.train.SummaryWriter(FLAGS.logdir + "/" + FLAGS.mode, sess.graph) 309 | ckpt_loader = CheckpointLoader(saver, model.global_step, FLAGS.logdir + "/train") 310 | 311 | with sess.as_default(): 312 | dataset = get_images(hps.dataset, mode[5:], hps.image_size) 313 | assert dataset.n % hps.batch_size == 0 314 | epoch_size = int(dataset.n / hps.batch_size) 315 | 316 | while ckpt_loader.load_checkpoint(): 317 | global_step = ckpt_loader.last_global_step 318 | 319 | dataset.shuffle() 320 | summary = tf.Summary() 321 | all_bits_per_dim = [] 322 | for _ in tqdm.trange(epoch_size): 323 | all_bits_per_dim += [sess.run(model.bits_per_dim, {model.x: dataset.next_batch(hps.batch_size)})] 324 | 325 | average_bits = float(np.mean(all_bits_per_dim)) 326 | print("Step: %d Score: %.3f" % (global_step, average_bits)) 327 | summary.value.add(tag='eval_bits_per_dim', simple_value=average_bits) 328 | 329 | if hps.k == 1: 330 | # show reconstructions from the model 331 | total_samples = 36 332 | num_examples = 0 333 | imgs_inputs = np.zeros([total_samples / 2, hps.image_size, hps.image_size, 3], np.float32) 334 | imgs_recs = np.zeros([total_samples / 2, hps.image_size, hps.image_size, 3], np.float32) 335 | while num_examples < total_samples / 2: 336 | batch = dataset.next_batch(hps.batch_size) 337 | sample_x = sess.run(model.m_trunc[0], {model.x: batch}) 338 | batch_bhwc = np.transpose(batch, (0, 2, 3, 1)) 339 | img_bhwc = np.transpose(sample_x, (0, 2, 3, 1)) 340 | 341 | if num_examples + hps.batch_size > total_samples / 2: 342 | cur_examples = total_samples / 2 - num_examples 343 | else: 344 | cur_examples = hps.batch_size 345 | 346 | imgs_inputs[num_examples:num_examples + cur_examples, ...] = img_stretch(batch_bhwc[:cur_examples, ...]) 347 | imgs_recs[num_examples:num_examples + cur_examples, ...] = img_stretch(img_bhwc[:cur_examples, ...]) 348 | num_examples += cur_examples 349 | 350 | imgs_to_plot = np.zeros([total_samples, hps.image_size, hps.image_size, 3], np.float32) 351 | imgs_to_plot[::2, ...] = imgs_inputs 352 | imgs_to_plot[1::2, ...] = imgs_recs 353 | imgs = img_tile(imgs_to_plot, aspect_ratio=1.0, border=0).astype(np.float32) 354 | imgs = np.expand_dims(imgs, 0) 355 | im_summary = tf.image_summary("reconstructions", imgs, 1) 356 | summary.MergeFromString(sess.run(im_summary)) 357 | 358 | # generate samples from the model 359 | num_examples = 0 360 | imgs_to_plot = np.zeros([total_samples, hps.image_size, hps.image_size, 3], np.float32) 361 | while num_examples < total_samples: 362 | sample_x = sess.run(sample_model.m_trunc[0]) 363 | img_bhwc = img_stretch(np.transpose(sample_x, (0, 2, 3, 1))) 364 | 365 | if num_examples + hps.batch_size > total_samples: 366 | cur_examples = total_samples - num_examples 367 | else: 368 | cur_examples = hps.batch_size 369 | 370 | imgs_to_plot[num_examples:num_examples+cur_examples, ...] 
= img_stretch(img_bhwc[:cur_examples, ...]) 371 | num_examples += cur_examples 372 | 373 | imgs = img_tile(imgs_to_plot, aspect_ratio=1.0, border=0).astype(np.float32) 374 | imgs = np.expand_dims(imgs, 0) 375 | im_summary = tf.image_summary("samples", imgs, 1) 376 | summary.MergeFromString(sess.run(im_summary)) 377 | 378 | sw.add_summary(summary, global_step) 379 | sw.flush() 380 | 381 | 382 | def main(_): 383 | hps = get_default_hparams().parse(FLAGS.hpconfig) 384 | 385 | if FLAGS.mode == "train": 386 | run(hps) 387 | else: 388 | run_eval(hps, FLAGS.mode) 389 | 390 | 391 | if __name__ == "__main__": 392 | tf.app.run() 393 | -------------------------------------------------------------------------------- /tf_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/iaf/ad33fe4872bf6e4b4f387e709a625376bb8b0d9d/tf_utils/__init__.py -------------------------------------------------------------------------------- /tf_utils/adamax.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops import control_flow_ops 2 | from tensorflow.python.ops import math_ops 3 | from tensorflow.python.ops import state_ops 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.training import optimizer 6 | import tensorflow as tf 7 | 8 | 9 | class AdamaxOptimizer(optimizer.Optimizer): 10 | """Optimizer that implements the Adamax algorithm. 11 | 12 | See [Kingma et. al., 2014](http://arxiv.org/abs/1412.6980) 13 | ([pdf](http://arxiv.org/pdf/1412.6980.pdf)). 14 | 15 | @@__init__ 16 | """ 17 | 18 | def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, use_locking=False, name="Adamax"): 19 | super(AdamaxOptimizer, self).__init__(use_locking, name) 20 | self._lr = learning_rate 21 | self._beta1 = beta1 22 | self._beta2 = beta2 23 | 24 | # Tensor versions of the constructor arguments, created in _prepare(). 25 | self._lr_t = None 26 | self._beta1_t = None 27 | self._beta2_t = None 28 | 29 | def _prepare(self): 30 | self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate") 31 | self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1") 32 | self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2") 33 | 34 | def _create_slots(self, var_list): 35 | # Create slots for the first and second moments. 36 | for v in var_list: 37 | self._zeros_slot(v, "m", self._name) 38 | self._zeros_slot(v, "v", self._name) 39 | 40 | def _apply_dense(self, grad, var): 41 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 42 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 43 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 44 | if var.dtype.base_dtype == tf.float16: 45 | eps = 1e-7 # Can't use 1e-8 due to underflow -- not sure if it makes a big difference. 46 | else: 47 | eps = 1e-8 48 | 49 | v = self.get_slot(var, "v") 50 | v_t = v.assign(beta1_t * v + (1. 
- beta1_t) * grad) 51 | m = self.get_slot(var, "m") 52 | m_t = m.assign(tf.maximum(beta2_t * m + eps, tf.abs(grad))) 53 | g_t = v_t / m_t 54 | 55 | var_update = state_ops.assign_sub(var, lr_t * g_t) 56 | return control_flow_ops.group(*[var_update, m_t, v_t]) 57 | 58 | def _apply_sparse(self, grad, var): 59 | raise NotImplementedError("Sparse gradient updates are not supported.") 60 | -------------------------------------------------------------------------------- /tf_utils/cifar10_data.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | import os 3 | import sys 4 | import tarfile 5 | from six.moves import urllib 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | def maybe_download_and_extract(data_dir, url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'): 11 | if not os.path.exists(os.path.join(data_dir, 'cifar-10-batches-py')): 12 | if not os.path.exists(data_dir): 13 | os.makedirs(data_dir) 14 | filename = url.split('/')[-1] 15 | filepath = os.path.join(data_dir, filename) 16 | if not os.path.exists(filepath): 17 | def _progress(count, block_size, total_size): 18 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 19 | float(count * block_size) / float(total_size) * 100.0)) 20 | sys.stdout.flush() 21 | filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) 22 | print() 23 | statinfo = os.stat(filepath) 24 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 25 | tarfile.open(filepath, 'r:gz').extractall(data_dir) 26 | 27 | 28 | def unpickle(file): 29 | fo = open(file, 'rb') 30 | d = cPickle.load(fo) 31 | fo.close() 32 | return {'x': np.cast[np.uint8](d['data'].reshape((10000, 3, 32, 32))), 33 | 'y': np.array(d['labels']).astype(np.uint8)} 34 | 35 | 36 | def load(data_dir, subset='train'): 37 | maybe_download_and_extract(data_dir) 38 | if subset == 'train': 39 | train_data = [unpickle(os.path.join(data_dir,'cifar-10-batches-py/data_batch_' + str(i))) for i in range(1,6)] 40 | trainx = np.concatenate([d['x'] for d in train_data],axis=0) 41 | trainy = np.concatenate([d['y'] for d in train_data],axis=0) 42 | return trainx, trainy 43 | elif subset=='test': 44 | test_data = unpickle(os.path.join(data_dir,'cifar-10-batches-py/test_batch')) 45 | testx = test_data['x'] 46 | testy = test_data['y'] 47 | return testx, testy 48 | else: 49 | raise NotImplementedError('subset should be either train or test') 50 | 51 | 52 | def read_cifar10(filename_queue): 53 | """Reads and parses examples from CIFAR10 data files. 54 | Recommendation: if you want N-way read parallelism, call this function 55 | N times. This will give you N independent Readers reading different 56 | files & positions within those files, which will give better mixing of 57 | examples. 58 | Args: 59 | filename_queue: A queue of strings with the filenames to read from. 60 | Returns: 61 | An object representing a single example, with the following fields: 62 | height: number of rows in the result (32) 63 | width: number of columns in the result (32) 64 | depth: number of color channels in the result (3) 65 | key: a scalar string Tensor describing the filename & record number 66 | for this example. 67 | label: an int32 Tensor with the label in the range 0..9. 68 | uint8image: a [height, width, depth] uint8 Tensor with the image data 69 | """ 70 | 71 | class CIFAR10Record(object): 72 | pass 73 | 74 | result = CIFAR10Record() 75 | 76 | # Dimensions of the images in the CIFAR-10 dataset. 
77 | # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the 78 | # input format. 79 | label_bytes = 1 # 2 for CIFAR-100 80 | result.height = 32 81 | result.width = 32 82 | result.depth = 3 83 | image_bytes = result.height * result.width * result.depth 84 | # Every record consists of a label followed by the image, with a 85 | # fixed number of bytes for each. 86 | record_bytes = label_bytes + image_bytes 87 | 88 | # Read a record, getting filenames from the filename_queue. No 89 | # header or footer in the CIFAR-10 format, so we leave header_bytes 90 | # and footer_bytes at their default of 0. 91 | reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) 92 | result.key, value = reader.read(filename_queue) 93 | 94 | # Convert from a string to a vector of uint8 that is record_bytes long. 95 | record_bytes = tf.decode_raw(value, tf.uint8) 96 | 97 | # The first bytes represent the label, which we convert from uint8->int32. 98 | result.label = tf.cast( 99 | tf.slice(record_bytes, [0], [label_bytes]), tf.int32) 100 | 101 | # The remaining bytes after the label represent the image, which we reshape 102 | # from [depth * height * width] to [depth, height, width]. 103 | result.uint8image = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), 104 | [result.depth, result.height, result.width]) 105 | return result 106 | 107 | 108 | def cifar_preloaded(images, batch_size): 109 | with tf.device("/cpu:0"): 110 | input_images = tf.constant(images) 111 | image = tf.train.slice_input_producer([input_images]) 112 | return tf.train.shuffle_batch(image, batch_size, 20000 + 3 * batch_size, 20000, 5) 113 | 114 | 115 | def cifar_inputs(data_dir, batch_size): 116 | with tf.device("/cpu:0"): 117 | filenames = [os.path.join(data_dir, "data_batch_%d.bin" % i) for i in range(1, 6)] 118 | filename_queue = tf.train.string_input_producer(filenames) 119 | image_list = [read_cifar10(filename_queue).uint8image for _ in range(5)] 120 | images = tf.train.shuffle_batch_join( 121 | [image_list], 122 | batch_size=batch_size, 123 | capacity=20000 + 3 * batch_size, 124 | min_after_dequeue=20000)[0] 125 | return images 126 | -------------------------------------------------------------------------------- /tf_utils/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"): 8 | def _assign(op): 9 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 10 | if node_def.op == "Variable": 11 | return ps_dev 12 | else: 13 | return "/gpu:%d" % gpu 14 | return _assign 15 | 16 | 17 | def find_trainable_variables(key): 18 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, ".*{}.*".format(key)) 19 | 20 | 21 | def split(x, split_dim, split_sizes): 22 | n = len(list(x.get_shape())) 23 | dim_size = np.sum(split_sizes) 24 | assert int(x.get_shape()[split_dim]) == dim_size 25 | ids = np.cumsum([0] + split_sizes) 26 | ids[-1] = -1 27 | begin_ids = ids[:-1] 28 | 29 | ret = [] 30 | for i in range(len(split_sizes)): 31 | cur_begin = np.zeros([n], dtype=np.int32) 32 | cur_begin[split_dim] = begin_ids[i] 33 | cur_end = np.zeros([n], dtype=np.int32) - 1 34 | cur_end[split_dim] = split_sizes[i] 35 | ret += [tf.slice(x, cur_begin, cur_end)] 36 | return ret 37 | 38 | 39 | def load_from_checkpoint(saver, logdir): 40 | sess = tf.get_default_session() 41 | ckpt = tf.train.get_checkpoint_state(logdir) 42 | if ckpt and 
ckpt.model_checkpoint_path: 43 | if os.path.isabs(ckpt.model_checkpoint_path): 44 | # Restores from checkpoint with absolute path. 45 | saver.restore(sess, ckpt.model_checkpoint_path) 46 | else: 47 | # Restores from checkpoint with relative path. 48 | saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path)) 49 | return True 50 | return False 51 | 52 | 53 | class CheckpointLoader(object): 54 | def __init__(self, saver, global_step, logdir): 55 | self.saver = saver 56 | self.global_step_tensor = global_step 57 | self.logdir = logdir 58 | # TODO(rafal): make it restart-proof? 59 | self.last_global_step = 0 60 | 61 | def load_checkpoint(self): 62 | while True: 63 | if load_from_checkpoint(self.saver, self.logdir): 64 | global_step = int(self.global_step_tensor.eval()) 65 | if global_step <= self.last_global_step: 66 | print("Waiting for a new checkpoint...") 67 | time.sleep(60) 68 | continue 69 | print("Succesfully loaded model at step=%s." % global_step) 70 | else: 71 | print("No checkpoint file found. Waiting...") 72 | time.sleep(60) 73 | continue 74 | self.last_global_step = global_step 75 | return True 76 | 77 | 78 | def average_grads(tower_grads): 79 | def average_dense(grad_and_vars): 80 | if len(grad_and_vars) == 1: 81 | return grad_and_vars[0][0] 82 | 83 | grad = grad_and_vars[0][0] 84 | for g, _ in grad_and_vars[1:]: 85 | grad += g 86 | return grad / len(grad_and_vars) 87 | 88 | def average_sparse(grad_and_vars): 89 | if len(grad_and_vars) == 1: 90 | return grad_and_vars[0][0] 91 | 92 | indices = [] 93 | values = [] 94 | for g, _ in grad_and_vars: 95 | indices += [g.indices] 96 | values += [g.values] 97 | indices = tf.concat(0, indices) 98 | values = tf.concat(0, values) 99 | return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape) 100 | 101 | average_grads = [] 102 | for grad_and_vars in zip(*tower_grads): 103 | if grad_and_vars[0][0] is None: 104 | grad = None 105 | elif isinstance(grad_and_vars[0][0], tf.IndexedSlices): 106 | grad = average_sparse(grad_and_vars) 107 | else: 108 | grad = average_dense(grad_and_vars) 109 | # Keep in mind that the Variables are redundant because they are shared 110 | # across towers. So .. we will just return the first tower's pointer to 111 | # the Variable. 112 | v = grad_and_vars[0][1] 113 | grad_and_var = (grad, v) 114 | average_grads.append(grad_and_var) 115 | return average_grads 116 | 117 | 118 | def img_stretch(img): 119 | img = img.astype(np.float32) 120 | img -= np.min(img) 121 | img /= np.max(img) + 1e-12 122 | return img 123 | 124 | 125 | def img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1, 126 | border_color=0): 127 | """ Tile images in a grid. 128 | If tile_shape is provided only as many images as specified in tile_shape 129 | will be included in the output. 
130 | """ 131 | 132 | # Prepare images 133 | imgs = np.asarray(imgs) 134 | if imgs.ndim != 3 and imgs.ndim != 4: 135 | raise ValueError('imgs has wrong number of dimensions.') 136 | n_imgs = imgs.shape[0] 137 | 138 | # Grid shape 139 | img_shape = np.array(imgs.shape[1:3]) 140 | if tile_shape is None: 141 | img_aspect_ratio = img_shape[1] / float(img_shape[0]) 142 | aspect_ratio *= img_aspect_ratio 143 | tile_height = int(np.ceil(np.sqrt(n_imgs * aspect_ratio))) 144 | tile_width = int(np.ceil(np.sqrt(n_imgs / aspect_ratio))) 145 | grid_shape = np.array((tile_height, tile_width)) 146 | else: 147 | assert len(tile_shape) == 2 148 | grid_shape = np.array(tile_shape) 149 | 150 | # Tile image shape 151 | tile_img_shape = np.array(imgs.shape[1:]) 152 | tile_img_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border 153 | 154 | # Assemble tile image 155 | tile_img = np.empty(tile_img_shape) 156 | tile_img[:] = border_color 157 | for i in range(grid_shape[0]): 158 | for j in range(grid_shape[1]): 159 | img_idx = j + i * grid_shape[1] 160 | if img_idx >= n_imgs: 161 | # No more images - stop filling out the grid. 162 | break 163 | img = imgs[img_idx] 164 | yoff = (img_shape[0] + border) * i 165 | xoff = (img_shape[1] + border) * j 166 | tile_img[yoff:yoff + img_shape[0], xoff:xoff + img_shape[1], ...] = img 167 | 168 | return tile_img 169 | 170 | 171 | # Fixes supervisor to start queue runners before initializing the model. 172 | # TODO(rafal): Send a patch to main tensorflow repo. 173 | class NotBuggySupervisor(tf.train.Supervisor): 174 | def prepare_or_wait_for_session(self, master="", config=None, 175 | wait_for_checkpoint=False, 176 | max_wait_secs=7200, 177 | start_standard_services=True): 178 | """Make sure the model is ready to be used. 179 | 180 | Create a session on 'master', recovering or initializing the model as 181 | needed, or wait for a session to be ready. If running as the chief 182 | and `start_standard_service` is set to True, also call the session 183 | manager to start the standard services. 184 | 185 | Args: 186 | master: name of the TensorFlow master to use. See the `tf.Session` 187 | constructor for how this is interpreted. 188 | config: Optional ConfigProto proto used to configure the session, 189 | which is passed as-is to create the session. 190 | wait_for_checkpoint: Whether we should wait for the availability of a 191 | checkpoint before creating Session. Defaults to False. 192 | max_wait_secs: Maximum time to wait for the session to become available. 193 | start_standard_services: Whether to start the standard services and the 194 | queue runners. 195 | 196 | Returns: 197 | A Session object that can be used to drive the model. 198 | """ 199 | # For users who recreate the session with prepare_or_wait_for_session(), we 200 | # need to clear the coordinator's stop_event so that threads managed by the 201 | # coordinator can run. 
202 | self._coord.clear_stop() 203 | 204 | if self._is_chief: 205 | sess, initialized = self._session_manager.recover_session( 206 | master, self.saver, checkpoint_dir=self._logdir, 207 | wait_for_checkpoint=wait_for_checkpoint, 208 | max_wait_secs=max_wait_secs, config=config) 209 | 210 | if start_standard_services: 211 | print("Starting queue runners") 212 | self.start_queue_runners(sess) 213 | 214 | if not initialized: 215 | if not self.init_op and not self._init_fn: 216 | raise RuntimeError("Model is not initialized and no init_op or " 217 | "init_fn was given") 218 | if self.init_op: 219 | sess.run(self.init_op, feed_dict=self._init_feed_dict) 220 | if self._init_fn: 221 | self._init_fn(sess) 222 | not_ready = self._session_manager._model_not_ready(sess) 223 | if not_ready: 224 | raise RuntimeError("Init operations did not make model ready. " 225 | "Init op: %s, init fn: %s, error: %s" 226 | % (self.init_op.name, self._init_fn, not_ready)) 227 | 228 | self._write_graph() 229 | if start_standard_services: 230 | self.start_standard_services(sess) 231 | else: 232 | sess = self._session_manager.wait_for_session(master, 233 | config=config, 234 | max_wait_secs=max_wait_secs) 235 | if start_standard_services: 236 | self.start_queue_runners(sess) 237 | return sess 238 | -------------------------------------------------------------------------------- /tf_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from cifar10_data import load, cifar_preloaded 3 | import numpy as np 4 | 5 | 6 | def get_inputs(dataset, mode, batch_size, image_size): 7 | if dataset == "cifar10": 8 | assert image_size == 32 9 | images = get_images(dataset, mode, image_size).x 10 | return cifar_preloaded(images, batch_size) 11 | # return cifar_inputs("/home/rafal/data/cifar-10-batches-bin", batch_size) 12 | 13 | 14 | class Dataset(object): 15 | def __init__(self, x, deterministic=False): 16 | self.x = x 17 | self.n = x.shape[0] 18 | self._deterministic = deterministic 19 | self._next_id = 0 20 | self.shuffle() 21 | 22 | def shuffle(self): 23 | if not self._deterministic: 24 | perm = np.arange(self.n) 25 | np.random.shuffle(perm) 26 | self.x = self.x[perm] 27 | self._next_id = 0 28 | 29 | def next_batch(self, batch_size): 30 | if self._next_id + batch_size > self.n: 31 | self.shuffle() 32 | 33 | cur_id = self._next_id 34 | self._next_id += batch_size 35 | 36 | return self.x[cur_id:cur_id+batch_size] 37 | 38 | 39 | class CIFAR(object): 40 | def __init__(self, data_dir, deterministic=False): 41 | self.train = Dataset(load(data_dir, "train")[0], deterministic) 42 | self.test = Dataset(load(data_dir, "test")[0], deterministic) 43 | 44 | 45 | def get_images(dataset, mode, image_size, deterministic=False): 46 | if dataset == "cifar10": 47 | path = os.environ['CIFAR10_PATH'] 48 | cifar = CIFAR(path, deterministic) 49 | if mode == "train": 50 | return cifar.train 51 | if mode == "test": 52 | return cifar.test 53 | -------------------------------------------------------------------------------- /tf_utils/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def gaussian_diag_logps(mean, logvar, sample=None): 6 | if sample is None: 7 | noise = tf.random_normal(tf.shape(mean)) 8 | sample = mean + tf.exp(0.5 * logvar) * noise 9 | 10 | return -0.5 * (np.log(2 * np.pi) + logvar + tf.square(sample - mean) / tf.exp(logvar)) 11 | 12 | 13 | class 
DiagonalGaussian(object): 14 | 15 | def __init__(self, mean, logvar, sample=None): 16 | self.mean = mean 17 | self.logvar = logvar 18 | 19 | if sample is None: 20 | noise = tf.random_normal(tf.shape(mean)) 21 | sample = mean + tf.exp(0.5 * logvar) * noise 22 | self.sample = sample 23 | 24 | def logps(self, sample): 25 | return gaussian_diag_logps(self.mean, self.logvar, sample) 26 | 27 | 28 | def discretized_logistic(mean, logscale, binsize=1 / 256.0, sample=None): 29 | scale = tf.exp(logscale) 30 | sample = (tf.floor(sample / binsize) * binsize - mean) / scale 31 | logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7) 32 | return tf.reduce_sum(logp, [1, 2, 3]) 33 | 34 | 35 | def logsumexp(x): 36 | x_max = tf.reduce_max(x, [1], keep_dims=True) 37 | return tf.reshape(x_max, [-1]) + tf.log(tf.reduce_sum(tf.exp(x - x_max), [1])) 38 | 39 | 40 | def repeat(x, n): 41 | if n == 1: 42 | return x 43 | 44 | shape = map(int, x.get_shape().as_list()) 45 | shape[0] *= n 46 | idx = tf.range(tf.shape(x)[0]) 47 | idx = tf.reshape(idx, [-1, 1]) 48 | idx = tf.tile(idx, [1, n]) 49 | idx = tf.reshape(idx, [-1]) 50 | x = tf.gather(x, idx) 51 | x.set_shape(shape) 52 | return x 53 | 54 | 55 | def compute_lowerbound(log_pxz, sum_kl_costs, k=1): 56 | if k == 1: 57 | return sum_kl_costs - log_pxz 58 | 59 | # log 1/k \sum p(x | z) * p(z) / q(z | x) = -log(k) + logsumexp(log p(x|z) + log p(z) - log q(z|x)) 60 | log_pxz = tf.reshape(log_pxz, [-1, k]) 61 | sum_kl_costs = tf.reshape(sum_kl_costs, [-1, k]) 62 | return - (- tf.log(float(k)) + logsumexp(log_pxz - sum_kl_costs)) 63 | -------------------------------------------------------------------------------- /tf_utils/distributions_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from distributions import logsumexp, compute_lowerbound, repeat 4 | 5 | 6 | class DistributionsTestCase(tf.test.test_util.TensorFlowTestCase): 7 | def test_logsumexp(self): 8 | a = np.arange(10) 9 | res = np.log(np.sum(np.exp(a))) 10 | 11 | with self.test_session(): 12 | res_tf = logsumexp(a.astype(np.float32).reshape([1, -1])).eval() 13 | self.assertEqual(res, res_tf) 14 | 15 | def test_lowerbound(self): 16 | a = np.log(np.array([0.3, 0.3, 0.3, 0.3], np.float32).reshape([1, -1])) 17 | b = np.log(np.array([0.1, 0.5, 0.9, 0.6], np.float32).reshape([1, -1])) 18 | 19 | res = - (- np.log(4) + np.log(np.sum(np.exp(a - b)))) 20 | with self.test_session(): 21 | res_tf = tf.reduce_sum(compute_lowerbound(a, b, 4)).eval() 22 | self.assertAlmostEqual(res, res_tf, places=4) 23 | 24 | def test_lowerbound2(self): 25 | a = np.log(np.array([0.3, 0.3, 0.3, 0.3], np.float32).reshape([-1, 1])) 26 | b = np.log(np.array([0.1, 0.5, 0.9, 0.6], np.float32).reshape([-1, 1])) 27 | 28 | res = (b - a).sum() 29 | with self.test_session(): 30 | res_tf = tf.reduce_sum(compute_lowerbound(a, b, 1)).eval() 31 | self.assertAlmostEqual(res, res_tf, places=4) 32 | 33 | def test_repeat(self): 34 | a = np.random.randn(10, 5, 2) 35 | repeated_a = np.repeat(a, 2, axis=0) 36 | with self.test_session(): 37 | repeated_a_tf = repeat(a, 2).eval() 38 | self.assertAllClose(repeated_a, repeated_a_tf) 39 | -------------------------------------------------------------------------------- /tf_utils/hparams.py: -------------------------------------------------------------------------------- 1 | class HParams(object): 2 | 3 | def __init__(self, **kwargs): 4 | self._items = {} 5 | for k, v in kwargs.items(): 6 | self._set(k, v) 
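`HParams` mirrors its keyword defaults as attributes; the `parse` method defined below overrides them from a comma-separated `key=value` string, which is how the `--hpconfig` flag in `tf_train.py` is consumed. A small usage sketch (the override values here are arbitrary):

```py
from hparams import HParams

hps = HParams(batch_size=16, learning_rate=0.01, dataset="cifar10")
hps2 = hps.parse("batch_size=64,learning_rate=0.002")
assert hps2.batch_size == 64 and hps2.learning_rate == 0.002
assert hps2.dataset == "cifar10"   # untouched keys keep their defaults
```

Values are cast to the type of each default (int, float, bool, or string), and `parse` returns a fresh `HParams` object, leaving the defaults unchanged.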
7 | 8 | def _set(self, k, v): 9 | self._items[k] = v 10 | setattr(self, k, v) 11 | 12 | def parse(self, str_value): 13 | hps = HParams(**self._items) 14 | for entry in str_value.strip().split(","): 15 | entry = entry.strip() 16 | if not entry: 17 | continue 18 | key, sep, value = entry.partition("=") 19 | if not sep: 20 | raise ValueError("Unable to parse: %s" % entry) 21 | default_value = hps._items[key] 22 | if isinstance(default_value, bool): 23 | hps._set(key, value.lower() == "true") 24 | elif isinstance(default_value, int): 25 | hps._set(key, int(value)) 26 | elif isinstance(default_value, float): 27 | hps._set(key, float(value)) 28 | else: 29 | hps._set(key, value) 30 | return hps 31 | -------------------------------------------------------------------------------- /tf_utils/hparams_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from hparams import HParams 3 | 4 | 5 | class HParamsTestCase(unittest.TestCase): 6 | def test_basic(self): 7 | hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test") 8 | self.assertEqual(hps.int_value, 13) 9 | self.assertEqual(hps.float_value, 17.5) 10 | self.assertEqual(hps.bool_value, True) 11 | self.assertEqual(hps.str_value, "test") 12 | 13 | def test_parse(self): 14 | hps = HParams(int_value=13, float_value=17.5, bool_value=True, str_value="test") 15 | self.assertEqual(hps.parse("int_value=10").int_value, 10) 16 | self.assertEqual(hps.parse("float_value=10").float_value, 10) 17 | self.assertEqual(hps.parse("float_value=10.3").float_value, 10.3) 18 | self.assertEqual(hps.parse("bool_value=true").bool_value, True) 19 | self.assertEqual(hps.parse("bool_value=True").bool_value, True) 20 | self.assertEqual(hps.parse("bool_value=false").bool_value, False) 21 | self.assertEqual(hps.parse("str_value=value").str_value, "value") 22 | 23 | if __name__ == '__main__': 24 | unittest.main() 25 | -------------------------------------------------------------------------------- /tf_utils/layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.framework.python.ops import arg_scope, add_arg_scope 4 | 5 | 6 | @add_arg_scope 7 | def linear(name, x, num_units, init_scale=1., init=False): 8 | with tf.variable_scope(name): 9 | if init: 10 | # data based initialization of parameters 11 | v = tf.get_variable("V", [int(x.get_shape()[1]), num_units], tf.float32, 12 | tf.random_normal_initializer(0, 0.05)) 13 | v_norm = tf.nn.l2_normalize(v.initialized_value(), [0]) 14 | x_init = tf.matmul(x, v_norm) 15 | m_init, v_init = tf.nn.moments(x_init, [0]) 16 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 17 | _ = tf.get_variable("g", initializer=scale_init) 18 | _ = tf.get_variable("b", initializer=-m_init * scale_init) 19 | return tf.reshape(scale_init, [1, num_units]) * (x_init - tf.reshape(m_init, [1, num_units])) 20 | else: 21 | v = tf.get_variable("V", [int(x.get_shape()[1]), num_units]) 22 | g = tf.get_variable("g", [num_units]) 23 | b = tf.get_variable("b", [num_units]) 24 | 25 | # use weight normalization (Salimans & Kingma, 2016) 26 | x = tf.matmul(x, v) 27 | scaler = g / tf.sqrt(tf.reduce_sum(tf.square(v), [0])) 28 | return tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units]) 29 | 30 | 31 | @add_arg_scope 32 | def conv2d(name, x, num_filters, filter_size=(3, 3), stride=(1, 1), pad="SAME", init_scale=0.1, init=False, 33 | mask=None, dtype=tf.float32, 
**_): 34 | stride_shape = [1, 1, stride[0], stride[1]] 35 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[1]), num_filters] 36 | 37 | with tf.variable_scope(name): 38 | if init: 39 | # data based initialization of parameters 40 | v = tf.get_variable("V", filter_shape, dtype, tf.random_normal_initializer(0, 0.05, dtype=dtype)) 41 | v = v.initialized_value() 42 | if mask is not None: # Used for auto-regressive convolutions. 43 | v = mask * v 44 | 45 | v_norm = tf.nn.l2_normalize(v, [0, 1, 2]) 46 | x_init = tf.nn.conv2d(x, v_norm, stride_shape, pad, data_format="NCHW") 47 | m_init, v_init = tf.nn.moments(x_init, [0, 2, 3]) 48 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 49 | _ = tf.get_variable("g", initializer=tf.log(scale_init) / 3.0) 50 | _ = tf.get_variable("b", initializer=-m_init * scale_init) 51 | return tf.reshape(scale_init, [1, -1, 1, 1]) * (x_init - tf.reshape(m_init, [1, -1, 1, 1])) 52 | else: 53 | v = tf.get_variable("V", filter_shape) 54 | g = tf.get_variable("g", [num_filters]) 55 | b = tf.get_variable("b", [num_filters]) 56 | if mask is not None: 57 | v = mask * v 58 | 59 | # use weight normalization (Salimans & Kingma, 2016) 60 | w = tf.reshape(tf.exp(g), [1, 1, 1, num_filters]) * tf.nn.l2_normalize(v, [0, 1, 2]) 61 | 62 | # calculate convolutional layer output 63 | b = tf.reshape(b, [1, -1, 1, 1]) 64 | return tf.nn.conv2d(x, w, stride_shape, pad, data_format="NCHW") + b 65 | 66 | 67 | def my_deconv2d(x, filters, strides): 68 | input_shape = x.get_shape() 69 | output_shape = [int(input_shape[0]), int(filters.get_shape()[2]), 70 | int(input_shape[2] * strides[2]), int(input_shape[2] * strides[3])] 71 | x = tf.transpose(x, (0, 2, 3, 1)) # go to NHWC data layout 72 | output_shape = [output_shape[0], output_shape[2], output_shape[3], output_shape[1]] 73 | strides = [strides[0], strides[2], strides[3], strides[1]] 74 | x = tf.nn.conv2d_transpose(x, filters, output_shape=output_shape, strides=strides, padding="SAME") 75 | x = tf.transpose(x, (0, 3, 1, 2)) # back to NCHW 76 | return x 77 | 78 | 79 | @add_arg_scope 80 | def deconv2d(name, x, num_filters, filter_size=(3, 3), stride=(2, 2), pad="SAME", init_scale=0.1, init=False, 81 | mask=None, dtype=tf.float32, **_): 82 | stride_shape = [1, 1, stride[0], stride[1]] 83 | filter_shape = [filter_size[0], filter_size[1], num_filters, int(x.get_shape()[1])] 84 | 85 | with tf.variable_scope(name): 86 | if init: 87 | # data based initialization of parameters 88 | v = tf.get_variable("V", filter_shape, dtype, tf.random_normal_initializer(0, 0.05, dtype=dtype)) 89 | v = v.initialized_value() 90 | if mask is not None: # Used for auto-regressive convolutions. 
91 | v = mask * v 92 | 93 | v_norm = tf.nn.l2_normalize(v, [0, 1, 2]) 94 | x_init = my_deconv2d(x, v_norm, stride_shape) 95 | m_init, v_init = tf.nn.moments(x_init, [0, 2, 3]) 96 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 97 | _ = tf.get_variable("g", initializer=tf.log(scale_init) / 3.0) 98 | _ = tf.get_variable("b", initializer=-m_init*scale_init) 99 | return tf.reshape(scale_init, [1, -1, 1, 1]) * (x_init - tf.reshape(m_init, [1, -1, 1, 1])) 100 | else: 101 | v = tf.get_variable("V", filter_shape) 102 | g = tf.get_variable("g", [num_filters]) 103 | b = tf.get_variable("b", [num_filters]) 104 | if mask is not None: 105 | v = mask * v 106 | 107 | # use weight normalization (Salimans & Kingma, 2016) 108 | w = tf.reshape(tf.exp(g), [1, 1, num_filters, 1]) * tf.nn.l2_normalize(v, [0, 1, 2]) 109 | 110 | # calculate convolutional layer output 111 | b = tf.reshape(b, [1, -1, 1, 1]) 112 | return my_deconv2d(x, w, stride_shape) + b 113 | 114 | 115 | def get_linear_ar_mask(n_in, n_out, zerodiagonal=False): 116 | assert n_in % n_out == 0 or n_out % n_in == 0, "%d - %d" % (n_in, n_out) 117 | 118 | mask = np.ones([n_in, n_out], dtype=np.float32) 119 | if n_out >= n_in: 120 | k = n_out / n_in 121 | for i in range(n_in): 122 | mask[i + 1:, i * k:(i + 1) * k] = 0 123 | if zerodiagonal: 124 | mask[i:i + 1, i * k:(i + 1) * k] = 0 125 | else: 126 | k = n_in / n_out 127 | for i in range(n_out): 128 | mask[(i + 1) * k:, i:i + 1] = 0 129 | if zerodiagonal: 130 | mask[i * k:(i + 1) * k:, i:i + 1] = 0 131 | return mask 132 | 133 | 134 | def get_conv_ar_mask(h, w, n_in, n_out, zerodiagonal=False): 135 | l = (h - 1) / 2 136 | m = (w - 1) / 2 137 | mask = np.ones([h, w, n_in, n_out], dtype=np.float32) 138 | mask[:l, :, :, :] = 0 139 | mask[l, :m, :, :] = 0 140 | mask[l, m, :, :] = get_linear_ar_mask(n_in, n_out, zerodiagonal) 141 | return mask 142 | 143 | 144 | @add_arg_scope 145 | def ar_conv2d(name, x, num_filters, filter_size=(3, 3), stride=(1, 1), pad="SAME", init_scale=1., 146 | zerodiagonal=True, **_): 147 | h = filter_size[0] 148 | w = filter_size[1] 149 | n_in = int(x.get_shape()[1]) 150 | n_out = num_filters 151 | 152 | mask = tf.constant(get_conv_ar_mask(h, w, n_in, n_out, zerodiagonal)) 153 | with arg_scope([conv2d]): 154 | return conv2d(name, x, num_filters, filter_size, stride, pad, init_scale, mask=mask) 155 | 156 | 157 | # Auto-Regressive convnet with l2 normalization 158 | @add_arg_scope 159 | def ar_multiconv2d(name, x, context, n_h, n_out, nl=tf.nn.elu, **_): 160 | with tf.variable_scope(name), arg_scope([ar_conv2d]): 161 | for i, size in enumerate(n_h): 162 | x = ar_conv2d("layer_%d" % i, x, size, zerodiagonal=False) 163 | if i == 0: 164 | x += context 165 | x = nl(x) 166 | return [ar_conv2d("layer_out_%d" % i, x, size, zerodiagonal=True) for i, size in enumerate(n_out)] 167 | 168 | 169 | def resize_nearest_neighbor(x, scale): 170 | input_shape = map(int, x.get_shape().as_list()) 171 | size = [int(input_shape[2] * scale), int(input_shape[3] * scale)] 172 | x = tf.transpose(x, (0, 2, 3, 1)) # go to NHWC data layout 173 | x = tf.image.resize_nearest_neighbor(x, size) 174 | x = tf.transpose(x, (0, 3, 1, 2)) # back to NCHW 175 | return x 176 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import graphy as G 2 | import numpy as np 3 | import time, sys, os 4 | from sacred import Experiment 5 | from __builtin__ import False 6 | 7 | ex = Experiment('Deep VAE') 8 | 9 
| @ex.config 10 | def config(): 11 | 12 | # optimization: 13 | n_reporting = 10 #epochs between reporting 14 | px = 'logistic' 15 | pad_x = 0 16 | 17 | # datatype 18 | problem = 'cifar10' 19 | n_batch = 16 # Minibatch size 20 | if problem == 'mnist': 21 | shape_x = (1,28,28) 22 | px = 'bernoulli' 23 | pad_x = 2 24 | n_h = 64 25 | n_z = 32 26 | if problem == 'cifar10': 27 | shape_x = (3,32,32) 28 | n_h = 160 29 | n_z = 32 30 | if problem == 'svhn': 31 | shape_x = (3,32,32) 32 | n_reporting = 1 33 | n_h = 160 34 | n_z = 32 35 | if problem == 'lfw': 36 | shape_x = (3,64,48) 37 | n_h = 160 38 | n_z = 32 39 | 40 | n_h1 = n_h 41 | n_h2 = n_h 42 | 43 | # dataset 44 | n_train = 0 45 | 46 | # model 47 | model_type = 'cvae1' 48 | 49 | if model_type == 'cvae1': 50 | depths = [2,2] 51 | 52 | margs = { 53 | 'shape_x': shape_x, 54 | 'depths': depths, 55 | 'n_h1': n_h1, 56 | 'n_h2': n_h2, 57 | 'n_z': n_z, 58 | 'prior': 'diag', 59 | 'posterior': 'down_diag', 60 | 'px': px, 61 | 'nl': 'elu', 62 | 'kernel_x': (5,5), 63 | 'kernel_h': (3,3), 64 | 'kl_min': 0.25, 65 | 'optim': 'adamax', 66 | 'alpha': 0.002, 67 | 'beta1': 0.1, 68 | 'pad_x': pad_x, 69 | 'weightsharing': False, 70 | 'depth_ar': 1, 71 | 'downsample_type': 'nn' 72 | } 73 | 74 | if model_type == 'simplecvae1': 75 | depths = [2,2,2] 76 | widths = [32,64,128] 77 | 78 | margs = { 79 | 'shape_x': shape_x, 80 | 'depths': depths, 81 | 'widths': widths, 82 | 'n_z': n_z, 83 | 'prior': 'diag', 84 | 'posterior': 'down_diag', 85 | 'px': px, 86 | 'nl': 'elu', 87 | 'kernel_x': (5,5), 88 | 'kernel_h': (3,3), 89 | 'kl_min': 0.25, 90 | 'optim': 'adamax', 91 | 'alpha': 0.002, 92 | 'beta1': 0.1, 93 | 'pad_x': pad_x, 94 | 'weightsharing': False 95 | } 96 | 97 | # model loading/saving 98 | save_model = True 99 | load_model_path = None 100 | load_model_complete = True # Whether loaded parameters are complete 101 | 102 | # Estimate the marginal likelihood 103 | est_marglik = 0. 104 | est_marglik_data = 'valid' 105 | 106 | def init_logs(): 107 | global logpath, logdir 108 | # Create log directory 109 | logdir = str(time.time()) 110 | logpath = os.environ['ML_LOG_PATH']+'/'+logdir+'/' 111 | print 'Logpath: '+logpath 112 | os.makedirs(logpath) 113 | # Log stdout messages to file 114 | sys.stdout = G.misc.logger.Logger(logpath+"log.txt") 115 | # Clone local source to logdir 116 | os.system("rsync -au --include '*/' --include '*.py' --exclude '*' . 
"+logpath+"source") 117 | with open(logpath+"source/run.sh", 'w') as f: 118 | f.write("python "+" ".join(sys.argv)+"\n") 119 | os.chmod(logpath+"source/run.sh", 0700) 120 | 121 | @ex.capture 122 | def construct_model(data_init, model_type, margs, load_model_path, load_model_complete, n_batch): 123 | import models 124 | margs['data_init'] = data_init 125 | if model_type == 'fcvae1': 126 | model = models.fcvae(**margs) 127 | if model_type == 'cvae1': 128 | model = models.cvae1(**margs) 129 | if model_type == 'simplecvae1': 130 | import simplemodel 131 | model = simplemodel.simplecvae1(**margs) 132 | 133 | if load_model_path != None: 134 | print 'Loading existing model at '+load_model_path 135 | _w = G.ndict.np_loadz(load_model_path+'/weights.ndict.tar.gz') 136 | 137 | G.ndict.set_value(model.w, _w, load_model_complete) 138 | G.ndict.set_value(model.w_avg, _w, load_model_complete) 139 | 140 | return model 141 | 142 | @ex.capture 143 | def get_data(problem, n_train, n_batch): 144 | 145 | if problem == 'cifar10': 146 | # Load data 147 | data_train, data_valid = G.misc.data.cifar10(False) 148 | if problem == 'svhn': 149 | # Load data 150 | data_train, data_valid = G.misc.data.svhn(False, True) 151 | elif problem == 'mnist': 152 | # Load data 153 | validset = False 154 | if validset: 155 | data_train, data_valid, data_test = G.misc.data.mnist_binarized(validset, False) 156 | else: 157 | data_train, data_valid = G.misc.data.mnist_binarized(validset, False) 158 | data_train['x'] = data_train['x'].reshape((-1,1,28,28)) 159 | data_valid['x'] = data_valid['x'].reshape((-1,1,28,28)) 160 | elif problem == 'lfw': 161 | data_train = G.misc.data.lfw(False,True) 162 | data_valid = G.ndict.getRows(data_train, 0, 1000) 163 | 164 | 165 | data_init = {'x':data_train['x'][:n_batch]} 166 | 167 | if n_train > 0: 168 | data_train = G.ndict.getRows(data_train, 0, n_train) 169 | data_valid = G.ndict.getRows(data_valid, 0, n_train) 170 | 171 | return data_train, data_valid, data_init 172 | 173 | @ex.automain 174 | def train(shape_x, problem, n_batch, n_train, n_reporting, save_model, est_marglik, est_marglik_data, margs): 175 | 176 | global logpath 177 | 178 | # Initialize logs 179 | init_logs() 180 | 181 | # Get data 182 | data_train, data_valid, data_init = get_data() 183 | 184 | # Construct model 185 | model = construct_model(data_init) 186 | 187 | # Estimate the marginal likelihood 188 | if est_marglik > 0: 189 | if est_marglik_data == 'valid': 190 | data = data_valid 191 | elif est_marglik_data == 'train': 192 | data = data_train 193 | # Correction since model's actual cost is divided by this factor 194 | correctionfactor = - (np.prod(shape_x) * np.log(2.)) 195 | obj_test = [] 196 | for i in range(est_marglik): 197 | cost = model.eval(data, n_batch=n_batch, randomorder=False)['cost'] * correctionfactor 198 | obj_test.append(cost) 199 | _obj = np.vstack(obj_test) 200 | _max = np.max(_obj, axis=0) 201 | _est = np.log(np.exp(_obj - _max).mean(axis=0)) + _max 202 | if i%1 == 0: 203 | print 'Estimate of logp(x) after', i+1, 'samples:', _est.mean() / correctionfactor 204 | raise Exception() 205 | sys.exit() 206 | 207 | # Report 208 | cost_best = [None] 209 | eps_fixed = model.eps({'n_batch':100}) 210 | def report(epoch, dt, cost): 211 | if np.isnan(cost): 212 | raise Exception('NaN detected!!') 213 | 214 | results_valid = model.eval(data_valid, n_batch=n_batch) 215 | for i in results_valid: results_valid[i] = results_valid[i].mean() 216 | 217 | _w = G.ndict.get_value(model.w_avg) 218 | G.ndict.np_savez(_w, 
logpath+'weights') 219 | 220 | if cost_best[0] is None or results_valid['cost'] < cost_best[0]: 221 | cost_best[0] = results_valid['cost'] 222 | if save_model: 223 | G.ndict.np_savez(_w, logpath+'weights_best') 224 | 225 | if True: 226 | # Write all results to file 227 | with open(logpath+"results.txt", "a") as log: 228 | if epoch == 0: 229 | log.write("Epoch "+" ".join(map(str, results_valid.keys())) + "\n") 230 | log.write(str(epoch)+" "+" ".join(map(str, results_valid.values())) + "\n") 231 | 232 | if True: 233 | eps = model.eps({'n_batch':100}) 234 | image = model.decode(eps) 235 | G.graphics.save_raster(image, logpath+'sample_'+str(epoch)+'.png') 236 | image = model.decode(eps_fixed) 237 | G.graphics.save_raster(image, logpath+'sample_fixed1_'+str(epoch)+'.png') 238 | 239 | #eps_fixed_copy = G.ndict.clone(eps_fixed) 240 | #for i in range(len(eps_fixed)): 241 | # eps_fixed_copy[''] 242 | 243 | if epoch == 0: 244 | print 'logdir:', 't:', 'Epoch:', 'Train cost:', 'Valid cost:', 'Best:', 'log(stdev) of p(x|z):' 245 | 246 | logsd_x = 0. 247 | if 'logsd_x' in model.w_avg: 248 | logsd_x = model.w_avg['logsd_x'].get_value() 249 | print logdir, '%.2f'%dt, epoch, '%.5f'%cost, '%.5f'%results_valid['cost'], '%.5f'%cost_best[0], logsd_x 250 | 251 | print 'Training' 252 | 253 | for epoch in xrange(1000000): 254 | t0 = time.time() 255 | 256 | result = model.train(data_train, n_batch=n_batch) 257 | 258 | if epoch <= 10 or epoch%n_reporting == 0: 259 | report(epoch, time.time()-t0, cost=np.mean(result)) 260 | 261 | 262 | --------------------------------------------------------------------------------