├── LICENSE.txt ├── README.md ├── environment.yml ├── keras.json.for_TensorFlow ├── keras.json.for_Theano ├── nbs ├── char-rnn.ipynb ├── convolution-intro.ipynb ├── dogs_cats_redux.ipynb ├── dogscats-ensemble.ipynb ├── imagenet_batchnorm.ipynb ├── lesson1.ipynb ├── lesson2.ipynb ├── lesson3.ipynb ├── lesson4.ipynb ├── lesson5.ipynb ├── lesson6.ipynb ├── lesson7.ipynb ├── mnist.ipynb ├── resnet50.py ├── sgd-intro.ipynb ├── statefarm-sample.ipynb ├── statefarm.ipynb ├── utils.py ├── vgg16.py ├── vgg16bn.py └── wordvectors.ipynb └── nbs2 ├── DCGAN.ipynb ├── Keras-Tensorflow-Tutorial.ipynb ├── attention_wrapper.py ├── babi-memnn.ipynb ├── batcher.py ├── bcolz_array_iterator.py ├── bcolz_iter_test.ipynb ├── dcgan.py ├── densenet-keras.ipynb ├── imagenet_process.ipynb ├── kmeans.py ├── kmeans_test.ipynb ├── meanshift.ipynb ├── neural-sr.ipynb ├── neural-style-pytorch.ipynb ├── neural-style.ipynb ├── pytorch-tut.ipynb ├── rossman.ipynb ├── rossman_exp.py ├── seq2seq-translation.ipynb ├── spelling_bee_RNN.ipynb ├── taxi.ipynb ├── taxi_data_prep_and_mlp.ipynb ├── tf-basics.ipynb ├── tiramisu-keras.ipynb ├── tiramisu-pytorch.ipynb ├── torch_utils.py ├── translate-pytorch.ipynb ├── translate.ipynb ├── utils2.py ├── vgg16.py ├── vgg16_avg.py └── wgan-pytorch.ipynb /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modified notebooks and Python files for Keras 2 and Python 3 from the fast.ai Deep Learning course v.1 2 | The repository includes modified copies of the original Jupyter notebooks and Python files from the excellent 3 | (and really unique) deep learning course "Practical Deep Learning For Coders" Part 1 and Part 2, v.1, 4 | created by [fast.ai](http://fast.ai). 
5 | 6 | The [original files](https://github.com/fastai/courses) require Keras 1. One main goal has been to modify the original files to the minimum extent possible. The comments added to the modules generally start with *"# -"* when they are not just *"# Keras 2"*. 7 | 8 | The current version of the repository has been tested with **_Keras 2.1.2_**. 9 | The previous version, tested with _Keras 2.0.6_, is available [here](https://github.com/roebius/deeplearning_keras2/releases). 10 | ### Part 1 11 | Located in the _nbs_ folder. Tested on _Ubuntu 16.04_ and _Python 3.5_, installed through [Anaconda](https://www.anaconda.com), using the [Theano](http://deeplearning.net/software/theano/) 1.0.1 backend. 12 | 13 | ### Part 2 14 | Located in the _nbs2_ folder. Tested on _Ubuntu 16.04_ and _Python 3.5_, installed through [Anaconda](https://www.anaconda.com), using the [TensorFlow](https://www.tensorflow.org/) 1.3.0 backend. 15 | A few modules requiring PyTorch were also tested, using [PyTorch](http://pytorch.org/) 0.3.0. 16 | 17 | The files _keras.json.for\_TensorFlow_ and _keras.json.for\_Theano_ provide a template for the appropriate _keras.json_ file, depending on which of the two backends Keras should use. 18 | 19 | An _environment.yml_ file for creating a suitable [conda environment](https://conda.io/docs/user-guide/tasks/manage-environments.html) is provided. 20 | 21 | 22 | ### Notes and issues about Part 2 23 | *neural-style.ipynb*: due to a function parameter change in _Keras 2.1_, the _VGG16_ provided by _Keras 2.1_ has been used instead of the original custom module _vgg16\_avg.py_. 24 | 25 | *rossman.ipynb*: section "Using 3rd place data" has been left out for lack of the required data. 26 | 27 | *spelling_bee_RNN.ipynb* and *attention_wrapper.py*: due to the changed implementation of the recurrent.py module in Keras 2.1, the attention part of the notebook no longer works. 28 | 29 | *taxi_data_prep_and_mlp.ipynb*: section "Uh oh ..." has been left out.
Caveat: running all the notebook at once exhausted 128 GB RAM; I was able to run each section individually only after resetting the notebook kernel each time 30 | 31 | *tiramisu-keras.ipynb*: in order to run the larger size model I had to reset the notebook kernel in order to free up enough GPU memory (almost 12 GB) and jump directly to the model 32 | 33 | 34 | #### Left-out modules 35 | *neural-style-pytorch.ipynb* (found no way to load the VGG weights; it looks like some version compatibility issue) 36 | 37 | *rossman_exp.py* 38 | 39 | *seq2seq-translation.ipynb* 40 | 41 | *taxi.ipynb* 42 | 43 | *tiramisu-pytorch.ipynb* 44 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: p3 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - backports.weakref=1.0rc1=py35_0 8 | - bleach=1.5.0=py35_0 9 | - distributed=1.20.2=py35_0 10 | - html5lib=0.9999999=py35_0 11 | - jupyter_contrib_core=0.3.3=py35_1 12 | - jupyter_nbextensions_configurator=0.3.0=py35_0 13 | - markdown=2.6.9=py35_0 14 | - asn1crypto=0.23.0=py35h4ab26a5_0 15 | - backports=1.0=py35hd471ac7_1 16 | - bcolz=1.1.2=py35hcb27967_0 17 | - binutils_impl_linux-64=2.28.1=h04c84fa_2 18 | - binutils_linux-64=7.2.0=25 19 | - bokeh=0.12.13=py35h2f9c1c0_0 20 | - boto=2.48.0=py35h2cfd601_1 21 | - bz2file=0.98=py35_0 22 | - bzip2=1.0.6=h6d464ef_2 23 | - ca-certificates=2017.08.26=h1d4fec5_0 24 | - certifi=2017.11.5=py35h9749603_0 25 | - cffi=1.11.2=py35hc7b2db7_0 26 | - chardet=3.0.4=py35hb6e9ddf_1 27 | - click=6.7=py35h353a69f_0 28 | - cloudpickle=0.5.2=py35hbe86bc5_0 29 | - cryptography=2.1.4=py35hbeb2da1_0 30 | - cudatoolkit=8.0=3 31 | - cudnn=6.0.21=cuda8.0_0 32 | - cycler=0.10.0=py35hc4d5149_0 33 | - cython=0.27.3=py35h6cdc64b_0 34 | - dask=0.16.0=py35hcb8ecc8_0 35 | - dask-core=0.16.0=py35hfc66869_0 36 | - dbus=1.10.22=h3b5a359_0 37 | - decorator=4.1.2=py35h3a268aa_0 38 | - entrypoints=0.2.3=py35h48174a2_2 39 | - expat=2.2.5=he0dffb1_0 40 | - fastcache=1.0.2=py35hec2bbaa_0 41 | - fontconfig=2.12.4=h88586e7_1 42 | - freetype=2.8=hab7d2ae_1 43 | - gcc_impl_linux-64=7.2.0=hc5ce805_2 44 | - gcc_linux-64=7.2.0=25 45 | - gensim=3.1.0=py35h7300b16_0 46 | - glib=2.53.6=h5d9569c_2 47 | - gmp=6.1.2=h6c8ec71_1 48 | - gmpy2=2.0.8=py35hd0a1c9a_2 49 | - gst-plugins-base=1.12.2=he3457e5_0 50 | - gstreamer=1.12.2=h4f93127_0 51 | - gxx_impl_linux-64=7.2.0=hd3faf3d_2 52 | - gxx_linux-64=7.2.0=25 53 | - h5py=2.7.1=py35h8d53cdc_0 54 | - hdf5=1.10.1=h9caa474_1 55 | - heapdict=1.0.0=py35h51e6c10_0 56 | - icu=58.2=h9c2bf20_1 57 | - idna=2.6=py35h8605a33_1 58 | - imageio=2.2.0=py35hd0a6de2_0 59 | - intel-openmp=2018.0.0=hc7b2577_8 60 | - ipykernel=4.7.0=py35h2f9c1c0_0 61 | - ipython=6.2.1=py35hd850d2a_1 62 | - ipython_genutils=0.2.0=py35hc9e07d0_0 63 | - ipywidgets=7.0.5=py35h8147dc1_0 64 | - jedi=0.11.0=py35_2 65 | - jinja2=2.10=py35h480ab6d_0 66 | - jpeg=9b=h024ee3a_2 67 | - jsonschema=2.6.0=py35h4395190_0 68 | - jupyter=1.0.0=py35hd38625c_0 69 | - jupyter_client=5.1.0=py35h2bff583_0 70 | - jupyter_console=5.2.0=py35h4044a63_1 71 | - jupyter_core=4.4.0=py35ha89e94b_0 72 | - keras=2.1.2=py35_0 73 | - libedit=3.1=heed3624_0 74 | - libffi=3.2.1=hd88cf55_4 75 | - libgcc=7.2.0=h69d50b8_2 76 | - libgcc-ng=7.2.0=h7cc24e2_2 77 | - libgfortran-ng=7.2.0=h9f7466a_2 78 | - libgpuarray=0.7.5=h14c3975_0 79 | - libpng=1.6.32=hbd3595f_4 80 | - libprotobuf=3.4.1=h5b8497f_0 81 | - libsodium=1.0.15=hf101ebd_0 82 | - 
libstdcxx-ng=7.2.0=h7a57d05_2 83 | - libtiff=4.0.9=h28f6b97_0 84 | - libxcb=1.12=hcd93eb1_4 85 | - libxml2=2.9.4=h2e8b1d7_6 86 | - locket=0.2.0=py35h170bc82_1 87 | - lzo=2.10=h49e0be7_2 88 | - mako=1.0.7=py35h69899ea_0 89 | - markupsafe=1.0=py35h4f4fcf6_1 90 | - matplotlib=2.1.1=py35ha26af80_0 91 | - mistune=0.8.1=py35h9251d8c_0 92 | - mkl=2018.0.1=h19d6760_4 93 | - mkl-service=1.1.2=py35h0fc7090_4 94 | - mpc=1.0.3=hec55b23_5 95 | - mpfr=3.1.5=h11a74b3_2 96 | - mpmath=1.0.0=py35h7ce6e34_2 97 | - msgpack-python=0.4.8=py35h783f4c8_0 98 | - nbconvert=5.3.1=py35hc5194e3_0 99 | - nbformat=4.4.0=py35h12e6e07_0 100 | - ncurses=6.0=h9df7e31_2 101 | - networkx=2.0=py35hc690e10_0 102 | - nltk=3.2.5=py35h09ad193_0 103 | - notebook=5.2.2=py35he644770_0 104 | - numexpr=2.6.4=py35h119f745_0 105 | - numpy=1.13.3=py35hd829ed6_0 106 | - olefile=0.44=py35h2c86149_0 107 | - openssl=1.0.2n=hb7f436b_0 108 | - pandas=0.22.0=py35hf484d3e_0 109 | - pandoc=1.19.2.1=hea2e7c5_1 110 | - pandocfilters=1.4.2=py35h1565a15_1 111 | - parso=0.1.1=py35h1b200a3_0 112 | - partd=0.3.8=py35h68187f2_0 113 | - pcre=8.41=hc27e229_1 114 | - pexpect=4.3.0=py35hf410859_0 115 | - pickleshare=0.7.4=py35hd57304d_0 116 | - pillow=5.0.0=py35h3deb7b8_0 117 | - pip=9.0.1=py35h7e7da9d_4 118 | - prompt_toolkit=1.0.15=py35hc09de7a_0 119 | - protobuf=3.4.1=py35he6b9134_0 120 | - psutil=5.4.1=py35h2e39a06_0 121 | - ptyprocess=0.5.2=py35h38ce0a3_0 122 | - pycparser=2.18=py35h61b3040_1 123 | - pygments=2.2.0=py35h0f41973_0 124 | - pygpu=0.7.5=py35h14c3975_0 125 | - pyopenssl=17.5.0=py35h4f8b8c8_0 126 | - pyparsing=2.2.0=py35h041ed72_1 127 | - pyqt=5.6.0=py35h0e41ada_5 128 | - pysocks=1.6.7=py35h6aefbb0_1 129 | - pytables=3.4.2=py35hfa98db7_2 130 | - python=3.5.4=h417fded_24 131 | - python-dateutil=2.6.1=py35h90d5b31_1 132 | - pytz=2017.3=py35hb13c558_0 133 | - pywavelets=0.5.2=py35h53ec731_0 134 | - pyyaml=3.12=py35h46ef4ae_1 135 | - pyzmq=16.0.3=py35ha889422_0 136 | - qt=5.6.2=h974d657_12 137 | - qtconsole=4.3.1=py35h4626a06_0 138 | - readline=7.0=ha6073c6_4 139 | - requests=2.18.4=py35hb9e6ad1_1 140 | - scikit-image=0.13.1=py35h14c3975_1 141 | - scikit-learn=0.19.1=py35hbf1f462_0 142 | - scipy=1.0.0=py35hcbbe4a2_0 143 | - setuptools=36.5.0=py35ha8c1747_0 144 | - simplegeneric=0.8.1=py35h2ec4104_0 145 | - sip=4.18.1=py35h9eaea60_2 146 | - six=1.11.0=py35h423b573_1 147 | - smart_open=1.5.3=py35_0 148 | - sortedcontainers=1.5.7=py35h683703c_0 149 | - sqlite=3.20.1=hb898158_2 150 | - sympy=1.1.1=py35h919b29a_0 151 | - tblib=1.3.2=py35hf1eb0b4_0 152 | - tensorflow=1.3.0=0 153 | - tensorflow-base=1.3.0=py35h79a3156_1 154 | - tensorflow-gpu=1.3.0=0 155 | - tensorflow-gpu-base=1.3.0=py35cuda8.0cudnn6.0_1 156 | - tensorflow-tensorboard=0.1.5=py35_0 157 | - terminado=0.6=py35hce234ed_0 158 | - testpath=0.3.1=py35had42eaf_0 159 | - theano=1.0.1=py35h6bb024c_0 160 | - tk=8.6.7=hc745277_3 161 | - toolz=0.8.2=py35h90f1797_0 162 | - tornado=4.5.2=py35hf879e1d_0 163 | - tqdm=4.19.4=py35h68e51d2_0 164 | - traitlets=4.3.2=py35ha522a97_0 165 | - ujson=1.35=py35_0 166 | - urllib3=1.22=py35h2ab6e29_0 167 | - wcwidth=0.1.7=py35hcd08066_0 168 | - webencodings=0.5.1=py35hb6cf162_1 169 | - werkzeug=0.12.2=py35hbfc1ea6_0 170 | - wheel=0.30.0=py35hd3883cf_1 171 | - widgetsnbextension=3.0.8=py35h84cb72a_0 172 | - xz=5.2.3=h55aa19d_2 173 | - yaml=0.1.7=had09818_2 174 | - zeromq=4.2.2=hbedb6e5_2 175 | - zict=0.1.3=py35h29275ca_0 176 | - zlib=1.2.11=ha838bed_2 177 | - pytorch=0.3.0=py35_cuda8.0.61_cudnn7.0.3hb362f6e_4 178 | - torchvision=0.2.0=py35heaa392f_1 179 | - pip: 180 
| - keras-tqdm==2.0.1 181 | - tables==3.4.2 182 | - torch==0.3.0.post4 183 | - xgboost==0.7.post3 184 | prefix: /home/roebius/anaconda/envs/p3 185 | 186 | -------------------------------------------------------------------------------- /keras.json.for_TensorFlow: -------------------------------------------------------------------------------- 1 | { 2 | "epsilon": 1e-07, 3 | "backend": "tensorflow", 4 | "floatx": "float32", 5 | "image_data_format": "channels_last" 6 | } 7 | -------------------------------------------------------------------------------- /keras.json.for_Theano: -------------------------------------------------------------------------------- 1 | { 2 | "image_data_format": "channels_first", 3 | "epsilon": 1e-07, 4 | "floatx": "float32", 5 | "backend": "theano" 6 | } 7 | -------------------------------------------------------------------------------- /nbs/char-rnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2018-01-09T17:37:36.320493Z", 9 | "start_time": "2018-01-09T17:37:32.715223Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "Using cuDNN version 6021 on context None\n", 18 | "Mapped name None to device cuda0: GeForce GTX TITAN X (0000:04:00.0)\n", 19 | "Using Theano backend.\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "from __future__ import division, print_function\n", 25 | "%matplotlib inline\n", 26 | "from importlib import reload # Python 3\n", 27 | "import utils; reload(utils)\n", 28 | "from utils import *" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "ExecuteTime": { 36 | "end_time": "2018-01-09T17:37:38.078225Z", 37 | "start_time": "2018-01-09T17:37:38.073874Z" 38 | } 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "from keras.layers import TimeDistributed, Activation\n", 43 | "from numpy.random import choice" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## Setup" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "We haven't really looked into the detail of how this works yet - so this is provided for self-study for those who are interested. We'll look at it closely next week." 
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2018-01-09T17:37:39.530495Z", 66 | "start_time": "2018-01-09T17:37:39.513647Z" 67 | } 68 | }, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "corpus length: 600893\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "path = get_file('nietzsche.txt', origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\")\n", 80 | "text = open(path).read().lower()\n", 81 | "print('corpus length:', len(text))" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": { 88 | "ExecuteTime": { 89 | "end_time": "2018-01-09T17:37:40.553853Z", 90 | "start_time": "2018-01-09T17:37:40.408768Z" 91 | } 92 | }, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "are thinkers who believe in the saints.\r\n", 99 | "\r\n", 100 | "\r\n", 101 | "144\r\n", 102 | "\r\n", 103 | "It stands to reason that this sketch of the saint, made upon the model\r\n", 104 | "of the whole species, can be confronted with many opposing sketches that\r\n", 105 | "would create a more agreeable impression. There are certain exceptions\r\n", 106 | "among the species who distinguish themselves either by especial\r\n", 107 | "gentleness or especial humanity, and perhaps by the strength of their\r\n", 108 | "own personality. Others are in the highest degree fascinating because\r\n", 109 | "certain of their delusions shed a particular glow over their whole\r\n", 110 | "being, as is the case with the founder of christianity who took himself\r\n", 111 | "for the only begotten son of God and hence felt himself sinless; so that\r\n", 112 | "through his imagination--that should not be too harshly judged since the\r\n", 113 | "whole of antiquity swarmed with sons of god--he attained the same goal,\r\n", 114 | "the sense of complete sinlessness, complete irresponsibility, that can\r\n", 115 | "now be attained by every individual through science.--In the same manner\r\n", 116 | "I have viewed the saints of India who occupy an intermediate station\r\n", 117 | "between the christian saints and the Greek philosophers and hence are\r\n", 118 | "not to be regarded as a pure type. Knowledge and science--as far as they\r\n", 119 | "existed--and superiority to the rest of mankind by logical discipline\r\n", 120 | "and training of the intellectual powers were insisted upon by the\r\n", 121 | "Buddhists as essential to sanctity, just as they were denounced by the\r\n", 122 | "christian world as the indications of sinfulness." 
123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "!tail -n 25 {path}" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": { 134 | "ExecuteTime": { 135 | "end_time": "2018-01-09T17:37:42.261626Z", 136 | "start_time": "2018-01-09T17:37:42.232982Z" 137 | } 138 | }, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "total chars: 58\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "chars = sorted(list(set(text)))\n", 150 | "vocab_size = len(chars)+1\n", 151 | "print('total chars:', vocab_size)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "metadata": { 158 | "ExecuteTime": { 159 | "end_time": "2018-01-09T17:37:42.673825Z", 160 | "start_time": "2018-01-09T17:37:42.670388Z" 161 | } 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "chars.insert(0, \"\\0\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 7, 171 | "metadata": { 172 | "ExecuteTime": { 173 | "end_time": "2018-01-09T17:37:43.405865Z", 174 | "start_time": "2018-01-09T17:37:43.393184Z" 175 | } 176 | }, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/plain": [ 181 | "'\\n !\"\\'(),-.0123456789:;=?[]_abcdefghijklmnopqrstuvwx'" 182 | ] 183 | }, 184 | "execution_count": 7, 185 | "metadata": {}, 186 | "output_type": "execute_result" 187 | } 188 | ], 189 | "source": [ 190 | "''.join(chars[1:-6])" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 8, 196 | "metadata": { 197 | "ExecuteTime": { 198 | "end_time": "2018-01-09T17:37:43.653291Z", 199 | "start_time": "2018-01-09T17:37:43.648297Z" 200 | } 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "char_indices = dict((c, i) for i, c in enumerate(chars))\n", 205 | "indices_char = dict((i, c) for i, c in enumerate(chars))" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 9, 211 | "metadata": { 212 | "ExecuteTime": { 213 | "end_time": "2018-01-09T17:37:43.970560Z", 214 | "start_time": "2018-01-09T17:37:43.875090Z" 215 | } 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "idx = [char_indices[c] for c in text]" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": 10, 225 | "metadata": { 226 | "ExecuteTime": { 227 | "end_time": "2018-01-09T17:37:44.088043Z", 228 | "start_time": "2018-01-09T17:37:44.081181Z" 229 | } 230 | }, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "text/plain": [ 235 | "[43, 45, 32, 33, 28, 30, 32, 1, 1, 1]" 236 | ] 237 | }, 238 | "execution_count": 10, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "idx[:10]" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": { 251 | "ExecuteTime": { 252 | "end_time": "2018-01-09T17:37:44.278125Z", 253 | "start_time": "2018-01-09T17:37:44.272800Z" 254 | } 255 | }, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "'preface\\n\\n\\nsupposing that truth is a woman--what then? 
is there not gro'" 261 | ] 262 | }, 263 | "execution_count": 11, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "''.join(indices_char[i] for i in idx[:70])" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | "## Preprocess and create model" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 12, 282 | "metadata": { 283 | "ExecuteTime": { 284 | "end_time": "2018-01-09T17:37:50.507182Z", 285 | "start_time": "2018-01-09T17:37:48.167841Z" 286 | } 287 | }, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "nb sequences: 600854\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "maxlen = 40\n", 299 | "sentences = []\n", 300 | "next_chars = []\n", 301 | "for i in range(0, len(idx) - maxlen+1):\n", 302 | " sentences.append(idx[i: i + maxlen])\n", 303 | " next_chars.append(idx[i+1: i+maxlen+1])\n", 304 | "print('nb sequences:', len(sentences))" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 13, 310 | "metadata": { 311 | "ExecuteTime": { 312 | "end_time": "2018-01-09T17:37:57.204305Z", 313 | "start_time": "2018-01-09T17:37:50.508646Z" 314 | } 315 | }, 316 | "outputs": [], 317 | "source": [ 318 | "sentences = np.concatenate([[np.array(o)] for o in sentences[:-2]])\n", 319 | "next_chars = np.concatenate([[np.array(o)] for o in next_chars[:-2]])" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 14, 325 | "metadata": { 326 | "ExecuteTime": { 327 | "end_time": "2018-01-09T17:37:57.208861Z", 328 | "start_time": "2018-01-09T17:37:57.205817Z" 329 | } 330 | }, 331 | "outputs": [ 332 | { 333 | "data": { 334 | "text/plain": [ 335 | "((600852, 40), (600852, 40))" 336 | ] 337 | }, 338 | "execution_count": 14, 339 | "metadata": {}, 340 | "output_type": "execute_result" 341 | } 342 | ], 343 | "source": [ 344 | "sentences.shape, next_chars.shape" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 15, 350 | "metadata": { 351 | "ExecuteTime": { 352 | "end_time": "2018-01-09T17:37:57.249341Z", 353 | "start_time": "2018-01-09T17:37:57.209999Z" 354 | } 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "n_fac = 24" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 16, 364 | "metadata": { 365 | "ExecuteTime": { 366 | "end_time": "2018-01-09T17:38:10.121159Z", 367 | "start_time": "2018-01-09T17:37:57.250999Z" 368 | } 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "model=Sequential([\n", 373 | " Embedding(vocab_size, n_fac, input_length=maxlen),\n", 374 | " LSTM(units=512, input_shape=(n_fac,),return_sequences=True, dropout=0.2, recurrent_dropout=0.2,\n", 375 | " implementation=2),\n", 376 | " Dropout(0.2),\n", 377 | " LSTM(512, return_sequences=True, dropout=0.2, recurrent_dropout=0.2,\n", 378 | " implementation=2),\n", 379 | " Dropout(0.2),\n", 380 | " TimeDistributed(Dense(vocab_size)),\n", 381 | " Activation('softmax')\n", 382 | " ]) " 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 17, 388 | "metadata": { 389 | "ExecuteTime": { 390 | "end_time": "2018-01-09T17:38:10.153817Z", 391 | "start_time": "2018-01-09T17:38:10.123477Z" 392 | } 393 | }, 394 | "outputs": [], 395 | "source": [ 396 | "model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "## Train" 404 | ] 405 | }, 406 
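A side note (not part of the original notebook): because the model ends with `TimeDistributed(Dense(vocab_size))` followed by a softmax, its predictions have shape `(batch, maxlen, vocab_size)`. With `sparse_categorical_crossentropy`, the integer targets therefore need a trailing axis of length 1, which is why the `fit()` calls below pass `np.expand_dims(next_chars, -1)`. A minimal illustration, using a placeholder array with the notebook's target shape:

```python
import numpy as np

# Placeholder with the same shape as the notebook's next_chars array.
next_chars_demo = np.zeros((600852, 40), dtype=np.int32)

# Add a trailing axis so each timestep's target is a length-1 vector holding the
# integer class index, matching the (batch, maxlen, vocab_size) softmax output.
targets = np.expand_dims(next_chars_demo, -1)
print(targets.shape)  # (600852, 40, 1)
```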
| { 407 | "cell_type": "code", 408 | "execution_count": 18, 409 | "metadata": { 410 | "ExecuteTime": { 411 | "end_time": "2018-01-09T17:38:12.858009Z", 412 | "start_time": "2018-01-09T17:38:12.840547Z" 413 | } 414 | }, 415 | "outputs": [], 416 | "source": [ 417 | "def print_example():\n", 418 | " seed_string=\"ethics is a basic foundation of all that\"\n", 419 | " for i in range(320):\n", 420 | " x=np.array([char_indices[c] for c in seed_string[-40:]])[np.newaxis,:] # [-40] picks up the last 40 chars\n", 421 | " preds = model.predict(x, verbose=0)[0][-1] # [-1] picks up the last char\n", 422 | " preds = preds/np.sum(preds)\n", 423 | " next_char = choice(chars, p=preds)\n", 424 | " seed_string = seed_string + next_char\n", 425 | " print(seed_string)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 19, 431 | "metadata": { 432 | "ExecuteTime": { 433 | "end_time": "2018-01-09T17:53:01.861777Z", 434 | "start_time": "2018-01-09T17:38:13.104719Z" 435 | } 436 | }, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | "Epoch 1/1\n", 443 | "600852/600852 [==============================] - 795s 1ms/step - loss: 1.4965\n" 444 | ] 445 | }, 446 | { 447 | "data": { 448 | "text/plain": [ 449 | "" 450 | ] 451 | }, 452 | "execution_count": 19, 453 | "metadata": {}, 454 | "output_type": "execute_result" 455 | } 456 | ], 457 | "source": [ 458 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 20, 464 | "metadata": { 465 | "ExecuteTime": { 466 | "end_time": "2018-01-09T17:53:16.668682Z", 467 | "start_time": "2018-01-09T17:53:01.863269Z" 468 | }, 469 | "scrolled": true 470 | }, 471 | "outputs": [ 472 | { 473 | "name": "stdout", 474 | "output_type": "stream", 475 | "text": [ 476 | "ethics is a basic foundation of all that which principle. there is i have said gon to fight on the responsibility\n", 477 | "of intercourse is\n", 478 | "is not subsequently possible that one\n", 479 | "can not promise solitude, neither with all this over the half. 
the whole mewaphysical philosophers have were this requirement to his even failure as his power; even in love comes to be it, d\n" 480 | ] 481 | } 482 | ], 483 | "source": [ 484 | "print_example()" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 21, 490 | "metadata": { 491 | "ExecuteTime": { 492 | "end_time": "2018-01-09T18:06:27.312422Z", 493 | "start_time": "2018-01-09T17:53:16.670290Z" 494 | } 495 | }, 496 | "outputs": [ 497 | { 498 | "name": "stdout", 499 | "output_type": "stream", 500 | "text": [ 501 | "Epoch 1/1\n", 502 | "600852/600852 [==============================] - 791s 1ms/step - loss: 1.2726\n" 503 | ] 504 | }, 505 | { 506 | "data": { 507 | "text/plain": [ 508 | "" 509 | ] 510 | }, 511 | "execution_count": 21, 512 | "metadata": {}, 513 | "output_type": "execute_result" 514 | } 515 | ], 516 | "source": [ 517 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 22, 523 | "metadata": { 524 | "ExecuteTime": { 525 | "end_time": "2018-01-09T18:06:34.396111Z", 526 | "start_time": "2018-01-09T18:06:27.314283Z" 527 | }, 528 | "scrolled": true 529 | }, 530 | "outputs": [ 531 | { 532 | "name": "stdout", 533 | "output_type": "stream", 534 | "text": [ 535 | "ethics is a basic foundation of all that he realized how can the same\n", 536 | "degree, and\n", 537 | "bitter! everywhere may not\n", 538 | "be pessimistic time. it sympathy and of our dull things, one may demand and would not have reaction also a kind of the advance of the brute\", this deenest race: it is necessary to understand\n", 539 | "to contradict it; but the just as a\n", 540 | "being which does not bel\n" 541 | ] 542 | } 543 | ], 544 | "source": [ 545 | "print_example()" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 23, 551 | "metadata": { 552 | "ExecuteTime": { 553 | "end_time": "2018-01-09T18:06:34.400512Z", 554 | "start_time": "2018-01-09T18:06:34.398029Z" 555 | } 556 | }, 557 | "outputs": [], 558 | "source": [ 559 | "model.optimizer.lr=0.001" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 24, 565 | "metadata": { 566 | "ExecuteTime": { 567 | "end_time": "2018-01-09T18:19:44.658226Z", 568 | "start_time": "2018-01-09T18:06:34.402291Z" 569 | } 570 | }, 571 | "outputs": [ 572 | { 573 | "name": "stdout", 574 | "output_type": "stream", 575 | "text": [ 576 | "Epoch 1/1\n", 577 | "600852/600852 [==============================] - 790s 1ms/step - loss: 1.2383\n" 578 | ] 579 | }, 580 | { 581 | "data": { 582 | "text/plain": [ 583 | "" 584 | ] 585 | }, 586 | "execution_count": 24, 587 | "metadata": {}, 588 | "output_type": "execute_result" 589 | } 590 | ], 591 | "source": [ 592 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 25, 598 | "metadata": { 599 | "ExecuteTime": { 600 | "end_time": "2018-01-09T18:19:51.757814Z", 601 | "start_time": "2018-01-09T18:19:44.659529Z" 602 | } 603 | }, 604 | "outputs": [ 605 | { 606 | "name": "stdout", 607 | "output_type": "stream", 608 | "text": [ 609 | "ethics is a basic foundation of all that originates him\n", 610 | "\n", 611 | "instance, it true impulses and belief\n", 612 | "in christianity, results, easily allowed to\n", 613 | "regard our principle.--one dests inspire concerning the logical is termination; and that the\n", 614 | "contrary to puritante and attain.\n", 615 | "\n", 616 | "162. 
from deveropment and little itself we have deceived ourselves to\n", 617 | "action, and without dec\n" 618 | ] 619 | } 620 | ], 621 | "source": [ 622 | "print_example()" 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 26, 628 | "metadata": { 629 | "ExecuteTime": { 630 | "end_time": "2018-01-09T18:19:51.761120Z", 631 | "start_time": "2018-01-09T18:19:51.759261Z" 632 | } 633 | }, 634 | "outputs": [], 635 | "source": [ 636 | "model.optimizer.lr=0.0001" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 27, 642 | "metadata": { 643 | "ExecuteTime": { 644 | "end_time": "2018-01-09T18:33:08.856328Z", 645 | "start_time": "2018-01-09T18:19:51.762328Z" 646 | }, 647 | "scrolled": true 648 | }, 649 | "outputs": [ 650 | { 651 | "name": "stdout", 652 | "output_type": "stream", 653 | "text": [ 654 | "Epoch 1/1\n", 655 | "600852/600852 [==============================] - 797s 1ms/step - loss: 1.2193\n" 656 | ] 657 | }, 658 | { 659 | "data": { 660 | "text/plain": [ 661 | "" 662 | ] 663 | }, 664 | "execution_count": 27, 665 | "metadata": {}, 666 | "output_type": "execute_result" 667 | } 668 | ], 669 | "source": [ 670 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 28, 676 | "metadata": { 677 | "ExecuteTime": { 678 | "end_time": "2018-01-09T18:33:15.941120Z", 679 | "start_time": "2018-01-09T18:33:08.857628Z" 680 | } 681 | }, 682 | "outputs": [ 683 | { 684 | "name": "stdout", 685 | "output_type": "stream", 686 | "text": [ 687 | "ethics is a basic foundation of all that \"ego,\" is craceful easy, and through the trainly left itself until feelings, makes this very pleasure to shiftand\n", 688 | "an emotion of their gutting, mopling and skepcicism--he would like to brighten men and them as france of\n", 689 | "humanity.\n", 690 | "\n", 691 | "\n", 692 | "54\n", 693 | "\n", 694 | "=justice, or even of the foundation of causality which always \"does not know about w\n" 695 | ] 696 | } 697 | ], 698 | "source": [ 699 | "print_example()" 700 | ] 701 | }, 702 | { 703 | "cell_type": "code", 704 | "execution_count": 29, 705 | "metadata": { 706 | "ExecuteTime": { 707 | "end_time": "2018-01-09T18:33:15.995753Z", 708 | "start_time": "2018-01-09T18:33:15.942509Z" 709 | } 710 | }, 711 | "outputs": [], 712 | "source": [ 713 | "model.save_weights('data/char_rnn.h5')" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": 30, 719 | "metadata": { 720 | "ExecuteTime": { 721 | "end_time": "2018-01-09T18:33:16.029984Z", 722 | "start_time": "2018-01-09T18:33:15.998784Z" 723 | } 724 | }, 725 | "outputs": [], 726 | "source": [ 727 | "model.optimizer.lr=0.00001" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 31, 733 | "metadata": { 734 | "ExecuteTime": { 735 | "end_time": "2018-01-09T18:46:26.796768Z", 736 | "start_time": "2018-01-09T18:33:16.033101Z" 737 | } 738 | }, 739 | "outputs": [ 740 | { 741 | "name": "stdout", 742 | "output_type": "stream", 743 | "text": [ 744 | "Epoch 1/1\n", 745 | "600852/600852 [==============================] - 791s 1ms/step - loss: 1.2049\n" 746 | ] 747 | }, 748 | { 749 | "data": { 750 | "text/plain": [ 751 | "" 752 | ] 753 | }, 754 | "execution_count": 31, 755 | "metadata": {}, 756 | "output_type": "execute_result" 757 | } 758 | ], 759 | "source": [ 760 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 761 | ] 762 | }, 763 | { 764 | "cell_type": "code", 765 | 
"execution_count": 32, 766 | "metadata": { 767 | "ExecuteTime": { 768 | "end_time": "2018-01-09T18:46:33.857340Z", 769 | "start_time": "2018-01-09T18:46:26.798046Z" 770 | } 771 | }, 772 | "outputs": [ 773 | { 774 | "name": "stdout", 775 | "output_type": "stream", 776 | "text": [ 777 | "ethics is a basic foundation of all that sympathy thinks which they may be the most customary or else,\n", 778 | "owing to\n", 779 | "horror, to a new\n", 780 | "riddle-like experience (and is there why learnt\n", 781 | "to bring at the\n", 782 | "immense, and have long still profound, dissatisfied in that neighbour and incapacity for\n", 783 | "me?\" in\n", 784 | "spite of the proper people, an intercourse, still upself--subtlety and\n" 785 | ] 786 | } 787 | ], 788 | "source": [ 789 | "print_example()" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 33, 795 | "metadata": { 796 | "ExecuteTime": { 797 | "end_time": "2018-01-09T18:59:44.358611Z", 798 | "start_time": "2018-01-09T18:46:33.858823Z" 799 | } 800 | }, 801 | "outputs": [ 802 | { 803 | "name": "stdout", 804 | "output_type": "stream", 805 | "text": [ 806 | "Epoch 1/1\n", 807 | "600852/600852 [==============================] - 790s 1ms/step - loss: 1.1925\n" 808 | ] 809 | }, 810 | { 811 | "data": { 812 | "text/plain": [ 813 | "" 814 | ] 815 | }, 816 | "execution_count": 33, 817 | "metadata": {}, 818 | "output_type": "execute_result" 819 | } 820 | ], 821 | "source": [ 822 | "model.fit(sentences, np.expand_dims(next_chars,-1), batch_size=64, epochs=1)" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": 34, 828 | "metadata": { 829 | "ExecuteTime": { 830 | "end_time": "2018-01-09T18:59:51.517817Z", 831 | "start_time": "2018-01-09T18:59:44.360741Z" 832 | } 833 | }, 834 | "outputs": [ 835 | { 836 | "name": "stdout", 837 | "output_type": "stream", 838 | "text": [ 839 | "ethics is a basic foundation of all that is called \"higher,\" inspire the permanent thing, at once and strive, remains the most, but that the new\n", 840 | "construction which do not believe in germany. in music only a shamed through the discipline of mind whone same goethe) perhaps they have its personality\n", 841 | "itself, they are responsible, and it seems to\n", 842 | "feel them is\n", 843 | "don\n" 844 | ] 845 | } 846 | ], 847 | "source": [ 848 | "print_example()" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 35, 854 | "metadata": { 855 | "ExecuteTime": { 856 | "end_time": "2018-01-09T18:59:58.596503Z", 857 | "start_time": "2018-01-09T18:59:51.519361Z" 858 | } 859 | }, 860 | "outputs": [ 861 | { 862 | "name": "stdout", 863 | "output_type": "stream", 864 | "text": [ 865 | "ethics is a basic foundation of all that is always vained by the reward. 
if one should grew back again acknowledge with their semi-barbarity,--they are avlided to life.--we have\n", 866 | "finds a contradictory?--so the ascetic judgs a defect in every\n", 867 | "deception change away something that is remained by\n", 868 | "means of community, in fact, it is precisely through napoleon's sen\n" 869 | ] 870 | } 871 | ], 872 | "source": [ 873 | "print_example()" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": 36, 879 | "metadata": { 880 | "ExecuteTime": { 881 | "end_time": "2018-01-09T18:59:58.615619Z", 882 | "start_time": "2018-01-09T18:59:58.597957Z" 883 | } 884 | }, 885 | "outputs": [], 886 | "source": [ 887 | "model.save_weights('data/char_rnn.h5')" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": null, 893 | "metadata": {}, 894 | "outputs": [], 895 | "source": [] 896 | } 897 | ], 898 | "metadata": { 899 | "kernelspec": { 900 | "display_name": "Python 3", 901 | "language": "python", 902 | "name": "python3" 903 | }, 904 | "language_info": { 905 | "codemirror_mode": { 906 | "name": "ipython", 907 | "version": 3 908 | }, 909 | "file_extension": ".py", 910 | "mimetype": "text/x-python", 911 | "name": "python", 912 | "nbconvert_exporter": "python", 913 | "pygments_lexer": "ipython3", 914 | "version": "3.5.4" 915 | }, 916 | "nav_menu": {}, 917 | "toc": { 918 | "navigate_menu": true, 919 | "number_sections": true, 920 | "sideBar": true, 921 | "threshold": 6, 922 | "toc_cell": true, 923 | "toc_section_display": "block", 924 | "toc_window_display": false 925 | }, 926 | "widgets": { 927 | "state": {}, 928 | "version": "1.1.2" 929 | } 930 | }, 931 | "nbformat": 4, 932 | "nbformat_minor": 1 933 | } 934 | -------------------------------------------------------------------------------- /nbs/imagenet_batchnorm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This notebook explains how to add batch normalization to VGG. The code shown here is implemented in [vgg_bn.py](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/vgg16bn.py), and there is a version of ``vgg_ft`` (our fine tuning function) with batch norm called ``vgg_ft_bn`` in [utils.py](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/utils.py)." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "collapsed": true 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "from __future__ import division, print_function\n", 19 | "%matplotlib inline\n", 20 | "from importlib import reload\n", 21 | "import utils; reload(utils)\n", 22 | "from utils import *" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# The problem, and the solution" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## The problem" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "The problem that we faced in the lesson 3 is that when we wanted to add batch normalization, we initialized *all* the dense layers of the model to random weights, and then tried to train them with our cats v dogs dataset. But that's a lot of weights to initialize to random - out of 134m params, around 119m are in the dense layers! Take a moment to think about why this is, and convince yourself that dense layers are where most of the weights will be. 
Also, think about whether this implies that most of the *time* will be spent training these weights. What do you think?\n", 44 | "\n", 45 | "Trying to train 120m params using just 23k images is clearly an unreasonable expectation. The reason we haven't had this problem before is that the dense layers were not random, but were trained to recognize imagenet categories (other than the very last layer, which only has 8194 params)." 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## The solution" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "The solution, obviously enough, is to add batch normalization to the VGG model! To do so, we have to be careful - we can't just insert batchnorm layers, since their parameters (*gamma* - which is used to multiply by each activation, and *beta* - which is used to add to each activation) will not be set correctly. Without setting these correctly, the new batchnorm layers will normalize the previous layer's activations, meaning that the next layer will receive totally different activations from what it would have without the new batchnorm layer. And that means that all the pre-trained weights are no longer of any use!\n", 60 | "\n", 61 | "So instead, we need to figure out what beta and gamma to choose when we insert the layers. The answer to this turns out to be pretty simple - we need to calculate what the mean and standard deviation of the activations for that layer are when calculated on all of imagenet, and then set beta and gamma to these values. That means that the new batchnorm layer will normalize the data with the mean and standard deviation, and then immediately un-normalize the data using the beta and gamma parameters we provide. So the output of the batchnorm layer will be identical to its input - which means that all the pre-trained weights will continue to work just as well as before.\n", 62 | "\n", 63 | "The benefit of this is that when we wish to fine-tune our own networks, we will have all the benefits of batch normalization (higher learning rates, more resilient training, and less need for dropout) plus all the benefits of a pre-trained network (see the short code sketch after the download notes below)." 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "To calculate the mean and standard deviation of the activations on imagenet, we need to download imagenet. You can download imagenet from http://www.image-net.org/download-images . The file you want is the one titled **Download links to ILSVRC2013 image data**. You'll need to request access from the imagenet admins for this, although it seems to be an automated system - I've always found that access is provided instantly. Once you're logged in and have gone to that page, look for the **CLS-LOC dataset** section. Both training and validation images are available, and you should download both. There's not much reason to download the test images, however.\n", 71 | "\n", 72 | "Note that this will not be the entire imagenet archive, but just the 1000 categories that are used in the annual competition. Since that's what VGG16 was originally trained on, that seems like a good choice - especially since the full dataset is 1.1 terabytes, whereas the 1000 category dataset is 138 gigabytes."
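To make "The solution" above concrete, here is a minimal sketch (not part of the original notebook) of how a newly inserted BatchNormalization layer could be initialized so that its output initially equals its input. The `acts` array is a random placeholder standing in for a layer's real activations computed over an ImageNet sample.

```python
import numpy as np
from keras.models import Sequential
from keras.layers import BatchNormalization

# Placeholder: `acts` stands in for one dense layer's activations computed
# over a large ImageNet sample (shape: n_images x n_units).
acts = np.random.randn(10000, 4096).astype('float32')
mean, var = acts.mean(axis=0), acts.var(axis=0)

# Build a BatchNormalization layer for inputs of this width.
bn = BatchNormalization(input_shape=(acts.shape[1],))
bn_model = Sequential([bn])  # wrapping the layer in a model builds it, so its weights exist

# Keras 2 weight order for BatchNormalization: [gamma, beta, moving_mean, moving_variance].
# With gamma = std and beta = mean, and the moving statistics set to (mean, var),
# the layer normalizes by the ImageNet statistics and immediately undoes it, so its
# output is (up to epsilon) identical to its input, and the pre-trained layers that
# follow keep receiving the activations they expect.
bn.set_weights([np.sqrt(var), mean, mean, var])
```

The notebook itself computes the real layer activations with a `K.function`, as shown in the "Calculating batchnorm params" section further on.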
73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "# Adding batchnorm to Imagenet" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## Setup" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Sample" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "As per usual, we create a sample so we can experiment more rapidly." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "collapsed": true 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "# %pushd data/imagenet\n", 112 | "%pushd data/imagenet\n", 113 | "%cd train" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": { 120 | "collapsed": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "%mkdir ../sample\n", 125 | "%mkdir ../sample/train\n", 126 | "%mkdir ../sample/valid\n", 127 | "\n", 128 | "from shutil import copyfile\n", 129 | "\n", 130 | "g = glob('*')\n", 131 | "for d in g: \n", 132 | " os.mkdir('../sample/train/'+d)\n", 133 | " os.mkdir('../sample/valid/'+d)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "collapsed": true 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "g = glob('*/*.JPEG')\n", 145 | "shuf = np.random.permutation(g)\n", 146 | "for i in range(25000): copyfile(shuf[i], '../sample/train/' + shuf[i])" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "collapsed": true, 154 | "scrolled": true 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "%cd ../valid\n", 159 | "\n", 160 | "g = glob('*/*.JPEG')\n", 161 | "shuf = np.random.permutation(g)\n", 162 | "for i in range(5000): copyfile(shuf[i], '../sample/valid/' + shuf[i])\n", 163 | "\n", 164 | "%cd .." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "collapsed": true 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "%mkdir sample/results" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "collapsed": true 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "%popd" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "### Data setup" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "We set up our paths, data, and labels in the usual way. Note that we don't try to read all of Imagenet into memory! We only load the sample into memory." 
201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "collapsed": true 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "sample_path = \"data/imagenet/sample/\"\n", 212 | "path = \"data/imagenet/\"\n", 213 | "\n", 214 | "#sample_path = 'data/jhoward/imagenet/sample/'\n", 215 | "# This is the path to my fast SSD - I put datasets there when I can to get the speed benefit\n", 216 | "#fast_path = '/home/jhoward/ILSVRC2012_img_proc/'\n", 217 | "#path = '/data/jhoward/imagenet/sample/'\n", 218 | "#path = 'data/jhoward/imagenet/'" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "collapsed": true 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "batch_size=64" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": { 236 | "collapsed": true 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "samp_trn = get_data(sample_path+'train')\n", 241 | "samp_val = get_data(sample_path+'valid')" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": { 248 | "collapsed": true 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "save_array(sample_path+'results/trn.dat', samp_trn)\n", 253 | "save_array(sample_path+'results/val.dat', samp_val)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": { 260 | "collapsed": true 261 | }, 262 | "outputs": [], 263 | "source": [ 264 | "samp_trn = load_array(sample_path+'results/trn.dat')\n", 265 | "samp_val = load_array(sample_path+'results/val.dat')" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "collapsed": true, 273 | "scrolled": true 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "(val_classes, trn_classes, val_labels, trn_labels, \n", 278 | " val_filenames, filenames, test_filenames) = get_classes(path)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": { 285 | "collapsed": true 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "(samp_val_classes, samp_trn_classes, samp_val_labels, samp_trn_labels, \n", 290 | " samp_val_filenames, samp_filenames, samp_test_filenames) = get_classes(sample_path)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "### Model setup" 298 | ] 299 | }, 300 | { 301 | "cell_type": "markdown", 302 | "metadata": {}, 303 | "source": [ 304 | "Since we're just working with the dense layers, we should pre-compute the output of the convolutional layers." 
305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "metadata": { 311 | "collapsed": true, 312 | "scrolled": true 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "vgg = Vgg16()\n", 317 | "model = vgg.model" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": { 324 | "collapsed": true 325 | }, 326 | "outputs": [], 327 | "source": [ 328 | "layers = model.layers\n", 329 | "last_conv_idx = [index for index,layer in enumerate(layers) \n", 330 | " if type(layer) is Conv2D][-1]\n", 331 | "conv_layers = layers[:last_conv_idx+1]" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": { 338 | "collapsed": true 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "dense_layers = layers[last_conv_idx+1:]" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": { 349 | "collapsed": true 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "conv_model = Sequential(conv_layers)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": { 360 | "collapsed": true 361 | }, 362 | "outputs": [], 363 | "source": [ 364 | "samp_conv_val_feat = conv_model.predict(samp_val, batch_size=batch_size*2)\n", 365 | "samp_conv_feat = conv_model.predict(samp_trn, batch_size=batch_size*2)" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": { 372 | "collapsed": true 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "save_array(sample_path+'results/conv_val_feat.dat', samp_conv_val_feat)\n", 377 | "save_array(sample_path+'results/conv_feat.dat', samp_conv_feat)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "collapsed": true 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "samp_conv_feat = load_array(sample_path+'results/conv_feat.dat')\n", 389 | "samp_conv_val_feat = load_array(sample_path+'results/conv_val_feat.dat')" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": { 396 | "collapsed": true, 397 | "scrolled": true 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "samp_conv_val_feat.shape" 402 | ] 403 | }, 404 | { 405 | "cell_type": "markdown", 406 | "metadata": {}, 407 | "source": [ 408 | "This is our usual Vgg network just covering the dense layers:" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": { 415 | "collapsed": true 416 | }, 417 | "outputs": [], 418 | "source": [ 419 | "def get_dense_layers():\n", 420 | " return [\n", 421 | " MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),\n", 422 | " Flatten(),\n", 423 | " Dense(4096, activation='relu'),\n", 424 | " Dropout(0.5),\n", 425 | " Dense(4096, activation='relu'),\n", 426 | " Dropout(0.5),\n", 427 | " # Dense(1000, activation='softmax')\n", 428 | " Dense(1000, activation='relu')\n", 429 | " ]" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "collapsed": true 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "dense_model = Sequential(get_dense_layers())" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": { 447 | "collapsed": true 448 | }, 449 | "outputs": [], 450 | "source": [ 451 | "for l1, l2 in zip(dense_layers, dense_model.layers):\n", 452 | " l2.set_weights(l1.get_weights())" 453 | ] 454 | 
}, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": { 459 | "collapsed": true 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "dense_model.add(Dense(763, activation='softmax'))" 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "### Check model" 471 | ] 472 | }, 473 | { 474 | "cell_type": "markdown", 475 | "metadata": {}, 476 | "source": [ 477 | "It's a good idea to check that your models are giving reasonable answers, before using them." 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": { 484 | "collapsed": true 485 | }, 486 | "outputs": [], 487 | "source": [ 488 | "dense_model.compile(Adam(), 'categorical_crossentropy', ['accuracy'])" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": { 495 | "collapsed": true 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "dense_model.evaluate(samp_conv_val_feat, samp_val_labels)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "collapsed": true 507 | }, 508 | "outputs": [], 509 | "source": [ 510 | "model.compile(Adam(), 'categorical_crossentropy', ['accuracy'])" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "collapsed": true 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "# should be identical to above\n", 522 | "# model.evaluate(val, val_labels)" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": { 529 | "collapsed": true 530 | }, 531 | "outputs": [], 532 | "source": [ 533 | "# should be a little better than above, since VGG authors overfit\n", 534 | "# dense_model.evaluate(conv_feat, trn_labels)" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": { 540 | "collapsed": true 541 | }, 542 | "source": [ 543 | "## Adding our new layers" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "### Calculating batchnorm params" 551 | ] 552 | }, 553 | { 554 | "cell_type": "markdown", 555 | "metadata": {}, 556 | "source": [ 557 | "To calculate the output of a layer in a Keras sequential model, we have to create a function that defines the input layer and the output layer, like this:" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": { 564 | "collapsed": true 565 | }, 566 | "outputs": [], 567 | "source": [ 568 | "k_layer_out = K.function([dense_model.layers[0].input, K.learning_phase()], \n", 569 | " [dense_model.layers[2].output])" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": {}, 575 | "source": [ 576 | "Then we can call the function to get our layer activations:" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": null, 582 | "metadata": { 583 | "collapsed": true 584 | }, 585 | "outputs": [], 586 | "source": [ 587 | "d0_out = k_layer_out([samp_conv_val_feat, 0])[0]" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "metadata": { 594 | "collapsed": true 595 | }, 596 | "outputs": [], 597 | "source": [ 598 | "k_layer_out = K.function([dense_model.layers[0].input, K.learning_phase()], \n", 599 | " [dense_model.layers[4].output])" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": { 606 | "collapsed": true 607 | }, 608 | "outputs": 
[], 609 | "source": [ 610 | "d2_out = k_layer_out([samp_conv_val_feat, 0])[0]" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "metadata": {}, 616 | "source": [ 617 | "Now that we've got our activations, we can calculate the mean and standard deviation for each (note that due to a bug in keras, it's actually the variance that we'll need)." 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "metadata": { 624 | "collapsed": true 625 | }, 626 | "outputs": [], 627 | "source": [ 628 | "mu0,var0 = d0_out.mean(axis=0), d0_out.var(axis=0)\n", 629 | "mu2,var2 = d2_out.mean(axis=0), d2_out.var(axis=0)" 630 | ] 631 | }, 632 | { 633 | "cell_type": "markdown", 634 | "metadata": {}, 635 | "source": [ 636 | "### Creating batchnorm model" 637 | ] 638 | }, 639 | { 640 | "cell_type": "markdown", 641 | "metadata": {}, 642 | "source": [ 643 | "Now we're ready to create and insert our layers just after each dense layer." 644 | ] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "execution_count": null, 649 | "metadata": { 650 | "collapsed": true 651 | }, 652 | "outputs": [], 653 | "source": [ 654 | "nl1 = BatchNormalization()\n", 655 | "nl2 = BatchNormalization()" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": { 662 | "collapsed": true 663 | }, 664 | "outputs": [], 665 | "source": [ 666 | "bn_model = insert_layer(dense_model, nl2, 5)\n", 667 | "bn_model = insert_layer(bn_model, nl1, 3)" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": null, 673 | "metadata": { 674 | "collapsed": true 675 | }, 676 | "outputs": [], 677 | "source": [ 678 | "bnl1 = bn_model.layers[3]\n", 679 | "bnl4 = bn_model.layers[6]" 680 | ] 681 | }, 682 | { 683 | "cell_type": "markdown", 684 | "metadata": {}, 685 | "source": [ 686 | "After inserting the layers, we can set their weights to the variance and mean we just calculated." 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "metadata": { 693 | "collapsed": true 694 | }, 695 | "outputs": [], 696 | "source": [ 697 | "bnl1.set_weights([var0, mu0, mu0, var0])\n", 698 | "bnl4.set_weights([var2, mu2, mu2, var2])" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": null, 704 | "metadata": { 705 | "collapsed": true 706 | }, 707 | "outputs": [], 708 | "source": [ 709 | "bn_model.compile(Adam(1e-5), 'categorical_crossentropy', ['accuracy'])" 710 | ] 711 | }, 712 | { 713 | "cell_type": "markdown", 714 | "metadata": {}, 715 | "source": [ 716 | "We should find that the new model gives identical results to those provided by the original VGG model." 
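A note on why the evaluation below should come out unchanged: at inference a batchnorm layer computes roughly y = gamma·(x − mean)/denominator + beta, and the weights being set above are in the order [gamma, beta, mean, std]. If, as the earlier comment about the keras bug suggests, the stored "std" slot is used as a straight divisor and actually holds the variance, then choosing gamma = var and beta = mean makes the scaling cancel, so the freshly inserted layer starts out as (almost) the identity. A small numpy check of that reading, under exactly those assumptions rather than the library's actual code path:

```python
import numpy as np

x = np.random.randn(1000) * 3 + 7          # fake activations for one unit
mu, var, eps = x.mean(), x.var(), 1e-3

# gamma = var, beta = mu, stored mean = mu, stored "std" slot = var
y = var * (x - mu) / (var + eps) + mu
print(np.allclose(x, y, atol=1e-2))        # True: the layer is ~identity when inserted
```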
717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": null, 722 | "metadata": { 723 | "collapsed": true 724 | }, 725 | "outputs": [], 726 | "source": [ 727 | "bn_model.evaluate(samp_conv_val_feat, samp_val_labels)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "metadata": { 734 | "collapsed": true 735 | }, 736 | "outputs": [], 737 | "source": [ 738 | "bn_model.evaluate(samp_conv_feat, samp_trn_labels)" 739 | ] 740 | }, 741 | { 742 | "cell_type": "markdown", 743 | "metadata": {}, 744 | "source": [ 745 | "### Optional - additional fine-tuning" 746 | ] 747 | }, 748 | { 749 | "cell_type": "markdown", 750 | "metadata": {}, 751 | "source": [ 752 | "Now that we have a VGG model with batchnorm, we might expect that the optimal weights would be a little different to what they were when originally created without batchnorm. So we fine tune the weights for one epoch." 753 | ] 754 | }, 755 | { 756 | "cell_type": "code", 757 | "execution_count": null, 758 | "metadata": { 759 | "collapsed": true 760 | }, 761 | "outputs": [], 762 | "source": [ 763 | "feat_bc = bcolz.open(fast_path+'trn_features.dat')" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": null, 769 | "metadata": { 770 | "collapsed": true 771 | }, 772 | "outputs": [], 773 | "source": [ 774 | "labels = load_array(fast_path+'trn_labels.dat')" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": null, 780 | "metadata": { 781 | "collapsed": true 782 | }, 783 | "outputs": [], 784 | "source": [ 785 | "val_feat_bc = bcolz.open(fast_path+'val_features.dat')" 786 | ] 787 | }, 788 | { 789 | "cell_type": "code", 790 | "execution_count": null, 791 | "metadata": { 792 | "collapsed": true 793 | }, 794 | "outputs": [], 795 | "source": [ 796 | "val_labels = load_array(fast_path+'val_labels.dat')" 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": { 803 | "collapsed": true 804 | }, 805 | "outputs": [], 806 | "source": [ 807 | "bn_model.fit(feat_bc, labels, nb_epoch=1, batch_size=batch_size,\n", 808 | " validation_data=(val_feat_bc, val_labels))" 809 | ] 810 | }, 811 | { 812 | "cell_type": "markdown", 813 | "metadata": {}, 814 | "source": [ 815 | "The results look quite encouraging! Note that these VGG weights are now specific to how keras handles image scaling - that is, it squashes and stretches images, rather than adding black borders. So this model is best used on images created in that way." 816 | ] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": null, 821 | "metadata": { 822 | "collapsed": true 823 | }, 824 | "outputs": [], 825 | "source": [ 826 | "bn_model.save_weights(path+'models/bn_model2.h5')" 827 | ] 828 | }, 829 | { 830 | "cell_type": "code", 831 | "execution_count": null, 832 | "metadata": { 833 | "collapsed": true 834 | }, 835 | "outputs": [], 836 | "source": [ 837 | "bn_model.load_weights(path+'models/bn_model2.h5')" 838 | ] 839 | }, 840 | { 841 | "cell_type": "markdown", 842 | "metadata": { 843 | "collapsed": true 844 | }, 845 | "source": [ 846 | "### Create combined model" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": {}, 852 | "source": [ 853 | "Our last step is simply to copy our new dense layers on to the end of the convolutional part of the network, and save the new complete set of weights, so we can use them in the future when using VGG. (Of course, we'll also need to update our VGG architecture to add the batchnorm layers)." 
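`copy_layers` and `copy_weights`, used in the next cells, are small utils.py helpers: the first builds fresh layer objects from the existing layers' configurations (so they can be added to another model), and the second transfers the trained weights pairwise. A sketch of the weight-copying half; the cloning half depends on keras-version internals, so it is only described here:

```python
def copy_weights_sketch(from_layers, to_layers):
    # Assumes the two lists line up and corresponding layers have identical shapes.
    for src, dst in zip(from_layers, to_layers):
        dst.set_weights(src.get_weights())
```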
854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "metadata": { 860 | "collapsed": true 861 | }, 862 | "outputs": [], 863 | "source": [ 864 | "new_layers = copy_layers(bn_model.layers)\n", 865 | "for layer in new_layers:\n", 866 | " conv_model.add(layer)" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": null, 872 | "metadata": { 873 | "collapsed": true 874 | }, 875 | "outputs": [], 876 | "source": [ 877 | "copy_weights(bn_model.layers, new_layers)" 878 | ] 879 | }, 880 | { 881 | "cell_type": "code", 882 | "execution_count": null, 883 | "metadata": { 884 | "collapsed": true 885 | }, 886 | "outputs": [], 887 | "source": [ 888 | "conv_model.compile(Adam(1e-5), 'categorical_crossentropy', ['accuracy'])" 889 | ] 890 | }, 891 | { 892 | "cell_type": "code", 893 | "execution_count": null, 894 | "metadata": { 895 | "collapsed": true 896 | }, 897 | "outputs": [], 898 | "source": [ 899 | "conv_model.evaluate(samp_val, samp_val_labels)" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": { 906 | "collapsed": true 907 | }, 908 | "outputs": [], 909 | "source": [ 910 | "conv_model.save_weights(path+'models/inet_224squash_bn.h5')" 911 | ] 912 | }, 913 | { 914 | "cell_type": "markdown", 915 | "metadata": { 916 | "collapsed": true 917 | }, 918 | "source": [ 919 | "The code shown here is implemented in [vgg_bn.py](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/vgg16bn.py), and there is a version of ``vgg_ft`` (our fine tuning function) with batch norm called ``vgg_ft_bn`` in [utils.py](https://github.com/fastai/courses/blob/master/deeplearning1/nbs/utils.py)." 920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "execution_count": null, 925 | "metadata": { 926 | "collapsed": true 927 | }, 928 | "outputs": [], 929 | "source": [] 930 | } 931 | ], 932 | "metadata": { 933 | "anaconda-cloud": {}, 934 | "kernelspec": { 935 | "display_name": "Python 3", 936 | "language": "python", 937 | "name": "python3" 938 | }, 939 | "language_info": { 940 | "codemirror_mode": { 941 | "name": "ipython", 942 | "version": 3 943 | }, 944 | "file_extension": ".py", 945 | "mimetype": "text/x-python", 946 | "name": "python", 947 | "nbconvert_exporter": "python", 948 | "pygments_lexer": "ipython3", 949 | "version": "3.5.4" 950 | }, 951 | "nav_menu": {}, 952 | "toc": { 953 | "navigate_menu": true, 954 | "number_sections": true, 955 | "sideBar": true, 956 | "threshold": 6, 957 | "toc_cell": false, 958 | "toc_section_display": "block", 959 | "toc_window_display": false 960 | } 961 | }, 962 | "nbformat": 4, 963 | "nbformat_minor": 1 964 | } 965 | -------------------------------------------------------------------------------- /nbs/lesson5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2018-01-09T15:36:28.456263Z", 9 | "start_time": "2018-01-09T15:36:25.246783Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "Using cuDNN version 6021 on context None\n", 18 | "Mapped name None to device cuda0: GeForce GTX TITAN X (0000:04:00.0)\n", 19 | "Using Theano backend.\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "from __future__ import division, print_function\n", 25 | "%matplotlib inline\n", 26 | "from importlib import reload # Python 3\n", 27 | "import utils; 
reload(utils)\n", 28 | "from utils import *" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "ExecuteTime": { 36 | "end_time": "2018-01-09T15:36:28.461234Z", 37 | "start_time": "2018-01-09T15:36:28.458339Z" 38 | } 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "path = \"data/imdb/\"\n", 43 | "model_path = path + 'models/'\n", 44 | "if not os.path.exists(model_path): os.mkdir(model_path)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "## Setup data" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "We're going to look at the IMDB dataset, which contains movie reviews from IMDB, along with their sentiment. Keras comes with some helpers for this dataset." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 3, 64 | "metadata": { 65 | "ExecuteTime": { 66 | "end_time": "2018-01-09T15:36:28.541861Z", 67 | "start_time": "2018-01-09T15:36:28.463022Z" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "from keras.datasets import imdb\n", 73 | "idx = imdb.get_word_index()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "This is the word list:" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": { 87 | "ExecuteTime": { 88 | "end_time": "2018-01-09T15:36:28.602069Z", 89 | "start_time": "2018-01-09T15:36:28.543828Z" 90 | } 91 | }, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "['the', 'and', 'a', 'of', 'to', 'is', 'br', 'in', 'it', 'i']" 97 | ] 98 | }, 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "idx_arr = sorted(idx, key=idx.get)\n", 106 | "idx_arr[:10]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "...and this is the mapping from id to word" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": { 120 | "ExecuteTime": { 121 | "end_time": "2018-01-09T15:36:28.644220Z", 122 | "start_time": "2018-01-09T15:36:28.603536Z" 123 | }, 124 | "scrolled": false 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "idx2word = {v: k for k, v in idx.items()}" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "We download the reviews using code copied from keras.datasets:" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": { 142 | "ExecuteTime": { 143 | "end_time": "2018-01-09T15:36:30.774903Z", 144 | "start_time": "2018-01-09T15:36:28.645649Z" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "path = get_file('imdb_full.pkl',\n", 150 | " origin='https://s3.amazonaws.com/text-datasets/imdb_full.pkl',\n", 151 | " md5_hash='d091312047c43cf9e4e38fef92437263')\n", 152 | "f = open(path, 'rb')\n", 153 | "(x_train, labels_train), (x_test, labels_test) = pickle.load(f)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 7, 159 | "metadata": { 160 | "ExecuteTime": { 161 | "end_time": "2018-01-09T15:36:30.779253Z", 162 | "start_time": "2018-01-09T15:36:30.776488Z" 163 | } 164 | }, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "25000" 170 | ] 171 | }, 172 | "execution_count": 7, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "len(x_train)" 179 | ] 180 | }, 
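One detail worth knowing about the download cell above: `get_file` caches the file locally (by default under keras' datasets cache directory) and returns the cached path, so re-running the notebook doesn't re-download the 65MB pickle. A minimal illustration using the same URL as above:

```python
from keras.utils.data_utils import get_file

# Downloads once, then reuses the cached copy on subsequent runs.
local_path = get_file('imdb_full.pkl',
                      origin='https://s3.amazonaws.com/text-datasets/imdb_full.pkl')
print(local_path)
```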
181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "Here's the 1st review. As you see, the words have been replaced by ids. The ids can be looked up in idx2word." 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 8, 191 | "metadata": { 192 | "ExecuteTime": { 193 | "end_time": "2018-01-09T15:36:30.817217Z", 194 | "start_time": "2018-01-09T15:36:30.780486Z" 195 | }, 196 | "scrolled": false 197 | }, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "'23022, 309, 6, 3, 1069, 209, 9, 2175, 30, 1, 169, 55, 14, 46, 82, 5869, 41, 393, 110, 138, 14, 5359, 58, 4477, 150, 8, 1, 5032, 5948, 482, 69, 5, 261, 12, 23022, 73935, 2003, 6, 73, 2436, 5, 632, 71, 6, 5359, 1, 25279, 5, 2004, 10471, 1, 5941, 1534, 34, 67, 64, 205, 140, 65, 1232, 63526, 21145, 1, 49265, 4, 1, 223, 901, 29, 3024, 69, 4, 1, 5863, 10, 694, 2, 65, 1534, 51, 10, 216, 1, 387, 8, 60, 3, 1472, 3724, 802, 5, 3521, 177, 1, 393, 10, 1238, 14030, 30, 309, 3, 353, 344, 2989, 143, 130, 5, 7804, 28, 4, 126, 5359, 1472, 2375, 5, 23022, 309, 10, 532, 12, 108, 1470, 4, 58, 556, 101, 12, 23022, 309, 6, 227, 4187, 48, 3, 2237, 12, 9, 215'" 203 | ] 204 | }, 205 | "execution_count": 8, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "', '.join(map(str, x_train[0]))" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "The first word of the first review is 23022. Let's see what that is." 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 9, 224 | "metadata": { 225 | "ExecuteTime": { 226 | "end_time": "2018-01-09T15:36:30.839786Z", 227 | "start_time": "2018-01-09T15:36:30.819303Z" 228 | } 229 | }, 230 | "outputs": [ 231 | { 232 | "data": { 233 | "text/plain": [ 234 | "'bromwell'" 235 | ] 236 | }, 237 | "execution_count": 9, 238 | "metadata": {}, 239 | "output_type": "execute_result" 240 | } 241 | ], 242 | "source": [ 243 | "idx2word[23022]" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "Here's the whole review, mapped from ids to words." 
251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 10, 256 | "metadata": { 257 | "ExecuteTime": { 258 | "end_time": "2018-01-09T15:36:30.870150Z", 259 | "start_time": "2018-01-09T15:36:30.841763Z" 260 | }, 261 | "scrolled": false 262 | }, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "\"bromwell high is a cartoon comedy it ran at the same time as some other programs about school life such as teachers my 35 years in the teaching profession lead me to believe that bromwell high's satire is much closer to reality than is teachers the scramble to survive financially the insightful students who can see right through their pathetic teachers' pomp the pettiness of the whole situation all remind me of the schools i knew and their students when i saw the episode in which a student repeatedly tried to burn down the school i immediately recalled at high a classic line inspector i'm here to sack one of your teachers student welcome to bromwell high i expect that many adults of my age think that bromwell high is far fetched what a pity that it isn't\"" 268 | ] 269 | }, 270 | "execution_count": 10, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "' '.join([idx2word[o] for o in x_train[0]])" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "The labels are 1 for positive, 0 for negative." 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 11, 289 | "metadata": { 290 | "ExecuteTime": { 291 | "end_time": "2018-01-09T15:36:30.910789Z", 292 | "start_time": "2018-01-09T15:36:30.872255Z" 293 | } 294 | }, 295 | "outputs": [ 296 | { 297 | "data": { 298 | "text/plain": [ 299 | "[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]" 300 | ] 301 | }, 302 | "execution_count": 11, 303 | "metadata": {}, 304 | "output_type": "execute_result" 305 | } 306 | ], 307 | "source": [ 308 | "labels_train[:10]" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "Reduce vocab size by setting rare words to max index." 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 12, 321 | "metadata": { 322 | "ExecuteTime": { 323 | "end_time": "2018-01-09T15:36:32.664917Z", 324 | "start_time": "2018-01-09T15:36:30.914144Z" 325 | } 326 | }, 327 | "outputs": [ 328 | { 329 | "data": { 330 | "text/plain": [ 331 | "'bergman'" 332 | ] 333 | }, 334 | "execution_count": 12, 335 | "metadata": {}, 336 | "output_type": "execute_result" 337 | } 338 | ], 339 | "source": [ 340 | "vocab_size = 5000\n", 341 | "\n", 342 | "trn = [np.array([i if i