├── .gitignore
├── .gitignore.swp
├── README.md
├── adversarial_autoencoder_denoising.ipynb
├── cnn_denoising.ipynb
├── dataprocessing.py
├── helper.py
├── install.sh
├── lib
│   ├── __init__.py
│   ├── activations.py
│   ├── config.py
│   ├── costs.py
│   ├── cv2_utils.py
│   ├── data_utils.py
│   ├── inits.py
│   ├── metrics.py
│   ├── ops.py
│   ├── python3.6
│   │   ├── __future__.py
│   │   ├── _bootlocale.py
│   │   ├── _collections_abc.py
│   │   ├── _dummy_thread.py
│   │   ├── _weakrefset.py
│   │   ├── abc.py
│   │   ├── base64.py
│   │   ├── bisect.py
│   │   ├── codecs.py
│   │   ├── collections
│   │   ├── config-3.6m-x86_64-linux-gnu
│   │   ├── copy.py
│   │   ├── copyreg.py
│   │   ├── distutils
│   │   │   ├── __init__.py
│   │   │   └── distutils.cfg
│   │   ├── encodings
│   │   ├── enum.py
│   │   ├── fnmatch.py
│   │   ├── functools.py
│   │   ├── genericpath.py
│   │   ├── hashlib.py
│   │   ├── heapq.py
│   │   ├── hmac.py
│   │   ├── imp.py
│   │   ├── importlib
│   │   ├── io.py
│   │   ├── keyword.py
│   │   ├── lib-dynload
│   │   ├── linecache.py
│   │   ├── locale.py
│   │   ├── no-global-site-packages.txt
│   │   ├── ntpath.py
│   │   ├── operator.py
│   │   ├── orig-prefix.txt
│   │   ├── os.py
│   │   ├── posixpath.py
│   │   ├── random.py
│   │   ├── re.py
│   │   ├── reprlib.py
│   │   ├── rlcompleter.py
│   │   ├── shutil.py
│   │   ├── site.py
│   │   ├── sre_compile.py
│   │   ├── sre_constants.py
│   │   ├── sre_parse.py
│   │   ├── stat.py
│   │   ├── struct.py
│   │   ├── tarfile.py
│   │   ├── tempfile.py
│   │   ├── token.py
│   │   ├── tokenize.py
│   │   ├── types.py
│   │   ├── warnings.py
│   │   └── weakref.py
│   ├── rng.py
│   ├── theano_utils.py
│   ├── updates.py
│   └── vis.py
├── nmf_denoising.ipynb
├── pca_denoising.ipynb
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | #images
10 | faces.hdf5
11 | 
12 | bin/
13 | include/
14 | 
15 | .ipynb_checkpoints
16 | 
17 | # Distribution / packaging
18 | .Python
19 | env/
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 | 
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 | 
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *,cover
53 | .hypothesis/
54 | 
55 | # Translations
56 | *.mo
57 | *.pot
58 | 
59 | # Django stuff:
60 | *.log
61 | 
62 | # Sphinx documentation
63 | docs/_build/
64 | 
65 | # PyBuilder
66 | target/
--------------------------------------------------------------------------------
/.gitignore.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heartyguy/ml-image-denoising/2c2d2e454714c890cc53f18c45d8733a6b9aeae8/.gitignore.swp
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ml-image-denoising
2 | Image denoising using PCA, NMF, SVD, spectral decomposition, a CNN, and a state-of-the-art generative adversarial denoising autoencoder
3 | 
4 | ## How to run
5 | Requires pip, virtualenv and git.
6 | 
7 | ```./install.sh```
8 | 
9 | Download the CelebA dataset from the [project website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) (you're looking for a file called img_align_celeba.zip) and unzip it into the parent directory (`../`),
then run 10 | 11 | ``` python dataprocessing.py ``` 12 | 13 | This will crop the images to the right size and store them in HDF5 format. 14 | 15 | Currently, we only process 5 batches * 1024 images 16 | 17 | Next run the dcgan notbook. 18 | 19 | ``` jupyter notebook ``` 20 | -------------------------------------------------------------------------------- /adversarial_autoencoder_denoising.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": { 14 | "collapsed": false 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "#sys.path.append('..')\n", 20 | "import os\n", 21 | "import json\n", 22 | "from time import time\n", 23 | "import numpy as np\n", 24 | "from tqdm import tqdm\n", 25 | "\n", 26 | "import theano\n", 27 | "import theano.tensor as T\n", 28 | "from theano.sandbox.cuda.dnn import dnn_conv\n", 29 | "\n", 30 | "from PIL import Image" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "N.B. The code from the following imports is lifted from the original [dcgan project](https://github.com/Newmu/dcgan_code)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": { 44 | "collapsed": false 45 | }, 46 | "outputs": [ 47 | { 48 | "ename": "ModuleNotFoundError", 49 | "evalue": "No module named 'theano_utils'", 50 | "output_type": "error", 51 | "traceback": [ 52 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 53 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 54 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mactivations\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mupdates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0minits\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrng\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpy_rng\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp_rng\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mops\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mbatchnorm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconv_cond_concat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdeconv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml2normalize\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 55 | "\u001b[0;32m/home/tian/classwinter17/stat442/finalproject/ml-image-denoising/lib/updates.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtheano_utils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mshared0s\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mfloatX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msharedX\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mops\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0ml2norm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 56 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'theano_utils'" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "from lib import activations\n", 62 | "from lib import updates\n", 63 | "from lib import inits\n", 64 | "from lib.rng import py_rng, np_rng\n", 65 | "from lib.ops import batchnorm, conv_cond_concat, deconv, dropout, l2normalize\n", 66 | "from lib.metrics import nnc_score, nnd_score\n", 67 | "from lib.theano_utils import floatX, sharedX\n", 68 | "from lib.data_utils import OneHot, shuffle, iter_data, center_crop, patch" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "collapsed": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "from fuel.datasets.hdf5 import H5PYDataset\n", 80 | "from fuel.schemes import ShuffledScheme, SequentialScheme\n", 81 | "from fuel.streams import DataStream" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "# Data Stuff" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "import h5py\n", 100 | "try:\n", 101 | " hf[\"target\"].shape\n", 102 | "except:\n", 103 | " hf = h5py.File('faces.hdf5','r+')\n", 104 | "num_samples = hf[\"input\"].shape[0]\n", 105 | "\n", 106 | "print \"number of samples in dataset : %i\" %num_samples" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "collapsed": false 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "split_dict = {\n", 118 | " 'train': {'input': (2000, num_samples), 'target': (2000, num_samples)},\n", 119 | " 'test': {'input': (0, 1000), 'target': (0, 1000)},\n", 120 | " 'val': {'input': (1000, 2000), 'target': (1000, 2000)}\n", 121 | "}\n", 122 | "\n", 123 | "hf.attrs['split'] = H5PYDataset.create_split_array(split_dict)\n", 124 | "train_set = H5PYDataset('faces.hdf5', which_sets=('train',))\n", 125 | "test_set = H5PYDataset('faces.hdf5', which_sets=('test',))\n", 126 | "val_set = H5PYDataset('faces.hdf5', which_sets=('val',))\n", 127 | "\n", 128 | "#batch_size = 128\n", 129 | "batch_size = 12\n", 130 | "#TODO : use shuffledscheme instead? 
Seems slower, might have screwed up the chunksize in the HDF5 files?\n", 131 | "\n", 132 | "tr_scheme = SequentialScheme(examples=train_set.num_examples, batch_size=batch_size)\n", 133 | "tr_stream = DataStream(train_set, iteration_scheme=tr_scheme)\n", 134 | "\n", 135 | "val_scheme = SequentialScheme(examples=val_set.num_examples, batch_size=batch_size)\n", 136 | "val_stream = DataStream(val_set, iteration_scheme=val_scheme)\n", 137 | "\n", 138 | "test_scheme = SequentialScheme(examples=test_set.num_examples, batch_size=batch_size)\n", 139 | "test_stream = DataStream(test_set, iteration_scheme=test_scheme)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "## Check data looks sensible" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "collapsed": false 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "for x_train, x_target in tr_stream.get_epoch_iterator():\n", 158 | " break\n", 159 | "print \"EXAMPLE TARGET IMAGE:\"\n", 160 | "\n", 161 | "Image.fromarray(x_target[3].astype(np.uint8))" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "collapsed": false 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "print \"EXAMPLE INPUT IMAGE:\"\n", 173 | "Image.fromarray(x_train[3].astype(np.uint8))" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "# Setup Neural Network" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "collapsed": true 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "def target_transform(X):\n", 192 | " return floatX(X).transpose(0, 3, 1, 2)/127.5 - 1.\n", 193 | "\n", 194 | "def input_transform(X):\n", 195 | " return target_transform(X)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": { 202 | "collapsed": false 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "l2 = 1e-5 # l2 weight decay\n", 207 | "nvis = 196 # # of samples to visualize during training\n", 208 | "b1 = 0.5 # momentum term of adam\n", 209 | "nc = 3 # # of channels in image\n", 210 | "#nbatch = 128 # # of examples in batch\n", 211 | "nbatch = 12 # # of examples in batch\n", 212 | "npx = 64 # # of pixels width/height of images\n", 213 | "\n", 214 | "nx = npx*npx*nc # # of dimensions in X\n", 215 | "niter = 1000 # # of iter at starting learning rate\n", 216 | "niter_decay = 30 # # of iter to linearly decay learning rate to zero\n", 217 | "lr = 0.0002 # initial learning rate for adam\n", 218 | "ntrain = 25000 # # of examples to train on\n", 219 | "\n", 220 | "relu = activations.Rectify()\n", 221 | "sigmoid = activations.Sigmoid()\n", 222 | "lrelu = activations.LeakyRectify()\n", 223 | "tanh = activations.Tanh()\n", 224 | "bce = T.nnet.binary_crossentropy\n", 225 | "\n", 226 | "def mse(x,y):\n", 227 | " return T.sum(T.pow(x-y,2), axis = 1)\n", 228 | "\n", 229 | "gifn = inits.Normal(scale=0.02)\n", 230 | "difn = inits.Normal(scale=0.02)\n", 231 | "sigma_ifn = inits.Normal(loc = -100., scale=0.02)\n", 232 | "gain_ifn = inits.Normal(loc=1., scale=0.02)\n", 233 | "bias_ifn = inits.Constant(c=0.)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "metadata": {}, 239 | "source": [ 240 | "\n", 241 | "The following methods are to help adjust the sizes of the convolutional layers in the generator and discriminator, which is very fiddly to do otherwise. 
The (overloaded) method make_conv_set can be used to create both the conv \n", 242 | "and deconv sets of layers. Note that the 'size' of the images is the size of the shortest side (32 in the input set, 128 in the target set). Only use powers of 2 here." 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": { 249 | "collapsed": false 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "def make_conv_layer(X, input_size, output_size, input_filters, \n", 254 | " output_filters, name, index,\n", 255 | " weights = None, filter_sz = 5):\n", 256 | " \n", 257 | " is_deconv = output_size >= input_size\n", 258 | "\n", 259 | " w_size = (input_filters, output_filters, filter_sz, filter_sz) \\\n", 260 | " if is_deconv else (output_filters, input_filters, filter_sz, filter_sz)\n", 261 | " \n", 262 | " if weights is None:\n", 263 | " w = gifn(w_size, '%sw%i' %(name, index))\n", 264 | " g = gain_ifn((output_filters), '%sg%i' %(name, index))\n", 265 | " b = bias_ifn((output_filters), '%sb%i' %(name, index))\n", 266 | " else:\n", 267 | " w,g,b = weights\n", 268 | " \n", 269 | " \n", 270 | " conv_method = deconv if is_deconv else dnn_conv\n", 271 | " activation = relu if is_deconv else lrelu\n", 272 | " \n", 273 | " sub = output_size / input_size if is_deconv else input_size / output_size\n", 274 | " \n", 275 | " if filter_sz == 3:\n", 276 | " bm = 1\n", 277 | " else:\n", 278 | " bm = 2\n", 279 | " \n", 280 | " layer = activation(batchnorm(conv_method(X, w, subsample=(sub, sub), border_mode=(bm, bm)), g=g, b=b))\n", 281 | " \n", 282 | " return layer, [w,g,b]\n", 283 | "\n", 284 | "def make_conv_set(input, layer_sizes, num_filters, name, weights = None, filter_szs = None):\n", 285 | " assert(len(layer_sizes) == len(num_filters))\n", 286 | " \n", 287 | " vars_ = []\n", 288 | " layers_ = []\n", 289 | " current_layer = input\n", 290 | " \n", 291 | " for i in range(len(layer_sizes) - 1):\n", 292 | " input_size = layer_sizes[i]\n", 293 | " output_size = layer_sizes[i + 1]\n", 294 | " input_filters = num_filters[i]\n", 295 | " output_filters = num_filters[i + 1]\n", 296 | " \n", 297 | " if weights is not None:\n", 298 | " this_wts = weights[i * 3 : i * 3 + 3]\n", 299 | " else:\n", 300 | " this_wts = None\n", 301 | " \n", 302 | " if filter_szs != None:\n", 303 | " filter_sz = filter_szs[i]\n", 304 | " else:\n", 305 | " filter_sz = 5\n", 306 | " \n", 307 | " layer, new_vars = make_conv_layer(current_layer, input_size, output_size, \n", 308 | " input_filters, output_filters, name, i, \n", 309 | " weights = this_wts, filter_sz = filter_sz)\n", 310 | " \n", 311 | " vars_ += new_vars\n", 312 | " layers_ += [layer]\n", 313 | " current_layer = layer\n", 314 | " \n", 315 | " return current_layer, vars_, layers_" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "metadata": { 322 | "collapsed": false 323 | }, 324 | "outputs": [], 325 | "source": [ 326 | "import pickle\n", 327 | "#Use code below if you want use a saved model\n", 328 | "'''\n", 329 | "[e_params, g_params, d_params] = pickle.load( open( \"models/autoencoder_100epoch/faces_dcgan_denoising_64epoch_100encoding.pkl\", \"rb\" ) )\n", 330 | "gwx = g_params[-1]\n", 331 | "dwy = d_params[-1]\n", 332 | "\n", 333 | "# inputs\n", 334 | "X = T.tensor4()\n", 335 | "\n", 336 | "## encode layer\n", 337 | "e_layer_sizes = [128, 64, 32, 16, 8]\n", 338 | "e_filter_sizes = [3, 256, 256, 512, 1024]\n", 339 | "\n", 340 | "eX, e_params, e_layers = make_conv_set(X, e_layer_sizes, e_filter_sizes, 
\"e\", weights=e_params)\n", 341 | "\n", 342 | "## generative layer\n", 343 | "g_layer_sizes = [8, 16, 32, 64, 128]\n", 344 | "g_num_filters = [1024, 512, 256, 256, 128]\n", 345 | "\n", 346 | "\n", 347 | "g_out, g_params, g_layers = make_conv_set(eX, g_layer_sizes, g_num_filters, \"g\", weights=g_params)\n", 348 | "g_params += [gwx]\n", 349 | "gX = tanh(deconv(g_out, gwx, subsample=(1, 1), border_mode=(2, 2)))\n", 350 | "\n", 351 | "\n", 352 | "## discrim layer(s)\n", 353 | "\n", 354 | "df1 = 128\n", 355 | "d_layer_sizes = [128, 64, 32, 16, 8]\n", 356 | "d_filter_sizes = [3, df1, 2 * df1, 4 * df1, 8 * df1]\n", 357 | "\n", 358 | "def discrim(input, name, weights=None):\n", 359 | " d_out, disc_params, d_layers = make_conv_set(input, d_layer_sizes, d_filter_sizes, name, weights = weights)\n", 360 | " d_flat = T.flatten(d_out, 2)\n", 361 | " \n", 362 | " disc_params += [dwy]\n", 363 | " y = sigmoid(T.dot(d_flat, dwy))\n", 364 | " \n", 365 | " return y, disc_params, d_layers\n", 366 | "\n", 367 | "# target outputs\n", 368 | "target = T.tensor4()\n", 369 | "\n", 370 | "p_real, d_params, d_layers = discrim(target, \"d\", weights=d_params)\n", 371 | "#we need to make sure the p_gen params are the same as the p_real params\n", 372 | "p_gen , d_params2, d_layers = discrim(gX, \"d\", weights=d_params)\n", 373 | "\n", 374 | "'''\n", 375 | "\n", 376 | "\n", 377 | "#Use code below if you are training a model from scratch\n", 378 | "\n", 379 | "# inputs\n", 380 | "X = T.tensor4()\n", 381 | "\n", 382 | "## encode layer\n", 383 | "e_layer_sizes = [128, 64, 32, 16, 8]\n", 384 | "e_filter_sizes = [3, 256, 256, 512, 1024]\n", 385 | "\n", 386 | "eX, e_params, e_layers = make_conv_set(X, e_layer_sizes, e_filter_sizes, \"e\")\n", 387 | "\n", 388 | "## generative layer\n", 389 | "g_layer_sizes = [8, 16, 32, 64, 128]\n", 390 | "g_num_filters = [1024, 512, 256, 256, 128]\n", 391 | "\n", 392 | "\n", 393 | "g_out, g_params, g_layers = make_conv_set(eX, g_layer_sizes, g_num_filters, \"g\")\n", 394 | "gwx = gifn((128, nc, 5, 5), 'gwx')\n", 395 | "g_params += [gwx]\n", 396 | "gX = tanh(deconv(g_out, gwx, subsample=(1, 1), border_mode=(2, 2)))\n", 397 | "\n", 398 | "\n", 399 | "## discrim layer(s)\n", 400 | "\n", 401 | "df1 = 128\n", 402 | "d_layer_sizes = [128, 64, 32, 16, 8]\n", 403 | "d_filter_sizes = [3, df1, 2 * df1, 4 * df1, 8 * df1]\n", 404 | "\n", 405 | "dwy = difn((df1 * 8 * 10 * 8, 1), 'dwy')\n", 406 | "\n", 407 | "def discrim(input, name, weights=None):\n", 408 | " d_out, disc_params, d_layers = make_conv_set(input, d_layer_sizes, d_filter_sizes, name, weights = weights)\n", 409 | " d_flat = T.flatten(d_out, 2)\n", 410 | " \n", 411 | " disc_params += [dwy]\n", 412 | " y = sigmoid(T.dot(d_flat, dwy))\n", 413 | " \n", 414 | " return y, disc_params, d_layers\n", 415 | "\n", 416 | "# target outputs\n", 417 | "target = T.tensor4()\n", 418 | "\n", 419 | "p_real, d_params, d_layers = discrim(target, \"d\")\n", 420 | "#we need to make sure the p_gen params are the same as the p_real params\n", 421 | "p_gen , d_params2, d_layers = discrim(gX, \"d\")" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": { 428 | "collapsed": false 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "# test everything working so far (errors are most likely size mismatches)\n", 433 | "f = theano.function([X], p_gen)\n", 434 | "f(input_transform(x_train)).shape" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "Next we set up the various 
cost functions we need"
442 |    ]
443 |   },
444 |   {
445 |    "cell_type": "code",
446 |    "execution_count": null,
447 |    "metadata": {
448 |     "collapsed": false
449 |    },
450 |    "outputs": [],
451 |    "source": [
452 |     "from theano.tensor.signal.downsample import max_pool_2d\n",
453 |     "\n",
454 |     "## GAN costs\n",
455 |     "d_cost_real = bce(p_real, T.ones(p_real.shape)).mean()\n",
456 |     "d_cost_gen = bce(p_gen, T.zeros(p_gen.shape)).mean()\n",
457 |     "g_cost_d = bce(p_gen, T.ones(p_gen.shape)).mean()\n",
458 |     "\n",
459 |     "## MSE encoding cost is done on an (averaged) downscaling of the image\n",
460 |     "target_pool = max_pool_2d(target, (4,4), mode=\"average_exc_pad\",ignore_border=True)\n",
461 |     "target_flat = T.flatten(target_pool, 2)\n",
462 |     "gX_pool = max_pool_2d(gX, (4,4), mode=\"average_exc_pad\",ignore_border=True)\n",
463 |     "gX_flat = T.flatten(gX_pool,2)\n",
464 |     "enc_cost = mse(gX_flat, target_flat).mean() \n",
465 |     "\n",
466 |     "## MSE encoding without max pooling\n",
467 |     "'''\n",
468 |     "target_flat = T.flatten(target, 2)\n",
469 |     "gX_flat = T.flatten(gX,2)\n",
470 |     "enc_cost = mse(gX_flat, target_flat).mean() \n",
471 |     "'''\n",
472 |     "\n",
473 |     "## generator cost is a linear combination of the discrim cost plus the MSE encoding cost\n",
474 |     "d_cost = d_cost_real + d_cost_gen\n",
475 |     "\n",
476 |     "#To change the weight of MSE, change the denominator. ex. enc_cost/5 weights MSE much less than enc_cost/1\n",
477 |     "g_cost = g_cost_d + enc_cost / 1 \n",
478 |     "\n",
479 |     "## N.B. e_cost and e_updates will only try and minimise MSE loss on the autoencoder (for debugging)\n",
480 |     "e_cost = enc_cost\n",
481 |     "\n",
482 |     "cost = [g_cost_d, d_cost_real, enc_cost]\n",
483 |     "\n",
484 |     "elrt = sharedX(0.002)\n",
485 |     "lrt = sharedX(lr)\n",
486 |     "d_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2))\n",
487 |     "g_updater = updates.Adam(lr=lrt, b1=b1, regularizer=updates.Regularizer(l2=l2))\n",
488 |     "e_updater = updates.Adam(lr=elrt, b1=b1, regularizer=updates.Regularizer(l2=l2))\n",
489 |     "\n",
490 |     "d_updates = d_updater(d_params, d_cost)\n",
491 |     "g_updates = g_updater(e_params + g_params, g_cost)\n",
492 |     "e_updates = e_updater(e_params, e_cost)"
493 |    ]
494 |   },
495 |   {
496 |    "cell_type": "code",
497 |    "execution_count": null,
498 |    "metadata": {
499 |     "collapsed": false
500 |    },
501 |    "outputs": [],
502 |    "source": [
503 |     "print 'COMPILING'\n",
504 |     "t = time()\n",
505 |     "_train_g = theano.function([X, target], cost, updates=g_updates)\n",
506 |     "_train_d = theano.function([X, target], cost, updates=d_updates)\n",
507 |     "_train_e = theano.function([X, target], cost, updates=e_updates)\n",
508 |     "_get_cost = theano.function([X, target], cost)\n",
509 |     "print '%.2f seconds to compile theano functions'%(time()-t)"
510 |    ]
511 |   },
512 |   {
513 |    "cell_type": "markdown",
514 |    "metadata": {},
515 |    "source": [
516 |     "# Training code\n",
517 |     "\n",
518 |     "Code for generating the images every 100 batches or so."
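For reference, here is what enc_cost above computes, re-expressed as a small numpy sketch (illustrative only; pooled_mse is not a function in this repo, and the shapes and 4x4 pooling factor simply follow the cell above):

```python
import numpy as np

def pooled_mse(a, b, k=4):
    """MSE on k-by-k block averages of NCHW arrays (assumes H, W divisible by k)."""
    n, c, h, w = a.shape
    pa = a.reshape(n, c, h // k, k, w // k, k).mean(axis=(3, 5))
    pb = b.reshape(n, c, h // k, k, w // k, k).mean(axis=(3, 5))
    # like mse() in the notebook: sum squared error per example, then mean over the batch
    return ((pa - pb).reshape(n, -1) ** 2).sum(axis=1).mean()
```

Comparing block averages rather than raw pixels penalises the generator on colour and coarse structure instead of exact per-pixel noise.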
519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "metadata": { 525 | "collapsed": false 526 | }, 527 | "outputs": [], 528 | "source": [ 529 | "img_dir = \"gen_images/\"\n", 530 | "\n", 531 | "if not os.path.exists(img_dir):\n", 532 | " os.makedirs(img_dir)\n", 533 | "\n", 534 | "ae_encode = theano.function([X, target], [gX, target])\n", 535 | "\n", 536 | "def inverse(X):\n", 537 | " X_pred = (X.transpose(0, 2, 3, 1) + 1) * 127.5\n", 538 | " X_pred = np.rint(X_pred).astype(int)\n", 539 | " X_pred = np.clip(X_pred, a_min = 0, a_max = 255)\n", 540 | " return X_pred.astype('uint8')\n", 541 | "\n", 542 | "\n", 543 | "def save_sample_pictures():\n", 544 | " for te_train, te_target in test_stream.get_epoch_iterator():\n", 545 | " break\n", 546 | " te_out, te_ta = ae_encode(input_transform(te_train), target_transform(te_target))\n", 547 | " te_reshape = inverse(te_out)\n", 548 | " te_target_reshape = inverse(te_ta)\n", 549 | "\n", 550 | " new_size = (128 * 6, 160 * 12)\n", 551 | " new_im = Image.new('RGB', new_size)\n", 552 | " r = np.random.choice(12, 24, replace=True).reshape(2,12)\n", 553 | " for i in range(2):\n", 554 | " for j in range(12):\n", 555 | " index = r[i][j]\n", 556 | " \n", 557 | " target_im = Image.fromarray(te_target_reshape[index])\n", 558 | " train_im = Image.fromarray(te_train[index].astype(np.uint8))\n", 559 | " im = Image.fromarray(te_reshape[index])\n", 560 | " \n", 561 | " new_im.paste(target_im, (128 * i * 3, 160 * j))\n", 562 | " new_im.paste(train_im, (128 * (i * 3 + 1), 160 * j))\n", 563 | " new_im.paste(im, (128 * (i * 3 + 2), 160 * j))\n", 564 | " img_loc = \"gen_images/%i.png\" %int(time()) \n", 565 | " print \"saving images to %s\" %img_loc\n", 566 | " new_im.save(img_loc)\n", 567 | "\n", 568 | "#saves output for all testing images. 
This may take a couple of minutes to run.\n", 569 | "def save_all_pictures():\n", 570 | " counter = 0\n", 571 | " for te_train, te_target in test_stream.get_epoch_iterator():\n", 572 | " te_out, te_ta = ae_encode(input_transform(te_train), target_transform(te_target))\n", 573 | " te_reshape = inverse(te_out)\n", 574 | " te_target_reshape = inverse(te_ta)\n", 575 | "\n", 576 | " new_size = (128 * 3, 160 * 12)\n", 577 | " new_im = Image.new('RGB', new_size)\n", 578 | " r = [range(12),range(12)]\n", 579 | " for i in range(1):\n", 580 | " for j in range(12):\n", 581 | " index = r[i][j]\n", 582 | " try:\n", 583 | " target_im = Image.fromarray(te_target_reshape[index])\n", 584 | " train_im = Image.fromarray(te_train[index].astype(np.uint8))\n", 585 | " im = Image.fromarray(te_reshape[index])\n", 586 | "\n", 587 | " new_im.paste(target_im, (128 * i * 3, 160 * j))\n", 588 | " new_im.paste(train_im, (128 * (i * 3 + 1), 160 * j))\n", 589 | " new_im.paste(im, (128 * (i * 3 + 2), 160 * j))\n", 590 | " except:\n", 591 | " print \"Eror with training image\"\n", 592 | " img_loc = \"gen_images/test_result_%i.png\" %counter \n", 593 | " print \"saving images to %s\" %img_loc\n", 594 | " new_im.save(img_loc)\n", 595 | " counter+=1\n", 596 | "\n", 597 | "#save_all_pictures() \n", 598 | "save_sample_pictures()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "metadata": { 605 | "collapsed": false 606 | }, 607 | "outputs": [], 608 | "source": [ 609 | "def mn(l):\n", 610 | " if sum(l) == 0:\n", 611 | " return 0\n", 612 | " return sum(l) / len(l)\n", 613 | "\n", 614 | "## TODO : nicer way of coding these means?\n", 615 | "\n", 616 | "def get_test_errors():\n", 617 | " print \"getting test error\"\n", 618 | " g_costs = []\n", 619 | " d_costs = []\n", 620 | " e_costs = []\n", 621 | " k_costs = []\n", 622 | " for i in range(20):\n", 623 | " try:\n", 624 | " x_train, x_target = te_iterator.next()\n", 625 | " except:\n", 626 | " te_iterator = val_stream.get_epoch_iterator()\n", 627 | " x_train, x_target = te_iterator.next()\n", 628 | " x = input_transform(x_train)\n", 629 | " t = target_transform(x_target)\n", 630 | " cost = _get_cost(x,t)\n", 631 | " g_cost, d_cost, enc_cost = cost\n", 632 | " g_costs.append(g_cost)\n", 633 | " d_costs.append(d_cost)\n", 634 | " e_costs.append(enc_cost)\n", 635 | " \n", 636 | " s= \" ,\".join([\"test errors :\", str(mn(g_costs)), str(mn(d_costs)), str(mn(e_costs))])\n", 637 | " return s\n" 638 | ] 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "metadata": {}, 643 | "source": [ 644 | "# Train Model\n", 645 | "\n", 646 | "Finally, we come to the actual training of the model. This code can be keyboard interrupted, and the weights will be stored in memory, allowing us to stop, adjust and restart the training (this is how I got the model to train). 
For advice on training see the blog post at (#TODO)" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": { 653 | "collapsed": true 654 | }, 655 | "outputs": [], 656 | "source": [ 657 | "iterator = tr_stream.get_epoch_iterator()\n", 658 | "\n", 659 | "# you may wish to reset the learning rate to something of your choosing if you feel it is too high/low\n", 660 | "lrt = sharedX(lr)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": { 667 | "collapsed": false, 668 | "scrolled": false 669 | }, 670 | "outputs": [], 671 | "source": [ 672 | "from time import time\n", 673 | "\n", 674 | "n_updates = 0\n", 675 | "t = time()\n", 676 | "\n", 677 | "n_epochs = 200\n", 678 | "\n", 679 | "print \"STARTING\"\n", 680 | "\n", 681 | "\n", 682 | "\n", 683 | "for epoch in range(n_epochs):\n", 684 | " \n", 685 | " \n", 686 | " tm = time()\n", 687 | "\n", 688 | " g_costs = []\n", 689 | " d_costs = []\n", 690 | " e_costs = []\n", 691 | " \n", 692 | " ## TODO : produces pretty ugly output, redo this?\n", 693 | " for i in tqdm(range(num_samples/128)):\n", 694 | " \n", 695 | " try:\n", 696 | " x_train, x_target = iterator.next()\n", 697 | " except:\n", 698 | " iterator = tr_stream.get_epoch_iterator()\n", 699 | " x_train, x_target = iterator.next()\n", 700 | " x = input_transform(x_train)\n", 701 | " t = target_transform(x_target)\n", 702 | "\n", 703 | " \n", 704 | " ## optional - change the criteria for how often we train the generator or discriminator\n", 705 | " if n_updates % 2 == 1:\n", 706 | " cost = _train_g(x,t) \n", 707 | " else:\n", 708 | " cost = _train_d(x,t)\n", 709 | " \n", 710 | " # optional - only train the generator on MSE cost. If you want to only train the autoencoder, uncomment the\n", 711 | " # and comment the cost updates above\n", 712 | " #cost = _train_e(x,t)\n", 713 | " g_cost, d_cost, enc_cost = cost\n", 714 | " g_costs.append(g_cost)\n", 715 | " d_costs.append(d_cost)\n", 716 | " e_costs.append(enc_cost)\n", 717 | "\n", 718 | " if n_updates % 100 == 0:\n", 719 | " s= \" ,\".join([\"training errors :\", str(mn(g_costs)), str(mn(d_costs)), str(mn(e_costs))])\n", 720 | " g_costs = []\n", 721 | " d_costs = []\n", 722 | " e_costs = []\n", 723 | " print get_test_errors()\n", 724 | " print s\n", 725 | " sys.stdout.flush()\n", 726 | " save_sample_pictures()\n", 727 | " n_updates += 1 \n", 728 | "\n", 729 | " print \"epoch %i of %i took %.2f seconds\" %(epoch, n_epochs, time() - tm)\n", 730 | " \n", 731 | " ## optional - reduce the learning rate as you go\n", 732 | " #lrt.set_value(floatX(lrt.get_value() * 0.95))\n", 733 | " #print lrt.get_value()\n", 734 | " \n", 735 | " \n", 736 | " sys.stdout.flush()\n", 737 | " \n", 738 | " \n", 739 | " " 740 | ] 741 | }, 742 | { 743 | "cell_type": "markdown", 744 | "metadata": {}, 745 | "source": [ 746 | "# Save weights if wanted\n", 747 | "You can reuse them by using the weights in the make_conv_set method #TODO - actually try this!" 
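To reload the weights later, a sketch that mirrors the commented-out block in the model-setup cell above (untested here; note that on Python 3 pickle files must be opened in binary mode, so prefer 'wb' when dumping below):

```python
import pickle

with open("faces_dcgan_denoising_165epoch_1encoding.pkl", "rb") as f:
    e_params, g_params, d_params = pickle.load(f)

gwx = g_params[-1]  # final deconv filters, appended after make_conv_set
dwy = d_params[-1]  # discriminator output weights

# then rebuild the graph with the loaded lists, e.g.
# eX, e_params, e_layers = make_conv_set(X, e_layer_sizes, e_filter_sizes, "e", weights=e_params)
```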
748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": null, 753 | "metadata": { 754 | "collapsed": false 755 | }, 756 | "outputs": [], 757 | "source": [ 758 | "import pickle\n", 759 | "\n", 760 | "all_params = [e_params, g_params, d_params]\n", 761 | "\n", 762 | "pickle.dump(all_params, open(\"faces_dcgan_denoising_165epoch_1encoding.pkl\", 'w'))" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": null, 768 | "metadata": { 769 | "collapsed": true 770 | }, 771 | "outputs": [], 772 | "source": [] 773 | } 774 | ], 775 | "metadata": { 776 | "kernelspec": { 777 | "display_name": "Python 3", 778 | "language": "python", 779 | "name": "python3" 780 | }, 781 | "language_info": { 782 | "codemirror_mode": { 783 | "name": "ipython", 784 | "version": 3 785 | }, 786 | "file_extension": ".py", 787 | "mimetype": "text/x-python", 788 | "name": "python", 789 | "nbconvert_exporter": "python", 790 | "pygments_lexer": "ipython3", 791 | "version": "3.6.0" 792 | } 793 | }, 794 | "nbformat": 4, 795 | "nbformat_minor": 0 796 | } 797 | -------------------------------------------------------------------------------- /cnn_denoising.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [] 11 | } 12 | ], 13 | "metadata": { 14 | "kernelspec": { 15 | "display_name": "Python 3", 16 | "language": "python", 17 | "name": "python3" 18 | }, 19 | "language_info": { 20 | "codemirror_mode": { 21 | "name": "ipython", 22 | "version": 3 23 | }, 24 | "file_extension": ".py", 25 | "mimetype": "text/x-python", 26 | "name": "python", 27 | "nbconvert_exporter": "python", 28 | "pygments_lexer": "ipython3", 29 | "version": "3.6.0" 30 | } 31 | }, 32 | "nbformat": 4, 33 | "nbformat_minor": 2 34 | } 35 | -------------------------------------------------------------------------------- /dataprocessing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from PIL import Image 4 | from os import listdir 5 | from os.path import isfile, join 6 | import numpy as np 7 | import pickle 8 | from time import time 9 | import sys 10 | import h5py 11 | import random 12 | from tqdm import tqdm 13 | 14 | 15 | image_dir = '../img_align_celeba/' 16 | try: 17 | image_locs = [join(image_dir, f) for f in listdir(image_dir) if isfile(join(image_dir, f))] 18 | except: 19 | print("expected aligned images directory, see README") 20 | 21 | print("first image at " + str(image_locs[0])) 22 | 23 | total_imgs = len(image_locs) 24 | print("found %i images in directory" %total_imgs) 25 | 26 | 27 | def process_image(im): 28 | if im.mode != "RGB": 29 | im = im.convert("RGB") 30 | new_size = [int(i/1.3) for i in im.size] 31 | im.thumbnail(new_size, Image.ANTIALIAS) 32 | target = np.array(im)[3:-3,4:-4,:] 33 | im = Image.fromarray(target) 34 | new_size = [i/4 for i in im.size] 35 | im.thumbnail(new_size, Image.ANTIALIAS) 36 | input = np.array(im) 37 | return input, target 38 | 39 | 40 | 41 | 42 | def proc_loc(loc): 43 | try: 44 | i = Image.open(loc) 45 | #print("open image " + str(i)); 46 | input, target = process_image(i) 47 | return (input, target) 48 | except KeyboardInterrupt: 49 | raise 50 | #except: 51 | # return None 52 | 53 | 54 | try: 55 | hf = h5py.File('faces.hdf5','r+') 56 | except: 57 | hf = h5py.File('faces.hdf5','w') 58 | 59 | 60 | try: 61 | dset_t = 
hf.create_dataset("target", (1,160,128,3), 62 | maxshape= (1e6,160,128,3), chunks = (1,160,128,3), compression = "gzip") 63 | except: 64 | dset_t = hf['target'] 65 | 66 | try: 67 | dset_i = hf.create_dataset("input", (1, 40, 32, 3), 68 | maxshape= (1e6, 40, 32, 3), chunks = (1, 40, 32, 3), compression = "gzip") 69 | except: 70 | dset_i = hf['input'] 71 | 72 | 73 | batch_size = 1024 74 | #num_iter = total_imgs / 1024 75 | num_iter = 5 76 | 77 | insert_point = 0 78 | 79 | 80 | for i in tqdm(range(num_iter)): 81 | sys.stdout.flush() 82 | 83 | X_in = [] 84 | X_ta = [] 85 | 86 | a = time() 87 | locs = image_locs[i * batch_size : (i + 1) * batch_size] 88 | 89 | proc = [proc_loc(loc) for loc in locs] 90 | 91 | for pair in proc: 92 | if pair is not None: 93 | input, target = pair 94 | X_in.append(input) 95 | X_ta.append(target) 96 | 97 | X_in = np.array(X_in) 98 | X_ta = np.array(X_ta) 99 | 100 | dset_i.resize((insert_point + len(X_in),40, 32, 3)) 101 | dset_t.resize((insert_point + len(X_in),160,128,3)) 102 | 103 | dset_i[insert_point:insert_point + len(X_in)] = X_in 104 | dset_t[insert_point:insert_point + len(X_in)] = X_ta 105 | 106 | insert_point += len(X_in) 107 | 108 | hf.close() 109 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from PIL import Image 4 | import numpy 5 | import scipy 6 | 7 | 8 | def cmp_images_psnr(truth_array, noisy_array, length): 9 | psnr_val = 0.0 10 | for i in range(length): 11 | psnr_val += psnr(truth_array[i], noisy_array[i], 255.0) 12 | return psnr_val/length 13 | 14 | 15 | 16 | def psnr(dataset1, dataset2, maximumDataValue, ignore=None): 17 | # Make sure that the provided data sets are numpy ndarrays, if not 18 | if type(dataset1).__module__ != numpy.__name__: 19 | d1 = numpy.asarray(dataset1).flatten() 20 | else: 21 | d1 = dataset1.flatten() 22 | 23 | if type(dataset2).__module__ != numpy.__name__: 24 | d2 = numpy.asarray(dataset2).flatten() 25 | else: 26 | d2 = dataset2.flatten() 27 | 28 | # Make sure that the provided data sets are the same size 29 | if d1.size != d2.size: 30 | raise ValueError('Provided datasets must have the same size/shape') 31 | 32 | # Check if the provided data sets are identical, and if so, return an 33 | # infinite peak-signal-to-noise ratio 34 | if numpy.array_equal(d1, d2): 35 | return float('inf') 36 | 37 | # If specified, remove the values to ignore from the analysis and compute 38 | # the element-wise difference between the data sets 39 | if ignore is not None: 40 | index = numpy.intersect1d(numpy.where(d1 != ignore)[0], 41 | numpy.where(d2 != ignore)[0]) 42 | error = d1[index].astype(numpy.float64) - d2[index].astype(numpy.float64) 43 | else: 44 | error = d1.astype(numpy.float64)-d2.astype(numpy.float64) 45 | 46 | # Compute the mean-squared error 47 | meanSquaredError = numpy.sum(error**2) / error.size 48 | 49 | # Return the peak-signal-to-noise ratio 50 | return 10.0 * numpy.log10(maximumDataValue**2 / meanSquaredError) 51 | 52 | 53 | def image_interpolate(image_np, pil_sample_opt): 54 | return numpy.array(Image.fromarray(image_np.astype(numpy.uint8)).resize((128,160), resample=pil_sample_opt)) 55 | 56 | def bulk_image_interpolate(image_np_array, pil_sample_opt): 57 | ret_array = [] 58 | for image_np in image_np_array: 59 | ret_array.append( image_interpolate(image_np, pil_sample_opt)) 60 | return ret_array 
-------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | virtualenv . 2 | 3 | # install blocks and fuels they are not available on pip. 4 | pip install git+git://github.com/mila-udem/blocks.git@stable \ 5 | -r https://raw.githubusercontent.com/mila-udem/blocks/stable/requirements.txt 6 | 7 | #pip install git+git://github.com/mila-udem/fuel.git 8 | 9 | pip install -r requirements.txt 10 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heartyguy/ml-image-denoising/2c2d2e454714c890cc53f18c45d8733a6b9aeae8/lib/__init__.py -------------------------------------------------------------------------------- /lib/activations.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | 4 | class Softmax(object): 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def __call__(self, x): 10 | e_x = T.exp(x - x.max(axis=1).dimshuffle(0, 'x')) 11 | return e_x / e_x.sum(axis=1).dimshuffle(0, 'x') 12 | 13 | class ConvSoftmax(object): 14 | 15 | def __init__(self): 16 | pass 17 | 18 | def __call__(self, x): 19 | e_x = T.exp(x - x.max(axis=1, keepdims=True)) 20 | return e_x / e_x.sum(axis=1, keepdims=True) 21 | 22 | class Maxout(object): 23 | 24 | def __init__(self, n_pool=2): 25 | self.n_pool = n_pool 26 | 27 | def __call__(self, x): 28 | if x.ndim == 2: 29 | x = T.max([x[:, n::self.n_pool] for n in range(self.n_pool)], axis=0) 30 | elif x.ndim == 4: 31 | x = T.max([x[:, n::self.n_pool, :, :] for n in range(self.n_pool)], axis=0) 32 | else: 33 | raise NotImplementedError 34 | return x 35 | 36 | class Rectify(object): 37 | 38 | def __init__(self): 39 | pass 40 | 41 | def __call__(self, x): 42 | return (x + abs(x)) / 2.0 43 | 44 | class ClippedRectify(object): 45 | 46 | def __init__(self, clip=10.): 47 | self.clip = clip 48 | 49 | def __call__(self, x): 50 | return T.clip((x + abs(x)) / 2.0, 0., self.clip) 51 | 52 | class LeakyRectify(object): 53 | 54 | def __init__(self, leak=0.2): 55 | self.leak = leak 56 | 57 | def __call__(self, x): 58 | f1 = 0.5 * (1 + self.leak) 59 | f2 = 0.5 * (1 - self.leak) 60 | return f1 * x + f2 * abs(x) 61 | 62 | class Prelu(object): 63 | 64 | def __init__(self): 65 | pass 66 | 67 | def __call__(self, x, leak): 68 | if x.ndim == 4: 69 | leak = leak.dimshuffle('x', 0, 'x', 'x') 70 | f1 = 0.5 * (1 + leak) 71 | f2 = 0.5 * (1 - leak) 72 | return f1 * x + f2 * abs(x) 73 | 74 | class Tanh(object): 75 | 76 | def __init__(self): 77 | pass 78 | 79 | def __call__(self, x): 80 | return T.tanh(x) 81 | 82 | class Sigmoid(object): 83 | 84 | def __init__(self): 85 | pass 86 | 87 | def __call__(self, x): 88 | return T.nnet.sigmoid(x) 89 | 90 | class Linear(object): 91 | 92 | def __init__(self): 93 | pass 94 | 95 | def __call__(self, x): 96 | return x 97 | 98 | class HardSigmoid(object): 99 | 100 | def __init__(self): 101 | pass 102 | 103 | def __call__(self, X): 104 | return T.clip(X + 0.5, 0., 1.) 105 | 106 | class TRec(object): 107 | 108 | def __init__(self, t=1): 109 | self.t = t 110 | 111 | def __call__(self, X): 112 | return X*(X > self.t) 113 | 114 | class HardTanh(object): 115 | 116 | def __init__(self): 117 | pass 118 | 119 | def __call__(self, X): 120 | return T.clip(X, -1., 1.) 
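# Note on LeakyRectify above (added annotation): with f1 = (1 + leak)/2 and
# f2 = (1 - leak)/2, the expression f1*x + f2*abs(x) equals x for x >= 0 and
# leak*x for x < 0, i.e. max(x, leak*x) written without a branch, which keeps
# the op elementwise and cheap on the GPU. Rectify above is the leak = 0 case.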
-------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | data_dir = '/home/mike/Documents/convolutional_variational_autoencoder/dcgan/mnist/' 2 | -------------------------------------------------------------------------------- /lib/costs.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | 4 | def CategoricalCrossEntropy(y_true, y_pred): 5 | return T.nnet.categorical_crossentropy(y_pred, y_true).mean() 6 | 7 | def BinaryCrossEntropy(y_true, y_pred): 8 | return T.nnet.binary_crossentropy(y_pred, y_true).mean() 9 | 10 | def MeanSquaredError(y_true, y_pred): 11 | return T.sqr(y_pred - y_true).mean() 12 | 13 | def MeanAbsoluteError(y_true, y_pred): 14 | return T.abs_(y_pred - y_true).mean() 15 | 16 | def SquaredHinge(y_true, y_pred): 17 | return T.sqr(T.maximum(1. - y_true * y_pred, 0.)).mean() 18 | 19 | def Hinge(y_true, y_pred): 20 | return T.maximum(1. - y_true * y_pred, 0.).mean() 21 | 22 | cce = CCE = CategoricalCrossEntropy 23 | bce = BCE = BinaryCrossEntropy 24 | mse = MSE = MeanSquaredError 25 | mae = MAE = MeanAbsoluteError 26 | -------------------------------------------------------------------------------- /lib/cv2_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | def min_resize(x, size, interpolation=cv2.INTER_LINEAR): 4 | """ 5 | Resize an image so that it is size along the minimum spatial dimension. 6 | """ 7 | w, h = map(float, x.shape[:2]) 8 | if min([w, h]) != size: 9 | if w <= h: 10 | x = cv2.resize(x, (int(round((h/w)*size)), int(size)), interpolation=interpolation) 11 | else: 12 | x = cv2.resize(x, (int(size), int(round((w/h)*size))), interpolation=interpolation) 13 | return x -------------------------------------------------------------------------------- /lib/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import utils as skutils 3 | 4 | from rng import np_rng, py_rng 5 | 6 | def center_crop(x, ph, pw=None): 7 | if pw is None: 8 | pw = ph 9 | h, w = x.shape[:2] 10 | j = int(round((h - ph)/2.)) 11 | i = int(round((w - pw)/2.)) 12 | return x[j:j+ph, i:i+pw] 13 | 14 | def patch(x, ph, pw=None): 15 | if pw is None: 16 | pw = ph 17 | h, w = x.shape[:2] 18 | j = py_rng.randint(0, h-ph) 19 | i = py_rng.randint(0, w-pw) 20 | x = x[j:j+ph, i:i+pw] 21 | return x 22 | 23 | def list_shuffle(*data): 24 | idxs = np_rng.permutation(np.arange(len(data[0]))) 25 | if len(data) == 1: 26 | return [data[0][idx] for idx in idxs] 27 | else: 28 | return [[d[idx] for idx in idxs] for d in data] 29 | 30 | def shuffle(*arrays, **options): 31 | if isinstance(arrays[0][0], basestring): 32 | return list_shuffle(*arrays) 33 | else: 34 | return skutils.shuffle(*arrays, random_state=np_rng) 35 | 36 | def OneHot(X, n=None, negative_class=0.): 37 | X = np.asarray(X).flatten() 38 | if n is None: 39 | n = np.max(X) + 1 40 | Xoh = np.ones((len(X), n)) * negative_class 41 | Xoh[np.arange(len(X)), X] = 1. 
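    # e.g. OneHot([0, 2], n=3) -> [[1., 0., 0.], [0., 0., 1.]]
    # (negative_class fills the remaining entries, 0. by default)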
42 | return Xoh 43 | 44 | def iter_data(*data, **kwargs): 45 | size = kwargs.get('size', 128) 46 | try: 47 | n = len(data[0]) 48 | except: 49 | n = data[0].shape[0] 50 | batches = n / size 51 | if n % size != 0: 52 | batches += 1 53 | 54 | for b in range(batches): 55 | start = b * size 56 | end = (b + 1) * size 57 | if end > n: 58 | end = n 59 | if len(data) == 1: 60 | yield data[0][start:end] 61 | else: 62 | yield tuple([d[start:end] for d in data]) -------------------------------------------------------------------------------- /lib/inits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import theano 5 | import theano.tensor as T 6 | 7 | from theano_utils import sharedX, floatX, intX 8 | from rng import np_rng 9 | 10 | class Uniform(object): 11 | def __init__(self, scale=0.05): 12 | self.scale = 0.05 13 | 14 | def __call__(self, shape, name=None): 15 | return sharedX(np_rng.uniform(low=-self.scale, high=self.scale, size=shape), name=name) 16 | 17 | class Normal(object): 18 | def __init__(self, loc=0., scale=0.05): 19 | self.scale = scale 20 | self.loc = loc 21 | 22 | def __call__(self, shape, name=None): 23 | return sharedX(np_rng.normal(loc=self.loc, scale=self.scale, size=shape), name=name) 24 | 25 | class Orthogonal(object): 26 | """ benanne lasagne ortho init (faster than qr approach)""" 27 | def __init__(self, scale=1.1): 28 | self.scale = scale 29 | 30 | def __call__(self, shape, name=None): 31 | print 'called orthogonal init with shape', shape 32 | flat_shape = (shape[0], np.prod(shape[1:])) 33 | a = np_rng.normal(0.0, 1.0, flat_shape) 34 | u, _, v = np.linalg.svd(a, full_matrices=False) 35 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 36 | q = q.reshape(shape) 37 | return sharedX(self.scale * q[:shape[0], :shape[1]], name=name) 38 | 39 | class Frob(object): 40 | 41 | def __init__(self): 42 | pass 43 | 44 | def __call__(self, shape, name=None): 45 | r = np_rng.normal(loc=0, scale=0.01, size=shape) 46 | r = r/np.sqrt(np.sum(r**2))*np.sqrt(shape[1]) 47 | return sharedX(r, name=name) 48 | 49 | class Constant(object): 50 | 51 | def __init__(self, c=0.): 52 | self.c = c 53 | 54 | def __call__(self, shape, name=None): 55 | return sharedX(np.ones(shape) * self.c, name=name) 56 | 57 | class ConvIdentity(object): 58 | 59 | def __init__(self, scale=1.): 60 | self.scale = scale 61 | 62 | def __call__(self, shape, name=None): 63 | w = np.zeros(shape) 64 | ycenter = shape[2]//2 65 | xcenter = shape[3]//2 66 | 67 | if shape[0] == shape[1]: 68 | o_idxs = np.arange(shape[0]) 69 | i_idxs = np.arange(shape[1]) 70 | elif shape[1] < shape[0]: 71 | o_idxs = np.arange(shape[0]) 72 | i_idxs = np.random.permutation(np.tile(np.arange(shape[1]), shape[0]/shape[1]+1))[:shape[0]] 73 | w[o_idxs, i_idxs, ycenter, xcenter] = self.scale 74 | return sharedX(w, name=name) 75 | 76 | class Identity(object): 77 | 78 | def __init__(self, scale=0.25): 79 | self.scale = scale 80 | 81 | def __call__(self, shape, name=None): 82 | if shape[0] != shape[1]: 83 | w = np.zeros(shape) 84 | o_idxs = np.arange(shape[0]) 85 | i_idxs = np.random.permutation(np.tile(np.arange(shape[1]), shape[0]/shape[1]+1))[:shape[0]] 86 | w[o_idxs, i_idxs] = self.scale 87 | else: 88 | w = np.identity(shape[0]) * self.scale 89 | return sharedX(w, name=name) 90 | 91 | class ReluInit(object): 92 | 93 | def __init__(self): 94 | pass 95 | 96 | def __call__(self, shape, name=None): 97 | if len(shape) == 2: 98 | scale = np.sqrt(2./shape[0]) 99 
| elif len(shape) == 4: 100 | scale = np.sqrt(2./np.prod(shape[1:])) 101 | else: 102 | raise NotImplementedError 103 | return sharedX(np_rng.normal(size=shape, scale=scale), name=name) -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import theano 4 | import theano.tensor as T 5 | import gc 6 | import time 7 | 8 | from theano_utils import floatX 9 | from ops import euclidean, cosine 10 | 11 | from sklearn import metrics 12 | from sklearn.linear_model import LogisticRegression as LR 13 | 14 | def cv_reg_lr(trX, trY, vaX, vaY, Cs=[0.01, 0.05, 0.1, 0.5, 1., 5., 10., 50., 100.]): 15 | tr_accs = [] 16 | va_accs = [] 17 | models = [] 18 | for C in Cs: 19 | model = LR(C=C) 20 | model.fit(trX, trY) 21 | tr_pred = model.predict(trX) 22 | va_pred = model.predict(vaX) 23 | tr_acc = metrics.accuracy_score(trY, tr_pred) 24 | va_acc = metrics.accuracy_score(vaY, va_pred) 25 | print '%.4f %.4f %.4f'%(C, tr_acc, va_acc) 26 | tr_accs.append(tr_acc) 27 | va_accs.append(va_acc) 28 | models.append(model) 29 | best = np.argmax(va_accs) 30 | print 'best model C: %.4f tr_acc: %.4f va_acc: %.4f'%(Cs[best], tr_accs[best], va_accs[best]) 31 | return models[best] 32 | 33 | def gpu_nnc_predict(trX, trY, teX, metric='cosine', batch_size=4096): 34 | if metric == 'cosine': 35 | metric_fn = cosine_dist 36 | else: 37 | metric_fn = euclid_dist 38 | idxs = [] 39 | for i in range(0, len(teX), batch_size): 40 | mb_dists = [] 41 | mb_idxs = [] 42 | for j in range(0, len(trX), batch_size): 43 | dist = metric_fn(floatX(teX[i:i+batch_size]), floatX(trX[j:j+batch_size])) 44 | if metric == 'cosine': 45 | mb_dists.append(np.max(dist, axis=1)) 46 | mb_idxs.append(j+np.argmax(dist, axis=1)) 47 | else: 48 | mb_dists.append(np.min(dist, axis=1)) 49 | mb_idxs.append(j+np.argmin(dist, axis=1)) 50 | mb_idxs = np.asarray(mb_idxs) 51 | mb_dists = np.asarray(mb_dists) 52 | if metric == 'cosine': 53 | i = mb_idxs[np.argmax(mb_dists, axis=0), np.arange(mb_idxs.shape[1])] 54 | else: 55 | i = mb_idxs[np.argmin(mb_dists, axis=0), np.arange(mb_idxs.shape[1])] 56 | idxs.append(i) 57 | idxs = np.concatenate(idxs, axis=0) 58 | nearest = trY[idxs] 59 | return nearest 60 | 61 | def gpu_nnd_score(trX, teX, metric='cosine', batch_size=4096): 62 | if metric == 'cosine': 63 | metric_fn = cosine_dist 64 | else: 65 | metric_fn = euclid_dist 66 | dists = [] 67 | for i in range(0, len(teX), batch_size): 68 | mb_dists = [] 69 | for j in range(0, len(trX), batch_size): 70 | dist = metric_fn(floatX(teX[i:i+batch_size]), floatX(trX[j:j+batch_size])) 71 | if metric == 'cosine': 72 | mb_dists.append(np.max(dist, axis=1)) 73 | else: 74 | mb_dists.append(np.min(dist, axis=1)) 75 | mb_dists = np.asarray(mb_dists) 76 | if metric == 'cosine': 77 | d = np.max(mb_dists, axis=0) 78 | else: 79 | d = np.min(mb_dists, axis=0) 80 | dists.append(d) 81 | dists = np.concatenate(dists, axis=0) 82 | return float(np.mean(dists)) 83 | 84 | A = T.matrix() 85 | B = T.matrix() 86 | 87 | ed = euclidean(A, B) 88 | cd = cosine(A, B) 89 | 90 | cosine_dist = theano.function([A, B], cd) 91 | euclid_dist = theano.function([A, B], ed) 92 | 93 | def nnc_score(trX, trY, teX, teY, metric='euclidean'): 94 | pred = gpu_nnc_predict(trX, trY, teX, metric=metric) 95 | acc = metrics.accuracy_score(teY, pred) 96 | return acc*100. 
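# Usage note (added annotation): nnc_score labels each test row with the label
# of its nearest training row, with distances computed batch-wise on the GPU,
# and returns accuracy as a percentage, e.g.
#
#   acc = nnc_score(trX, trY, teX, teY, metric='cosine')
#
# nnd_score below returns the mean nearest-neighbour distance instead, a rough
# measure of how close one sample set (e.g. generated images) lies to the
# training set.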
97 | 98 | def nnd_score(trX, teX, metric='euclidean'): 99 | return gpu_nnd_score(trX, teX, metric=metric) 100 | -------------------------------------------------------------------------------- /lib/ops.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, 4 | host_from_gpu, 5 | gpu_contiguous, HostFromGpu, 6 | gpu_alloc_empty) 7 | from theano.sandbox.cuda.dnn import GpuDnnConvDesc, GpuDnnConv, GpuDnnConvGradI, dnn_conv, dnn_pool 8 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 9 | 10 | from rng import t_rng 11 | 12 | t_rng = RandomStreams() 13 | 14 | def l2normalize(x, axis=1, e=1e-8, keepdims=True): 15 | return x/l2norm(x, axis=axis, e=e, keepdims=keepdims) 16 | 17 | def l2norm(x, axis=1, e=1e-8, keepdims=True): 18 | return T.sqrt(T.sum(T.sqr(x), axis=axis, keepdims=keepdims) + e) 19 | 20 | def cosine(x, y): 21 | d = T.dot(x, y.T) 22 | d /= l2norm(x).dimshuffle(0, 'x') 23 | d /= l2norm(y).dimshuffle('x', 0) 24 | return d 25 | 26 | def euclidean(x, y, e=1e-8): 27 | xx = T.sqr(T.sqrt((x*x).sum(axis=1) + e)) 28 | yy = T.sqr(T.sqrt((y*y).sum(axis=1) + e)) 29 | dist = T.dot(x, y.T) 30 | dist *= -2 31 | dist += xx.dimshuffle(0, 'x') 32 | dist += yy.dimshuffle('x', 0) 33 | dist = T.sqrt(dist) 34 | return dist 35 | 36 | def dropout(X, p=0.): 37 | """ 38 | dropout using activation scaling to avoid test time weight rescaling 39 | """ 40 | if p > 0: 41 | retain_prob = 1 - p 42 | X *= t_rng.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX) 43 | X /= retain_prob 44 | return X 45 | 46 | def conv_cond_concat(x, y): 47 | """ 48 | concatenate conditioning vector on feature map axis 49 | """ 50 | return T.concatenate([x, y*T.ones((x.shape[0], y.shape[1], x.shape[2], x.shape[3]))], axis=1) 51 | 52 | def batchnorm(X, g=None, b=None, u=None, s=None, a=1., e=1e-8): 53 | """ 54 | batchnorm with support for not using scale and shift parameters 55 | as well as inference values (u and s) and partial batchnorm (via a) 56 | will detect and use convolutional or fully connected version 57 | """ 58 | if X.ndim == 4: 59 | if u is not None and s is not None: 60 | b_u = u.dimshuffle('x', 0, 'x', 'x') 61 | b_s = s.dimshuffle('x', 0, 'x', 'x') 62 | else: 63 | b_u = T.mean(X, axis=[0, 2, 3]).dimshuffle('x', 0, 'x', 'x') 64 | b_s = T.mean(T.sqr(X - b_u), axis=[0, 2, 3]).dimshuffle('x', 0, 'x', 'x') 65 | if a != 1: 66 | b_u = (1. - a)*0. + a*b_u 67 | b_s = (1. - a)*1. + a*b_s 68 | X = (X - b_u) / T.sqrt(b_s + e) 69 | if g is not None and b is not None: 70 | X = X*g.dimshuffle('x', 0, 'x', 'x') + b.dimshuffle('x', 0, 'x', 'x') 71 | elif X.ndim == 2: 72 | if u is None and s is None: 73 | u = T.mean(X, axis=0) 74 | s = T.mean(T.sqr(X - u), axis=0) 75 | if a != 1: 76 | u = (1. - a)*0. + a*u 77 | s = (1. - a)*1. 
+ a*s 78 | X = (X - u) / T.sqrt(s + e) 79 | if g is not None and b is not None: 80 | X = X*g + b 81 | else: 82 | raise NotImplementedError 83 | return X 84 | 85 | def deconv(X, w, subsample=(1, 1), border_mode=(0, 0), conv_mode='conv'): 86 | """ 87 | sets up dummy convolutional forward pass and uses its grad as deconv 88 | currently only tested/working with same padding 89 | """ 90 | img = gpu_contiguous(X) 91 | kerns = gpu_contiguous(w) 92 | desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, 93 | conv_mode=conv_mode)(gpu_alloc_empty(img.shape[0], kerns.shape[1], img.shape[2]*subsample[0], img.shape[3]*subsample[1]).shape, kerns.shape) 94 | out = gpu_alloc_empty(img.shape[0], kerns.shape[1], img.shape[2]*subsample[0], img.shape[3]*subsample[1]) 95 | d_img = GpuDnnConvGradI()(kerns, img, out, desc) 96 | return d_img -------------------------------------------------------------------------------- /lib/python3.6/__future__.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/__future__.py -------------------------------------------------------------------------------- /lib/python3.6/_bootlocale.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/_bootlocale.py -------------------------------------------------------------------------------- /lib/python3.6/_collections_abc.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/_collections_abc.py -------------------------------------------------------------------------------- /lib/python3.6/_dummy_thread.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/_dummy_thread.py -------------------------------------------------------------------------------- /lib/python3.6/_weakrefset.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/_weakrefset.py -------------------------------------------------------------------------------- /lib/python3.6/abc.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/abc.py -------------------------------------------------------------------------------- /lib/python3.6/base64.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/base64.py -------------------------------------------------------------------------------- /lib/python3.6/bisect.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/bisect.py -------------------------------------------------------------------------------- /lib/python3.6/codecs.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/codecs.py -------------------------------------------------------------------------------- /lib/python3.6/collections: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/collections -------------------------------------------------------------------------------- /lib/python3.6/config-3.6m-x86_64-linux-gnu: -------------------------------------------------------------------------------- 1 | 
/home/tian/miniconda3/lib/python3.6/config-3.6m-x86_64-linux-gnu
--------------------------------------------------------------------------------
/lib/python3.6/copy.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/copy.py
--------------------------------------------------------------------------------
/lib/python3.6/copyreg.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/copyreg.py
--------------------------------------------------------------------------------
/lib/python3.6/distutils/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import warnings
4 | import imp
5 | import opcode # opcode is not a virtualenv module, so we can use it to find the stdlib
6 |               # Important! To work on pypy, this must be a module that resides in the
7 |               # lib-python/modified-x.y.z directory
8 | 
9 | dirname = os.path.dirname
10 | 
11 | distutils_path = os.path.join(os.path.dirname(opcode.__file__), 'distutils')
12 | if os.path.normpath(distutils_path) == os.path.dirname(os.path.normpath(__file__)):
13 |     warnings.warn(
14 |         "The virtualenv distutils package at %s appears to be in the same location as the system distutils?" % distutils_path)
15 | else:
16 |     __path__.insert(0, distutils_path)
17 |     real_distutils = imp.load_module("_virtualenv_distutils", None, distutils_path, ('', '', imp.PKG_DIRECTORY))
18 |     # Copy the relevant attributes
19 |     try:
20 |         __revision__ = real_distutils.__revision__
21 |     except AttributeError:
22 |         pass
23 |     __version__ = real_distutils.__version__
24 | 
25 | from distutils import dist, sysconfig
26 | 
27 | try:
28 |     basestring
29 | except NameError:
30 |     basestring = str
31 | 
32 | ## patch build_ext (distutils doesn't know how to get the libs directory
33 | ## path on windows - it hardcodes the paths around the patched sys.prefix)
34 | 
35 | if sys.platform == 'win32':
36 |     from distutils.command.build_ext import build_ext as old_build_ext
37 |     class build_ext(old_build_ext):
38 |         def finalize_options (self):
39 |             if self.library_dirs is None:
40 |                 self.library_dirs = []
41 |             elif isinstance(self.library_dirs, basestring):
42 |                 self.library_dirs = self.library_dirs.split(os.pathsep)
43 | 
44 |             self.library_dirs.insert(0, os.path.join(sys.real_prefix, "Libs"))
45 |             old_build_ext.finalize_options(self)
46 | 
47 |     from distutils.command import build_ext as build_ext_module
48 |     build_ext_module.build_ext = build_ext
49 | 
50 | ## distutils.dist patches:
51 | 
52 | old_find_config_files = dist.Distribution.find_config_files
53 | def find_config_files(self):
54 |     found = old_find_config_files(self)
55 |     system_distutils = os.path.join(distutils_path, 'distutils.cfg')
56 |     #if os.path.exists(system_distutils):
57 |     #    found.insert(0, system_distutils)
58 |     # What to call the per-user config file
59 |     if os.name == 'posix':
60 |         user_filename = ".pydistutils.cfg"
61 |     else:
62 |         user_filename = "pydistutils.cfg"
63 |     user_filename = os.path.join(sys.prefix, user_filename)
64 |     if os.path.isfile(user_filename):
65 |         for item in list(found):
66 |             if item.endswith('pydistutils.cfg'):
67 |                 found.remove(item)
68 |         found.append(user_filename)
69 |     return found
70 | dist.Distribution.find_config_files = find_config_files
71 | 
72 | ## distutils.sysconfig patches:
73 | 
74 | old_get_python_inc = sysconfig.get_python_inc
75 | def sysconfig_get_python_inc(plat_specific=0, prefix=None):
76 |     if prefix
is None: 77 | prefix = sys.real_prefix 78 | return old_get_python_inc(plat_specific, prefix) 79 | sysconfig_get_python_inc.__doc__ = old_get_python_inc.__doc__ 80 | sysconfig.get_python_inc = sysconfig_get_python_inc 81 | 82 | old_get_python_lib = sysconfig.get_python_lib 83 | def sysconfig_get_python_lib(plat_specific=0, standard_lib=0, prefix=None): 84 | if standard_lib and prefix is None: 85 | prefix = sys.real_prefix 86 | return old_get_python_lib(plat_specific, standard_lib, prefix) 87 | sysconfig_get_python_lib.__doc__ = old_get_python_lib.__doc__ 88 | sysconfig.get_python_lib = sysconfig_get_python_lib 89 | 90 | old_get_config_vars = sysconfig.get_config_vars 91 | def sysconfig_get_config_vars(*args): 92 | real_vars = old_get_config_vars(*args) 93 | if sys.platform == 'win32': 94 | lib_dir = os.path.join(sys.real_prefix, "libs") 95 | if isinstance(real_vars, dict) and 'LIBDIR' not in real_vars: 96 | real_vars['LIBDIR'] = lib_dir # asked for all 97 | elif isinstance(real_vars, list) and 'LIBDIR' in args: 98 | real_vars = real_vars + [lib_dir] # asked for list 99 | return real_vars 100 | sysconfig_get_config_vars.__doc__ = old_get_config_vars.__doc__ 101 | sysconfig.get_config_vars = sysconfig_get_config_vars 102 | -------------------------------------------------------------------------------- /lib/python3.6/distutils/distutils.cfg: -------------------------------------------------------------------------------- 1 | # This is a config file local to this virtualenv installation 2 | # You may include options that will be used by all distutils commands, 3 | # and by easy_install. For instance: 4 | # 5 | # [easy_install] 6 | # find_links = http://mylocalsite 7 | -------------------------------------------------------------------------------- /lib/python3.6/encodings: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/encodings -------------------------------------------------------------------------------- /lib/python3.6/enum.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/enum.py -------------------------------------------------------------------------------- /lib/python3.6/fnmatch.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/fnmatch.py -------------------------------------------------------------------------------- /lib/python3.6/functools.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/functools.py -------------------------------------------------------------------------------- /lib/python3.6/genericpath.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/genericpath.py -------------------------------------------------------------------------------- /lib/python3.6/hashlib.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/hashlib.py -------------------------------------------------------------------------------- /lib/python3.6/heapq.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/heapq.py -------------------------------------------------------------------------------- /lib/python3.6/hmac.py: 
-------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/hmac.py -------------------------------------------------------------------------------- /lib/python3.6/imp.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/imp.py -------------------------------------------------------------------------------- /lib/python3.6/importlib: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/importlib -------------------------------------------------------------------------------- /lib/python3.6/io.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/io.py -------------------------------------------------------------------------------- /lib/python3.6/keyword.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/keyword.py -------------------------------------------------------------------------------- /lib/python3.6/lib-dynload: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/lib-dynload -------------------------------------------------------------------------------- /lib/python3.6/linecache.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/linecache.py -------------------------------------------------------------------------------- /lib/python3.6/locale.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/locale.py -------------------------------------------------------------------------------- /lib/python3.6/no-global-site-packages.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/heartyguy/ml-image-denoising/2c2d2e454714c890cc53f18c45d8733a6b9aeae8/lib/python3.6/no-global-site-packages.txt -------------------------------------------------------------------------------- /lib/python3.6/ntpath.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/ntpath.py -------------------------------------------------------------------------------- /lib/python3.6/operator.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/operator.py -------------------------------------------------------------------------------- /lib/python3.6/orig-prefix.txt: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3 -------------------------------------------------------------------------------- /lib/python3.6/os.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/os.py -------------------------------------------------------------------------------- /lib/python3.6/posixpath.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/posixpath.py -------------------------------------------------------------------------------- /lib/python3.6/random.py: -------------------------------------------------------------------------------- 1 | 
/home/tian/miniconda3/lib/python3.6/random.py -------------------------------------------------------------------------------- /lib/python3.6/re.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/re.py -------------------------------------------------------------------------------- /lib/python3.6/reprlib.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/reprlib.py -------------------------------------------------------------------------------- /lib/python3.6/rlcompleter.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/rlcompleter.py -------------------------------------------------------------------------------- /lib/python3.6/shutil.py: -------------------------------------------------------------------------------- 1 | /home/tian/miniconda3/lib/python3.6/shutil.py -------------------------------------------------------------------------------- /lib/python3.6/site.py: -------------------------------------------------------------------------------- 1 | """Append module search paths for third-party packages to sys.path. 2 | 3 | **************************************************************** 4 | * This module is automatically imported during initialization. * 5 | **************************************************************** 6 | 7 | In earlier versions of Python (up to 1.5a3), scripts or modules that 8 | needed to use site-specific modules would place ``import site'' 9 | somewhere near the top of their code. Because of the automatic 10 | import, this is no longer necessary (but code that does it still 11 | works). 12 | 13 | This will append site-specific paths to the module search path. On 14 | Unix, it starts with sys.prefix and sys.exec_prefix (if different) and 15 | appends lib/python/site-packages as well as lib/site-python. 16 | It also supports the Debian convention of 17 | lib/python/dist-packages. On other platforms (mainly Mac and 18 | Windows), it uses just sys.prefix (and sys.exec_prefix, if different, 19 | but this is unlikely). The resulting directories, if they exist, are 20 | appended to sys.path, and also inspected for path configuration files. 21 | 22 | FOR DEBIAN, this sys.path is augmented with directories in /usr/local. 23 | Local addons go into /usr/local/lib/python/site-packages 24 | (resp. /usr/local/lib/site-python), Debian addons install into 25 | /usr/{lib,share}/python/dist-packages. 26 | 27 | A path configuration file is a file whose name has the form 28 | .pth; its contents are additional directories (one per line) 29 | to be added to sys.path. Non-existing directories (or 30 | non-directories) are never added to sys.path; no directory is added to 31 | sys.path more than once. Blank lines and lines beginning with 32 | '#' are skipped. Lines starting with 'import' are executed. 33 | 34 | For example, suppose sys.prefix and sys.exec_prefix are set to 35 | /usr/local and there is a directory /usr/local/lib/python2.X/site-packages 36 | with three subdirectories, foo, bar and spam, and two path 37 | configuration files, foo.pth and bar.pth. 
Assume foo.pth contains the 38 | following: 39 | 40 | # foo package configuration 41 | foo 42 | bar 43 | bletch 44 | 45 | and bar.pth contains: 46 | 47 | # bar package configuration 48 | bar 49 | 50 | Then the following directories are added to sys.path, in this order: 51 | 52 | /usr/local/lib/python2.X/site-packages/bar 53 | /usr/local/lib/python2.X/site-packages/foo 54 | 55 | Note that bletch is omitted because it doesn't exist; bar precedes foo 56 | because bar.pth comes alphabetically before foo.pth; and spam is 57 | omitted because it is not mentioned in either path configuration file. 58 | 59 | After these path manipulations, an attempt is made to import a module 60 | named sitecustomize, which can perform arbitrary additional 61 | site-specific customizations. If this import fails with an 62 | ImportError exception, it is silently ignored. 63 | 64 | """ 65 | 66 | import sys 67 | import os 68 | try: 69 | import __builtin__ as builtins 70 | except ImportError: 71 | import builtins 72 | try: 73 | set 74 | except NameError: 75 | from sets import Set as set 76 | 77 | # Prefixes for site-packages; add additional prefixes like /usr/local here 78 | PREFIXES = [sys.prefix, sys.exec_prefix] 79 | # Enable per user site-packages directory 80 | # set it to False to disable the feature or True to force the feature 81 | ENABLE_USER_SITE = None 82 | # for distutils.commands.install 83 | USER_SITE = None 84 | USER_BASE = None 85 | 86 | _is_64bit = (getattr(sys, 'maxsize', None) or getattr(sys, 'maxint')) > 2**32 87 | _is_pypy = hasattr(sys, 'pypy_version_info') 88 | _is_jython = sys.platform[:4] == 'java' 89 | if _is_jython: 90 | ModuleType = type(os) 91 | 92 | def makepath(*paths): 93 | dir = os.path.join(*paths) 94 | if _is_jython and (dir == '__classpath__' or 95 | dir.startswith('__pyclasspath__')): 96 | return dir, dir 97 | dir = os.path.abspath(dir) 98 | return dir, os.path.normcase(dir) 99 | 100 | def abs__file__(): 101 | """Set all module' __file__ attribute to an absolute path""" 102 | for m in sys.modules.values(): 103 | if ((_is_jython and not isinstance(m, ModuleType)) or 104 | hasattr(m, '__loader__')): 105 | # only modules need the abspath in Jython. and don't mess 106 | # with a PEP 302-supplied __file__ 107 | continue 108 | f = getattr(m, '__file__', None) 109 | if f is None: 110 | continue 111 | m.__file__ = os.path.abspath(f) 112 | 113 | def removeduppaths(): 114 | """ Remove duplicate entries from sys.path along with making them 115 | absolute""" 116 | # This ensures that the initial path provided by the interpreter contains 117 | # only absolute pathnames, even if we're running from the build directory. 118 | L = [] 119 | known_paths = set() 120 | for dir in sys.path: 121 | # Filter out duplicate paths (on case-insensitive file systems also 122 | # if they only differ in case); turn relative paths into absolute 123 | # paths. 124 | dir, dircase = makepath(dir) 125 | if not dircase in known_paths: 126 | L.append(dir) 127 | known_paths.add(dircase) 128 | sys.path[:] = L 129 | return known_paths 130 | 131 | # XXX This should not be part of site.py, since it is needed even when 132 | # using the -S option for Python. See http://www.python.org/sf/586680 133 | def addbuilddir(): 134 | """Append ./build/lib. 
in case we're running in the build dir 135 | (especially for Guido :-)""" 136 | from distutils.util import get_platform 137 | s = "build/lib.%s-%.3s" % (get_platform(), sys.version) 138 | if hasattr(sys, 'gettotalrefcount'): 139 | s += '-pydebug' 140 | s = os.path.join(os.path.dirname(sys.path[-1]), s) 141 | sys.path.append(s) 142 | 143 | def _init_pathinfo(): 144 | """Return a set containing all existing directory entries from sys.path""" 145 | d = set() 146 | for dir in sys.path: 147 | try: 148 | if os.path.isdir(dir): 149 | dir, dircase = makepath(dir) 150 | d.add(dircase) 151 | except TypeError: 152 | continue 153 | return d 154 | 155 | def addpackage(sitedir, name, known_paths): 156 | """Add a new path to known_paths by combining sitedir and 'name' or execute 157 | sitedir if it starts with 'import'""" 158 | if known_paths is None: 159 | _init_pathinfo() 160 | reset = 1 161 | else: 162 | reset = 0 163 | fullname = os.path.join(sitedir, name) 164 | try: 165 | f = open(fullname, "rU") 166 | except IOError: 167 | return 168 | try: 169 | for line in f: 170 | if line.startswith("#"): 171 | continue 172 | if line.startswith("import"): 173 | exec(line) 174 | continue 175 | line = line.rstrip() 176 | dir, dircase = makepath(sitedir, line) 177 | if not dircase in known_paths and os.path.exists(dir): 178 | sys.path.append(dir) 179 | known_paths.add(dircase) 180 | finally: 181 | f.close() 182 | if reset: 183 | known_paths = None 184 | return known_paths 185 | 186 | def addsitedir(sitedir, known_paths=None): 187 | """Add 'sitedir' argument to sys.path if missing and handle .pth files in 188 | 'sitedir'""" 189 | if known_paths is None: 190 | known_paths = _init_pathinfo() 191 | reset = 1 192 | else: 193 | reset = 0 194 | sitedir, sitedircase = makepath(sitedir) 195 | if not sitedircase in known_paths: 196 | sys.path.append(sitedir) # Add path component 197 | try: 198 | names = os.listdir(sitedir) 199 | except os.error: 200 | return 201 | names.sort() 202 | for name in names: 203 | if name.endswith(os.extsep + "pth"): 204 | addpackage(sitedir, name, known_paths) 205 | if reset: 206 | known_paths = None 207 | return known_paths 208 | 209 | def addsitepackages(known_paths, sys_prefix=sys.prefix, exec_prefix=sys.exec_prefix): 210 | """Add site-packages (and possibly site-python) to sys.path""" 211 | prefixes = [os.path.join(sys_prefix, "local"), sys_prefix] 212 | if exec_prefix != sys_prefix: 213 | prefixes.append(os.path.join(exec_prefix, "local")) 214 | 215 | for prefix in prefixes: 216 | if prefix: 217 | if sys.platform in ('os2emx', 'riscos') or _is_jython: 218 | sitedirs = [os.path.join(prefix, "Lib", "site-packages")] 219 | elif _is_pypy: 220 | sitedirs = [os.path.join(prefix, 'site-packages')] 221 | elif sys.platform == 'darwin' and prefix == sys_prefix: 222 | 223 | if prefix.startswith("/System/Library/Frameworks/"): # Apple's Python 224 | 225 | sitedirs = [os.path.join("/Library/Python", sys.version[:3], "site-packages"), 226 | os.path.join(prefix, "Extras", "lib", "python")] 227 | 228 | else: # any other Python distros on OSX work this way 229 | sitedirs = [os.path.join(prefix, "lib", 230 | "python" + sys.version[:3], "site-packages")] 231 | 232 | elif os.sep == '/': 233 | sitedirs = [os.path.join(prefix, 234 | "lib", 235 | "python" + sys.version[:3], 236 | "site-packages"), 237 | os.path.join(prefix, "lib", "site-python"), 238 | os.path.join(prefix, "python" + sys.version[:3], "lib-dynload")] 239 | lib64_dir = os.path.join(prefix, "lib64", "python" + sys.version[:3], "site-packages") 240 | 
if (os.path.exists(lib64_dir) and 241 | os.path.realpath(lib64_dir) not in [os.path.realpath(p) for p in sitedirs]): 242 | if _is_64bit: 243 | sitedirs.insert(0, lib64_dir) 244 | else: 245 | sitedirs.append(lib64_dir) 246 | try: 247 | # sys.getobjects only available in --with-pydebug build 248 | sys.getobjects 249 | sitedirs.insert(0, os.path.join(sitedirs[0], 'debug')) 250 | except AttributeError: 251 | pass 252 | # Debian-specific dist-packages directories: 253 | sitedirs.append(os.path.join(prefix, "local/lib", 254 | "python" + sys.version[:3], 255 | "dist-packages")) 256 | if sys.version[0] == '2': 257 | sitedirs.append(os.path.join(prefix, "lib", 258 | "python" + sys.version[:3], 259 | "dist-packages")) 260 | else: 261 | sitedirs.append(os.path.join(prefix, "lib", 262 | "python" + sys.version[0], 263 | "dist-packages")) 264 | sitedirs.append(os.path.join(prefix, "lib", "dist-python")) 265 | else: 266 | sitedirs = [prefix, os.path.join(prefix, "lib", "site-packages")] 267 | if sys.platform == 'darwin': 268 | # for framework builds *only* we add the standard Apple 269 | # locations. Currently only per-user, but /Library and 270 | # /Network/Library could be added too 271 | if 'Python.framework' in prefix: 272 | home = os.environ.get('HOME') 273 | if home: 274 | sitedirs.append( 275 | os.path.join(home, 276 | 'Library', 277 | 'Python', 278 | sys.version[:3], 279 | 'site-packages')) 280 | for sitedir in sitedirs: 281 | if os.path.isdir(sitedir): 282 | addsitedir(sitedir, known_paths) 283 | return None 284 | 285 | def check_enableusersite(): 286 | """Check if user site directory is safe for inclusion 287 | 288 | The function tests for the command line flag (including environment var), 289 | process uid/gid equal to effective uid/gid. 290 | 291 | None: Disabled for security reasons 292 | False: Disabled by user (command line option) 293 | True: Safe and enabled 294 | """ 295 | if hasattr(sys, 'flags') and getattr(sys.flags, 'no_user_site', False): 296 | return False 297 | 298 | if hasattr(os, "getuid") and hasattr(os, "geteuid"): 299 | # check process uid == effective uid 300 | if os.geteuid() != os.getuid(): 301 | return None 302 | if hasattr(os, "getgid") and hasattr(os, "getegid"): 303 | # check process gid == effective gid 304 | if os.getegid() != os.getgid(): 305 | return None 306 | 307 | return True 308 | 309 | def addusersitepackages(known_paths): 310 | """Add a per user site-package to sys.path 311 | 312 | Each user has its own python directory with site-packages in the 313 | home directory. 314 | 315 | USER_BASE is the root directory for all Python versions 316 | 317 | USER_SITE is the user specific site-packages directory 318 | 319 | USER_SITE/.. can be used for data. 
320 | """ 321 | global USER_BASE, USER_SITE, ENABLE_USER_SITE 322 | env_base = os.environ.get("PYTHONUSERBASE", None) 323 | 324 | def joinuser(*args): 325 | return os.path.expanduser(os.path.join(*args)) 326 | 327 | #if sys.platform in ('os2emx', 'riscos'): 328 | # # Don't know what to put here 329 | # USER_BASE = '' 330 | # USER_SITE = '' 331 | if os.name == "nt": 332 | base = os.environ.get("APPDATA") or "~" 333 | if env_base: 334 | USER_BASE = env_base 335 | else: 336 | USER_BASE = joinuser(base, "Python") 337 | USER_SITE = os.path.join(USER_BASE, 338 | "Python" + sys.version[0] + sys.version[2], 339 | "site-packages") 340 | else: 341 | if env_base: 342 | USER_BASE = env_base 343 | else: 344 | USER_BASE = joinuser("~", ".local") 345 | USER_SITE = os.path.join(USER_BASE, "lib", 346 | "python" + sys.version[:3], 347 | "site-packages") 348 | 349 | if ENABLE_USER_SITE and os.path.isdir(USER_SITE): 350 | addsitedir(USER_SITE, known_paths) 351 | if ENABLE_USER_SITE: 352 | for dist_libdir in ("lib", "local/lib"): 353 | user_site = os.path.join(USER_BASE, dist_libdir, 354 | "python" + sys.version[:3], 355 | "dist-packages") 356 | if os.path.isdir(user_site): 357 | addsitedir(user_site, known_paths) 358 | return known_paths 359 | 360 | 361 | 362 | def setBEGINLIBPATH(): 363 | """The OS/2 EMX port has optional extension modules that do double duty 364 | as DLLs (and must use the .DLL file extension) for other extensions. 365 | The library search path needs to be amended so these will be found 366 | during module import. Use BEGINLIBPATH so that these are at the start 367 | of the library search path. 368 | 369 | """ 370 | dllpath = os.path.join(sys.prefix, "Lib", "lib-dynload") 371 | libpath = os.environ['BEGINLIBPATH'].split(';') 372 | if libpath[-1]: 373 | libpath.append(dllpath) 374 | else: 375 | libpath[-1] = dllpath 376 | os.environ['BEGINLIBPATH'] = ';'.join(libpath) 377 | 378 | 379 | def setquit(): 380 | """Define new built-ins 'quit' and 'exit'. 381 | These are simply strings that display a hint on how to exit. 382 | 383 | """ 384 | if os.sep == ':': 385 | eof = 'Cmd-Q' 386 | elif os.sep == '\\': 387 | eof = 'Ctrl-Z plus Return' 388 | else: 389 | eof = 'Ctrl-D (i.e. EOF)' 390 | 391 | class Quitter(object): 392 | def __init__(self, name): 393 | self.name = name 394 | def __repr__(self): 395 | return 'Use %s() or %s to exit' % (self.name, eof) 396 | def __call__(self, code=None): 397 | # Shells like IDLE catch the SystemExit, but listen when their 398 | # stdin wrapper is closed. 
399 | try: 400 | sys.stdin.close() 401 | except: 402 | pass 403 | raise SystemExit(code) 404 | builtins.quit = Quitter('quit') 405 | builtins.exit = Quitter('exit') 406 | 407 | 408 | class _Printer(object): 409 | """interactive prompt objects for printing the license text, a list of 410 | contributors and the copyright notice.""" 411 | 412 | MAXLINES = 23 413 | 414 | def __init__(self, name, data, files=(), dirs=()): 415 | self.__name = name 416 | self.__data = data 417 | self.__files = files 418 | self.__dirs = dirs 419 | self.__lines = None 420 | 421 | def __setup(self): 422 | if self.__lines: 423 | return 424 | data = None 425 | for dir in self.__dirs: 426 | for filename in self.__files: 427 | filename = os.path.join(dir, filename) 428 | try: 429 | fp = open(filename, "rU") 430 | data = fp.read() 431 | fp.close() 432 | break 433 | except IOError: 434 | pass 435 | if data: 436 | break 437 | if not data: 438 | data = self.__data 439 | self.__lines = data.split('\n') 440 | self.__linecnt = len(self.__lines) 441 | 442 | def __repr__(self): 443 | self.__setup() 444 | if len(self.__lines) <= self.MAXLINES: 445 | return "\n".join(self.__lines) 446 | else: 447 | return "Type %s() to see the full %s text" % ((self.__name,)*2) 448 | 449 | def __call__(self): 450 | self.__setup() 451 | prompt = 'Hit Return for more, or q (and Return) to quit: ' 452 | lineno = 0 453 | while 1: 454 | try: 455 | for i in range(lineno, lineno + self.MAXLINES): 456 | print(self.__lines[i]) 457 | except IndexError: 458 | break 459 | else: 460 | lineno += self.MAXLINES 461 | key = None 462 | while key is None: 463 | try: 464 | key = raw_input(prompt) 465 | except NameError: 466 | key = input(prompt) 467 | if key not in ('', 'q'): 468 | key = None 469 | if key == 'q': 470 | break 471 | 472 | def setcopyright(): 473 | """Set 'copyright' and 'credits' in __builtin__""" 474 | builtins.copyright = _Printer("copyright", sys.copyright) 475 | if _is_jython: 476 | builtins.credits = _Printer( 477 | "credits", 478 | "Jython is maintained by the Jython developers (www.jython.org).") 479 | elif _is_pypy: 480 | builtins.credits = _Printer( 481 | "credits", 482 | "PyPy is maintained by the PyPy developers: http://pypy.org/") 483 | else: 484 | builtins.credits = _Printer("credits", """\ 485 | Thanks to CWI, CNRI, BeOpen.com, Zope Corporation and a cast of thousands 486 | for supporting Python development. See www.python.org for more information.""") 487 | here = os.path.dirname(os.__file__) 488 | builtins.license = _Printer( 489 | "license", "See http://www.python.org/%.3s/license.html" % sys.version, 490 | ["LICENSE.txt", "LICENSE"], 491 | [os.path.join(here, os.pardir), here, os.curdir]) 492 | 493 | 494 | class _Helper(object): 495 | """Define the built-in 'help'. 496 | This is a wrapper around pydoc.help (with a twist). 497 | 498 | """ 499 | 500 | def __repr__(self): 501 | return "Type help() for interactive help, " \ 502 | "or help(object) for help about object." 503 | def __call__(self, *args, **kwds): 504 | import pydoc 505 | return pydoc.help(*args, **kwds) 506 | 507 | def sethelper(): 508 | builtins.help = _Helper() 509 | 510 | def aliasmbcs(): 511 | """On Windows, some default encodings are not provided by Python, 512 | while they are always available as "mbcs" in each locale. Make 513 | them usable by aliasing to "mbcs" in such a case.""" 514 | if sys.platform == 'win32': 515 | import locale, codecs 516 | enc = locale.getdefaultlocale()[1] 517 | if enc.startswith('cp'): # "cp***" ? 
518 | try: 519 | codecs.lookup(enc) 520 | except LookupError: 521 | import encodings 522 | encodings._cache[enc] = encodings._unknown 523 | encodings.aliases.aliases[enc] = 'mbcs' 524 | 525 | def setencoding(): 526 | """Set the string encoding used by the Unicode implementation. The 527 | default is 'ascii', but if you're willing to experiment, you can 528 | change this.""" 529 | encoding = "ascii" # Default value set by _PyUnicode_Init() 530 | if 0: 531 | # Enable to support locale aware default string encodings. 532 | import locale 533 | loc = locale.getdefaultlocale() 534 | if loc[1]: 535 | encoding = loc[1] 536 | if 0: 537 | # Enable to switch off string to Unicode coercion and implicit 538 | # Unicode to string conversion. 539 | encoding = "undefined" 540 | if encoding != "ascii": 541 | # On Non-Unicode builds this will raise an AttributeError... 542 | sys.setdefaultencoding(encoding) # Needs Python Unicode build ! 543 | 544 | 545 | def execsitecustomize(): 546 | """Run custom site specific code, if available.""" 547 | try: 548 | import sitecustomize 549 | except ImportError: 550 | pass 551 | 552 | def virtual_install_main_packages(): 553 | f = open(os.path.join(os.path.dirname(__file__), 'orig-prefix.txt')) 554 | sys.real_prefix = f.read().strip() 555 | f.close() 556 | pos = 2 557 | hardcoded_relative_dirs = [] 558 | if sys.path[0] == '': 559 | pos += 1 560 | if _is_jython: 561 | paths = [os.path.join(sys.real_prefix, 'Lib')] 562 | elif _is_pypy: 563 | if sys.version_info > (3, 2): 564 | cpyver = '%d' % sys.version_info[0] 565 | elif sys.pypy_version_info >= (1, 5): 566 | cpyver = '%d.%d' % sys.version_info[:2] 567 | else: 568 | cpyver = '%d.%d.%d' % sys.version_info[:3] 569 | paths = [os.path.join(sys.real_prefix, 'lib_pypy'), 570 | os.path.join(sys.real_prefix, 'lib-python', cpyver)] 571 | if sys.pypy_version_info < (1, 9): 572 | paths.insert(1, os.path.join(sys.real_prefix, 573 | 'lib-python', 'modified-%s' % cpyver)) 574 | hardcoded_relative_dirs = paths[:] # for the special 'darwin' case below 575 | # 576 | # This is hardcoded in the Python executable, but relative to sys.prefix: 577 | for path in paths[:]: 578 | plat_path = os.path.join(path, 'plat-%s' % sys.platform) 579 | if os.path.exists(plat_path): 580 | paths.append(plat_path) 581 | elif sys.platform == 'win32': 582 | paths = [os.path.join(sys.real_prefix, 'Lib'), os.path.join(sys.real_prefix, 'DLLs')] 583 | else: 584 | paths = [os.path.join(sys.real_prefix, 'lib', 'python'+sys.version[:3])] 585 | hardcoded_relative_dirs = paths[:] # for the special 'darwin' case below 586 | lib64_path = os.path.join(sys.real_prefix, 'lib64', 'python'+sys.version[:3]) 587 | if os.path.exists(lib64_path): 588 | if _is_64bit: 589 | paths.insert(0, lib64_path) 590 | else: 591 | paths.append(lib64_path) 592 | # This is hardcoded in the Python executable, but relative to 593 | # sys.prefix. Debian change: we need to add the multiarch triplet 594 | # here, which is where the real stuff lives. As per PEP 421, in 595 | # Python 3.3+, this lives in sys.implementation, while in Python 2.7 596 | # it lives in sys. 597 | try: 598 | arch = getattr(sys, 'implementation', sys)._multiarch 599 | except AttributeError: 600 | # This is a non-multiarch aware Python. Fallback to the old way. 
601 | arch = sys.platform 602 | plat_path = os.path.join(sys.real_prefix, 'lib', 603 | 'python'+sys.version[:3], 604 | 'plat-%s' % arch) 605 | if os.path.exists(plat_path): 606 | paths.append(plat_path) 607 | # This is hardcoded in the Python executable, but 608 | # relative to sys.prefix, so we have to fix up: 609 | for path in list(paths): 610 | tk_dir = os.path.join(path, 'lib-tk') 611 | if os.path.exists(tk_dir): 612 | paths.append(tk_dir) 613 | 614 | # These are hardcoded in the Apple's Python executable, 615 | # but relative to sys.prefix, so we have to fix them up: 616 | if sys.platform == 'darwin': 617 | hardcoded_paths = [os.path.join(relative_dir, module) 618 | for relative_dir in hardcoded_relative_dirs 619 | for module in ('plat-darwin', 'plat-mac', 'plat-mac/lib-scriptpackages')] 620 | 621 | for path in hardcoded_paths: 622 | if os.path.exists(path): 623 | paths.append(path) 624 | 625 | sys.path.extend(paths) 626 | 627 | def force_global_eggs_after_local_site_packages(): 628 | """ 629 | Force easy_installed eggs in the global environment to get placed 630 | in sys.path after all packages inside the virtualenv. This 631 | maintains the "least surprise" result that packages in the 632 | virtualenv always mask global packages, never the other way 633 | around. 634 | 635 | """ 636 | egginsert = getattr(sys, '__egginsert', 0) 637 | for i, path in enumerate(sys.path): 638 | if i > egginsert and path.startswith(sys.prefix): 639 | egginsert = i 640 | sys.__egginsert = egginsert + 1 641 | 642 | def virtual_addsitepackages(known_paths): 643 | force_global_eggs_after_local_site_packages() 644 | return addsitepackages(known_paths, sys_prefix=sys.real_prefix) 645 | 646 | def fixclasspath(): 647 | """Adjust the special classpath sys.path entries for Jython. These 648 | entries should follow the base virtualenv lib directories. 649 | """ 650 | paths = [] 651 | classpaths = [] 652 | for path in sys.path: 653 | if path == '__classpath__' or path.startswith('__pyclasspath__'): 654 | classpaths.append(path) 655 | else: 656 | paths.append(path) 657 | sys.path = paths 658 | sys.path.extend(classpaths) 659 | 660 | def execusercustomize(): 661 | """Run custom user specific code, if available.""" 662 | try: 663 | import usercustomize 664 | except ImportError: 665 | pass 666 | 667 | 668 | def main(): 669 | global ENABLE_USER_SITE 670 | virtual_install_main_packages() 671 | abs__file__() 672 | paths_in_sys = removeduppaths() 673 | if (os.name == "posix" and sys.path and 674 | os.path.basename(sys.path[-1]) == "Modules"): 675 | addbuilddir() 676 | if _is_jython: 677 | fixclasspath() 678 | GLOBAL_SITE_PACKAGES = not os.path.exists(os.path.join(os.path.dirname(__file__), 'no-global-site-packages.txt')) 679 | if not GLOBAL_SITE_PACKAGES: 680 | ENABLE_USER_SITE = False 681 | if ENABLE_USER_SITE is None: 682 | ENABLE_USER_SITE = check_enableusersite() 683 | paths_in_sys = addsitepackages(paths_in_sys) 684 | paths_in_sys = addusersitepackages(paths_in_sys) 685 | if GLOBAL_SITE_PACKAGES: 686 | paths_in_sys = virtual_addsitepackages(paths_in_sys) 687 | if sys.platform == 'os2emx': 688 | setBEGINLIBPATH() 689 | setquit() 690 | setcopyright() 691 | sethelper() 692 | aliasmbcs() 693 | setencoding() 694 | execsitecustomize() 695 | if ENABLE_USER_SITE: 696 | execusercustomize() 697 | # Remove sys.setdefaultencoding() so that users cannot change the 698 | # encoding after initialization. The test for presence is needed when 699 | # this module is run as a script, because this code is executed twice. 
700 |     if hasattr(sys, "setdefaultencoding"):
701 |         del sys.setdefaultencoding
702 | 
703 | main()
704 | 
705 | def _script():
706 |     help = """\
707 |     %s [--user-base] [--user-site]
708 | 
709 |     Without arguments print some useful information
710 |     With arguments print the value of USER_BASE and/or USER_SITE separated
711 |     by '%s'.
712 | 
713 |     Exit codes with --user-base or --user-site:
714 |       0 - user site directory is enabled
715 |       1 - user site directory is disabled by user
716 |       2 - user site directory is disabled by super user
717 |           or for security reasons
718 |      >2 - unknown error
719 |     """
720 |     args = sys.argv[1:]
721 |     if not args:
722 |         print("sys.path = [")
723 |         for dir in sys.path:
724 |             print("    %r," % (dir,))
725 |         print("]")
726 |         def exists(path):
727 |             if os.path.isdir(path):
728 |                 return "exists"
729 |             else:
730 |                 return "doesn't exist"
731 |         print("USER_BASE: %r (%s)" % (USER_BASE, exists(USER_BASE)))
732 |         print("USER_SITE: %r (%s)" % (USER_SITE, exists(USER_SITE)))
733 |         print("ENABLE_USER_SITE: %r" % ENABLE_USER_SITE)
734 |         sys.exit(0)
735 | 
736 |     buffer = []
737 |     if '--user-base' in args:
738 |         buffer.append(USER_BASE)
739 |     if '--user-site' in args:
740 |         buffer.append(USER_SITE)
741 | 
742 |     if buffer:
743 |         print(os.pathsep.join(buffer))
744 |         if ENABLE_USER_SITE:
745 |             sys.exit(0)
746 |         elif ENABLE_USER_SITE is False:
747 |             sys.exit(1)
748 |         elif ENABLE_USER_SITE is None:
749 |             sys.exit(2)
750 |         else:
751 |             sys.exit(3)
752 |     else:
753 |         import textwrap
754 |         print(textwrap.dedent(help % (sys.argv[0], os.pathsep)))
755 |         sys.exit(10)
756 | 
757 | if __name__ == '__main__':
758 |     _script()
759 | 
--------------------------------------------------------------------------------
/lib/python3.6/sre_compile.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/sre_compile.py
--------------------------------------------------------------------------------
/lib/python3.6/sre_constants.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/sre_constants.py
--------------------------------------------------------------------------------
/lib/python3.6/sre_parse.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/sre_parse.py
--------------------------------------------------------------------------------
/lib/python3.6/stat.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/stat.py
--------------------------------------------------------------------------------
/lib/python3.6/struct.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/struct.py
--------------------------------------------------------------------------------
/lib/python3.6/tarfile.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/tarfile.py
--------------------------------------------------------------------------------
/lib/python3.6/tempfile.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/tempfile.py
--------------------------------------------------------------------------------
/lib/python3.6/token.py:
--------------------------------------------------------------------------------
1 |
/home/tian/miniconda3/lib/python3.6/token.py
--------------------------------------------------------------------------------
/lib/python3.6/tokenize.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/tokenize.py
--------------------------------------------------------------------------------
/lib/python3.6/types.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/types.py
--------------------------------------------------------------------------------
/lib/python3.6/warnings.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/warnings.py
--------------------------------------------------------------------------------
/lib/python3.6/weakref.py:
--------------------------------------------------------------------------------
1 | /home/tian/miniconda3/lib/python3.6/weakref.py
--------------------------------------------------------------------------------
/lib/rng.py:
--------------------------------------------------------------------------------
1 | from numpy.random import RandomState
2 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
3 | from random import Random
4 | 
5 | seed = 42
6 | 
7 | py_rng = Random(seed)
8 | np_rng = RandomState(seed)
9 | t_rng = RandomStreams(seed)
10 | 
11 | def set_seed(n):
12 |     global seed, py_rng, np_rng, t_rng
13 | 
14 |     seed = n
15 |     py_rng = Random(seed)
16 |     np_rng = RandomState(seed)
17 |     t_rng = RandomStreams(seed)
18 | 
--------------------------------------------------------------------------------
/lib/theano_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import theano
3 | 
4 | def intX(X):
5 |     return np.asarray(X, dtype=np.int32)
6 | 
7 | def floatX(X):
8 |     return np.asarray(X, dtype=theano.config.floatX)
9 | 
10 | def sharedX(X, dtype=theano.config.floatX, name=None):
11 |     return theano.shared(np.asarray(X, dtype=dtype), name=name)
12 | 
13 | def shared0s(shape, dtype=theano.config.floatX, name=None):
14 |     return sharedX(np.zeros(shape), dtype=dtype, name=name)
15 | 
16 | def sharedNs(shape, n, dtype=theano.config.floatX, name=None):
17 |     return sharedX(np.ones(shape)*n, dtype=dtype, name=name)
--------------------------------------------------------------------------------
/lib/updates.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | 
5 | from lib.theano_utils import shared0s, floatX, sharedX  # was `from theano_utils import ...`, a Python 2 implicit relative import
6 | from lib.ops import l2norm  # fixes the ModuleNotFoundError raised when the notebook runs `from lib import updates`
7 | 
8 | def clip_norm(g, c, n):
9 |     if c > 0:
10 |         g = T.switch(T.ge(n, c), g*c/n, g)
11 |     return g
12 | 
13 | def clip_norms(gs, c):
14 |     norm = T.sqrt(sum([T.sum(g**2) for g in gs]))
15 |     return [clip_norm(g, c, norm) for g in gs]
16 | 
17 | class Regularizer(object):
18 | 
19 |     def __init__(self, l1=0., l2=0., maxnorm=0., l2norm=False, frobnorm=False):
20 |         self.__dict__.update(locals())
21 | 
22 |     def max_norm(self, p, maxnorm):
23 |         if maxnorm > 0:
24 |             norms = T.sqrt(T.sum(T.sqr(p), axis=0))
25 |             desired = T.clip(norms, 0, maxnorm)
26 |             p = p * (desired/ (1e-7 + norms))
27 |         return p
28 | 
29 |     def l2_norm(self, p):
30 |         return p/l2norm(p, axis=0)
31 | 
32 |     def frob_norm(self, p, nrows):
33 |         return (p/T.sqrt(T.sum(T.sqr(p))))*T.sqrt(nrows)
34 | 
35 |     def gradient_regularize(self, p, g):
36 |         g += p * self.l2
37 |         g += T.sgn(p) *
self.l1 38 | return g 39 | 40 | def weight_regularize(self, p): 41 | p = self.max_norm(p, self.maxnorm) 42 | if self.l2norm: 43 | p = self.l2_norm(p) 44 | if self.frobnorm > 0: 45 | p = self.frob_norm(p, self.frobnorm) 46 | return p 47 | 48 | 49 | class Update(object): 50 | 51 | def __init__(self, regularizer=Regularizer(), clipnorm=0.): 52 | self.__dict__.update(locals()) 53 | 54 | def __call__(self, params, grads): 55 | raise NotImplementedError 56 | 57 | class SGD(Update): 58 | 59 | def __init__(self, lr=0.01, *args, **kwargs): 60 | Update.__init__(self, *args, **kwargs) 61 | self.__dict__.update(locals()) 62 | 63 | def __call__(self, params, cost): 64 | updates = [] 65 | grads = T.grad(cost, params) 66 | grads = clip_norms(grads, self.clipnorm) 67 | for p,g in zip(params,grads): 68 | g = self.regularizer.gradient_regularize(p, g) 69 | updated_p = p - self.lr * g 70 | updated_p = self.regularizer.weight_regularize(updated_p) 71 | updates.append((p, updated_p)) 72 | return updates 73 | 74 | class Momentum(Update): 75 | 76 | def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs): 77 | Update.__init__(self, *args, **kwargs) 78 | self.__dict__.update(locals()) 79 | 80 | def __call__(self, params, cost): 81 | updates = [] 82 | grads = T.grad(cost, params) 83 | grads = clip_norms(grads, self.clipnorm) 84 | for p,g in zip(params,grads): 85 | g = self.regularizer.gradient_regularize(p, g) 86 | m = theano.shared(p.get_value() * 0.) 87 | v = (self.momentum * m) - (self.lr * g) 88 | updates.append((m, v)) 89 | 90 | updated_p = p + v 91 | updated_p = self.regularizer.weight_regularize(updated_p) 92 | updates.append((p, updated_p)) 93 | return updates 94 | 95 | 96 | class NAG(Update): 97 | 98 | def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs): 99 | Update.__init__(self, *args, **kwargs) 100 | self.__dict__.update(locals()) 101 | 102 | def __call__(self, params, cost): 103 | updates = [] 104 | grads = T.grad(cost, params) 105 | grads = clip_norms(grads, self.clipnorm) 106 | for p, g in zip(params, grads): 107 | g = self.regularizer.gradient_regularize(p, g) 108 | m = theano.shared(p.get_value() * 0.) 109 | v = (self.momentum * m) - (self.lr * g) 110 | 111 | updated_p = p + self.momentum * v - self.lr * g 112 | updated_p = self.regularizer.weight_regularize(updated_p) 113 | updates.append((m,v)) 114 | updates.append((p, updated_p)) 115 | return updates 116 | 117 | 118 | class RMSprop(Update): 119 | 120 | def __init__(self, lr=0.001, rho=0.9, epsilon=1e-6, *args, **kwargs): 121 | Update.__init__(self, *args, **kwargs) 122 | self.__dict__.update(locals()) 123 | 124 | def __call__(self, params, cost): 125 | updates = [] 126 | grads = T.grad(cost, params) 127 | grads = clip_norms(grads, self.clipnorm) 128 | for p,g in zip(params,grads): 129 | g = self.regularizer.gradient_regularize(p, g) 130 | acc = theano.shared(p.get_value() * 0.) 
131 | acc_new = self.rho * acc + (1 - self.rho) * g ** 2 132 | updates.append((acc, acc_new)) 133 | 134 | updated_p = p - self.lr * (g / T.sqrt(acc_new + self.epsilon)) 135 | updated_p = self.regularizer.weight_regularize(updated_p) 136 | updates.append((p, updated_p)) 137 | return updates 138 | 139 | 140 | class Adam(Update): 141 | 142 | def __init__(self, lr=0.001, b1=0.9, b2=0.999, e=1e-8, l=1-1e-8, *args, **kwargs): 143 | Update.__init__(self, *args, **kwargs) 144 | self.__dict__.update(locals()) 145 | 146 | def __call__(self, params, cost): 147 | updates = [] 148 | grads = T.grad(cost, params) 149 | grads = clip_norms(grads, self.clipnorm) 150 | t = theano.shared(floatX(1.)) 151 | b1_t = self.b1*self.l**(t-1) 152 | 153 | for p, g in zip(params, grads): 154 | g = self.regularizer.gradient_regularize(p, g) 155 | m = theano.shared(p.get_value() * 0.) 156 | v = theano.shared(p.get_value() * 0.) 157 | 158 | m_t = b1_t*m + (1 - b1_t)*g 159 | v_t = self.b2*v + (1 - self.b2)*g**2 160 | m_c = m_t / (1-self.b1**t) 161 | v_c = v_t / (1-self.b2**t) 162 | p_t = p - (self.lr * m_c) / (T.sqrt(v_c) + self.e) 163 | p_t = self.regularizer.weight_regularize(p_t) 164 | updates.append((m, m_t)) 165 | updates.append((v, v_t)) 166 | updates.append((p, p_t) ) 167 | updates.append((t, t + 1.)) 168 | return updates 169 | 170 | 171 | class Adagrad(Update): 172 | 173 | def __init__(self, lr=0.01, epsilon=1e-6, *args, **kwargs): 174 | Update.__init__(self, *args, **kwargs) 175 | self.__dict__.update(locals()) 176 | 177 | def __call__(self, params, cost): 178 | updates = [] 179 | grads = T.grad(cost, params) 180 | grads = clip_norms(grads, self.clipnorm) 181 | for p,g in zip(params,grads): 182 | g = self.regularizer.gradient_regularize(p, g) 183 | acc = theano.shared(p.get_value() * 0.) 184 | acc_t = acc + g ** 2 185 | updates.append((acc, acc_t)) 186 | 187 | p_t = p - (self.lr / T.sqrt(acc_t + self.epsilon)) * g 188 | p_t = self.regularizer.weight_regularize(p_t) 189 | updates.append((p, p_t)) 190 | return updates 191 | 192 | 193 | class Adadelta(Update): 194 | 195 | def __init__(self, lr=0.5, rho=0.95, epsilon=1e-6, *args, **kwargs): 196 | Update.__init__(self, *args, **kwargs) 197 | self.__dict__.update(locals()) 198 | 199 | def __call__(self, params, cost): 200 | updates = [] 201 | grads = T.grad(cost, params) 202 | grads = clip_norms(grads, self.clipnorm) 203 | for p,g in zip(params,grads): 204 | g = self.regularizer.gradient_regularize(p, g) 205 | 206 | acc = theano.shared(p.get_value() * 0.) 207 | acc_delta = theano.shared(p.get_value() * 0.) 
208 |             acc_new = self.rho * acc + (1 - self.rho) * g ** 2
209 |             updates.append((acc,acc_new))
210 | 
211 |             update = g * T.sqrt(acc_delta + self.epsilon) / T.sqrt(acc_new + self.epsilon)
212 |             updated_p = p - self.lr * update
213 |             updated_p = self.regularizer.weight_regularize(updated_p)
214 |             updates.append((p, updated_p))
215 | 
216 |             acc_delta_new = self.rho * acc_delta + (1 - self.rho) * update ** 2
217 |             updates.append((acc_delta,acc_delta_new))
218 |         return updates
219 | 
220 | 
221 | class NoUpdate(Update):
222 | 
223 |     def __init__(self, lr=0.01, momentum=0.9, *args, **kwargs):
224 |         Update.__init__(self, *args, **kwargs)
225 |         self.__dict__.update(locals())
226 | 
227 |     def __call__(self, params, cost):
228 |         updates = []
229 |         for p in params:
230 |             updates.append((p, p))
231 |         return updates
232 | 
--------------------------------------------------------------------------------
/lib/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.misc import imsave  # removed in SciPy >= 1.2; use imageio.imwrite on newer installs
3 | 
4 | def grayscale_grid_vis(X, shape, save_path=None):
5 |     nh, nw = shape  # grid size in tiles; tuple parameters were Python 2 only syntax
6 |     h, w = X[0].shape[:2]
7 |     img = np.zeros((h*nh, w*nw))
8 |     for n, x in enumerate(X):
9 |         j = n//nw  # integer division for Python 3
10 |         i = n%nw
11 |         img[j*h:j*h+h, i*w:i*w+w] = x
12 |     if save_path is not None:
13 |         imsave(save_path, img)
14 |     return img
15 | 
16 | def color_grid_vis(X, shape, save_path=None):
17 |     nh, nw = shape
18 |     h, w = X[0].shape[:2]
19 |     img = np.zeros((h*nh, w*nw, 3))
20 |     for n, x in enumerate(X):
21 |         j = n//nw
22 |         i = n%nw
23 |         img[j*h:j*h+h, i*w:i*w+w, :] = x
24 |     if save_path is not None:
25 |         imsave(save_path, img)
26 |     return img
27 | 
28 | def grayscale_weight_grid_vis(w, shape, save_path=None):
29 |     w = (w - w.min())/(w.max() - w.min())  # min-max normalize to [0, 1]; original `w + w.min()` shifted the wrong way
30 |     return grayscale_grid_vis(w, shape, save_path=save_path)
--------------------------------------------------------------------------------
/nmf_denoising.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {
7 |     "collapsed": true
8 |    },
9 |    "outputs": [],
10 |    "source": []
11 |   }
12 |  ],
13 |  "metadata": {
14 |   "kernelspec": {
15 |    "display_name": "R",
16 |    "language": "R",
17 |    "name": "ir"
18 |   },
19 |   "language_info": {
20 |    "codemirror_mode": "r",
21 |    "file_extension": ".r",
22 |    "mimetype": "text/x-r-source",
23 |    "name": "R",
24 |    "pygments_lexer": "r",
25 |    "version": "3.3.2"
26 |   }
27 |  },
28 |  "nbformat": 4,
29 |  "nbformat_minor": 2
30 | }
31 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #blocks>=0.0.1
2 | #fuel>=0.1.1
3 | h5py>=2.5.0
4 | ipython>=4.0.0
5 | ipython-genutils>=0.1.0
6 | numpy>=1.10.2
7 | Pillow>=3.0.0
8 | Theano==0.7.0
9 | tqdm>=3.4.0
10 | 
--------------------------------------------------------------------------------
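Usage sketch (illustrative, not a file in this repository):
--------------------------------------------------------------------------------
A minimal sketch of how the pieces above fit together: sharedX parameters from
lib/theano_utils.py, a cost expression, and Adam from lib/updates.py compiled
into a Theano training function. It assumes the pinned Theano 0.7 with CUDA
available (lib/ops.py imports theano.sandbox.cuda at import time) and the
repository root on sys.path; the toy linear "denoiser", the data shapes, and
the hyperparameter values are illustrative assumptions, not code from the repo.

import numpy as np
import theano
import theano.tensor as T

from lib.theano_utils import floatX, sharedX
from lib.updates import Adam, Regularizer

# toy linear "denoiser": one weight matrix mapping noisy vectors to clean ones
W = sharedX(np.random.randn(784, 784) * 0.01)

X_noisy = T.matrix()  # minibatch of flattened noisy images
X_clean = T.matrix()  # matching clean targets
X_hat = T.dot(X_noisy, W)

cost = T.sqr(X_hat - X_clean).mean()  # mean squared reconstruction error

# Adam as defined in lib/updates.py; lr/b1/l2 values are illustrative
updater = Adam(lr=0.0002, b1=0.5, regularizer=Regularizer(l2=1e-5))
updates = updater([W], cost)  # list of (shared_variable, update_expression) pairs

train = theano.function([X_noisy, X_clean], cost, updates=updates)

# one gradient step on random data, just to show the call pattern
print(train(floatX(np.random.rand(128, 784)), floatX(np.random.rand(128, 784))))

The same updater(params, cost) pattern applies to the Momentum, NAG, RMSprop,
Adagrad, and Adadelta classes above, since they all share Update's interface.
--------------------------------------------------------------------------------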