├── .gitmodules ├── AIS.ipynb ├── README.md ├── ops.py ├── simple_example.ipynb └── theanorc_cpu /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "eval_gen"] 2 | path = eval_gen 3 | url = https://github.com/DmitryUlyanov/eval_gen 4 | -------------------------------------------------------------------------------- /AIS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "os.environ['THEANORC'] = 'theanorc_cpu'\n", 13 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", 14 | "\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "import theano\n", 18 | "import numpy as np\n", 19 | "import theano \n", 20 | "import theano.tensor as T\n", 21 | "\n", 22 | "print (theano.__version__)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "collapsed": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "from ops import pytorch_wrapper\n", 34 | "import sys\n", 35 | "sys.path.append(\"./eval_gen/\")\n", 36 | "torch.backends.cudnn.benchmark = True" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Load data" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "collapsed": true 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "import h5py\n", 55 | "f = h5py.File('/sdh/data_fuel/mnist.hdf5')\n", 56 | "\n", 57 | "num_test = 24\n", 58 | "num_samples= 16\n", 59 | "\n", 60 | "X = f['features'].value[60000:]\n", 61 | "\n", 62 | "permutation = np.random.RandomState(seed=2919).permutation(X.shape[0])\n", 63 | "X= X[permutation][:num_test]" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "# Load model" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "collapsed": true, 78 | "scrolled": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "dtype= torch.cuda.FloatTensor\n", 83 | "age_model = torch.load('netG_epoch_25.pth')\n", 84 | "age_model.eval()\n", 85 | "age_model.type(dtype)\n", 86 | "net = age_model\n", 87 | "\n", 88 | "class NetWrapper(nn.Module):\n", 89 | " '''\n", 90 | " The AIS code needs samples to be of shape (B x 32^2)\n", 91 | " But my network produces (B x 32 x 32)\n", 92 | " So I use this wrapper.\n", 93 | " '''\n", 94 | " def __init__(self, net):\n", 95 | " super(NetWrapper, self).__init__()\n", 96 | " self.net = net\n", 97 | " \n", 98 | " def forward(self, x):\n", 99 | " return self.net(None, x.view(-1,10,1,1)).view(-1,32*32) / 2. 
+ 0.5" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "# Go" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "collapsed": true, 114 | "scrolled": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "def generator(z):\n", 119 | " '''\n", 120 | " This function should define theano computational graph\n", 121 | " for evaluating net(z), for `z` -- latent vector of shape [B x Z_dim]\n", 122 | " '''\n", 123 | " \n", 124 | " f = pytorch_wrapper(NetWrapper(net), dtype=dtype)\n", 125 | " out = f(z)\n", 126 | " \n", 127 | " return out\n", 128 | "\n", 129 | "\n", 130 | "from sampling import samplers_32 as samplers\n", 131 | "lld, pf, finalstate = samplers.run_ais(generator, \n", 132 | " X, \n", 133 | " num_samples,\n", 134 | " num_steps=10000, \n", 135 | " sigma=0.03, \n", 136 | " hdim=10, \n", 137 | " L=10, \n", 138 | " epsilon=0.01, \n", 139 | " data='continuous',\n", 140 | " prior=\"normal\")" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "Python 2", 147 | "language": "python", 148 | "name": "python2" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 2 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython2", 160 | "version": "2.7.13" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch inside Theano 2 | 3 | (and pytorch wrapper for [AIS](https://github.com/tonywu95/eval_gen)) 4 | 5 | This repo shows a dirty hack, how to run Pytorch graphs inside any Theano graph. Moreover, both forward and backward passes are supported. So whenever you want to mix Pytorch and Theano you may use the wrapper from this repo. 6 | 7 | In particular, I wanted to use [this code](https://github.com/tonywu95/eval_gen) to evaluate a generative model using [1]. Their code is written in Theano and my [AGE](https://arxiv.org/abs/1704.02304) model was trained using Pytorch. 8 | 9 | [1] *On the Quantitative Analysis of Decoder-Based Generative Models*, Yuhuai Wu, Yuri Burda, Ruslan Salakhutdinov, Roger Grosse, ICLR 2016 10 | 11 | ## General usage 12 | As easy as it could be. 13 | ``` 14 | from ops import pytorch_wrapper 15 | f_theano = pytorch_wrapper(f_pytorch, dtype=dtype, debug=True) 16 | ``` 17 | And then use `f_theano` in your theano graphs. See `simple_example.ipynb`. 18 | ## As AIS wrapper on MNIST dataset 19 | 20 | 0. Train your pytoch model on MNIST dataset. 21 | 22 | 1. Clone this repo with `--recursive` flag: 23 | ``` 24 | git clone --recursive https://github.com/DmitryUlyanov/pytorch_in_theano 25 | ``` 26 | 2. See `AIS.ipynb` for an example how I used it for AGE model on MNIST dataset. You will need to replace `NetWrapper(net)` in `generator(z)` with your network. 27 | 28 | 29 | If you want to compare your generative model, here are the two likelihood scores I've computed for MNIST with `z_dim=10`: 30 | 31 | | method | score | 32 | |--------|-------| 33 | | [AGE](https://arxiv.org/abs/1704.02304) | 746 | 34 | | ALI | 721 | 35 | 36 | And the results from the paper: 705 for VAE and 328 for GAN. 
37 | ### On other datasets 38 | To be honest, I no longer remember why I had to [modify the sampler](https://github.com/DmitryUlyanov/eval_gen/commit/2347d967ef5554719cb6c4fa1a12f0a7b7903939) in the `eval_gen` code. Most likely it was because of shape-mismatch errors that I struggled to track down for a long time. So please examine the sampler file before blindly running the code. You will certainly need to put the right value [here](https://github.com/DmitryUlyanov/eval_gen/blob/master/sampling/samplers_32.py#L109) and probably change a few other places as well. 39 | 40 | ## Misc 41 | 42 | Tested with Python 2, `theano=0.8.2.dev-901275534cbfe3fbbe290ce85d1abf8bb9a5b203`, `pytorch=0.2.0_4`. 43 | 44 | 45 | If you find this code helpful for your research, please cite this repo: 46 | 47 | ``` 48 | @misc{Ulyanov2017_ais_wrapper, 49 | author = {Ulyanov, Dmitry}, 50 | title = {Pytorch wrapper for AIS}, 51 | year = {2017}, 52 | publisher = {GitHub}, 53 | journal = {GitHub repository}, 54 | howpublished = {\url{https://github.com/DmitryUlyanov/pytorch-in-theano}}, 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import torch 3 | import numpy as np 4 | 5 | class bp(theano.Op): 6 | ''' 7 | Theano Op for the backward pass, used by the forward-pass op `pytorch_wrapper`. 8 | Do not use it explicitly in your graphs. 9 | ''' 10 | def __init__(self, net, debug, dtype): 11 | self.net = net 12 | 13 | self.output_ = None 14 | self.input_ = None 15 | self.input_np_ = None 16 | 17 | self.debug = debug 18 | self.dtype = dtype 19 | 20 | __props__ = () 21 | 22 | def make_node(self, x, y): 23 | x_ = theano.tensor.as_tensor_variable(x) 24 | y_ = theano.tensor.as_tensor_variable(y) 25 | return theano.gof.Apply(self, [x_, y_], [x_.type()]) 26 | 27 | def perform(self, node, inputs, output_storage): 28 | ''' 29 | Actual backward pass computations.
30 | We will do some kind of caching: 31 | Check if the input is the same as the stored one during forward pass 32 | If it is the same -- do only backward pass, if it is different do forward pass again here 33 | ''' 34 | 35 | input = inputs[0] 36 | grad_output = inputs[1] 37 | 38 | if self.debug: print('Backward pass:') 39 | 40 | # Caching 41 | if self.input_np_ is not None: 42 | if np.all(np.allclose(inputs[0], self.input_np_)): 43 | # assume np.all(np.allclose(output_var.data.cpu().numpy(), self.output_.data.cpu().numpy())) 44 | 45 | output_var = self.output_ 46 | input_var = self.input_ 47 | 48 | if self.debug: print('\t1)Forward in backward: cached') 49 | else: 50 | assert False, 'Buffer does not match input, IT\'s A BUG' 51 | else: 52 | 53 | input_var = torch.autograd.Variable(torch.from_numpy(input).type(self.dtype), requires_grad=True) 54 | output_var = self.net(input_var) 55 | 56 | if self.debug: print('\t1)Forward in backward: compute') 57 | 58 | 59 | if self.debug: print('\t2) Backward in backward') 60 | 61 | # Backward 62 | grad = torch.from_numpy(grad_output).type(self.dtype) 63 | output_var.backward(gradient = grad) 64 | 65 | # Put result in the right place 66 | output_storage[0][0] = input_var.grad.data.cpu().numpy().astype(inputs[0].dtype) 67 | 68 | def grad(self, inputs, output_grads): 69 | assert False, 'We should never get here' 70 | return [output_grads[0]] 71 | 72 | def __str__(self): 73 | return 'backward_pass' 74 | 75 | class pytorch_wrapper(theano.Op): 76 | ''' 77 | This is a theano.Op that can evaluate network from pytorch 78 | And get its gradient w.r.t. input 79 | ''' 80 | def __init__(self, net, debug=False, dtype=torch.FloatTensor): 81 | self.net = net.type(dtype) 82 | self.dtype = dtype 83 | 84 | self.bpop = bp(self.net, debug, dtype) 85 | self.debug = debug 86 | __props__ = () 87 | 88 | def make_node(self, x): 89 | x_ = theano.tensor.as_tensor_variable(x) 90 | return theano.gof.Apply(self, [x_], [x_.type()]) 91 | 92 | def perform(self, node, inputs, output_storage): 93 | ''' 94 | In this function we should compute output tensor 95 | Inputs are numpy array, so it's easy 96 | ''' 97 | if self.debug: print('Forward pass') 98 | 99 | # Wrap input into variable 100 | input = torch.autograd.Variable(torch.from_numpy(inputs[0]).type(self.dtype), requires_grad=True) 101 | out = self.net(input) 102 | out_np = out.data.cpu().numpy().astype(inputs[0].dtype) 103 | 104 | # Put output to the right place 105 | output_storage[0][0] = out_np 106 | 107 | 108 | self.bpop.output_ = out 109 | self.bpop.input_ = input 110 | self.bpop.input_np_ = inputs[0] 111 | 112 | def grad(self, inputs, output_grads): 113 | ''' 114 | And `grad` should operate TheanoOps only, not numpy arrays 115 | So the only workaround I've found is to define another TheanoOp for backward pass and call it 116 | ''' 117 | return [self.bpop(inputs[0], output_grads[0])] 118 | 119 | def __str__(self): 120 | return 'forward_pass' -------------------------------------------------------------------------------- /simple_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import os\n", 12 | "os.environ['THEANORC'] = './theanorc_cpu'\n", 13 | "\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import theano\n", 17 | "import numpy as np\n", 18 | "import theano.tensor as T\n", 19 | "\n", 20 | "print 
(torch.__version__, theano.__version__)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "from ops import pytorch_wrapper" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Simple example" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "collapsed": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "dtype = torch.FloatTensor\n", 50 | "\n", 51 | "# We use neural networks for testing, but you can use any function from pytorch\n", 52 | "net = nn.Sequential(nn.Conv2d(3,4,5)).type(dtype)\n", 53 | "x = np.random.rand(2,3,13,13).astype(np.float32)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "#### Get output and grad in pytorch" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "collapsed": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# Forward\n", 72 | "input = torch.autograd.Variable(dtype(x), requires_grad=True)\n", 73 | "out_var = net(input).sum()\n", 74 | "\n", 75 | "# Backward\n", 76 | "out_var.backward()\n", 77 | "\n", 78 | "input_grad_pytorch = input.grad.data.numpy()\n", 79 | "out_pytorch = out_var.data.numpy()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "#### Now try to get the same values using theano" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "collapsed": true, 94 | "scrolled": true 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "# Wrap forward pass\n", 99 | "f = pytorch_wrapper(net, dtype=dtype, debug=True)\n", 100 | "\n", 101 | "# Create theano graph\n", 102 | "xt = T.tensor4('x')\n", 103 | "yt = f(xt).sum()\n", 104 | "gy = T.grad(yt, xt)\n", 105 | "\n", 106 | "# Define function\n", 107 | "f_grad = theano.function([xt], gy, on_unused_input='warn')\n", 108 | "f_out = theano.function([xt], yt, on_unused_input='warn')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": { 115 | "collapsed": true, 116 | "scrolled": true 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "out_theano = f_out(x)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "collapsed": true, 128 | "scrolled": true 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "input_grad_theano = f_grad(x)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "assert np.isclose(out_theano, out_pytorch), 'Outputs do not match'\n", 144 | "assert np.all(np.isclose(input_grad_theano, input_grad_pytorch)), 'Grads do not match'" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "# Test 2" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "AIS but with dummy network" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": { 165 | "collapsed": true 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "import h5py\n", 170 | "f = h5py.File('/sdh/data_fuel/mnist.hdf5')\n", 171 | "\n", 172 | "num_test = 4\n", 173 | "num_samples= 2\n", 174 | "\n", 175 | "X = f['features'].value[60000:]\n", 176 | "\n", 177 | 
"permutation = np.random.RandomState(seed=2919).permutation(X.shape[0])\n", 178 | "X= X[permutation][:num_test]" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": { 185 | "collapsed": true, 186 | "scrolled": true 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "dtype = torch.cuda.FloatTensor\n", 191 | "net = nn.Sequential(nn.ConvTranspose2d(10,1,32))\n", 192 | "net(torch.autograd.Variable(torch.zeros(2,10,1,1))).size()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": { 199 | "collapsed": true, 200 | "scrolled": true 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "class NetWrapper(nn.Module):\n", 205 | " '''\n", 206 | " The AIS code needs samples to be of shape (B x 32^2)\n", 207 | " But my network produces (B x 32 x 32)\n", 208 | " So we use this wrapper.\n", 209 | " '''\n", 210 | " def __init__(self, net):\n", 211 | " super(NetWrapper, self).__init__()\n", 212 | " self.net = net\n", 213 | " \n", 214 | " def forward(self, x):\n", 215 | " return self.net(x.view(-1,10,1,1)).view(-1,32*32) / 2. + 0.5\n", 216 | " \n", 217 | " \n", 218 | "def generator(z):\n", 219 | " '''\n", 220 | " This function should define theano computational graph\n", 221 | " for evaluating net(z), for `z` -- latent vector of shape [B x Z_dim]\n", 222 | " '''\n", 223 | " \n", 224 | " f = pytorch_wrapper(NetWrapper(net),dtype=dtype)\n", 225 | " out = f(z)\n", 226 | " \n", 227 | " return out\n", 228 | "\n", 229 | "import sys\n", 230 | "sys.path.append(\"./eval_gen/\")\n", 231 | "\n", 232 | "from sampling import samplers_32 as samplers\n", 233 | "lld, pf,finalstate = samplers.run_ais(generator, \n", 234 | " X, \n", 235 | " num_samples,\n", 236 | " num_steps=10000, \n", 237 | " sigma=0.03, \n", 238 | " hdim=10, \n", 239 | " L=10, \n", 240 | " epsilon=0.01, \n", 241 | " data='continuous',\n", 242 | " prior=\"normal\")" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 2", 249 | "language": "python", 250 | "name": "python2" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 2 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython2", 262 | "version": "2.7.13" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /theanorc_cpu: -------------------------------------------------------------------------------- 1 | [global] 2 | device = cpu 3 | floatX = float32 4 | mode = FAST_RUN 5 | optimizer = None 6 | exception_verbosity = high 7 | 8 | [lib] 9 | cnmem = 0.8 10 | 11 | [dnn.conv] 12 | algo_fwd = time_once 13 | algo_bwd_data = time_once 14 | algo_bwd_filter = time_once 15 | --------------------------------------------------------------------------------