├── README.md ├── docs ├── Makefile ├── make.sh └── source │ ├── conf.py │ ├── examples.rst │ ├── index.rst │ ├── modules.rst │ ├── tests.rst │ └── tools.rst ├── examples ├── __init__.py ├── bsds500.py ├── bsds500 │ └── .gitignore ├── cifar10.py ├── cifar10 │ └── test_dog.png ├── iris.py ├── iris │ └── .gitignore ├── lmdb_io.py ├── mnist.py ├── mnist │ ├── test_0.png │ ├── test_1.png │ ├── test_2.png │ ├── test_3.png │ ├── test_4.png │ ├── test_5.png │ ├── test_6.png │ ├── test_7.png │ ├── test_8.png │ └── test_9.png └── visualization.py ├── install_caffe.sh ├── tests ├── __init__.py ├── cifar10_test │ ├── 00000.png │ ├── 00001.png │ ├── 00002.png │ ├── 00003.png │ └── 00004.png ├── lmdb_io.py ├── mnist_test │ ├── 00000000.png │ ├── 00000001.png │ ├── 00000002.png │ ├── 00000003.png │ └── 00000004.png ├── pre_processing.py ├── prototxt.py └── solvers.py └── tools ├── __init__.py ├── data_augmentation.py ├── layers.py ├── lmdb_io.py ├── pre_processing.py ├── prototxt.py ├── solvers.py └── visualization.py /README.md: -------------------------------------------------------------------------------- 1 | # Caffe-Tools 2 | 3 | Tools and examples for pyCaffe, including: 4 | 5 | * LMDB input and output and conversion from/to CSV and image files; 6 | * monitoring the training process including error, loss and gradients; 7 | * on-the-fly data augmentation; 8 | * custom Python layers. 9 | 10 | The data used for the examples can either be generated manually (see the documentation 11 | or the corresponding files in `examples`) or downloaded from [davidstutz/caffe-tools-data](https://github.com/davidstutz/caffe-tools-data). 12 | 13 | Also see the corresponding blog articles at [davidstutz.de](http://davidstutz.de). 14 | 15 | ## Examples 16 | 17 | The provided examples include: 18 | 19 | * [MNIST](http://yann.lecun.com/exdb/mnist/): [examples/mnist.py](examples/mnist.py) 20 | * [Iris](https://archive.ics.uci.edu/ml/datasets/Iris): [examples/iris.py](examples/iris.py) 21 | * [Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html): [examples/cifar10.py](examples/cifar10.py) 22 | * [BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html): [examples/bsds500.py](examples/bsds500.py) 23 | 24 | Note that the BSDS500 example is **work in progress**! The corresponding data can 25 | be downloaded from [davidstutz/caffe-tools-data](https://github.com/davidstutz/caffe-tools-data). 26 | See the instructions in the corresponding files for details. 
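To get a quick feel for the LMDB tooling, the sketch below writes a few random images to an LMDB via `tools.lmdb_io` (the same calls used in `examples/bsds500.py`) and reads the records back with plain `lmdb` and Caffe's `Datum` protobuf. The path is only an example and the exact `LMDB` interface may differ slightly, so treat this as a minimal sketch rather than a reference:

    import numpy
    import lmdb
    import caffe
    import tools.lmdb_io

    # Write a handful of random 32x32 BGR images with binary labels,
    # mirroring the write calls in examples/bsds500.py.
    lmdb_path = 'examples/example_lmdb'
    images = [(numpy.random.rand(32, 32, 3) * 255).astype(numpy.uint8) for i in range(10)]
    labels = [i % 2 for i in range(10)]
    tools.lmdb_io.LMDB(lmdb_path).write(images, labels)

    # Read the records back; Caffe's LMDB-backed Data layer consumes serialized Datum protobufs.
    env = lmdb.open(lmdb_path, readonly = True)
    with env.begin() as transaction:
        for key, raw in transaction.cursor():
            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(raw)
            image = caffe.io.datum_to_array(datum)  # channels x height x width
            label = datum.label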
27 | 28 | ## Resources 29 | 30 | Some resources I found useful while working with Caffe: 31 | 32 | * Installation: 33 | * http://stackoverflow.com/questions/31395729/how-to-enable-multithreading-with-caffe/31396229 34 | * https://github.com/BVLC/caffe/wiki/Install-Caffe-on-EC2-from-scratch-(Ubuntu,-CUDA-7,-cuDNN-3) 35 | * https://github.com/BVLC/caffe/wiki/Ubuntu-16.04-or-15.10-Installation-Guide 36 | * https://gist.github.com/titipata/f0ef48ad2f0ebc07bcb9 37 | * https://github.com/asampat3090/caffe-ubuntu-14.04 38 | * https://github.com/mrgloom/Caffe-snippets 39 | * GitHub Repositories for pyCaffe: 40 | * https://github.com/nitnelave/pycaffe_tutorial 41 | * https://github.com/pulkitag/pycaffe-utils 42 | * https://github.com/DeeperCS/pycaffe-mnist 43 | * https://github.com/swift-n-brutal/pycaffe_utils 44 | * https://github.com/jimgoo/caffe-oxford102 45 | * https://github.com/ruimashita/caffe-train 46 | * https://github.com/roseperrone/video-object-detection 47 | * https://github.com/pecarlat/caffeTools 48 | * https://github.com/donnemartin/data-science-ipython-notebooks 49 | * https://github.com/jay-mahadeokar/pynetbuilder 50 | * https://github.com/adilmoujahid/deeplearning-cats-dogs-tutorial 51 | * https://github.com/Franck-Dernoncourt/caffe_demos 52 | * https://github.com/koosyong/caffestudy 53 | * Issues: 54 | * https://github.com/BVLC/caffe/issues/3651 (solverstate) 55 | * https://github.com/BVLC/caffe/issues/1566 56 | * https://github.com/BVLC/caffe/pull/3082/files (snapshot) 57 | * https://github.com/BVLC/caffe/issues/1257 (net surgery on solver net) 58 | * https://github.com/BVLC/caffe/issues/409 (net diverges, loss = NaN) 59 | * https://github.com/BVLC/caffe/issues/1168 (pyCaffe example included) 60 | * https://github.com/BVLC/caffe/issues/462 (pyCaffe example included) 61 | * https://github.com/BVLC/caffe/issues/2684 (change batch size) 62 | * https://github.com/rbgirshick/py-faster-rcnn/issues/77 (load solverstate) 63 | * https://github.com/BVLC/caffe/issues/2116 (Caffe LMDB float data) 64 | * LMDB: 65 | * https://lmdb.readthedocs.io/en/release/ 66 | * http://research.beenfrog.com/code/2015/12/30/write-read-lmdb-example.html 67 | * http://deepdish.io/2015/04/28/creating-lmdb-in-python/ 68 | * https://github.com/BVLC/caffe/issues/3959 69 | * Tutorials/Blogs: 70 | * http://christopher5106.github.io/deep/learning/2015/09/04/Deep-learning-tutorial-on-Caffe-Technology.html 71 | * http://www.alanzucconi.com/2016/05/25/generating-deep-dreams/#part2 72 | * http://adilmoujahid.com/posts/2016/06/introduction-deep-learning-python-caffe/ 73 | * Caffe Versions: 74 | * https://github.com/kevinlin311tw/caffe-augmentation (on-the-fly data augmentation) 75 | * https://github.com/ShaharKatz/Caffe-Data-Augmentation (data augmentation) 76 | * 3D: 77 | * https://github.com/faustomilletari/3D-Caffe 78 | * https://github.com/wps712/caffe4video 79 | * StackOverflow: 80 | * http://stackoverflow.com/questions/33905326/caffe-training-without-testing (training without testing) 81 | * http://stackoverflow.com/questions/38348801/caffe-hangs-after-printing-data-label (stuck at data -> label) 82 | * http://stackoverflow.com/questions/35529078/how-to-predict-in-pycaffe (predicting in pyCaffe) 83 | * http://stackoverflow.com/questions/35529078/how-to-predict-in-pycaffe/35572495#35572495 (testing from LMDB with transformer) 84 | * http://stackoverflow.com/questions/37642885/am-i-using-lmdb-incorrectly-it-says-environment-mapsize-limit-reached-after-0-i (LMDB mapsize) 85 | * 
http://stackoverflow.com/questions/31820976/lmdb-increase-map-size (LMDB mapsize) 86 | * http://stackoverflow.com/questions/34092606/how-to-get-the-dataset-size-of-a-caffe-net-in-python/34117558 (dataset size) 87 | * http://stackoverflow.com/questions/32379878/cheat-sheet-for-caffe-pycaffe (pyCaffe cheat sheet) 88 | * http://stackoverflow.com/questions/38511503/how-to-compute-test-validation-loss-in-pycaffe (copying weights to test net) 89 | * http://stackoverflow.com/questions/29788075/setting-glog-minloglevel-1-to-prevent-output-in-shell-from-caffe (silence GLOG logging in the shell) 90 | * http://stackoverflow.com/questions/36108120/shuffle-data-in-lmdb-file 91 | * http://stackoverflow.com/questions/36459266/caffe-python-manual-sgd 92 | * Layers: 93 | * http://installing-caffe-the-right-way.wikidot.com/start 94 | * https://github.com/NVIDIA/DIGITS/tree/master/examples/python-layer 95 | * https://github.com/BVLC/caffe/blob/master/examples/pycaffe/layers/pyloss.py 96 | * https://github.com/BVLC/caffe/blob/master/examples/pycaffe/layers/pascal_multilabel_datalayers.py 97 | * http://stackoverflow.com/questions/34549743/caffe-how-to-get-the-phase-of-a-python-layer/34588801#34588801 98 | * http://stackoverflow.com/questions/34996075/caffe-data-layer-example-step-by-step 99 | * https://github.com/BVLC/caffe/issues/4023 100 | * https://codegists.com/code/caffe-python-layer/ 101 | * https://codedump.io/share/CiQmhfC63OD0/1/pycaffe-how-to-create-custom-weights-in-a-python-layer 102 | * http://stackoverflow.com/questions/34498527/pycaffe-how-to-create-custom-weights-in-a-python-layer 103 | * https://github.com/gcucurull/caffe-conf-matrix/blob/master/python_confmat.py | http://gcucurull.github.io/caffe/python/deep-learning/2016/06/29/caffe-confusion-matrix/ 104 | 105 | ## Documentation 106 | 107 | Installing and running Sphinx (also see [davidstutz/sphinx-example](https://github.com/davidstutz/sphinx-example) for details): 108 | 109 | $ sudo apt-get install python-sphinx 110 | $ sudo pip install sphinx 111 | $ cd docs 112 | $ make html 113 | 114 | ## License 115 | 116 | Copyright (c) 2016 David Stutz. All rights reserved. 117 | 118 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 119 | 120 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 121 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 122 | * Neither the name of David Stutz nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 123 | 124 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 125 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -c "import sys,sphinx;sys.exit(sphinx.main(sys.argv))" 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 16 | 17 | .PHONY: help 18 | help: 19 | @echo "Please use \`make ' where is one of" 20 | @echo " html to make standalone HTML files" 21 | @echo " dirhtml to make HTML files named index.html in directories" 22 | @echo " singlehtml to make a single large HTML file" 23 | @echo " pickle to make pickle files" 24 | @echo " json to make JSON files" 25 | @echo " htmlhelp to make HTML files and a HTML help project" 26 | @echo " qthelp to make HTML files and a qthelp project" 27 | @echo " applehelp to make an Apple Help Book" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " epub3 to make an epub3" 31 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 32 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 33 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 34 | @echo " text to make text files" 35 | @echo " man to make manual pages" 36 | @echo " texinfo to make Texinfo files" 37 | @echo " info to make Texinfo files and run them through makeinfo" 38 | @echo " gettext to make PO message catalogs" 39 | @echo " changes to make an overview of all changed/added/deprecated items" 40 | @echo " xml to make Docutils-native XML files" 41 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 42 | @echo " linkcheck to check all external links for integrity" 43 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 44 | @echo " coverage to run coverage check of the documentation (if enabled)" 45 | @echo " dummy to check syntax errors of document sources" 46 | 47 | .PHONY: clean 48 | clean: 49 | rm -rf $(BUILDDIR)/* 50 | 51 | .PHONY: html 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | .PHONY: dirhtml 58 | dirhtml: 59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 60 | @echo 61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 62 | 63 | .PHONY: singlehtml 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 
68 | 69 | .PHONY: pickle 70 | pickle: 71 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 72 | @echo 73 | @echo "Build finished; now you can process the pickle files." 74 | 75 | .PHONY: json 76 | json: 77 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 78 | @echo 79 | @echo "Build finished; now you can process the JSON files." 80 | 81 | .PHONY: htmlhelp 82 | htmlhelp: 83 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 84 | @echo 85 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 86 | ".hhp project file in $(BUILDDIR)/htmlhelp." 87 | 88 | .PHONY: qthelp 89 | qthelp: 90 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 91 | @echo 92 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 93 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 94 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/caffe-tools.qhcp" 95 | @echo "To view the help file:" 96 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/caffe-tools.qhc" 97 | 98 | .PHONY: applehelp 99 | applehelp: 100 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 101 | @echo 102 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 103 | @echo "N.B. You won't be able to view it unless you put it in" \ 104 | "~/Library/Documentation/Help or install it in your application" \ 105 | "bundle." 106 | 107 | .PHONY: devhelp 108 | devhelp: 109 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 110 | @echo 111 | @echo "Build finished." 112 | @echo "To view the help file:" 113 | @echo "# mkdir -p $$HOME/.local/share/devhelp/caffe-tools" 114 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/caffe-tools" 115 | @echo "# devhelp" 116 | 117 | .PHONY: epub 118 | epub: 119 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 120 | @echo 121 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 122 | 123 | .PHONY: epub3 124 | epub3: 125 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 126 | @echo 127 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." 128 | 129 | .PHONY: latex 130 | latex: 131 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 132 | @echo 133 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 134 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 135 | "(use \`make latexpdf' here to do that automatically)." 136 | 137 | .PHONY: latexpdf 138 | latexpdf: 139 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 140 | @echo "Running LaTeX files through pdflatex..." 141 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 142 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 143 | 144 | .PHONY: latexpdfja 145 | latexpdfja: 146 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 147 | @echo "Running LaTeX files through platex and dvipdfmx..." 148 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 149 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 150 | 151 | .PHONY: text 152 | text: 153 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 154 | @echo 155 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 156 | 157 | .PHONY: man 158 | man: 159 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 160 | @echo 161 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 162 | 163 | .PHONY: texinfo 164 | texinfo: 165 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 166 | @echo 167 | @echo "Build finished. 
The Texinfo files are in $(BUILDDIR)/texinfo." 168 | @echo "Run \`make' in that directory to run these through makeinfo" \ 169 | "(use \`make info' here to do that automatically)." 170 | 171 | .PHONY: info 172 | info: 173 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 174 | @echo "Running Texinfo files through makeinfo..." 175 | make -C $(BUILDDIR)/texinfo info 176 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 177 | 178 | .PHONY: gettext 179 | gettext: 180 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 181 | @echo 182 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 183 | 184 | .PHONY: changes 185 | changes: 186 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 187 | @echo 188 | @echo "The overview file is in $(BUILDDIR)/changes." 189 | 190 | .PHONY: linkcheck 191 | linkcheck: 192 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 193 | @echo 194 | @echo "Link check complete; look for any errors in the above output " \ 195 | "or in $(BUILDDIR)/linkcheck/output.txt." 196 | 197 | .PHONY: doctest 198 | doctest: 199 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 200 | @echo "Testing of doctests in the sources finished, look at the " \ 201 | "results in $(BUILDDIR)/doctest/output.txt." 202 | 203 | .PHONY: coverage 204 | coverage: 205 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 206 | @echo "Testing of coverage in the sources finished, look at the " \ 207 | "results in $(BUILDDIR)/coverage/python.txt." 208 | 209 | .PHONY: xml 210 | xml: 211 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 212 | @echo 213 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 214 | 215 | .PHONY: pseudoxml 216 | pseudoxml: 217 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 218 | @echo 219 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 220 | 221 | .PHONY: dummy 222 | dummy: 223 | $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy 224 | @echo 225 | @echo "Build finished. Dummy builder generates no files." 226 | -------------------------------------------------------------------------------- /docs/make.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sphinx-apidoc -f -o source/ ../tools/ 3 | sphinx-apidoc -f -o source/ ../tests/ 4 | sphinx-apidoc -f -o source/ ../examples/ 5 | make html 6 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | # 4 | # caffe-tools documentation build configuration file, created by 5 | # sphinx-quickstart on Mon Nov 28 20:26:36 2016. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
19 | # 20 | import os 21 | import sys 22 | sys.path.insert(0, os.path.abspath('../../')) 23 | 24 | # -- General configuration ------------------------------------------------ 25 | 26 | # If your documentation needs a minimal Sphinx version, state it here. 27 | # 28 | # needs_sphinx = '1.0' 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.todo', 36 | 'sphinx.ext.mathjax', 37 | 'sphinxarg.ext', 38 | ] 39 | 40 | # Add any paths that contain templates here, relative to this directory. 41 | templates_path = ['_templates'] 42 | 43 | # The suffix(es) of source filenames. 44 | # You can specify multiple suffix as a list of string: 45 | # 46 | # source_suffix = ['.rst', '.md'] 47 | source_suffix = '.rst' 48 | 49 | # The encoding of source files. 50 | # 51 | # source_encoding = 'utf-8-sig' 52 | 53 | # The master toctree document. 54 | master_doc = 'index' 55 | 56 | # General information about the project. 57 | project = 'caffe-tools' 58 | copyright = '2016, David Stutz' 59 | author = 'David Stutz' 60 | 61 | # The version info for the project you're documenting, acts as replacement for 62 | # |version| and |release|, also used in various other places throughout the 63 | # built documents. 64 | # 65 | # The short X.Y version. 66 | version = '1.0' 67 | # The full version, including alpha/beta/rc tags. 68 | release = '1.0' 69 | 70 | # The language for content autogenerated by Sphinx. Refer to documentation 71 | # for a list of supported languages. 72 | # 73 | # This is also used if you do content translation via gettext catalogs. 74 | # Usually you set "language" from the command line for these cases. 75 | language = None 76 | 77 | # There are two options for replacing |today|: either, you set today to some 78 | # non-false value, then it is used: 79 | # 80 | # today = '' 81 | # 82 | # Else, today_fmt is used as the format for a strftime call. 83 | # 84 | # today_fmt = '%B %d, %Y' 85 | 86 | # List of patterns, relative to source directory, that match files and 87 | # directories to ignore when looking for source files. 88 | # This patterns also effect to html_static_path and html_extra_path 89 | exclude_patterns = [] 90 | 91 | # The reST default role (used for this markup: `text`) to use for all 92 | # documents. 93 | # 94 | # default_role = None 95 | 96 | # If true, '()' will be appended to :func: etc. cross-reference text. 97 | # 98 | # add_function_parentheses = True 99 | 100 | # If true, the current module name will be prepended to all description 101 | # unit titles (such as .. function::). 102 | # 103 | # add_module_names = True 104 | 105 | # If true, sectionauthor and moduleauthor directives will be shown in the 106 | # output. They are ignored by default. 107 | # 108 | # show_authors = False 109 | 110 | # The name of the Pygments (syntax highlighting) style to use. 111 | pygments_style = 'sphinx' 112 | 113 | # A list of ignored prefixes for module index sorting. 114 | # modindex_common_prefix = [] 115 | 116 | # If true, keep warnings as "system message" paragraphs in the built documents. 117 | # keep_warnings = False 118 | 119 | # If true, `todo` and `todoList` produce output, else they produce nothing. 120 | todo_include_todos = True 121 | 122 | 123 | # -- Options for HTML output ---------------------------------------------- 124 | 125 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 126 | # a list of builtin themes. 127 | # 128 | 129 | html_theme = 'classic' 130 | 131 | # Theme options are theme-specific and customize the look and feel of a theme 132 | # further. For a list of options available for each theme, see the 133 | # documentation. 134 | # 135 | # html_theme_options = {} 136 | 137 | # Add any paths that contain custom themes here, relative to this directory. 138 | # html_theme_path = [] 139 | 140 | # The name for this set of Sphinx documents. 141 | # " v documentation" by default. 142 | # 143 | # html_title = 'caffe-tools v1.0' 144 | 145 | # A shorter title for the navigation bar. Default is the same as html_title. 146 | # 147 | # html_short_title = None 148 | 149 | # The name of an image file (relative to this directory) to place at the top 150 | # of the sidebar. 151 | # 152 | # html_logo = None 153 | 154 | # The name of an image file (relative to this directory) to use as a favicon of 155 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 156 | # pixels large. 157 | # 158 | # html_favicon = None 159 | 160 | # Add any paths that contain custom static files (such as style sheets) here, 161 | # relative to this directory. They are copied after the builtin static files, 162 | # so a file named "default.css" will overwrite the builtin "default.css". 163 | html_static_path = ['_static'] 164 | 165 | # Add any extra paths that contain custom files (such as robots.txt or 166 | # .htaccess) here, relative to this directory. These files are copied 167 | # directly to the root of the documentation. 168 | # 169 | # html_extra_path = [] 170 | 171 | # If not None, a 'Last updated on:' timestamp is inserted at every page 172 | # bottom, using the given strftime format. 173 | # The empty string is equivalent to '%b %d, %Y'. 174 | # 175 | # html_last_updated_fmt = None 176 | 177 | # If true, SmartyPants will be used to convert quotes and dashes to 178 | # typographically correct entities. 179 | # 180 | # html_use_smartypants = True 181 | 182 | # Custom sidebar templates, maps document names to template names. 183 | # 184 | # html_sidebars = {} 185 | 186 | # Additional templates that should be rendered to pages, maps page names to 187 | # template names. 188 | # 189 | # html_additional_pages = {} 190 | 191 | # If false, no module index is generated. 192 | # 193 | # html_domain_indices = True 194 | 195 | # If false, no index is generated. 196 | # 197 | # html_use_index = True 198 | 199 | # If true, the index is split into individual pages for each letter. 200 | # 201 | # html_split_index = False 202 | 203 | # If true, links to the reST sources are added to the pages. 204 | # 205 | # html_show_sourcelink = True 206 | 207 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 208 | # 209 | # html_show_sphinx = True 210 | 211 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 212 | # 213 | # html_show_copyright = True 214 | 215 | # If true, an OpenSearch description file will be output, and all pages will 216 | # contain a tag referring to it. The value of this option must be the 217 | # base URL from which the finished HTML is served. 218 | # 219 | # html_use_opensearch = '' 220 | 221 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 222 | # html_file_suffix = None 223 | 224 | # Language to be used for generating the HTML full-text search index. 
225 | # Sphinx supports the following languages: 226 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja' 227 | # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr', 'zh' 228 | # 229 | # html_search_language = 'en' 230 | 231 | # A dictionary with options for the search language support, empty by default. 232 | # 'ja' uses this config value. 233 | # 'zh' user can custom change `jieba` dictionary path. 234 | # 235 | # html_search_options = {'type': 'default'} 236 | 237 | # The name of a javascript file (relative to the configuration directory) that 238 | # implements a search results scorer. If empty, the default will be used. 239 | # 240 | # html_search_scorer = 'scorer.js' 241 | 242 | # Output file base name for HTML help builder. 243 | htmlhelp_basename = 'caffe-toolsdoc' 244 | 245 | # -- Options for LaTeX output --------------------------------------------- 246 | 247 | latex_elements = { 248 | # The paper size ('letterpaper' or 'a4paper'). 249 | # 250 | # 'papersize': 'letterpaper', 251 | 252 | # The font size ('10pt', '11pt' or '12pt'). 253 | # 254 | # 'pointsize': '10pt', 255 | 256 | # Additional stuff for the LaTeX preamble. 257 | # 258 | # 'preamble': '', 259 | 260 | # Latex figure (float) alignment 261 | # 262 | # 'figure_align': 'htbp', 263 | } 264 | 265 | # Grouping the document tree into LaTeX files. List of tuples 266 | # (source start file, target name, title, 267 | # author, documentclass [howto, manual, or own class]). 268 | latex_documents = [ 269 | (master_doc, 'caffe-tools.tex', 'caffe-tools Documentation', 270 | 'David Stutz', 'manual'), 271 | ] 272 | 273 | # The name of an image file (relative to this directory) to place at the top of 274 | # the title page. 275 | # 276 | # latex_logo = None 277 | 278 | # For "manual" documents, if this is true, then toplevel headings are parts, 279 | # not chapters. 280 | # 281 | # latex_use_parts = False 282 | 283 | # If true, show page references after internal links. 284 | # 285 | # latex_show_pagerefs = False 286 | 287 | # If true, show URL addresses after external links. 288 | # 289 | # latex_show_urls = False 290 | 291 | # Documents to append as an appendix to all manuals. 292 | # 293 | # latex_appendices = [] 294 | 295 | # It false, will not define \strong, \code, itleref, \crossref ... but only 296 | # \sphinxstrong, ..., \sphinxtitleref, ... To help avoid clash with user added 297 | # packages. 298 | # 299 | # latex_keep_old_macro_names = True 300 | 301 | # If false, no module index is generated. 302 | # 303 | # latex_domain_indices = True 304 | 305 | 306 | # -- Options for manual page output --------------------------------------- 307 | 308 | # One entry per manual page. List of tuples 309 | # (source start file, name, description, authors, manual section). 310 | man_pages = [ 311 | (master_doc, 'caffe-tools', 'caffe-tools Documentation', 312 | [author], 1) 313 | ] 314 | 315 | # If true, show URL addresses after external links. 316 | # 317 | # man_show_urls = False 318 | 319 | 320 | # -- Options for Texinfo output ------------------------------------------- 321 | 322 | # Grouping the document tree into Texinfo files. List of tuples 323 | # (source start file, target name, title, author, 324 | # dir menu entry, description, category) 325 | texinfo_documents = [ 326 | (master_doc, 'caffe-tools', 'caffe-tools Documentation', 327 | author, 'caffe-tools', 'One line description of project.', 328 | 'Miscellaneous'), 329 | ] 330 | 331 | # Documents to append as an appendix to all manuals. 
332 | # 333 | # texinfo_appendices = [] 334 | 335 | # If false, no module index is generated. 336 | # 337 | # texinfo_domain_indices = True 338 | 339 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 340 | # 341 | # texinfo_show_urls = 'footnote' 342 | 343 | # If true, do not generate a @detailmenu in the "Top" node's menu. 344 | # 345 | # texinfo_no_detailmenu = False 346 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | examples package 2 | ================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | examples.bsds500 module 8 | ----------------------- 9 | 10 | .. automodule:: examples.bsds500 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | examples.cifar10 module 16 | ----------------------- 17 | 18 | .. automodule:: examples.cifar10 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | examples.iris module 24 | -------------------- 25 | 26 | .. automodule:: examples.iris 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | examples.lmdb_io module 32 | ----------------------- 33 | 34 | .. automodule:: examples.lmdb_io 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | examples.mnist module 40 | --------------------- 41 | 42 | .. automodule:: examples.mnist 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | examples.visualization module 48 | ----------------------------- 49 | 50 | .. automodule:: examples.visualization 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | 56 | Module contents 57 | --------------- 58 | 59 | .. automodule:: examples 60 | :members: 61 | :undoc-members: 62 | :show-inheritance: 63 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. caffe-tools documentation master file, created by 2 | sphinx-quickstart on Mon Nov 28 20:26:36 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to caffe-tools's documentation! 7 | ======================================= 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | 23 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | caffe-tools 2 | =========== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | examples 8 | tests 9 | tools 10 | -------------------------------------------------------------------------------- /docs/source/tests.rst: -------------------------------------------------------------------------------- 1 | tests package 2 | ============= 3 | 4 | Submodules 5 | ---------- 6 | 7 | tests.lmdb_io module 8 | -------------------- 9 | 10 | .. automodule:: tests.lmdb_io 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tests.pre_processing module 16 | --------------------------- 17 | 18 | .. automodule:: tests.pre_processing 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tests.prototxt module 24 | --------------------- 25 | 26 | .. 
automodule:: tests.prototxt 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | tests.solvers module 32 | -------------------- 33 | 34 | .. automodule:: tests.solvers 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: tests 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/tools.rst: -------------------------------------------------------------------------------- 1 | tools package 2 | ============= 3 | 4 | Submodules 5 | ---------- 6 | 7 | tools.data_augmentation module 8 | ------------------------------ 9 | 10 | .. automodule:: tools.data_augmentation 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | tools.layers module 16 | ------------------- 17 | 18 | .. automodule:: tools.layers 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | tools.lmdb_io module 24 | -------------------- 25 | 26 | .. automodule:: tools.lmdb_io 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | tools.pre_processing module 32 | --------------------------- 33 | 34 | .. automodule:: tools.pre_processing 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | tools.prototxt module 40 | --------------------- 41 | 42 | .. automodule:: tools.prototxt 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | tools.solvers module 48 | -------------------- 49 | 50 | .. automodule:: tools.solvers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | tools.visualization module 56 | -------------------------- 57 | 58 | .. automodule:: tools.visualization 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | 64 | Module contents 65 | --------------- 66 | 67 | .. automodule:: tools 68 | :members: 69 | :undoc-members: 70 | :show-inheritance: 71 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Examples. 3 | """ -------------------------------------------------------------------------------- /examples/bsds500.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example for edge detection on BSDS500 [1] 3 | 4 | .. code-block:: python 5 | 6 | [1] P. Arbelaez, M. Maire, C. Fowlkes and J. Malik. 7 | Contour Detection and Hierarchical Image Segmentation. 8 | IEEE TPAMI, Vol. 33, No. 5, 2011. 9 | 10 | **Note: the LMDBs can also be found in the data repository, see README.** 11 | 12 | In order for the example to work, there are two options: Either download the 13 | BSDS500 dataset with CSV ground truths or directly download the corresponding 14 | LMDBs. You can find both in the resources section of the repository. 15 | 16 | In either case, the directory structure (after converting the datasets to 17 | LMDBs, if applicable) should look as follows: 18 | 19 | .. code-block:: python 20 | 21 | examples/bsds500 22 | |- csv_groundTruth/ 23 | |- test/ 24 | |- train/ 25 | |- val/ 26 | |- images/ 27 | |- test/ 28 | |- train/ 29 | |- val/ 30 | |- test_lmdb/ 31 | |- train_lmdb/ 32 | 33 | .. argparse:: 34 | :ref: examples.bsds500.get_parser 35 | :prog: bsds500 36 | """ 37 | 38 | import os 39 | import cv2 40 | import csv 41 | import glob 42 | import numpy 43 | import random 44 | import argparse 45 | 46 | # To silence Caffe! 
Must be added before importing Caffe or modules which 47 | # are importing Caffe. 48 | os.environ['GLOG_minloglevel'] = '3' 49 | import caffe 50 | import tools.solvers 51 | import tools.lmdb_io 52 | import tools.prototxt 53 | import tools.pre_processing 54 | 55 | caffe.set_mode_gpu() 56 | 57 | def get_parser(): 58 | """ 59 | Get the parser. 60 | 61 | :return: parser 62 | :rtype: argparse.ArgumentParser 63 | """ 64 | 65 | parser = argparse.ArgumentParser(description = 'Deep learning for edge detection on BSDS500.') 66 | parser.add_argument('mode', default = 'convert', 67 | help = 'Mode to run: "extract", "subsample_test" or "train"') 68 | parser.add_argument('--working_directory', default = 'examples/bsds500', type = str, 69 | help = 'path to the working directory, see documentation of this example') 70 | parser.add_argument('--train_lmdb', default = 'examples/bsds500/train_lmdb', type = str, 71 | help = 'path to train LMDB') 72 | parser.add_argument('--test_lmdb', default = 'examples/bsds500/test_lmdb', type = str, 73 | help = 'path to test LMDB') 74 | parser.add_argument('--iterations', dest = 'iterations', type = int, 75 | help = 'number of iterations to train or resume', 76 | default = 10000) 77 | 78 | return parser 79 | 80 | def csv_read(csv_file, delimiter = ','): 81 | """ 82 | Read a CSV file into a numpy.ndarray assuming that each row has the same 83 | number as columns. 84 | 85 | :param csv_file: path to CSV file 86 | :type csv_file: string 87 | :param delimiter: delimiter between cells 88 | :type delimiter: string 89 | :return: CSV contents as Numpy array as float 90 | :rtype: numpy.ndarray 91 | """ 92 | 93 | cols = -1 94 | array = [] 95 | 96 | with open(csv_file) as f: 97 | for cells in csv.reader(f, delimiter = delimiter): 98 | cells = [cell.strip() for cell in cells if len(cell.strip()) > 0] 99 | 100 | if len(cells) > 0: 101 | if cols < 0: 102 | cols = len(cells) 103 | 104 | assert cols == len(cells), "CSV file does not contain a consistent number of columns" 105 | 106 | cells = [float(cell) for cell in cells] 107 | array.append(cells) 108 | 109 | return numpy.array(array) 110 | 111 | def main_extract(): 112 | """ 113 | Extracts train and test samples from the train and test images and ground truth 114 | in bsds500/csv_groundTruth and bsds500/images. For each positive edge pixels, 115 | a quadratic patch is extracted. For non-edge pixels, all patches are subsampled 116 | by only taking 20% of the patches. 117 | 118 | It might be beneficial to also run :func:`examples.bsds500.main_subsample_test` 119 | on the extracted test LMDB for efficient testing during training. 
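    In short, the per-pixel sampling rule implemented below is (with ``k = 3``,
    i.e. 7x7 patches; the 20% rate corresponds to the hard-coded ``r > 0.8``
    threshold):

    .. code-block:: python

        patch = image[i - k:i + k + 1, j - k:j + k + 1, :]
        if segmentation[i, j] > 0:     # edge pixel: always kept as a positive
            images.append(patch); labels.append(1)
        elif random.random() > 0.8:    # non-edge pixel: kept as a negative in ~20% of cases
            images.append(patch); labels.append(0)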
120 | """ 121 | 122 | def extract(directory, lmdb_path): 123 | assert not os.path.exists(lmdb_path), "%s already exists" % lmdb_path 124 | 125 | segmentation_files = [filename for filename in os.listdir(args.working_directory + '/csv_groundTruth/' + directory) if filename[-4:] == '.csv'] 126 | 127 | lmdb_path = args.working_directory + '/' + directory + '_lmdb' 128 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 129 | 130 | s = 1 131 | for segmentation_file in segmentation_files: 132 | image_file = args.working_directory + '/images/' + directory + '/' + segmentation_file[:-6] + '.jpg' 133 | image = cv2.imread(image_file) 134 | segmentation = csv_read(args.working_directory + '/csv_groundTruth/' + directory + '/' + segmentation_file) 135 | 136 | inner = segmentation[1:segmentation.shape[0] - 2, 1:segmentation.shape[1] - 2] 137 | inner_top = segmentation[0:segmentation.shape[0] - 3, 1:segmentation.shape[1] - 2] 138 | inner_left = segmentation[1:segmentation.shape[0] - 2, 0:segmentation.shape[1] - 3] 139 | 140 | segmentation[1:segmentation.shape[0] - 2, 1:segmentation.shape[1] - 2] = numpy.abs(inner - inner_top) + numpy.abs(inner - inner_left) 141 | 142 | segmentation[:, :2] = 0 143 | segmentation[:, segmentation.shape[1] - 3:] = 0 144 | segmentation[:2, :] = 0 145 | segmentation[segmentation.shape[0] - 3:, :] = 0 146 | 147 | segmentation[segmentation > 0] = 1 148 | 149 | images = [] 150 | labels = [] 151 | 152 | k = 3 153 | n = 0 154 | for i in range(k, segmentation.shape[0] - k): 155 | for j in range(k, segmentation.shape[1] - k): 156 | 157 | r = random.random() 158 | patch = image[i - k:i + k + 1, j - k:j + k + 1, :] 159 | 160 | if segmentation[i, j] > 0: 161 | images.append(patch) 162 | labels.append(1) 163 | elif r > 0.8: 164 | images.append(patch) 165 | labels.append(0) 166 | 167 | n += 1 168 | 169 | lmdb.write(images, labels) 170 | print(str(s) + '/' + str(len(segmentation_files))) 171 | s += 1 172 | 173 | extract('train', args.train_lmdb) 174 | extract('val', args.test_lmdb) 175 | 176 | def main_subsample_test(): 177 | """ 178 | Subsample the test LMDB by only taking 5% of the samples. The original test 179 | LMDB is renamed by appending '_full' and a newtest is created having the same 180 | name as the original one. 181 | """ 182 | 183 | test_in_lmdb = args.test_lmdb + '_full' 184 | test_out_lmdb = args.test_lmdb 185 | 186 | assert os.path.exists(test_out_lmdb), "LMDB %s not found" % test_out_lmdb 187 | os.rename(test_out_lmdb, test_in_lmdb) 188 | 189 | pp_in = tools.pre_processing.PreProcessingInputLMDB(test_in_lmdb) 190 | pp_out = tools.pre_processing.PreProcessingOutputLMDB(test_out_lmdb) 191 | pp = tools.pre_processing.PreProcessingSubsample(pp_in, pp_out, 0.05) 192 | pp.run() 193 | 194 | def main_train(): 195 | """ 196 | After running :func:`examples.bsds500.main_train`, a network can be trained. 197 | """ 198 | 199 | def network(lmdb_path, batch_size): 200 | """ 201 | The network definition given the LMDB path and the used batch size. 
202 | 203 | :param lmdb_path: path to LMDB to use (train or test LMDB) 204 | :type lmdb_path: string 205 | :param batch_size: batch size to use 206 | :type batch_size: int 207 | :return: the network definition as string to write to the prototxt file 208 | :rtype: string 209 | """ 210 | 211 | net = caffe.NetSpec() 212 | 213 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 214 | backend = caffe.params.Data.LMDB, 215 | source = lmdb_path, 216 | transform_param = dict(scale = 1./255), 217 | ntop = 2) 218 | 219 | net.conv1 = caffe.layers.Convolution(net.data, kernel_size = 3, num_output = 7, 220 | weight_filler = dict(type = 'xavier')) 221 | net.bn1 = caffe.layers.BatchNorm(net.conv1) 222 | net.relu1 = caffe.layers.ReLU(net.bn1, in_place = True) 223 | net.conv2 = caffe.layers.Convolution(net.relu1, kernel_size = 3, num_output = 21, 224 | weight_filler = dict(type = 'xavier')) 225 | net.bn2 = caffe.layers.BatchNorm(net.conv2) 226 | net.relu2 = caffe.layers.ReLU(net.bn2, in_place = True) 227 | net.conv3 = caffe.layers.Convolution(net.relu2, kernel_size = 3, num_output = 7, 228 | weight_filler = dict(type = 'xavier')) 229 | net.bn3 = caffe.layers.BatchNorm(net.conv3) 230 | net.relu3 = caffe.layers.ReLU(net.bn3, in_place = True) 231 | net.score = caffe.layers.InnerProduct(net.relu3, num_output = 1, 232 | weight_filler = dict(type = 'xavier')) 233 | net.loss = caffe.layers.SigmoidCrossEntropyLoss(net.score, net.labels) 234 | 235 | return net.to_proto() 236 | 237 | def count_errors(scores, labels): 238 | """ 239 | Utility method to count the errors given the ouput of the 240 | "score" layer and the labels. 241 | 242 | :param score: output of score layer 243 | :type score: numpy.ndarray 244 | :param labels: labels 245 | :type labels: numpy.ndarray 246 | :return: count of errors 247 | :rtype: int 248 | """ 249 | 250 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 251 | 252 | assert os.path.exists(args.train_lmdb), "LMDB %s does not exist" % args.train_lmdb 253 | assert os.path.exists(args.test_lmdb), "LMDB %s does not exist" % args.test_lmdb 254 | 255 | train_prototxt_path = args.working_directory + '/train.prototxt' 256 | test_prototxt_path = args.working_directory + '/test.prototxt' 257 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 258 | 259 | with open(train_prototxt_path, 'w') as f: 260 | f.write(str(network(args.train_lmdb, 1024))) 261 | 262 | with open(test_prototxt_path, 'w') as f: 263 | f.write(str(network(args.test_lmdb, 5000))) 264 | 265 | tools.prototxt.train2deploy(train_prototxt_path, (1, 3, 7, 7), deploy_prototxt_path) 266 | 267 | 268 | prototxt_solver = args.working_directory + '/solver.prototxt' 269 | solver_prototxt = tools.solvers.SolverProtoTXT({ 270 | 'train_net': train_prototxt_path, 271 | 'test_net': test_prototxt_path, 272 | 'test_initialization': 'false', # no testing 273 | 'test_iter': 0, # no testing 274 | 'test_interval': 100000, 275 | 'base_lr': 0.001, 276 | 'lr_policy': 'step', 277 | 'gamma': 0.01, 278 | 'stepsize': 1000, 279 | 'display': 100, 280 | 'max_iter': 1000, 281 | 'momentum': 0.95, 282 | 'weight_decay': 0.0005, 283 | 'snapshot': 0, # only at the end 284 | 'snapshot_prefix': args.working_directory + '/snapshot', 285 | 'solver_mode': 'CPU' 286 | }) 287 | 288 | solver_prototxt.write(prototxt_solver) 289 | solver = caffe.SGDSolver(prototxt_solver) 290 | callbacks = [] 291 | 292 | # Callback to report loss in console. Also automatically plots the loss 293 | # and writes it to the given file. 
In order to silence the console, 294 | # use plot_loss instead of report_loss. 295 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 296 | callbacks.append({ 297 | 'callback': tools.solvers.PlotLossCallback.report_loss, 298 | 'object': report_loss, 299 | 'interval': 1, 300 | }) 301 | 302 | # Callback to report error in console. 303 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 304 | solver_prototxt.get_parameters()['snapshot_prefix'], 305 | args.working_directory + '/error.png') 306 | callbacks.append({ 307 | 'callback': tools.solvers.PlotErrorCallback.report_error, 308 | 'object': report_error, 309 | 'interval': 500, 310 | }) 311 | 312 | # Callback to save an "early stopping" model. 313 | callbacks.append({ 314 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 315 | 'object': report_error, 316 | 'interval': 500, 317 | }) 318 | 319 | # Callback for reporting the gradients for all layers in the console. 320 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 321 | callbacks.append({ 322 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 323 | 'object': report_gradient, 324 | 'interval': 1, 325 | }) 326 | 327 | # Callback for saving regular snapshots using the snapshot_prefix in the 328 | # solver prototxt file. 329 | # Is added after the "early stopping" callback to avoid problems. 330 | callbacks.append({ 331 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 332 | 'object': tools.solvers.SnapshotCallback(), 333 | 'interval': 500, 334 | }) 335 | 336 | monitoring_solver = tools.solvers.MonitoringSolver(solver) 337 | monitoring_solver.register_callback(callbacks) 338 | monitoring_solver.solve(args.iterations) 339 | 340 | def main_resume(): 341 | """ 342 | Resume training; assumes training has been started using :func:`examples.bsds500.main_train`. 343 | """ 344 | 345 | def count_errors(scores, labels): 346 | """ 347 | Utility method to count the errors given the ouput of the 348 | "score" layer and the labels. 
349 | 350 | :param score: output of score layer 351 | :type score: numpy.ndarray 352 | :param labels: labels 353 | :type labels: numpy.ndarray 354 | :return: count of errors 355 | :rtype: int 356 | """ 357 | 358 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 359 | 360 | max_iteration = 0 361 | files = glob.glob(args.working_directory + '/*.solverstate') 362 | 363 | for filename in files: 364 | filenames = filename.split('_') 365 | iteration = filenames[-1][:-12] 366 | 367 | try: 368 | iteration = int(iteration) 369 | if iteration > max_iteration: 370 | max_iteration = iteration 371 | except: 372 | pass 373 | 374 | caffemodel = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.caffemodel' 375 | solverstate = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.solverstate' 376 | 377 | train_prototxt_path = args.working_directory + '/train.prototxt' 378 | test_prototxt_path = args.working_directory + '/test.prototxt' 379 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 380 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 381 | 382 | assert max_iteration > 0, "could not find a solverstate or snaphot file to resume" 383 | assert os.path.exists(caffemodel), "caffemodel %s not found" % caffemodel 384 | assert os.path.exists(solverstate), "solverstate %s not found" % solverstate 385 | assert os.path.exists(train_prototxt_path), "prototxt %s not found" % train_prototxt_path 386 | assert os.path.exists(test_prototxt_path), "prototxt %s not found" % test_prototxt_path 387 | assert os.path.exists(deploy_prototxt_path), "prototxt %s not found" % deploy_prototxt_path 388 | assert os.path.exists(solver_prototxt_path), "prototxt %s not found" % solver_prototxt_path 389 | 390 | solver = caffe.SGDSolver(solver_prototxt_path) 391 | solver.restore(solverstate) 392 | 393 | solver.net.copy_from(caffemodel) 394 | 395 | solver_prototxt = tools.solvers.SolverProtoTXT() 396 | solver_prototxt.read(solver_prototxt_path) 397 | callbacks = [] 398 | 399 | # Callback to report loss in console. 400 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 401 | callbacks.append({ 402 | 'callback': tools.solvers.PlotLossCallback.report_loss, 403 | 'object': report_loss, 404 | 'interval': 1, 405 | }) 406 | 407 | # Callback to report error in console. 408 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 409 | solver_prototxt.get_parameters()['snapshot_prefix'], 410 | args.working_directory + '/error.png') 411 | callbacks.append({ 412 | 'callback': tools.solvers.PlotErrorCallback.report_error, 413 | 'object': report_error, 414 | 'interval': 500, 415 | }) 416 | 417 | # Callback to save an "early stopping" model. 418 | callbacks.append({ 419 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 420 | 'object': report_error, 421 | 'interval': 500, 422 | }) 423 | 424 | # Callback for reporting the gradients for all layers in the console. 425 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 426 | callbacks.append({ 427 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 428 | 'object': report_gradient, 429 | 'interval': 1, 430 | }) 431 | 432 | # Callback for saving regular snapshots using the snapshot_prefix in the 433 | # solver prototxt file. 434 | # Is added after the "early stopping" callback to avoid problems. 
435 | callbacks.append({ 436 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 437 | 'object': tools.solvers.SnapshotCallback(), 438 | 'interval': 500, 439 | }) 440 | 441 | monitoring_solver = tools.solvers.MonitoringSolver(solver, max_iteration) 442 | monitoring_solver.register_callback(callbacks) 443 | monitoring_solver.solve(args.iterations) 444 | 445 | def main_detect(): 446 | """ 447 | Detect edges on a given image, after training a network using :func:`examples.bsds500.main_train`. 448 | """ 449 | 450 | pass 451 | 452 | if __name__ == '__main__': 453 | parser = get_parser() 454 | args = parser.parse_args() 455 | 456 | if args.mode == 'extract': 457 | main_extract() 458 | if args.mode == 'subsample_test': 459 | main_subsample_test() 460 | elif args.mode == 'train': 461 | main_train() 462 | elif args.mode =='resume': 463 | main_resume() 464 | else: 465 | print('Invalid mode.') -------------------------------------------------------------------------------- /examples/bsds500/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/bsds500/.gitignore -------------------------------------------------------------------------------- /examples/cifar10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example for classification on Cifar10 [1] 3 | 4 | .. code-block:: python 5 | 6 | [1] A. Krizhevsky. 7 | Learning Multiple Layers of Features from Tiny Images. 8 | 2009. 9 | 10 | **Note: the LMDBs can also be found in the data repository, see README.** 11 | 12 | Use ``caffe/data/cifar10/get_cifar10.sh`` to download Cifar10 and 13 | ``caffe/examples/create_cifar10.sh`` to create the corresponding LMDBs. 14 | Copy them over into ``examples/cifar10`` for the following data structure: 15 | 16 | .. code-block:: python 17 | 18 | examples/cifar10 19 | |- train_lmdb/ 20 | |- test_lmdb/ 21 | 22 | .. argparse:: 23 | :ref: examples.cifar10.get_parser 24 | :prog: cifar10 25 | """ 26 | 27 | import os 28 | import cv2 29 | import glob 30 | import numpy 31 | import argparse 32 | from matplotlib import pyplot 33 | 34 | # To silence Caffe! Must be added before importing Caffe or modules which 35 | # are importing Caffe. 36 | os.environ['GLOG_minloglevel'] = '3' 37 | import caffe 38 | import tools.solvers 39 | import tools.lmdb_io 40 | import tools.pre_processing 41 | import tools.prototxt 42 | 43 | caffe.set_mode_gpu() 44 | 45 | def get_parser(): 46 | """ 47 | Get the parser. 
48 | 49 | :return: parser 50 | :rtype: argparse.ArgumentParser 51 | """ 52 | 53 | parser = argparse.ArgumentParser(description = 'Caffe example on Cifar-10.') 54 | parser.add_argument('mode', default = 'train', 55 | help = 'mode to run: "train" or "resume"') 56 | parser.add_argument('--train_lmdb', dest = 'train_lmdb', type = str, 57 | help = 'path to train LMDB', 58 | default = 'examples/cifar10/train_lmdb') 59 | parser.add_argument('--test_lmdb', dest = 'test_lmdb', type = str, 60 | help = 'path to test LMDB', 61 | default = 'examples/cifar10/test_lmdb') 62 | parser.add_argument('--working_directory', dest = 'working_directory', type = str, 63 | help = 'path to a directory (created if not existent) where to store the created .prototxt and snapshot files', 64 | default = 'examples/cifar10') 65 | parser.add_argument('--iterations', dest = 'iterations', type = int, 66 | help = 'number of iterations to train or resume', 67 | default = 10000) 68 | parser.add_argument('--image', dest = 'image', type = str, 69 | help = 'path to image for testing', 70 | default = 'examples/cifar10/test_dog.png') 71 | return parser 72 | 73 | def main_train(): 74 | """ 75 | Train a network on Cifar10 on scratch. 76 | """ 77 | 78 | def network(lmdb_path, batch_size): 79 | """ 80 | The network definition given the LMDB path and the used batch size. 81 | 82 | :param lmdb_path: path to LMDB to use (train or test LMDB) 83 | :type lmdb_path: string 84 | :param batch_size: batch size to use 85 | :type batch_size: int 86 | :return: the network definition as string to write to the prototxt file 87 | :rtype: string 88 | """ 89 | 90 | net = caffe.NetSpec() 91 | 92 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 93 | backend = caffe.params.Data.LMDB, 94 | source = lmdb_path, 95 | transform_param = dict(scale = 1./255), 96 | ntop = 2) 97 | 98 | net.conv1 = caffe.layers.Convolution(net.data, kernel_size = 5, num_output = 32, pad = 2, 99 | stride = 1, weight_filler = dict(type = 'xavier')) 100 | net.pool1 = caffe.layers.Pooling(net.conv1, kernel_size = 3, stride = 2, 101 | pool = caffe.params.Pooling.MAX) 102 | net.relu1 = caffe.layers.ReLU(net.pool1, in_place = True) 103 | net.norm1 = caffe.layers.LRN(net.relu1, local_size = 3, alpha = 5e-05, 104 | beta = 0.75, norm_region = caffe.params.LRN.WITHIN_CHANNEL) 105 | net.conv2 = caffe.layers.Convolution(net.relu1, kernel_size = 5, num_output = 32, pad = 2, 106 | stride = 1, weight_filler = dict(type = 'xavier')) 107 | net.relu2 = caffe.layers.ReLU(net.conv2, in_place = True) 108 | net.pool2 = caffe.layers.Pooling(net.relu2, kernel_size = 3, stride = 2, 109 | pool = caffe.params.Pooling.AVE) 110 | net.norm2 = caffe.layers.LRN(net.pool2, local_size = 3, alpha = 5e-05, beta = 0.75, 111 | norm_region = caffe.params.LRN.WITHIN_CHANNEL) 112 | net.conv3 = caffe.layers.Convolution(net.norm2, kernel_size = 5, num_output = 64, pad = 2, 113 | stride = 1, weight_filler = dict(type = 'xavier')) 114 | net.relu3 = caffe.layers.ReLU(net.conv3, in_place = True) 115 | net.pool3 = caffe.layers.Pooling(net.relu3, kernel_size = 3, stride = 2, 116 | pool = caffe.params.Pooling.AVE) 117 | net.score = caffe.layers.InnerProduct(net.pool3, num_output = 10) 118 | net.loss = caffe.layers.SoftmaxWithLoss(net.score, net.labels) 119 | 120 | return net.to_proto() 121 | 122 | def count_errors(scores, labels): 123 | """ 124 | Utility method to count the errors given the ouput of the 125 | "score" layer and the labels. 
126 | 127 | :param score: output of score layer 128 | :type score: numpy.ndarray 129 | :param labels: labels 130 | :type labels: numpy.ndarray 131 | :return: count of errors 132 | :rtype: int 133 | """ 134 | 135 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 136 | 137 | train_prototxt_path = args.working_directory + '/train.prototxt' 138 | test_prototxt_path = args.working_directory + '/test.prototxt' 139 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 140 | 141 | with open(train_prototxt_path, 'w') as f: 142 | f.write(str(network(args.train_lmdb, 128))) 143 | 144 | with open(test_prototxt_path, 'w') as f: 145 | f.write(str(network(args.test_lmdb, 1000))) 146 | 147 | tools.prototxt.train2deploy(train_prototxt_path, (1, 3, 32, 32), deploy_prototxt_path) 148 | 149 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 150 | solver_prototxt = tools.solvers.SolverProtoTXT({ 151 | 'train_net': train_prototxt_path, 152 | 'test_net': test_prototxt_path, 153 | 'test_initialization': 'false', # no testing 154 | 'test_iter': 0, # no testing 155 | 'test_interval': 1000, 156 | 'base_lr': 0.01, 157 | 'lr_policy': 'inv', 158 | 'gamma': 0.0001, 159 | 'power': 0.75, 160 | 'stepsize': 1000, 161 | 'display': 100, 162 | 'max_iter': 1000, 163 | 'momentum': 0.95, 164 | 'weight_decay': 0.0005, 165 | 'snapshot': 0, # only at the end 166 | 'snapshot_prefix': args.working_directory + '/snapshot', 167 | 'solver_mode': 'CPU' 168 | }) 169 | 170 | solver_prototxt.write(solver_prototxt_path) 171 | solver = caffe.SGDSolver(solver_prototxt_path) 172 | callbacks = [] 173 | 174 | # Callback to report loss in console. Also automatically plots the loss 175 | # and writes it to the given file. In order to silence the console, 176 | # use plot_loss instead of report_loss. 177 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 178 | callbacks.append({ 179 | 'callback': tools.solvers.PlotLossCallback.report_loss, 180 | 'object': report_loss, 181 | 'interval': 1, 182 | }) 183 | 184 | # Callback to report error in console. 185 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 186 | solver_prototxt.get_parameters()['snapshot_prefix'], 187 | args.working_directory + '/error.png') 188 | callbacks.append({ 189 | 'callback': tools.solvers.PlotErrorCallback.report_error, 190 | 'object': report_error, 191 | 'interval': 500, 192 | }) 193 | 194 | # Callback to save an "early stopping" model. 195 | callbacks.append({ 196 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 197 | 'object': report_error, 198 | 'interval': 500, 199 | }) 200 | 201 | # Callback for reporting the gradients for all layers in the console. 202 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 203 | callbacks.append({ 204 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 205 | 'object': report_gradient, 206 | 'interval': 1, 207 | }) 208 | 209 | # Callback for saving regular snapshots using the snapshot_prefix in the 210 | # solver prototxt file. 211 | # Is added after the "early stopping" callback to avoid problems. 
212 | callbacks.append({ 213 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 214 | 'object': tools.solvers.SnapshotCallback(), 215 | 'interval': 500, 216 | }) 217 | 218 | monitoring_solver = tools.solvers.MonitoringSolver(solver) 219 | monitoring_solver.register_callback(callbacks) 220 | monitoring_solver.solve(args.iterations) 221 | 222 | def main_resume(): 223 | """ 224 | Resume training; assumes training has been started using :func:`examples.cifar10.main_train`. 225 | """ 226 | 227 | def count_errors(scores, labels): 228 | """ 229 | Utility method to count the errors given the ouput of the 230 | "score" layer and the labels. 231 | 232 | :param score: output of score layer 233 | :type score: numpy.ndarray 234 | :param labels: labels 235 | :type labels: numpy.ndarray 236 | :return: count of errors 237 | :rtype: int 238 | """ 239 | 240 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 241 | 242 | max_iteration = 0 243 | files = glob.glob(args.working_directory + '/*.solverstate') 244 | 245 | for filename in files: 246 | filenames = filename.split('_') 247 | iteration = filenames[-1][:-12] 248 | 249 | try: 250 | iteration = int(iteration) 251 | if iteration > max_iteration: 252 | max_iteration = iteration 253 | except: 254 | pass 255 | 256 | caffemodel = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.caffemodel' 257 | solverstate = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.solverstate' 258 | 259 | train_prototxt_path = args.working_directory + '/train.prototxt' 260 | test_prototxt_path = args.working_directory + '/test.prototxt' 261 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 262 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 263 | 264 | assert max_iteration > 0, "could not find a solverstate or snaphot file to resume" 265 | assert os.path.exists(caffemodel), "caffemodel %s not found" % caffemodel 266 | assert os.path.exists(solverstate), "solverstate %s not found" % solverstate 267 | assert os.path.exists(train_prototxt_path), "prototxt %s not found" % train_prototxt_path 268 | assert os.path.exists(test_prototxt_path), "prototxt %s not found" % test_prototxt_path 269 | assert os.path.exists(deploy_prototxt_path), "prototxt %s not found" % deploy_prototxt_path 270 | assert os.path.exists(solver_prototxt_path), "prototxt %s not found" % solver_prototxt_path 271 | 272 | solver = caffe.SGDSolver(solver_prototxt_path) 273 | solver.restore(solverstate) 274 | 275 | solver.net.copy_from(caffemodel) 276 | 277 | solver_prototxt = tools.solvers.SolverProtoTXT() 278 | solver_prototxt.read(solver_prototxt_path) 279 | callbacks = [] 280 | 281 | # Callback to report loss in console. 282 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 283 | callbacks.append({ 284 | 'callback': tools.solvers.PlotLossCallback.report_loss, 285 | 'object': report_loss, 286 | 'interval': 1, 287 | }) 288 | 289 | # Callback to report error in console. 290 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 291 | solver_prototxt.get_parameters()['snapshot_prefix'], 292 | args.working_directory + '/error.png') 293 | callbacks.append({ 294 | 'callback': tools.solvers.PlotErrorCallback.report_error, 295 | 'object': report_error, 296 | 'interval': 500, 297 | }) 298 | 299 | # Callback to save an "early stopping" model. 
300 | callbacks.append({ 301 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 302 | 'object': report_error, 303 | 'interval': 500, 304 | }) 305 | 306 | # Callback for reporting the gradients for all layers in the console. 307 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 308 | callbacks.append({ 309 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 310 | 'object': report_gradient, 311 | 'interval': 1, 312 | }) 313 | 314 | # Callback for saving regular snapshots using the snapshot_prefix in the 315 | # solver prototxt file. 316 | # Is added after the "early stopping" callback to avoid problems. 317 | callbacks.append({ 318 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 319 | 'object': tools.solvers.SnapshotCallback(), 320 | 'interval': 500, 321 | }) 322 | 323 | monitoring_solver = tools.solvers.MonitoringSolver(solver, max_iteration) 324 | monitoring_solver.register_callback(callbacks) 325 | monitoring_solver.solve(args.iterations) 326 | 327 | def main_test(): 328 | """ 329 | Test the latest model obtained by :func:`examples.cifar10.main_train` 330 | or :func:`examples.cifar10.main_resume` on the given input image. 331 | """ 332 | 333 | max_iteration = 0 334 | files = glob.glob(args.working_directory + '/*.solverstate') 335 | 336 | for filename in files: 337 | filenames = filename.split('_') 338 | iteration = filenames[-1][:-12] 339 | 340 | try: 341 | iteration = int(iteration) 342 | if iteration > max_iteration: 343 | max_iteration = iteration 344 | except: 345 | pass 346 | 347 | caffemodel = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.caffemodel' 348 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 349 | 350 | assert max_iteration > 0, "could not find a solverstate or snaphot file to resume" 351 | assert os.path.exists(caffemodel), "caffemodel %s not found" % caffemodel 352 | assert os.path.exists(deploy_prototxt_path), "prototxt %s not found" % deploy_prototxt_path 353 | 354 | net = caffe.Net(deploy_prototxt_path, caffemodel, caffe.TEST) 355 | transformer = caffe.io.Transformer({'data': (1, 3, 32, 32)}) 356 | transformer.set_transpose('data', (2, 0, 1)) 357 | transformer.set_raw_scale('data', 1/255.) 358 | 359 | assert os.path.exists(args.image), "image %s not found" % args.image 360 | image = cv2.imread(args.image) 361 | cv2.imshow('image', image) 362 | 363 | net.blobs['data'].reshape(1, 3, 32, 32) 364 | net.blobs['data'].data[...] 
= transformer.preprocess('data', image) 365 | 366 | net.forward() 367 | scores = net.blobs['score'].data 368 | 369 | x = range(10) 370 | classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 371 | 372 | pyplot.bar(x, scores[0, :], 1/1.5, color = 'blue') 373 | pyplot.xticks(x, classes, rotation = 90) 374 | pyplot.gcf().subplots_adjust(bottom = 0.2) 375 | pyplot.show() 376 | 377 | if __name__ == '__main__': 378 | parser = get_parser() 379 | args = parser.parse_args() 380 | 381 | if args.mode == 'train': 382 | main_train() 383 | elif args.mode == 'resume': 384 | main_resume() 385 | elif args.mode == 'test': 386 | main_test() 387 | else: 388 | print('Invalid mode') -------------------------------------------------------------------------------- /examples/cifar10/test_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/cifar10/test_dog.png -------------------------------------------------------------------------------- /examples/iris.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example for classification on Iris. 3 | 4 | **Note: the LMDBs can also be found in the data repository, see README.** 5 | 6 | To acquire the dataset, follow http://archive.ics.uci.edu/ml/datasets/Iris. 7 | The downloaded dataset should be saved in ``examples/iris/iris.data.txt``. 8 | :func:`examples.iris.main_convert` will then convert data to LMDBs to obtain the 9 | following data structure: 10 | 11 | .. code-block:: python 12 | 13 | examples/iris/ 14 | |- test_lmdb 15 | |- train_lmdb 16 | |- iris.data.txt 17 | 18 | .. argparse:: 19 | :ref: examples.iris.get_parser 20 | :prog: iris 21 | """ 22 | 23 | import os 24 | import numpy 25 | import shutil 26 | import argparse 27 | 28 | # To silence Caffe! Must be added before importing Caffe or modules which 29 | # are importing Caffe. 30 | os.environ['GLOG_minloglevel'] = '3' 31 | import caffe 32 | import tools.solvers 33 | import tools.lmdb_io 34 | import tools.pre_processing 35 | 36 | def get_parser(): 37 | """ 38 | Get the parser. 39 | 40 | :return: parser 41 | :rtype: argparse.ArgumentParser 42 | """ 43 | 44 | parser = argparse.ArgumentParser(description = 'Deep learning for Iris.') 45 | parser.add_argument('mode', default = 'convert') 46 | parser.add_argument('--file', default = 'examples/iris/iris.data.txt', type = str, 47 | help = 'path to the iris data file') 48 | parser.add_argument('--split', default = 0.8, type = float, 49 | help = 'fraction of samples to use for taining') 50 | parser.add_argument('--train_lmdb', default = 'examples/iris/train_lmdb', type = str, 51 | help = 'path to train LMDB') 52 | parser.add_argument('--test_lmdb', default = 'examples/iris/test_lmdb', type = str, 53 | help = 'path to test LMDB') 54 | parser.add_argument('--working_directory', dest = 'working_directory', type = str, 55 | help = 'path to a directory (created if not existent) where to store the created .prototxt and snapshot files', 56 | default = 'examples/mnist') 57 | 58 | return parser 59 | 60 | def main_convert(): 61 | """ 62 | Convert the Iris dataset to LMDB. 
63 |     """
64 | 
65 |     lmdb_converted = args.working_directory + '/lmdb_converted'
66 |     lmdb_shuffled = args.working_directory + '/lmdb_shuffled'
67 | 
68 |     if os.path.exists(lmdb_converted):
69 |         shutil.rmtree(lmdb_converted)
70 |     if os.path.exists(lmdb_shuffled):
71 |         shutil.rmtree(lmdb_shuffled)
72 | 
73 |     assert os.path.exists(args.file), "file %s could not be found" % args.file
74 |     assert not os.path.exists(args.train_lmdb), "LMDB %s already exists" % args.train_lmdb
75 |     assert not os.path.exists(args.test_lmdb), "LMDB %s already exists" % args.test_lmdb
76 | 
77 |     pp_in = tools.pre_processing.PreProcessingInputCSV(args.file, delimiter = ',',
78 |                                                        label_column = 4,
79 |                                                        label_column_mapping = {
80 |                                                            'Iris-setosa': 0,
81 |                                                            'Iris-versicolor': 1,
82 |                                                            'Iris-virginica': 2
83 |                                                        })
84 |     pp_out_converted = tools.pre_processing.PreProcessingOutputLMDB(lmdb_converted)
85 |     pp_convert = tools.pre_processing.PreProcessingNormalize(pp_in, pp_out_converted, 7.9)
86 |     pp_convert.run()
87 | 
88 |     pp_in_converted = tools.pre_processing.PreProcessingInputLMDB(lmdb_converted)
89 |     pp_out_shuffled = tools.pre_processing.PreProcessingOutputLMDB(lmdb_shuffled)
90 |     pp_shuffle = tools.pre_processing.PreProcessingShuffle(pp_in_converted, pp_out_shuffled)
91 |     pp_shuffle.run()
92 | 
93 |     pp_in_shuffled = tools.pre_processing.PreProcessingInputLMDB(lmdb_shuffled)
94 |     pp_out_train = tools.pre_processing.PreProcessingOutputLMDB(args.train_lmdb)
95 |     pp_out_test = tools.pre_processing.PreProcessingOutputLMDB(args.test_lmdb)
96 |     pp_split = tools.pre_processing.PreProcessingSplit(pp_in_shuffled, (pp_out_train, pp_out_test), (args.split, 1 - args.split))
97 |     pp_split.run()
98 | 
99 |     # Sanity check: read both LMDBs back and print their contents.
100 |     print('Train:')
101 |     lmdb = tools.lmdb_io.LMDB(args.train_lmdb)
102 |     images, labels, keys = lmdb.read()
103 | 
104 |     for n in range(len(images)):
105 |         print(images[n].reshape((4)), labels[n])
106 | 
107 |     print('Test:')
108 |     lmdb = tools.lmdb_io.LMDB(args.test_lmdb)
109 |     images, labels, keys = lmdb.read()
110 | 
111 |     for n in range(len(images)):
112 |         print(images[n].reshape((4)), labels[n])
113 | 
114 | def main_train():
115 |     """
116 |     Train a network from scratch on Iris using data augmentation to get more
117 |     training samples.
118 |     """
119 | 
120 |     def network(lmdb_path, batch_size):
121 |         """
122 |         The network definition given the LMDB path and the used batch size.
123 | 
124 |         :param lmdb_path: path to LMDB to use (train or test LMDB)
125 |         :type lmdb_path: string
126 |         :param batch_size: batch size to use
127 |         :type batch_size: int
128 |         :return: the network definition as string to write to the prototxt file
129 |         :rtype: string
130 |         """
131 | 
132 |         net = caffe.NetSpec()
133 |         net.data, net.labels = caffe.layers.Data(batch_size = batch_size, backend = caffe.params.Data.LMDB,
134 |                                                  source = lmdb_path, ntop = 2)
135 |         net.data_aug = caffe.layers.Python(net.data,
136 |                                            python_param = dict(module = 'tools.layers', layer = 'DataAugmentationRandomMultiplicativeNoiseLayer'))
137 |         net.labels_aug = caffe.layers.Python(net.labels,
138 |                                              python_param = dict(module = 'tools.layers', layer = 'DataAugmentationDuplicateLabelsLayer'))
139 |         net.fc1 = caffe.layers.InnerProduct(net.data_aug, num_output = 12,
140 |                                             bias_filler = dict(type = 'xavier', std = 0.1),
141 |                                             weight_filler = dict(type = 'xavier', std = 0.1))
142 |         net.sigmoid1 = caffe.layers.Sigmoid(net.fc1)
143 |         net.fc2 = caffe.layers.InnerProduct(net.sigmoid1, num_output = 3,
144 |                                             bias_filler = dict(type = 'xavier', std = 0.1),
145 |                                             weight_filler = dict(type = 'xavier', std = 0.1))
146 |         net.score = caffe.layers.Softmax(net.fc2)
147 |         net.loss = caffe.layers.MultinomialLogisticLoss(net.score, net.labels_aug)
148 | 
149 |         return net.to_proto()
150 | 
151 |     def count_errors(scores, labels):
152 |         """
153 |         Utility method to count the errors given the output of the
154 |         "score" layer and the labels.
155 | 
156 |         :param scores: output of score layer
157 |         :type scores: numpy.ndarray
158 |         :param labels: labels
159 |         :type labels: numpy.ndarray
160 |         :return: count of errors
161 |         :rtype: int
162 |         """
163 | 
164 |         return numpy.sum(numpy.argmax(scores, axis = 1) != labels)
165 | 
166 |     assert os.path.exists(args.train_lmdb), "LMDB %s not found" % args.train_lmdb
167 |     assert os.path.exists(args.test_lmdb), "LMDB %s not found" % args.test_lmdb
168 | 
169 |     if not os.path.exists(args.working_directory):
170 |         os.makedirs(args.working_directory)
171 | 
172 |     prototxt_train = args.working_directory + '/train.prototxt'
173 |     prototxt_test = args.working_directory + '/test.prototxt'
174 | 
175 |     with open(prototxt_train, 'w') as f:
176 |         f.write(str(network(args.train_lmdb, 6)))
177 | 
178 |     with open(prototxt_test, 'w') as f:
179 |         f.write(str(network(args.test_lmdb, 6)))
180 | 
181 |     prototxt_solver = args.working_directory + '/solver.prototxt'
182 |     solver_prototxt = tools.solvers.SolverProtoTXT({
183 |         'train_net': prototxt_train,
184 |         'test_net': prototxt_test,
185 |         'test_initialization': 'false', # no testing
186 |         'test_iter': 0, # no testing
187 |         'test_interval': 100000,
188 |         'base_lr': 0.001,
189 |         'lr_policy': 'step',
190 |         'gamma': 0.01,
191 |         'stepsize': 1000,
192 |         'display': 100,
193 |         'max_iter': 1000,
194 |         'momentum': 0.9,
195 |         'weight_decay': 0.0005,
196 |         'snapshot': 0, # only at the end
197 |         'snapshot_prefix': args.working_directory + '/snapshot',
198 |         'solver_mode': 'CPU'
199 |     })
200 | 
201 |     solver_prototxt.write(prototxt_solver)
202 |     solver = caffe.SGDSolver(prototxt_solver)
203 |     callbacks = []
204 | 
205 |     # Callback to report loss in console. Also automatically plots the loss
206 |     # and writes it to the given file. In order to silence the console,
207 |     # use plot_loss instead of report_loss.
208 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 209 | callbacks.append({ 210 | 'callback': tools.solvers.PlotLossCallback.report_loss, 211 | 'object': report_loss, 212 | 'interval': 1, 213 | }) 214 | 215 | # Callback to report error in console. 216 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 217 | solver_prototxt.get_parameters()['snapshot_prefix'], 218 | args.working_directory + '/error.png') 219 | callbacks.append({ 220 | 'callback': tools.solvers.PlotErrorCallback.report_error, 221 | 'object': report_error, 222 | 'interval': 500, 223 | }) 224 | 225 | # Callback to save an "early stopping" model. 226 | callbacks.append({ 227 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 228 | 'object': report_error, 229 | 'interval': 500, 230 | }) 231 | 232 | # Callback for reporting the gradients for all layers in the console. 233 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 234 | callbacks.append({ 235 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 236 | 'object': report_gradient, 237 | 'interval': 1, 238 | }) 239 | 240 | # Callback for saving regular snapshots using the snapshot_prefix in the 241 | # solver prototxt file. 242 | # Is added after the "early stopping" callback to avoid problems. 243 | callbacks.append({ 244 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 245 | 'object': tools.solvers.SnapshotCallback(), 246 | 'interval': 500, 247 | }) 248 | 249 | monitoring_solver = tools.solvers.MonitoringSolver(solver) 250 | monitoring_solver.register_callback(callbacks) 251 | monitoring_solver.solve(args.iterations) 252 | 253 | if __name__ == '__main__': 254 | parser = get_parser() 255 | args = parser.parse_args() 256 | 257 | if args.mode == 'convert': 258 | main_convert() 259 | elif args.mode == 'train': 260 | main_train() 261 | else: 262 | print('Invalid mode.') -------------------------------------------------------------------------------- /examples/iris/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/iris/.gitignore -------------------------------------------------------------------------------- /examples/lmdb_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Examples for reading LMDBs. 3 | 4 | .. argparse:: 5 | :ref: examples.lmdb_io.get_parser 6 | :prog: lmdb_io 7 | """ 8 | 9 | import os 10 | import cv2 11 | import argparse 12 | 13 | # To silence Caffe! Must be added before importing Caffe or modules which 14 | # are importing Caffe. 15 | os.environ['GLOG_minloglevel'] = '0' 16 | import tools.lmdb_io 17 | 18 | def get_parser(): 19 | """ 20 | Get the parser. 21 | 22 | :return: parser 23 | :rtype: argparse.ArgumentParser 24 | """ 25 | 26 | parser = argparse.ArgumentParser(description = 'Read LMDBs.') 27 | parser.add_argument('mode', default = 'read') 28 | parser.add_argument('--lmdb', default = 'examples/cifar10/train_lmdb', type = str, 29 | help = 'path to input LMDB') 30 | parser.add_argument('--output', default = 'examples/output', type = str, 31 | help = 'output directory') 32 | parser.add_argument('--limit', default = 100, type = int, 33 | help = 'limit the number of images to read') 34 | 35 | return parser 36 | 37 | def main_statistics(): 38 | """ 39 | Read and print the size of an LMDB. 
40 | """ 41 | 42 | lmdb = tools.lmdb_io.LMDB(args.lmdb) 43 | print(lmdb.count()) 44 | 45 | def main_read(): 46 | """ 47 | Read up to ``--limit`` images from the LMDB. 48 | """ 49 | 50 | lmdb = tools.lmdb_io.LMDB(args.lmdb) 51 | keys = lmdb.keys() 52 | 53 | if not os.path.exists(args.output): 54 | os.makedirs(args.output) 55 | 56 | with open(args.output + '/labels.txt', 'w') as f: 57 | for n in range(min(len(keys), args.limit)): 58 | image, label, key = lmdb.read_single(keys[n]) 59 | image_path = args.output + '/' + keys[n] + '.png' 60 | cv2.imwrite(image_path, image) 61 | f.write(image_path + ': ' + str(label) + '\n') 62 | 63 | if __name__ == '__main__': 64 | parser = get_parser() 65 | args = parser.parse_args() 66 | 67 | if args.mode == 'read': 68 | main_read() 69 | elif args.mode == 'statistics': 70 | main_statistics() 71 | else: 72 | print('Invalid mode.') -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example for classification on MNIST [1]. 3 | 4 | .. code-block:: python 5 | 6 | [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. 7 | Gradient-based learning applied to document recognition. 8 | Proceedings of the IEEE, 86(11), 1998. 9 | 10 | **Note: the LMDBs can also be found in the data repository, see README.** 11 | 12 | Use ``caffe/data/mnist/get_mnist.sh`` and ``caffe/examples/mnist/create_mnist.sh`` 13 | to convert MNIST to LMDBs. Copy the LMDBs to ``examples/mnist`` to get the 14 | following directory structure: 15 | 16 | .. code-block:: python 17 | 18 | examples/mnist/ 19 | |- train_lmdb 20 | |- test_lmdb 21 | """ 22 | 23 | import os 24 | import cv2 25 | import glob 26 | import numpy 27 | import argparse 28 | from matplotlib import pyplot 29 | 30 | # To silence Caffe! Must be added before importing Caffe or modules which 31 | # are importing Caffe. 32 | os.environ['GLOG_minloglevel'] = '3' 33 | import caffe 34 | import tools.solvers 35 | import tools.lmdb_io 36 | import tools.prototxt 37 | import tools.pre_processing 38 | 39 | caffe.set_mode_gpu() 40 | 41 | def get_parser(): 42 | """ 43 | Get the parser. 44 | 45 | :return: parser 46 | :rtype: argparse.ArgumentParser 47 | """ 48 | 49 | parser = argparse.ArgumentParser(description = 'Caffe example on MNIST.') 50 | parser.add_argument('mode', default = 'train') 51 | parser.add_argument('--train_lmdb', dest = 'train_lmdb', type = str, 52 | help = 'path to train LMDB', 53 | default = 'examples/mnist/train_lmdb') 54 | parser.add_argument('--test_lmdb', dest = 'test_lmdb', type = str, 55 | help = 'path to test LMDB', 56 | default = 'examples/mnist/test_lmdb') 57 | parser.add_argument('--working_directory', dest = 'working_directory', type = str, 58 | help = 'path to a directory (created if not existent) where to store the created .prototxt and snapshot files', 59 | default = 'examples/mnist') 60 | parser.add_argument('--iterations', dest = 'iterations', type = int, 61 | help = 'number of iterations to train or resume', 62 | default = 10000) 63 | parser.add_argument('--image', dest = 'image', type = str, 64 | help = 'path to image for testing', 65 | default = 'examples/mnist/test_1.png') 66 | 67 | return parser 68 | 69 | def main_train(): 70 | """ 71 | Train a network for MNIST from scratch. 72 | """ 73 | 74 | def network(lmdb_path, batch_size): 75 | """ 76 | The network definition given the LMDB path and the used batch size. 
77 | 78 | :param lmdb_path: path to LMDB to use (train or test LMDB) 79 | :type lmdb_path: string 80 | :param batch_size: batch size to use 81 | :type batch_size: int 82 | :return: the network definition as string to write to the prototxt file 83 | :rtype: string 84 | """ 85 | 86 | net = caffe.NetSpec() 87 | 88 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 89 | backend = caffe.params.Data.LMDB, 90 | source = lmdb_path, 91 | transform_param = dict(scale = 1./255), 92 | ntop = 2) 93 | 94 | net.conv1 = caffe.layers.Convolution(net.data, kernel_size = 5, num_output = 20, 95 | weight_filler = dict(type = 'xavier')) 96 | net.pool1 = caffe.layers.Pooling(net.conv1, kernel_size = 2, stride = 2, 97 | pool = caffe.params.Pooling.MAX) 98 | net.conv2 = caffe.layers.Convolution(net.pool1, kernel_size = 5, num_output = 50, 99 | weight_filler = dict(type = 'xavier')) 100 | net.pool2 = caffe.layers.Pooling(net.conv2, kernel_size = 2, stride = 2, 101 | pool = caffe.params.Pooling.MAX) 102 | net.fc1 = caffe.layers.InnerProduct(net.pool2, num_output = 500, 103 | weight_filler = dict(type = 'xavier')) 104 | net.relu1 = caffe.layers.ReLU(net.fc1, in_place = True) 105 | net.score = caffe.layers.InnerProduct(net.relu1, num_output = 10, 106 | weight_filler = dict(type = 'xavier')) 107 | net.loss = caffe.layers.SoftmaxWithLoss(net.score, net.labels) 108 | 109 | return net.to_proto() 110 | 111 | def count_errors(scores, labels): 112 | """ 113 | Utility method to count the errors given the ouput of the 114 | "score" layer and the labels. 115 | 116 | :param score: output of score layer 117 | :type score: numpy.ndarray 118 | :param labels: labels 119 | :type labels: numpy.ndarray 120 | :return: count of errors 121 | :rtype: int 122 | """ 123 | 124 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 125 | 126 | assert os.path.exists(args.train_lmdb), "LMDB %s not found" % args.train_lmdb 127 | assert os.path.exists(args.test_lmdb), "LMDB %s not found" % args.test_lmdb 128 | 129 | if not os.path.exists(args.working_directory): 130 | os.makedirs(args.working_directory) 131 | 132 | train_prototxt_path = args.working_directory + '/train.prototxt' 133 | test_prototxt_path = args.working_directory + '/test.prototxt' 134 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 135 | 136 | with open(train_prototxt_path, 'w') as f: 137 | f.write(str(network(args.train_lmdb, 128))) 138 | 139 | with open(test_prototxt_path, 'w') as f: 140 | f.write(str(network(args.test_lmdb, 1000))) 141 | 142 | tools.prototxt.train2deploy(train_prototxt_path, (1, 1, 28, 28), deploy_prototxt_path) 143 | 144 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 145 | solver_prototxt = tools.solvers.SolverProtoTXT({ 146 | 'train_net': train_prototxt_path, 147 | 'test_net': test_prototxt_path, 148 | 'test_initialization': 'false', # no testing 149 | 'test_iter': 0, # no testing 150 | 'test_interval': 1000, 151 | 'base_lr': 0.01, 152 | 'lr_policy': 'inv', 153 | 'gamma': 0.0001, 154 | 'power': 0.75, 155 | 'stepsize': 1000, 156 | 'display': 100, 157 | 'max_iter': 1000, 158 | 'momentum': 0.95, 159 | 'weight_decay': 0.0005, 160 | 'snapshot': 0, # only at the end 161 | 'snapshot_prefix': args.working_directory + '/snapshot', 162 | 'solver_mode': 'CPU' 163 | }) 164 | 165 | solver_prototxt.write(solver_prototxt_path) 166 | solver = caffe.SGDSolver(solver_prototxt_path) 167 | callbacks = [] 168 | 169 | # Callback to report loss in console. 
Also automatically plots the loss 170 | # and writes it to the given file. In order to silence the console, 171 | # use plot_loss instead of report_loss. 172 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 173 | callbacks.append({ 174 | 'callback': tools.solvers.PlotLossCallback.report_loss, 175 | 'object': report_loss, 176 | 'interval': 1, 177 | }) 178 | 179 | # Callback to report error in console. 180 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 181 | solver_prototxt.get_parameters()['snapshot_prefix'], 182 | args.working_directory + '/error.png') 183 | callbacks.append({ 184 | 'callback': tools.solvers.PlotErrorCallback.report_error, 185 | 'object': report_error, 186 | 'interval': 500, 187 | }) 188 | 189 | # Callback to save an "early stopping" model. 190 | callbacks.append({ 191 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 192 | 'object': report_error, 193 | 'interval': 500, 194 | }) 195 | 196 | # Callback for reporting the gradients for all layers in the console. 197 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 198 | callbacks.append({ 199 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 200 | 'object': report_gradient, 201 | 'interval': 1, 202 | }) 203 | 204 | # Callback for saving regular snapshots using the snapshot_prefix in the 205 | # solver prototxt file. 206 | # Is added after the "early stopping" callback to avoid problems. 207 | callbacks.append({ 208 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 209 | 'object': tools.solvers.SnapshotCallback(), 210 | 'interval': 500, 211 | }) 212 | 213 | monitoring_solver = tools.solvers.MonitoringSolver(solver) 214 | monitoring_solver.register_callback(callbacks) 215 | monitoring_solver.solve(args.iterations) 216 | 217 | def main_train_augmented(): 218 | """ 219 | Train a network from scratch on augmented MNIST. Augmentation is done on the 220 | fly and only involves multiplicative Gaussian noise. 221 | 222 | Uses the same working directory as :func:`examples.mnist.main_train`, i.e. 223 | the corresponding snapshots will be overwritten if not changed via 224 | ``--working_directory``. 225 | """ 226 | 227 | def network(lmdb_path, batch_size): 228 | """ 229 | The network definition given the LMDB path and the used batch size. 
230 | 231 | :param lmdb_path: path to LMDB to use (train or test LMDB) 232 | :type lmdb_path: string 233 | :param batch_size: batch size to use 234 | :type batch_size: int 235 | :return: the network definition as string to write to the prototxt file 236 | :rtype: string 237 | """ 238 | 239 | net = caffe.NetSpec() 240 | 241 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 242 | backend = caffe.params.Data.LMDB, 243 | source = lmdb_path, 244 | transform_param = dict(scale = 1./255), 245 | ntop = 2) 246 | net.augmented_data = caffe.layers.Python(net.data, python_param = dict(module = 'tools.layers', layer = 'DataAugmentationMultiplicativeGaussianNoiseLayer')) 247 | net.augmented_labels = caffe.layers.Python(net.labels, python_param = dict(module = 'tools.layers', layer = 'DataAugmentationDoubleLabelsLayer')) 248 | 249 | net.conv1 = caffe.layers.Convolution(net.augmented_data, kernel_size = 5, num_output = 20, 250 | weight_filler = dict(type = 'xavier')) 251 | net.pool1 = caffe.layers.Pooling(net.conv1, kernel_size = 2, stride = 2, 252 | pool = caffe.params.Pooling.MAX) 253 | net.conv2 = caffe.layers.Convolution(net.pool1, kernel_size = 5, num_output = 50, 254 | weight_filler = dict(type = 'xavier')) 255 | net.pool2 = caffe.layers.Pooling(net.conv2, kernel_size = 2, stride = 2, 256 | pool = caffe.params.Pooling.MAX) 257 | net.fc1 = caffe.layers.InnerProduct(net.pool2, num_output = 500, 258 | weight_filler = dict(type = 'xavier')) 259 | net.relu1 = caffe.layers.ReLU(net.fc1, in_place = True) 260 | net.score = caffe.layers.InnerProduct(net.relu1, num_output = 10, 261 | weight_filler = dict(type = 'xavier')) 262 | net.loss = caffe.layers.SoftmaxWithLoss(net.score, net.augmented_labels) 263 | 264 | return net.to_proto() 265 | 266 | def count_errors(scores, labels): 267 | """ 268 | Utility method to count the errors given the ouput of the 269 | "score" layer and the labels. 
270 | 271 | :param score: output of score layer 272 | :type score: numpy.ndarray 273 | :param labels: labels 274 | :type labels: numpy.ndarray 275 | :return: count of errors 276 | :rtype: int 277 | """ 278 | 279 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 280 | 281 | assert os.path.exists(args.train_lmdb), "LMDB %s not found" % args.train_lmdb 282 | assert os.path.exists(args.test_lmdb), "LMDB %s not found" % args.test_lmdb 283 | 284 | if not os.path.exists(args.working_directory): 285 | os.makedirs(args.working_directory) 286 | 287 | train_prototxt_path = args.working_directory + '/train.prototxt' 288 | test_prototxt_path = args.working_directory + '/test.prototxt' 289 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 290 | 291 | with open(train_prototxt_path, 'w') as f: 292 | f.write(str(network(args.train_lmdb, 128))) 293 | 294 | with open(test_prototxt_path, 'w') as f: 295 | f.write(str(network(args.test_lmdb, 1000))) 296 | 297 | tools.prototxt.train2deploy(train_prototxt_path, (1, 1, 28, 28), deploy_prototxt_path) 298 | 299 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 300 | solver_prototxt = tools.solvers.SolverProtoTXT({ 301 | 'train_net': train_prototxt_path, 302 | 'test_net': test_prototxt_path, 303 | 'test_initialization': 'false', # no testing 304 | 'test_iter': 0, # no testing 305 | 'test_interval': 1000, 306 | 'base_lr': 0.01, 307 | 'lr_policy': 'inv', 308 | 'gamma': 0.0001, 309 | 'power': 0.75, 310 | 'stepsize': 1000, 311 | 'display': 100, 312 | 'max_iter': 1000, 313 | 'momentum': 0.95, 314 | 'weight_decay': 0.0005, 315 | 'snapshot': 0, # only at the end 316 | 'snapshot_prefix': args.working_directory + '/snapshot', 317 | 'solver_mode': 'CPU' 318 | }) 319 | 320 | solver_prototxt.write(solver_prototxt_path) 321 | solver = caffe.SGDSolver(solver_prototxt_path) 322 | callbacks = [] 323 | 324 | # Callback to report loss in console. Also automatically plots the loss 325 | # and writes it to the given file. In order to silence the console, 326 | # use plot_loss instead of report_loss. 327 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 328 | callbacks.append({ 329 | 'callback': tools.solvers.PlotLossCallback.report_loss, 330 | 'object': report_loss, 331 | 'interval': 1, 332 | }) 333 | 334 | # Callback to report error in console. 335 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 336 | solver_prototxt.get_parameters()['snapshot_prefix'], 337 | args.working_directory + '/error.png') 338 | callbacks.append({ 339 | 'callback': tools.solvers.PlotErrorCallback.report_error, 340 | 'object': report_error, 341 | 'interval': 500, 342 | }) 343 | 344 | # Callback to save an "early stopping" model. 345 | callbacks.append({ 346 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 347 | 'object': report_error, 348 | 'interval': 500, 349 | }) 350 | 351 | # Callback for reporting the gradients for all layers in the console. 352 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 353 | callbacks.append({ 354 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 355 | 'object': report_gradient, 356 | 'interval': 1, 357 | }) 358 | 359 | # Callback for saving regular snapshots using the snapshot_prefix in the 360 | # solver prototxt file. 361 | # Is added after the "early stopping" callback to avoid problems. 
362 | callbacks.append({ 363 | 'callback': tools.solvers.SnapshotCallback.write_snapshot, 364 | 'object': tools.solvers.SnapshotCallback(), 365 | 'interval': 500, 366 | }) 367 | 368 | monitoring_solver = tools.solvers.MonitoringSolver(solver) 369 | monitoring_solver.register_callback(callbacks) 370 | monitoring_solver.solve(args.iterations) 371 | 372 | def main_resume(): 373 | """ 374 | Resume training a network as started via :func:`examples.mnist.main_train`, 375 | :func:`examples.mnist.main_train_augmented` or :func:`examples.mnist.main_train_autoencoder`. 376 | """ 377 | 378 | def network(lmdb_path, batch_size): 379 | """ 380 | The network definition given the LMDB path and the used batch size. 381 | 382 | :param lmdb_path: path to LMDB to use (train or test LMDB) 383 | :type lmdb_path: string 384 | :param batch_size: batch size to use 385 | :type batch_size: int 386 | :return: the network definition as string to write to the prototxt file 387 | :rtype: string 388 | """ 389 | 390 | net = caffe.NetSpec() 391 | 392 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 393 | backend = caffe.params.Data.LMDB, 394 | source = lmdb_path, 395 | transform_param = dict(scale = 1./255), 396 | ntop = 2) 397 | net.augmented_data = caffe.layers.Python(net.data, python_param = dict(module = 'tools.layers', layer = 'DataAugmentationMultiplicativeGaussianNoiseLayer')) 398 | net.augmented_labels = caffe.layers.Python(net.labels, python_param = dict(module = 'tools.layers', layer = 'DataAugmentationDoubleLabelsLayer')) 399 | 400 | net.conv1 = caffe.layers.Convolution(net.augmented_data, kernel_size = 5, num_output = 20, 401 | weight_filler = dict(type = 'xavier')) 402 | net.pool1 = caffe.layers.Pooling(net.conv1, kernel_size = 2, stride = 2, 403 | pool = caffe.params.Pooling.MAX) 404 | net.conv2 = caffe.layers.Convolution(net.pool1, kernel_size = 5, num_output = 50, 405 | weight_filler = dict(type = 'xavier')) 406 | net.pool2 = caffe.layers.Pooling(net.conv2, kernel_size = 2, stride = 2, 407 | pool = caffe.params.Pooling.MAX) 408 | net.fc1 = caffe.layers.InnerProduct(net.pool2, num_output = 500, 409 | weight_filler = dict(type = 'xavier')) 410 | net.relu1 = caffe.layers.ReLU(net.fc1, in_place = True) 411 | net.score = caffe.layers.InnerProduct(net.relu1, num_output = 10, 412 | weight_filler = dict(type = 'xavier')) 413 | net.loss = caffe.layers.SoftmaxWithLoss(net.score, net.augmented_labels) 414 | 415 | return net.to_proto() 416 | 417 | def count_errors(scores, labels): 418 | """ 419 | Utility method to count the errors given the ouput of the 420 | "score" layer and the labels. 
421 | 422 | :param score: output of score layer 423 | :type score: numpy.ndarray 424 | :param labels: labels 425 | :type labels: numpy.ndarray 426 | :return: count of errors 427 | :rtype: int 428 | """ 429 | 430 | return numpy.sum(numpy.argmax(scores, axis = 1) != labels) 431 | 432 | max_iteration = 0 433 | files = glob.glob(args.working_directory + '/*.solverstate') 434 | 435 | for filename in files: 436 | filenames = filename.split('_') 437 | iteration = filenames[-1][:-12] 438 | 439 | try: 440 | iteration = int(iteration) 441 | if iteration > max_iteration: 442 | max_iteration = iteration 443 | except: 444 | pass 445 | 446 | caffemodel = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.caffemodel' 447 | solverstate = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.solverstate' 448 | 449 | train_prototxt_path = args.working_directory + '/train.prototxt' 450 | test_prototxt_path = args.working_directory + '/test.prototxt' 451 | deploy_prototxt_path = args.working_directory + '/deploy.prototxt' 452 | solver_prototxt_path = args.working_directory + '/solver.prototxt' 453 | 454 | assert max_iteration > 0, "could not find a solverstate or snaphot file to resume" 455 | assert os.path.exists(caffemodel), "caffemodel %s not found" % caffemodel 456 | assert os.path.exists(solverstate), "solverstate %s not found" % solverstate 457 | assert os.path.exists(train_prototxt_path), "prototxt %s not found" % train_prototxt_path 458 | assert os.path.exists(test_prototxt_path), "prototxt %s not found" % test_prototxt_path 459 | assert os.path.exists(deploy_prototxt_path), "prototxt %s not found" % deploy_prototxt_path 460 | assert os.path.exists(solver_prototxt_path), "prototxt %s not found" % solver_prototxt_path 461 | 462 | solver = caffe.SGDSolver(solver_prototxt_path) 463 | solver.restore(solverstate) 464 | 465 | solver.net.copy_from(caffemodel) 466 | 467 | solver_prototxt = tools.solvers.SolverProtoTXT() 468 | solver_prototxt.read(solver_prototxt_path) 469 | callbacks = [] 470 | 471 | # Callback to report loss in console. 472 | report_loss = tools.solvers.PlotLossCallback(100, args.working_directory + '/loss.png') 473 | callbacks.append({ 474 | 'callback': tools.solvers.PlotLossCallback.report_loss, 475 | 'object': report_loss, 476 | 'interval': 1, 477 | }) 478 | 479 | # Callback to report error in console. 480 | report_error = tools.solvers.PlotErrorCallback(count_errors, 60000, 10000, 481 | solver_prototxt.get_parameters()['snapshot_prefix'], 482 | args.working_directory + '/error.png') 483 | callbacks.append({ 484 | 'callback': tools.solvers.PlotErrorCallback.report_error, 485 | 'object': report_error, 486 | 'interval': 500, 487 | }) 488 | 489 | # Callback to save an "early stopping" model. 490 | callbacks.append({ 491 | 'callback': tools.solvers.PlotErrorCallback.stop_early, 492 | 'object': report_error, 493 | 'interval': 500, 494 | }) 495 | 496 | # Callback for reporting the gradients for all layers in the console. 497 | report_gradient = tools.solvers.PlotGradientCallback(100, args.working_directory + '/gradient.png') 498 | callbacks.append({ 499 | 'callback': tools.solvers.PlotGradientCallback.report_gradient, 500 | 'object': report_gradient, 501 | 'interval': 1, 502 | }) 503 | 504 | # Callback for saving regular snapshots using the snapshot_prefix in the 505 | # solver prototxt file. 506 | # Is added after the "early stopping" callback to avoid problems. 
507 |     callbacks.append({
508 |         'callback': tools.solvers.SnapshotCallback.write_snapshot,
509 |         'object': tools.solvers.SnapshotCallback(),
510 |         'interval': 500,
511 |     })
512 | 
513 |     monitoring_solver = tools.solvers.MonitoringSolver(solver, max_iteration)
514 |     monitoring_solver.register_callback(callbacks)
515 |     monitoring_solver.solve(args.iterations)
516 | 
517 | def main_test():
518 |     """
519 |     Test the latest model obtained by :func:`examples.mnist.main_train`
520 |     or :func:`examples.mnist.main_resume` on the given input image.
521 |     """
522 | 
523 |     max_iteration = 0
524 |     files = glob.glob(args.working_directory + '/*.solverstate')
525 | 
526 |     for filename in files:
527 |         filenames = filename.split('_')
528 |         iteration = filenames[-1][:-12]
529 | 
530 |         try:
531 |             iteration = int(iteration)
532 |             if iteration > max_iteration:
533 |                 max_iteration = iteration
534 |         except:
535 |             pass
536 | 
537 |     caffemodel = args.working_directory + '/snapshot_iter_' + str(max_iteration) + '.caffemodel'
538 |     deploy_prototxt_path = args.working_directory + '/deploy.prototxt'
539 | 
540 |     assert max_iteration > 0, "could not find a solverstate or snapshot file to load"
541 |     assert os.path.exists(caffemodel), "caffemodel %s not found" % caffemodel
542 |     assert os.path.exists(deploy_prototxt_path), "prototxt %s not found" % deploy_prototxt_path
543 | 
544 |     net = caffe.Net(deploy_prototxt_path, caffemodel, caffe.TEST)
545 |     transformer = caffe.io.Transformer({'data': (1, 1, 28, 28)})
546 |     transformer.set_transpose('data', (2, 0, 1))
547 |     transformer.set_raw_scale('data', 1/255.)
548 | 
549 |     assert os.path.exists(args.image), "image %s not found" % args.image
550 |     image = cv2.imread(args.image)
551 |     image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
552 |     #image = 255 - image
553 |     image.resize((28, 28, 1))
554 |     cv2.imshow('image', image)
555 | 
556 |     net.blobs['data'].reshape(1, 1, 28, 28)
557 |     net.blobs['data'].data[...] 
= transformer.preprocess('data', image) 558 | 559 | net.forward() 560 | scores = net.blobs['score'].data 561 | 562 | x = range(10) 563 | pyplot.bar(x, scores[0, :], 1/1.5, color = 'blue') 564 | pyplot.gcf().subplots_adjust(bottom = 0.2) 565 | pyplot.show() 566 | 567 | if __name__ == '__main__': 568 | parser = get_parser() 569 | args = parser.parse_args() 570 | 571 | if args.mode == 'train': 572 | main_train() 573 | elif args.mode == 'train_augmented': 574 | main_train_augmented() 575 | elif args.mode == 'resume': 576 | main_resume() 577 | elif args.mode == 'test': 578 | main_test() 579 | else: 580 | print('Invalid mode.') -------------------------------------------------------------------------------- /examples/mnist/test_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_0.png -------------------------------------------------------------------------------- /examples/mnist/test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_1.png -------------------------------------------------------------------------------- /examples/mnist/test_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_2.png -------------------------------------------------------------------------------- /examples/mnist/test_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_3.png -------------------------------------------------------------------------------- /examples/mnist/test_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_4.png -------------------------------------------------------------------------------- /examples/mnist/test_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_5.png -------------------------------------------------------------------------------- /examples/mnist/test_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_6.png -------------------------------------------------------------------------------- /examples/mnist/test_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_7.png -------------------------------------------------------------------------------- /examples/mnist/test_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_8.png -------------------------------------------------------------------------------- 
/examples/mnist/test_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/examples/mnist/test_9.png -------------------------------------------------------------------------------- /examples/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize trained network weights. 3 | 4 | .. argparse:: 5 | :ref: examples.visualization.get_parser 6 | :prog: visualization 7 | """ 8 | 9 | import os 10 | import cv2 11 | import numpy 12 | import argparse 13 | from matplotlib import pyplot 14 | 15 | # To silence Caffe! Must be added before importing Caffe or modules which 16 | # are importing Caffe. 17 | os.environ['GLOG_minloglevel'] = '3' 18 | import caffe 19 | import tools.visualization 20 | 21 | caffe.set_mode_gpu() 22 | 23 | def get_parser(): 24 | """ 25 | Get the parser. 26 | 27 | :return: parser 28 | :rtype: argparse.ArgumentParser 29 | """ 30 | 31 | parser = argparse.ArgumentParser(description = 'Caffe example on Cifar-10.') 32 | parser.add_argument('--prototxt', dest = 'prototxt', type = str, 33 | help = 'path to the prototxt network definition', 34 | default = 'examples/cifar10/train_lmdb') 35 | parser.add_argument('--caffemodel', dest = 'caffemodel', type = str, 36 | help = 'path to the caffemodel', 37 | default = 'examples/cifar10/test_lmdb') 38 | parser.add_argument('--output', dest = 'output', type = str, 39 | help = 'output directory for visualizations', 40 | default = 'examples/output') 41 | 42 | return parser 43 | 44 | def main(): 45 | """ 46 | Visualize weights of the network. 47 | """ 48 | 49 | assert os.path.exists(args.prototxt), "prototxt %s not found" % args.prototxt 50 | assert os.path.exists(args.caffemodel), "caffemodel %s not found" % args.caffemodel 51 | 52 | if not os.path.exists(args.output): 53 | os.makedirs(args.output) 54 | 55 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST) 56 | layers = tools.visualization.get_layers(net) 57 | 58 | for layer in layers: 59 | if layer.find('conv') >= 0: 60 | kernels = tools.visualization.visualize_kernels(net, layer, 5) 61 | 62 | cv2.imwrite(args.output + '/' + layer + '.png', (kernels*255).astype(numpy.uint8)) 63 | #pyplot.imshow(kernels, interpolation = 'none') 64 | #pyplot.colorbar() 65 | #pyplot.savefig(args.output + '/' + layer + '.png') 66 | #pyplot.clf() 67 | 68 | elif layer.find('fc') >= 0 or layer.find('score') >= 0: 69 | weights = tools.visualization.visualize_weights(net, layer, 5) 70 | 71 | cv2.imwrite(args.output + '/' + layer + '.png', (weights*255).astype(numpy.uint8)) 72 | #pyplot.imshow(weights, interpolation = 'none') 73 | #pyplot.colorbar() 74 | #pyplot.savefig(args.output + '/' + layer + '.png') 75 | #pyplot.clf() 76 | 77 | if __name__ == '__main__': 78 | parser = get_parser() 79 | args = parser.parse_args() 80 | 81 | main() -------------------------------------------------------------------------------- /install_caffe.sh: -------------------------------------------------------------------------------- 1 | # This script installs Caffe and pycaffe on Ubuntu 14.04 x64 or 14.10 x64. CPU only, multi-threaded Caffe. 2 | # Usage: 3 | # 0. Set up here how many cores you want to use during the installation: 4 | # By default Caffe will use all these cores. 5 | NUMBER_OF_CORES=2 6 | # 1. Execute this script, e.g. "bash compile_caffe_ubuntu_14.04.sh" (~30 to 60 minutes on a new Ubuntu). 7 | # 2. 
Open a new shell (or run "source ~/.bash_profile"). You're done. You can try 8 | # running "import caffe" from the Python interpreter to test. 9 | 10 | #http://caffe.berkeleyvision.org/install_apt.html : (general install info: http://caffe.berkeleyvision.org/installation.html) 11 | cd 12 | sudo apt-get update 13 | #sudo apt-get upgrade -y # If you are OK getting prompted 14 | sudo DEBIAN_FRONTEND=noninteractive apt-get upgrade -y -q -o Dpkg::Options::="--force-confdef" -o Dpkg::Options::="--force-confold" # If you are OK with all defaults 15 | 16 | sudo apt-get install -y libprotobuf-dev libleveldb-dev libsnappy-dev libopencv-dev libhdf5-serial-dev 17 | sudo apt-get install -y --no-install-recommends libboost-all-dev 18 | sudo apt-get install -y libatlas-base-dev 19 | sudo apt-get install -y python-dev 20 | sudo apt-get install -y python-pip git 21 | 22 | # For Ubuntu 14.04 23 | sudo apt-get install -y libgflags-dev libgoogle-glog-dev liblmdb-dev protobuf-compiler 24 | 25 | git clone https://github.com/LMDB/lmdb.git 26 | cd lmdb/libraries/liblmdb 27 | sudo make 28 | sudo make install 29 | 30 | # More pre-requisites 31 | sudo apt-get install -y cmake unzip doxygen 32 | sudo apt-get install -y protobuf-compiler 33 | sudo apt-get install -y libffi-dev python-dev build-essential 34 | sudo pip install lmdb 35 | sudo pip install numpy 36 | sudo apt-get install -y python-numpy 37 | sudo apt-get install -y gfortran # required by scipy 38 | sudo pip install scipy # required by scikit-image 39 | sudo apt-get install -y python-scipy # in case pip failed 40 | sudo apt-get install -y python-nose 41 | sudo pip install scikit-image # to fix https://github.com/BVLC/caffe/issues/50 42 | 43 | # Get caffe (http://caffe.berkeleyvision.org/installation.html#compilation) 44 | cd 45 | mkdir caffe 46 | cd caffe 47 | wget https://github.com/BVLC/caffe/archive/master.zip 48 | unzip -o master.zip 49 | cd caffe-master 50 | 51 | # Prepare Python binding (pycaffe) 52 | cd python 53 | for req in $(cat requirements.txt); do sudo pip install $req; done 54 | echo "export PYTHONPATH=$(pwd):$PYTHONPATH " >> ~/.bash_profile # to be able to call "import caffe" from Python after reboot 55 | source ~/.bash_profile # Update shell 56 | cd .. 57 | 58 | # Compile caffe and pycaffe 59 | cp Makefile.config.example Makefile.config 60 | sed -i '8s/.*/CPU_ONLY := 1/' Makefile.config # Line 8: CPU only 61 | sudo apt-get install -y libopenblas-dev 62 | sed -i '33s/.*/BLAS := open/' Makefile.config # Line 33: to use OpenBLAS 63 | # Note that if one day the Makefile.config changes and these line numbers change, we're screwed 64 | # Maybe it would be best to simply append those changes at the end of Makefile.config 65 | echo "export OPENBLAS_NUM_THREADS=($NUMBER_OF_CORES)" >> ~/.bash_profile 66 | mkdir build 67 | cd build 68 | cmake .. 69 | cd .. 70 | make all -j$NUMBER_OF_CORES # 4 is the number of parallel threads for compilation: typically equal to number of physical cores 71 | make pycaffe -j$NUMBER_OF_CORES 72 | make test 73 | make runtest 74 | #make matcaffe 75 | make distribute 76 | 77 | # Bonus for other work with pycaffe 78 | sudo pip install pydot 79 | sudo apt-get install -y graphviz 80 | sudo pip install scikit-learn 81 | 82 | # At the end, you need to run "source ~/.bash_profile" manually or start a new shell to be able to do 'python import caffe', 83 | # because one cannot source in a bash script. 
(http://stackoverflow.com/questions/ 84 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests, make sure to copy over the correct LMDBs used for testing. 3 | """ -------------------------------------------------------------------------------- /tests/cifar10_test/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/cifar10_test/00000.png -------------------------------------------------------------------------------- /tests/cifar10_test/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/cifar10_test/00001.png -------------------------------------------------------------------------------- /tests/cifar10_test/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/cifar10_test/00002.png -------------------------------------------------------------------------------- /tests/cifar10_test/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/cifar10_test/00003.png -------------------------------------------------------------------------------- /tests/cifar10_test/00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/cifar10_test/00004.png -------------------------------------------------------------------------------- /tests/lmdb_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for reading and writing LMDBs using :class:`tools.lmdb.LMDB`. 3 | 4 | In order to run the tests, the MNIST LMDB created by Caffe is created. 5 | Change to the caffe root directory and run ``./data/mnist/get_mnist.sh`` 6 | as well as ``./examples/mnist/create_mnist.sh``. Then copy 7 | ``examples/mnist/mnist_test_lmdb`` to this folder. 8 | 9 | Do the same for Cifar10! 10 | """ 11 | 12 | import tools.lmdb_io 13 | import unittest 14 | import shutil 15 | import numpy 16 | import cv2 17 | import os 18 | 19 | class TestLMDB(unittest.TestCase): 20 | """ 21 | Tests for :class:`tools.lmdb.LMDB`. 22 | """ 23 | 24 | def test_keys_mnist(self): 25 | """ 26 | Tests reading the keys from the MNIST LMDB. 27 | """ 28 | 29 | lmdb_path = 'tests/mnist_test_lmdb' 30 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 31 | 32 | i = 0 33 | for key in lmdb.keys(): 34 | self.assertEqual(key, '{:08}'.format(i)) 35 | i += 1 36 | 37 | def test_read_mnist(self): 38 | """ 39 | Tests reading from the MNIST LMDB. 
40 | """ 41 | 42 | lmdb_path = 'tests/mnist_test_lmdb' 43 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 44 | 45 | keys = lmdb.keys(5) 46 | for key in keys: 47 | image, label, key = lmdb.read(key) 48 | 49 | image_path = 'tests/mnist_test/' + key + '.png' 50 | assert os.path.exists(image_path) 51 | 52 | image = cv2.imread(image_path, cv2.CV_LOAD_IMAGE_GRAYSCALE) 53 | 54 | for i in range(image.shape[0]): 55 | for j in range(image.shape[1]): 56 | self.assertEqual(image[i, j], image[i, j]) 57 | 58 | def test_keys_cifar(self): 59 | """ 60 | Tests reading the keys from the cifar10 LMDB. 61 | """ 62 | 63 | lmdb_path = 'tests/cifar10_test_lmdb' 64 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 65 | 66 | i = 0 67 | for key in lmdb.keys(): 68 | self.assertEqual(key, '{:05}'.format(i)) 69 | i += 1 70 | 71 | def test_read_cifar(self): 72 | """ 73 | Tests reading from the cifar10 LMDB. 74 | """ 75 | 76 | lmdb_path = 'tests/cifar10_test_lmdb' 77 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 78 | 79 | keys = lmdb.keys(5) 80 | for key in keys: 81 | image, label, key = lmdb.read(key) 82 | 83 | image_path = 'tests/cifar10_test/' + key + '.png' 84 | assert os.path.exists(image_path) 85 | 86 | image = cv2.imread(image_path) 87 | 88 | for i in range(image.shape[0]): 89 | for j in range(image.shape[1]): 90 | for c in range(image.shape[2]): 91 | self.assertEqual(image[i, j, c], image[i, j, c]) 92 | 93 | def test_write_read_random(self): 94 | """ 95 | Tests writing and reading on sample images. 96 | """ 97 | 98 | lmdb_path = 'tests/test_lmdb' 99 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 100 | 101 | write_images = [(numpy.random.rand(10, 10, 3)*255).astype(numpy.uint8)]*10 102 | write_labels = [0]*10 103 | 104 | lmdb.write(write_images, write_labels) 105 | read_images, read_labels, read_keys = lmdb.read() 106 | 107 | for n in range(10): 108 | for i in range(10): 109 | for j in range(10): 110 | for c in range(3): 111 | self.assertEqual(write_images[n][i, j, c], read_images[n][i, j, c]) 112 | 113 | self.assertEqual(write_labels[n], read_labels[n]) 114 | 115 | if os.path.exists(lmdb_path): 116 | shutil.rmtree(lmdb_path) 117 | 118 | def test_write_read_random_float(self): 119 | """ 120 | Tests writing and reading on sample images. 
121 | """ 122 | 123 | lmdb_path = 'tests/test_lmdb' 124 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 125 | 126 | write_images = [numpy.random.rand(10, 10, 3).astype(numpy.float)]*10 127 | write_labels = [0]*10 128 | 129 | lmdb.write(write_images, write_labels) 130 | read_images, read_labels, read_keys = lmdb.read() 131 | 132 | for n in range(10): 133 | for i in range(10): 134 | for j in range(10): 135 | for c in range(3): 136 | self.assertAlmostEqual(write_images[n][i, j, c], read_images[n][i, j, c]) 137 | 138 | self.assertEqual(write_labels[n], read_labels[n]) 139 | 140 | if os.path.exists(lmdb_path): 141 | shutil.rmtree(lmdb_path) 142 | 143 | if __name__ == '__main__': 144 | unittest.main() -------------------------------------------------------------------------------- /tests/mnist_test/00000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/mnist_test/00000000.png -------------------------------------------------------------------------------- /tests/mnist_test/00000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/mnist_test/00000001.png -------------------------------------------------------------------------------- /tests/mnist_test/00000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/mnist_test/00000002.png -------------------------------------------------------------------------------- /tests/mnist_test/00000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/mnist_test/00000003.png -------------------------------------------------------------------------------- /tests/mnist_test/00000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidstutz/caffe-tools/b18e78e39aaabfd9a7b70d9202c77a5dec2aae18/tests/mnist_test/00000004.png -------------------------------------------------------------------------------- /tests/pre_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for :mod:`tools.pre_processing`. 3 | """ 4 | 5 | import tools.pre_processing 6 | import tools.lmdb_io 7 | import unittest 8 | import random 9 | import shutil 10 | import numpy 11 | import cv2 12 | import os 13 | 14 | class TestPreProcessing(unittest.TestCase): 15 | """ 16 | Tests for :mod:`tools.pre_processing`. 17 | """ 18 | 19 | def test_lmdb_input(self): 20 | """ 21 | Test LMDB input to pre processing. 
22 | """ 23 | 24 | N = 27 25 | H = 10 26 | W = 10 27 | C = 3 28 | 29 | images = [] 30 | labels = [] 31 | 32 | for n in range(N): 33 | image = (numpy.random.rand(H, W, C)*255).astype(numpy.uint8) 34 | label = random.randint(0, 1000) 35 | 36 | images.append(image) 37 | labels.append(label) 38 | 39 | 40 | lmdb_path = 'tests/test_lmdb' 41 | if os.path.exists(lmdb_path): 42 | shutil.rmtree(lmdb_path) 43 | 44 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 45 | lmdb.write(images, labels) 46 | 47 | pp_in = tools.pre_processing.PreProcessingInputLMDB(lmdb_path) 48 | self.assertEqual(pp_in.count(), 27) 49 | 50 | n = 0 51 | for i in range(3): 52 | read_images, read_labels = pp_in.read(10) 53 | self.assertEqual(len(read_images), len(read_labels)) 54 | 55 | for j in range(len(read_images)): 56 | for ii in range(H): 57 | for jj in range(W): 58 | for cc in range(C): 59 | self.assertEqual(images[n][ii, jj, cc], read_images[j][ii, jj, cc]) 60 | 61 | self.assertEqual(labels[n], read_labels[j]) 62 | n += 1 63 | 64 | self.assertTrue(pp_in.end()) 65 | 66 | def test_files_input(self): 67 | """ 68 | Test LMDB input to pre processing. 69 | """ 70 | 71 | N = 27 72 | H = 10 73 | W = 10 74 | C = 3 75 | 76 | files = [] 77 | images = [] 78 | labels = [] 79 | 80 | pos_path = 'tests/test_pos' 81 | neg_path = 'tests/test_neg' 82 | 83 | if os.path.exists(pos_path): 84 | shutil.rmtree(pos_path) 85 | 86 | if os.path.exists(neg_path): 87 | shutil.rmtree(neg_path) 88 | 89 | os.mkdir(pos_path) 90 | os.mkdir(neg_path) 91 | 92 | for n in range(N): 93 | image = (numpy.random.rand(H, W, C)*255).astype(numpy.uint8) 94 | 95 | label = random.randint(0, 1) 96 | path = neg_path + '/' + str(n) + '.png' 97 | 98 | if label == 1: 99 | path = pos_path + '/' + str(n) + '.png' 100 | 101 | files.append(path) 102 | cv2.imwrite(path, image) 103 | 104 | images.append(image) 105 | labels.append(label) 106 | 107 | pp_in = tools.pre_processing.PreProcessingInputFiles(files, labels) 108 | self.assertEqual(pp_in.count(), 27) 109 | 110 | n = 0 111 | for i in range(3): 112 | read_images, read_labels = pp_in.read(10) 113 | self.assertEqual(len(read_images), len(read_labels)) 114 | 115 | for j in range(len(read_images)): 116 | for ii in range(H): 117 | for jj in range(W): 118 | for cc in range(C): 119 | self.assertEqual(images[n][ii, jj, cc], read_images[j][ii, jj, cc]) 120 | 121 | self.assertEqual(labels[n], read_labels[j]) 122 | n += 1 123 | 124 | self.assertTrue(pp_in.end()) 125 | 126 | def test_csv_input(self): 127 | """ 128 | Test CSV input. 
129 | """ 130 | 131 | csv_file = 'tests/csv.csv' 132 | if os.path.exists(csv_file): 133 | os.unlink(csv_file) 134 | 135 | lmdb_path = 'tests/test_lmdb' 136 | if os.path.exists(lmdb_path): 137 | shutil.rmtree(lmdb_path) 138 | 139 | images = [] 140 | labels = [] 141 | 142 | with open(csv_file, 'w') as f: 143 | for n in range(27): 144 | image = [random.random(), random.random(), random.random()] 145 | 146 | label = 0 147 | if random.random() > 0.5: 148 | label = 1 149 | 150 | if label == 0: 151 | f.write(str(image[0]) + ',' + str(image[1]) + ',' + str(image[2]) + ',eins\n') 152 | else: 153 | f.write(str(image[0]) + ',' + str(image[1]) + ',' + str(image[2]) + ',zwei\n') 154 | 155 | images.append(image) 156 | labels.append(label) 157 | 158 | pp_in = tools.pre_processing.PreProcessingInputCSV(csv_file, ',', 3, {'eins': 0, 'zwei': 1}) 159 | 160 | self.assertEqual(pp_in.count(), 27) 161 | 162 | n = 0 163 | for t in range(3): 164 | read_images, read_labels = pp_in.read(10) 165 | self.assertEqual(len(read_images), len(read_labels)) 166 | 167 | for i in range(len(read_images)): 168 | for ii in range(3): 169 | self.assertAlmostEqual(images[n][ii], read_images[i][ii, 0, 0]) 170 | 171 | self.assertEqual(labels[n], read_labels[i]) 172 | n += 1 173 | 174 | self.assertTrue(pp_in.end()) 175 | 176 | def test_pre_processing(self): 177 | """ 178 | Test Pre-Processing. 179 | """ 180 | 181 | N = 27 182 | H = 10 183 | W = 10 184 | C = 3 185 | 186 | files = [] 187 | images = [] 188 | labels = [] 189 | 190 | pos_path = 'tests/test_pos' 191 | neg_path = 'tests/test_neg' 192 | lmdb_path = 'tests/test_lmdb' 193 | 194 | if os.path.exists(pos_path): 195 | shutil.rmtree(pos_path) 196 | 197 | if os.path.exists(neg_path): 198 | shutil.rmtree(neg_path) 199 | 200 | if os.path.exists(lmdb_path): 201 | shutil.rmtree(lmdb_path) 202 | 203 | os.mkdir(pos_path) 204 | os.mkdir(neg_path) 205 | 206 | for n in range(N): 207 | image = (numpy.random.rand(H, W, C)*255).astype(numpy.uint8) 208 | 209 | label = random.randint(0, 1) 210 | path = neg_path + '/' + str(n) + '.png' 211 | 212 | if label == 1: 213 | path = pos_path + '/' + str(n) + '.png' 214 | 215 | files.append(path) 216 | cv2.imwrite(path, image) 217 | 218 | images.append(image) 219 | labels.append(label) 220 | 221 | pp_in = tools.pre_processing.PreProcessingInputFiles(files, labels) 222 | pp_out = tools.pre_processing.PreProcessingOutputLMDB(lmdb_path) 223 | pp = tools.pre_processing.PreProcessing(pp_in, pp_out, 10) 224 | 225 | pp.run() 226 | 227 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 228 | read_images, read_labels, read_keys = lmdb.read() 229 | self.assertEqual(len(images), len(labels)) 230 | 231 | n = 0 232 | for n in range(N): 233 | for i in range(H): 234 | for j in range(W): 235 | for c in range(C): 236 | self.assertEqual(images[n][i, j, c], read_images[n][i, j, c]) 237 | 238 | self.assertEqual(labels[n], read_labels[n]) 239 | n += 1 240 | 241 | def test_pre_processing_csv(self): 242 | """ 243 | Test CSV pre processing. 
244 | """ 245 | 246 | csv_file = 'tests/csv.csv' 247 | if os.path.exists(csv_file): 248 | os.unlink(csv_file) 249 | 250 | lmdb_path = 'tests/test_lmdb' 251 | if os.path.exists(lmdb_path): 252 | shutil.rmtree(lmdb_path) 253 | 254 | images = [] 255 | labels = [] 256 | 257 | with open(csv_file, 'w') as f: 258 | for n in range(100): 259 | image = [random.random(), random.random(), random.random()] 260 | 261 | label = 0 262 | if random.random() > 0.5: 263 | label = 1 264 | 265 | if label == 0: 266 | f.write(str(image[0]) + ',' + str(image[1]) + ',' + str(image[2]) + ',eins\n') 267 | else: 268 | f.write(str(image[0]) + ',' + str(image[1]) + ',' + str(image[2]) + ',zwei\n') 269 | 270 | images.append(image) 271 | labels.append(label) 272 | 273 | pp_in = tools.pre_processing.PreProcessingInputCSV(csv_file, ',', 3, {'eins': 0, 'zwei': 1}) 274 | pp_out = tools.pre_processing.PreProcessingOutputLMDB(lmdb_path) 275 | pp = tools.pre_processing.PreProcessing(pp_in, pp_out, 10) 276 | 277 | pp.run() 278 | 279 | lmdb = tools.lmdb_io.LMDB(lmdb_path) 280 | read_images, read_labels, read_keys = lmdb.read() 281 | 282 | self.assertEqual(len(images), 100) 283 | self.assertEqual(len(images), len(labels)) 284 | 285 | for n in range(100): 286 | for i in range(3): 287 | self.assertAlmostEqual(images[n][i], read_images[n][i, 0, 0]) 288 | 289 | self.assertEqual(labels[n], read_labels[n]) 290 | 291 | if __name__ == '__main__': 292 | unittest.main() -------------------------------------------------------------------------------- /tests/prototxt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for :mod:`tools.prototxt`. 3 | """ 4 | 5 | import tools.prototxt 6 | import unittest 7 | import caffe 8 | 9 | class TestPrototxt(unittest.TestCase): 10 | """ 11 | Tests for :mod:`tools.prototxt`. 12 | """ 13 | 14 | def test_train2deploy(self): 15 | """ 16 | Test train to deploy conversion. 
17 | """ 18 | 19 | def network(lmdb_path, batch_size): 20 | net = caffe.NetSpec() 21 | 22 | net.data, net.labels = caffe.layers.Data(batch_size = batch_size, 23 | backend = caffe.params.Data.LMDB, 24 | source = lmdb_path, 25 | transform_param = dict(scale = 1./255), 26 | ntop = 2) 27 | 28 | net.conv1 = caffe.layers.Convolution(net.data, kernel_size = 5, num_output = 20, 29 | weight_filler = dict(type = 'xavier')) 30 | net.pool1 = caffe.layers.Pooling(net.conv1, kernel_size = 2, stride = 2, 31 | pool = caffe.params.Pooling.MAX) 32 | net.conv2 = caffe.layers.Convolution(net.pool1, kernel_size = 5, num_output = 50, 33 | weight_filler = dict(type = 'xavier')) 34 | net.pool2 = caffe.layers.Pooling(net.conv2, kernel_size = 2, stride = 2, 35 | pool = caffe.params.Pooling.MAX) 36 | net.fc1 = caffe.layers.InnerProduct(net.pool2, num_output = 500, 37 | weight_filler = dict(type = 'xavier')) 38 | net.relu1 = caffe.layers.ReLU(net.fc1, in_place = True) 39 | net.score = caffe.layers.InnerProduct(net.relu1, num_output = 10, 40 | weight_filler = dict(type = 'xavier')) 41 | net.loss = caffe.layers.SoftmaxWithLoss(net.score, net.labels) 42 | 43 | return net.to_proto() 44 | 45 | train_prototxt_path = 'tests/train.prototxt' 46 | deploy_prototxt_path = 'tests/deploy.prototxt' 47 | 48 | with open(train_prototxt_path, 'w') as f: 49 | f.write(str(network('tests/train_lmdb', 128))) 50 | 51 | tools.prototxt.train2deploy(train_prototxt_path, (128, 3, 28, 28), deploy_prototxt_path) 52 | 53 | if __name__ == '__main__': 54 | unittest.main() -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for Caffe. 3 | """ -------------------------------------------------------------------------------- /tools/data_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data augmentation methods. These methods are used in :mod:`tools.layers` and 3 | assume that the data is shaped according to Caffe (i.e. batch, height, width, channels) 4 | and some data augmentation methods assume float data. 5 | """ 6 | 7 | import cv2 8 | import numpy 9 | 10 | def multiplicative_gaussian_noise(images, std = 0.05): 11 | """ 12 | Multiply with Gaussian noise. 13 | 14 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 15 | :type images: numpy.ndarray 16 | :param std: standard deviation of Gaussian 17 | :type std: float 18 | :return: images (or data) with multiplicative Gaussian noise 19 | :rtype: numpy.ndarray 20 | """ 21 | 22 | assert images.ndim == 4 23 | assert images.dtype == numpy.float32 24 | 25 | return numpy.multiply(images, numpy.random.randn(images.shape[0], images.shape[1], images.shape[2], images.shape[3])*std + 1) 26 | 27 | def additive_gaussian_noise(images, std = 0.05): 28 | """ 29 | Add Gaussian noise to the images. 30 | 31 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 32 | :type images: numpy.ndarray 33 | :param std: standard deviation of Gaussian 34 | :type std: float 35 | :return: images (or data) with additive Gaussian noise 36 | :rtype: numpy.ndarray 37 | """ 38 | 39 | assert images.ndim == 4 40 | assert images.dtype == numpy.float32 41 | 42 | return images + numpy.random.randn(images.shape[0], images.shape[1], images.shape[2], images.shape[3])*std 43 | 44 | def crop(images, crop): 45 | """ 46 | Crop the images along all dimensions. 
47 | 48 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 49 | :type images: numpy.ndarray 50 | :param crop: cropy in (crop_left, crop_top, crop_right, crop_bottom) 51 | :type crop: (int, int, int, int) 52 | :return: images (or data) cropped 53 | :rtype: numpy.ndarray 54 | """ 55 | 56 | assert images.ndim == 4 57 | assert images.dtype == numpy.float32 58 | assert len(crop) == 4 59 | assert crop[0] >= 0 and crop[1] >= 0 and crop[2] >= 0 and crop[3] >= 0 60 | assert crop[0] + crop[2] <= images.shape[2] 61 | assert crop[1] + crop[3] <= images.shape[1] 62 | 63 | return images[:, crop[1]:images.shape[1] - crop[3], crop[0]:images.shape[2] - crop[2], :] 64 | 65 | def flip(images): 66 | """ 67 | Flip the images horizontally. 68 | 69 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 70 | :type images: numpy.ndarray 71 | :return: images (or data) flipped horizontally 72 | :rtype: numpy.ndarray 73 | """ 74 | 75 | pass 76 | 77 | def drop_color_gaussian(images, channel, mean = 0.5, std = 0.05): 78 | """ 79 | Drop the specified color channel and replace by Gaussian noise with given mean 80 | and standard deviation. 81 | 82 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 83 | :type images: numpy.ndarray 84 | :param channel: channel to drop 85 | :type channel: int 86 | :param mean: mean of Gaussian noise 87 | :type mean: float 88 | :param std: standard deviation of Gaussian noise 89 | :type std: float 90 | :return: images (or data) dropped channel 91 | :rtype: numpy.ndarray 92 | """ 93 | 94 | assert images.ndim == 4 95 | assert images.dtype == numpy.float32 96 | assert images.shape[3] == 3 97 | 98 | channels = [] 99 | for i in range(images.shape[3]): 100 | if i == channel: 101 | channels.append(numpy.random.randn(images.shape[0], images.shape[1], images.shape[3], 1)) 102 | else: 103 | channels.append(images[:, :, :, i].reshape(images.shape[0], images.shape[1], images.shape[3], 1)) 104 | 105 | return numpy.concatenate(tuple(channels), axis = 3) 106 | 107 | def scaling_artifacts(image, factor = .5, interpolation = cv2.INTER_LINEAR): 108 | """ 109 | Introduce scaling articacts by downscaling to the given factor and upscaling again. 110 | 111 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 112 | :type images: numpy.ndarray 113 | :param factor: factor to downsample by 114 | :type factor: float 115 | :param interpolation: interpolation to use, see OpenCV documentation for resize 116 | :type interpolation: int 117 | :return: images (or data) with scaling artifacts 118 | :rtype: numpy.ndarray 119 | """ 120 | 121 | pass 122 | 123 | def contrast(images, exponent): 124 | """ 125 | Apply contrast transformation. 126 | 127 | :param images: images (or data) in Caffe format (batch_size, height, width, channels) 128 | :type images: numpy.ndarray 129 | :param exponent: exponentfor contrast transformation 130 | :type exponent: float 131 | :return: images (or data) with contrast normalization 132 | :rtype: numpy.ndarray 133 | """ 134 | 135 | return numpy.power(images, exponent) -------------------------------------------------------------------------------- /tools/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python layers. 
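The layers are meant to be used through Caffe's Python layer support; a
hypothetical prototxt snippet (layer name and tops are placeholders) could
look like::

    layer {
      name: "augment"
      type: "Python"
      bottom: "data"
      top: "data_augmented"
      python_param {
        module: "tools.layers"
        layer: "DataAugmentationAdditiveGaussianNoiseLayer"
      }
    }

This assumes the repository root is on the PYTHONPATH so that Caffe can import
``tools.layers``; layers without parameters additionally need
``force_backward: true`` in the net specification, see :class:`TestLayer` below.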
3 | """ 4 | 5 | import caffe 6 | import numpy 7 | import random 8 | import tools.data_augmentation 9 | 10 | class TestLayer(caffe.Layer): 11 | """ 12 | A test layer meant for testing purposes which actually does nothing. 13 | Note, however, to use the force_backward: true option in the net specification 14 | to enable the backward pass in layers without parameters. 15 | """ 16 | 17 | def setup(self, bottom, top): 18 | """ 19 | Checks the correct number of bottom inputs. 20 | 21 | :param bottom: bottom inputs 22 | :type bottom: [numpy.ndarray] 23 | :param top: top outputs 24 | :type top: [numpy.ndarray] 25 | """ 26 | 27 | pass 28 | 29 | def reshape(self, bottom, top): 30 | """ 31 | Make sure all involved blobs have the right dimension. 32 | 33 | :param bottom: bottom inputs 34 | :type bottom: caffe._caffe.RawBlobVec 35 | :param top: top outputs 36 | :type top: caffe._caffe.RawBlobVec 37 | """ 38 | 39 | top[0].reshape(bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2], bottom[0].data.shape[3]) 40 | 41 | def forward(self, bottom, top): 42 | """ 43 | Forward propagation. 44 | 45 | :param bottom: bottom inputs 46 | :type bottom: caffe._caffe.RawBlobVec 47 | :param top: top outputs 48 | :type top: caffe._caffe.RawBlobVec 49 | """ 50 | 51 | top[0].data[...] = bottom[0].data 52 | 53 | def backward(self, top, propagate_down, bottom): 54 | """ 55 | Backward pass. 56 | 57 | :param bottom: bottom inputs 58 | :type bottom: caffe._caffe.RawBlobVec 59 | :param propagate_down: 60 | :type propagate_down: 61 | :param top: top outputs 62 | :type top: caffe._caffe.RawBlobVec 63 | """ 64 | 65 | bottom[0].diff[...] = top[0].diff[...] 66 | 67 | class DataAugmentationDoubleLabelsLayer(caffe.Layer): 68 | """ 69 | All data augmentation labels double or quadruple the number of samples per 70 | batch. This layer is the base layer to double or quadruple the 71 | labels accordingly. 72 | """ 73 | 74 | def setup(self, bottom, top): 75 | """ 76 | Checks the correct number of bottom inputs. 77 | 78 | :param bottom: bottom inputs 79 | :type bottom: [numpy.ndarray] 80 | :param top: top outputs 81 | :type top: [numpy.ndarray] 82 | """ 83 | 84 | self._k = 2 85 | 86 | def reshape(self, bottom, top): 87 | """ 88 | Make sure all involved blobs have the right dimension. 89 | 90 | :param bottom: bottom inputs 91 | :type bottom: caffe._caffe.RawBlobVec 92 | :param top: top outputs 93 | :type top: caffe._caffe.RawBlobVec 94 | """ 95 | 96 | if len(bottom[0].shape) == 4: 97 | top[0].reshape(self._k*bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2], bottom[0].data.shape[3]) 98 | elif len(bottom[0].shape) == 3: 99 | top[0].reshape(self._k*bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2]) 100 | elif len(bottom[0].shape) == 2: 101 | top[0].reshape(self._k*bottom[0].data.shape[0], bottom[0].data.shape[1]) 102 | else: 103 | top[0].reshape(self._k*bottom[0].data.shape[0]) 104 | 105 | def forward(self, bottom, top): 106 | """ 107 | Forward propagation. 
108 | 109 | :param bottom: bottom inputs 110 | :type bottom: caffe._caffe.RawBlobVec 111 | :param top: top outputs 112 | :type top: caffe._caffe.RawBlobVec 113 | """ 114 | 115 | batch_size = bottom[0].data.shape[0] 116 | if len(bottom[0].shape) == 4: 117 | top[0].data[0:batch_size, :, :, :] = bottom[0].data 118 | 119 | for i in range(self._k - 1): 120 | top[0].data[(i + 1)*batch_size:(i + 2)*batch_size, :, :, :] = bottom[0].data 121 | elif len(bottom[0].shape) == 3: 122 | top[0].data[0:batch_size, :, :] = bottom[0].data 123 | 124 | for i in range(self._k - 1): 125 | top[0].data[(i + 1)*batch_size:(i + 2)*batch_size, :, :] = bottom[0].data 126 | elif len(bottom[0].shape) == 2: 127 | top[0].data[0:batch_size, :] = bottom[0].data 128 | 129 | for i in range(self._k - 1): 130 | top[0].data[(i + 1)*batch_size:(i + 2)*batch_size, :] = bottom[0].data 131 | else: 132 | top[0].data[0:batch_size] = bottom[0].data 133 | 134 | for i in range(self._k - 1): 135 | top[0].data[(i + 1)*batch_size:(i + 2)*batch_size] = bottom[0].data 136 | 137 | def backward(self, top, propagate_down, bottom): 138 | """ 139 | Backward pass. 140 | 141 | :param bottom: bottom inputs 142 | :type bottom: caffe._caffe.RawBlobVec 143 | :param propagate_down: 144 | :type propagate_down: 145 | :param top: top outputs 146 | :type top: caffe._caffe.RawBlobVec 147 | """ 148 | 149 | pass 150 | 151 | class DataAugmentationMultiplicativeGaussianNoiseLayer(caffe.Layer): 152 | """ 153 | Multiplicative Gaussian noise. 154 | """ 155 | 156 | def setup(self, bottom, top): 157 | """ 158 | Checks the correct number of bottom inputs. 159 | 160 | :param bottom: bottom inputs 161 | :type bottom: [numpy.ndarray] 162 | :param top: top outputs 163 | :type top: [numpy.ndarray] 164 | """ 165 | 166 | pass 167 | 168 | def reshape(self, bottom, top): 169 | """ 170 | Make sure all involved blobs have the right dimension. 171 | 172 | :param bottom: bottom inputs 173 | :type bottom: caffe._caffe.RawBlobVec 174 | :param top: top outputs 175 | :type top: caffe._caffe.RawBlobVec 176 | """ 177 | 178 | top[0].reshape(2*bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2], bottom[0].data.shape[3]) 179 | 180 | def forward(self, bottom, top): 181 | """ 182 | Forward propagation. 183 | 184 | :param bottom: bottom inputs 185 | :type bottom: caffe._caffe.RawBlobVec 186 | :param top: top outputs 187 | :type top: caffe._caffe.RawBlobVec 188 | """ 189 | 190 | batch_size = bottom[0].data.shape[0] 191 | top[0].data[0:batch_size, :, :, :] = bottom[0].data 192 | top[0].data[batch_size:2*batch_size, :, :, :] = tools.data_augmentation.multiplicative_gaussian_noise(bottom[0].data) 193 | 194 | def backward(self, top, propagate_down, bottom): 195 | """ 196 | Backward pass. 197 | 198 | :param bottom: bottom inputs 199 | :type bottom: caffe._caffe.RawBlobVec 200 | :param propagate_down: 201 | :type propagate_down: 202 | :param top: top outputs 203 | :type top: caffe._caffe.RawBlobVec 204 | """ 205 | 206 | pass 207 | 208 | class DataAugmentationAdditiveGaussianNoiseLayer(caffe.Layer): 209 | """ 210 | Additive Gaussian noise. 211 | """ 212 | 213 | def setup(self, bottom, top): 214 | """ 215 | Checks the correct number of bottom inputs. 216 | 217 | :param bottom: bottom inputs 218 | :type bottom: [numpy.ndarray] 219 | :param top: top outputs 220 | :type top: [numpy.ndarray] 221 | """ 222 | 223 | pass 224 | 225 | def reshape(self, bottom, top): 226 | """ 227 | Make sure all involved blobs have the right dimension. 
228 | 229 | :param bottom: bottom inputs 230 | :type bottom: caffe._caffe.RawBlobVec 231 | :param top: top outputs 232 | :type top: caffe._caffe.RawBlobVec 233 | """ 234 | 235 | top[0].reshape(2*bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2], bottom[0].data.shape[3]) 236 | 237 | def forward(self, bottom, top): 238 | """ 239 | Forward propagation. 240 | 241 | :param bottom: bottom inputs 242 | :type bottom: caffe._caffe.RawBlobVec 243 | :param top: top outputs 244 | :type top: caffe._caffe.RawBlobVec 245 | """ 246 | 247 | batch_size = bottom[0].data.shape[0] 248 | top[0].data[0:batch_size, :, :, :] = bottom[0].data 249 | top[0].data[batch_size:2*batch_size, :, :, :] = tools.data_augmentation.additive_gaussian_noise(bottom[0].data) 250 | 251 | def backward(self, top, propagate_down, bottom): 252 | """ 253 | Backward pass. 254 | 255 | :param bottom: bottom inputs 256 | :type bottom: caffe._caffe.RawBlobVec 257 | :param propagate_down: 258 | :type propagate_down: 259 | :param top: top outputs 260 | :type top: caffe._caffe.RawBlobVec 261 | """ 262 | 263 | pass 264 | 265 | class DataAugmentationQuadrupleCropsLayer(caffe.Layer): 266 | """ 267 | Quadruple the data with random crops. Note that this reduces the size of the input 268 | by (per default) 4 pixels in each dimension. 269 | """ 270 | 271 | def setup(self, bottom, top): 272 | """ 273 | Checks the correct number of bottom inputs. 274 | 275 | :param bottom: bottom inputs 276 | :type bottom: [numpy.ndarray] 277 | :param top: top outputs 278 | :type top: [numpy.ndarray] 279 | """ 280 | 281 | pass 282 | 283 | def reshape(self, bottom, top): 284 | """ 285 | Make sure all involved blobs have the right dimension. 286 | 287 | :param bottom: bottom inputs 288 | :type bottom: caffe._caffe.RawBlobVec 289 | :param top: top outputs 290 | :type top: caffe._caffe.RawBlobVec 291 | """ 292 | 293 | top[0].reshape(2*bottom[0].data.shape[0], bottom[0].data.shape[1], bottom[0].data.shape[2], bottom[0].data.shape[3]) 294 | 295 | def forward(self, bottom, top): 296 | """ 297 | Forward propagation. 298 | 299 | :param bottom: bottom inputs 300 | :type bottom: caffe._caffe.RawBlobVec 301 | :param top: top outputs 302 | :type top: caffe._caffe.RawBlobVec 303 | """ 304 | 305 | batch_size = bottom[0].data.shape[0] 306 | crop_left = random.randint(0, 4) 307 | crop_top = random.randint(0, 4) 308 | top[0].data[0:batch_size, :, :, :] = tools.data_augmentation.crop(bottom[0].data, (crop_left, crop_top, 4 - crop_left, 4 - crop_top)) 309 | 310 | crop_left = random.randint(0, 4) 311 | crop_top = random.randint(0, 4) 312 | top[0].data[batch_size:2*batch_size, :, :, :] = tools.data_augmentation.crop(bottom[0].data, (crop_left, crop_top, 4 - crop_left, 4 - crop_top)) 313 | 314 | crop_left = random.randint(0, 4) 315 | crop_top = random.randint(0, 4) 316 | top[0].data[2*batch_size:3*batch_size, :, :, :] = tools.data_augmentation.crop(bottom[0].data, (crop_left, crop_top, 4 - crop_left, 4 - crop_top)) 317 | 318 | crop_left = random.randint(0, 4) 319 | crop_top = random.randint(0, 4) 320 | top[0].data[3*batch_size:4*batch_size, :, :, :] = tools.data_augmentation.crop(bottom[0].data, (crop_left, crop_top, 4 - crop_left, 4 - crop_top)) 321 | 322 | def backward(self, top, propagate_down, bottom): 323 | """ 324 | Backward pass. 
325 | 326 | :param bottom: bottom inputs 327 | :type bottom: caffe._caffe.RawBlobVec 328 | :param propagate_down: 329 | :type propagate_down: 330 | :param top: top outputs 331 | :type top: caffe._caffe.RawBlobVec 332 | """ 333 | 334 | pass 335 | 336 | class ManhattenLoss(caffe.Layer): 337 | """ 338 | Compute the Manhatten Loss. 339 | """ 340 | 341 | def setup(self, bottom, top): 342 | """ 343 | Checks the correct number of bottom inputs. 344 | 345 | :param bottom: bottom inputs 346 | :type bottom: [numpy.ndarray] 347 | :param top: top outputs 348 | :type top: [numpy.ndarray] 349 | """ 350 | 351 | if len(bottom) != 2: 352 | raise Exception('Need two bottom inputs for Manhatten distance.') 353 | 354 | def reshape(self, bottom, top): 355 | """ 356 | Make sure all involved blobs have the right dimension. 357 | 358 | :param bottom: bottom inputs 359 | :type bottom: caffe._caffe.RawBlobVec 360 | :param top: top outputs 361 | :type top: caffe._caffe.RawBlobVec 362 | """ 363 | 364 | # Check bottom dimensions. 365 | if bottom[0].count != bottom[1].count: 366 | raise Exception('Inputs of both bottom inputs have to match.') 367 | 368 | # Set shape of diff to input shape. 369 | self.diff = numpy.zeros_like(bottom[0].data, dtype = numpy.float32) 370 | 371 | # Set output dimensions: 372 | top[0].reshape(1) 373 | 374 | def forward(self, bottom, top): 375 | """ 376 | Forward propagation, i.e. compute the Manhatten loss. 377 | 378 | :param bottom: bottom inputs 379 | :type bottom: caffe._caffe.RawBlobVec 380 | :param top: top outputs 381 | :type top: caffe._caffe.RawBlobVec 382 | """ 383 | 384 | scores = bottom[0].data # network output 385 | labels = bottom[1].data.reshape(scores.shape) # labels 386 | 387 | self.diff[...] = (-1)*(scores < labels).astype(int) \ 388 | + (scores > labels).astype(int) 389 | 390 | top[0].data[0] = numpy.sum(numpy.abs(scores - labels)) / bottom[0].num 391 | 392 | def backward(self, top, propagate_down, bottom): 393 | """ 394 | Backward pass. 395 | 396 | :param bottom: bottom inputs 397 | :type bottom: caffe._caffe.RawBlobVec 398 | :param propagate_down: 399 | :type propagate_down: 400 | :param top: top outputs 401 | :type top: caffe._caffe.RawBlobVec 402 | """ 403 | 404 | for i in range(2): 405 | if not propagate_down[i]: 406 | continue 407 | 408 | if i == 0: 409 | sign = 1 410 | else: 411 | sign = -1 412 | 413 | # also see the discussion at http://davidstutz.de/pycaffe-tools-examples-and-resources/ 414 | bottom[i].diff[...] = (sign * self.diff * top[0].diff[0] / bottom[i].num).reshape(bottom[i].diff.shape) 415 | -------------------------------------------------------------------------------- /tools/lmdb_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for comfortably reading and writing LMDBs. 3 | """ 4 | 5 | import lmdb 6 | import numpy 7 | import re 8 | 9 | import caffe 10 | 11 | def version_compare(version_a, version_b): 12 | """ 13 | Compare two versions given as strings, taken from `here`_. 14 | 15 | .. 
_here: http://stackoverflow.com/questions/1714027/version-number-comparison 16 | 17 | :param version_a: version a 18 | :type version_a: string 19 | :param version_b: version b 20 | :type version_b: string 21 | :return: 0 if versions are equivalent, < 0 if version_a is lower than version_b 22 | , > 0 if version_b is lower than version_b 23 | """ 24 | def normalize(v): 25 | return [int(x) for x in re.sub(r'(\.0+)*$','', v).split(".")] 26 | 27 | return cmp(normalize(version_a), normalize(version_b)) 28 | 29 | def to_key(i): 30 | """ 31 | Transform the given id integer to the key used by :class:`lmdb_io.LMDB`. 32 | 33 | :param i: integer id 34 | :type i: int 35 | :return: string key 36 | :rtype: string 37 | """ 38 | 39 | return '{:08}'.format(i) 40 | 41 | class LMDB: 42 | """ 43 | Utility class to read and write LMDBs. The code is based on the `LMDB documentation`_, 44 | as well as `this blog post`_. 45 | 46 | .. _LMDB documentation: https://lmdb.readthedocs.io/en/release/ 47 | .. _this blog post: http://deepdish.io/2015/04/28/creating-lmdb-in-python/ 48 | """ 49 | 50 | def __init__(self, lmdb_path): 51 | """ 52 | Constructor, given LMDB path. 53 | 54 | :param lmdb_path: path to LMDB 55 | :type lmdb_path: string 56 | """ 57 | 58 | self._lmdb_path = lmdb_path 59 | """ (string) The path to the LMDB to read or write. """ 60 | 61 | self._write_pointer = 0 62 | """ (int) Pointer for writing and appending. """ 63 | 64 | def read(self, key = ''): 65 | """ 66 | Read a single element or the whole LMDB depending on whether 'key' 67 | is specified. Essentially a prox for :func:`lmdb.LMDB.read_single` 68 | and :func:`lmdb.LMDB.read_all`. 69 | 70 | :param key: key as 8-digit string of the entry to read 71 | :type key: string 72 | :return: data and labels from the LMDB as associate dictionaries, where 73 | the key as string is the dictionary key and the value the numpy.ndarray 74 | for the data and the label for the labels 75 | :rtype: ({string: numpy.ndarray}, {string: float}) 76 | """ 77 | 78 | if not key: 79 | return self.read_all(); 80 | else: 81 | return self.read_single(key); 82 | 83 | def read_single(self, key): 84 | """ 85 | Read a single element according to the given key. Note that data in an 86 | LMDB is organized using string keys, which are eight-digit numbers 87 | when using this class to write and read LMDBs. 88 | 89 | :param key: the key to read 90 | :type key: string 91 | :return: image, label and corresponding key 92 | :rtype: (numpy.ndarray, int, string) 93 | """ 94 | 95 | image = False 96 | label = False 97 | env = lmdb.open(self._lmdb_path, readonly = True) 98 | 99 | with env.begin() as transaction: 100 | raw = transaction.get(key) 101 | datum = caffe.proto.caffe_pb2.Datum() 102 | datum.ParseFromString(raw) 103 | 104 | label = datum.label 105 | if datum.data: 106 | image = numpy.fromstring(datum.data, dtype = numpy.uint8).reshape(datum.channels, datum.height, datum.width).transpose(1, 2, 0) 107 | else: 108 | image = numpy.array(datum.float_data).astype(numpy.float).reshape(datum.channels, datum.height, datum.width).transpose(1, 2, 0) 109 | 110 | return image, label, key 111 | 112 | def read_all(self): 113 | """ 114 | Read the whole LMDB. The method will return the data and labels (if 115 | applicable) as dictionary which is indexed by the eight-digit numbers 116 | stored as strings. 
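        A usage sketch (assumes an LMDB already exists at the given path, e.g.
        the MNIST test LMDB used in the unit tests)::

            lmdb = tools.lmdb_io.LMDB('tests/mnist_test_lmdb')
            images, labels, keys = lmdb.read_all()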
117 | 118 | :return: images, labels and corresponding keys 119 | :rtype: ([numpy.ndarray], [int], [string]) 120 | """ 121 | 122 | images = [] 123 | labels = [] 124 | keys = [] 125 | env = lmdb.open(self._lmdb_path, readonly = True) 126 | 127 | with env.begin() as transaction: 128 | cursor = transaction.cursor(); 129 | 130 | for key, raw in cursor: 131 | datum = caffe.proto.caffe_pb2.Datum() 132 | datum.ParseFromString(raw) 133 | 134 | label = datum.label 135 | 136 | if datum.data: 137 | image = numpy.fromstring(datum.data, dtype = numpy.uint8).reshape(datum.channels, datum.height, datum.width).transpose(1, 2, 0) 138 | else: 139 | image = numpy.array(datum.float_data).astype(numpy.float).reshape(datum.channels, datum.height, datum.width).transpose(1, 2, 0) 140 | 141 | images.append(image) 142 | labels.append(label) 143 | keys.append(key) 144 | 145 | return images, labels, keys 146 | 147 | def count(self): 148 | """ 149 | Get the number of elements in the LMDB. 150 | 151 | :return: count of elements 152 | :rtype: int 153 | """ 154 | 155 | env = lmdb.open(self._lmdb_path) 156 | with env.begin() as transaction: 157 | return transaction.stat()['entries'] 158 | 159 | def keys(self, n = 0): 160 | """ 161 | Get the first n (or all) keys of the LMDB 162 | 163 | :param n: number of keys to get, 0 to get all keys 164 | :type n: int 165 | :return: list of keys 166 | :rtype: [string] 167 | """ 168 | 169 | keys = [] 170 | env = lmdb.open(self._lmdb_path, readonly = True) 171 | 172 | with env.begin() as transaction: 173 | cursor = transaction.cursor() 174 | 175 | i = 0 176 | for key, value in cursor: 177 | 178 | if i >= n and n > 0: 179 | break; 180 | 181 | keys.append(key) 182 | i += 1 183 | 184 | return keys 185 | 186 | def write(self, images, labels = []): 187 | """ 188 | Write a single image or multiple images and the corresponding label(s). 189 | The imags are expected to be two-dimensional NumPy arrays with 190 | multiple channels (if applicable). 
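        A usage sketch mirroring the unit tests (path and data are placeholders)::

            lmdb = tools.lmdb_io.LMDB('tests/test_lmdb')
            images = [(numpy.random.rand(10, 10, 3)*255).astype(numpy.uint8)]*10
            keys = lmdb.write(images, [0]*10)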
191 | 192 | :param images: input images as list of numpy.ndarray with height x width x channels 193 | :type images: [numpy.ndarray] 194 | :param labels: corresponding labels (if applicable) as list 195 | :type labels: [float] 196 | :return: list of keys corresponding to the written images 197 | :rtype: [string] 198 | """ 199 | 200 | if len(labels) > 0: 201 | assert len(images) == len(labels) 202 | 203 | keys = [] 204 | env = lmdb.open(self._lmdb_path, map_size = max(1099511627776, len(images)*images[0].nbytes)) 205 | 206 | with env.begin(write = True) as transaction: 207 | for i in range(len(images)): 208 | datum = caffe.proto.caffe_pb2.Datum() 209 | datum.channels = images[i].shape[2] 210 | datum.height = images[i].shape[0] 211 | datum.width = images[i].shape[1] 212 | 213 | assert version_compare(numpy.version.version, '1.9') < 0, "installed numpy is 1.9 or higher, change .tostring() to .tobytes()" 214 | assert images[i].dtype == numpy.uint8 or images[i].dtype == numpy.float, "currently only numpy.uint8 and numpy.float images are supported" 215 | 216 | if images[i].dtype == numpy.uint8: 217 | datum.data = images[i].transpose(2, 0, 1).tostring() 218 | else: 219 | datum.float_data.extend(images[i].transpose(2, 0, 1).flat) 220 | 221 | if len(labels) > 0: 222 | datum.label = labels[i] 223 | 224 | key = to_key(self._write_pointer) 225 | keys.append(key) 226 | 227 | transaction.put(key.encode('ascii'), datum.SerializeToString()); 228 | self._write_pointer += 1 229 | 230 | return keys 231 | -------------------------------------------------------------------------------- /tools/pre_processing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pre-processing for Caffe. 3 | """ 4 | 5 | import os 6 | import cv2 7 | import csv 8 | import numpy 9 | import random 10 | import tools.lmdb_io 11 | 12 | class PreProcessing(object): 13 | """ 14 | Pre-processing utilities to normalize data and compute mean as well as 15 | standard deviation. 16 | """ 17 | 18 | def __init__(self, source, dest, batch_size = 100000): 19 | """ 20 | Constructor, needs a :class:`tools.pre_processing.PreProcessingInput` as 21 | input source and a :class:`tools.pre_processing.PreProcessingOutput` as 22 | output destination. 23 | 24 | :param source: input source 25 | :type source: (tools.pre_processing.PreProcessingInput) 26 | :param dest: output destination 27 | :type dest: (tools.pre_processing.PreProcessingOutput) 28 | :param batch_size: batch size in which to process the images 29 | :type batch_size: int 30 | """ 31 | 32 | self._source = source 33 | """ (tools.pre_processing.PreProcessingInput) Input source. """ 34 | 35 | self._dest = dest 36 | """ (tools.pre_processing.PreProcessingOutput) Output source. """ 37 | 38 | self._batch_size = batch_size 39 | """ (int) Batch size to process images in. """ 40 | 41 | def run(self): 42 | """ 43 | Run pre-processing, in this case simply writes from source to destination. 44 | """ 45 | 46 | while not self._source.end(): 47 | images, labels = self._source.read(self._batch_size) 48 | self._dest.write(images, labels) 49 | 50 | class PreProcessingNormalize(PreProcessing): 51 | """ 52 | Normalize the data to lie in [0, 1] by dividing by a fixed given value. 53 | """ 54 | 55 | def __init__(self, source, dest, normalizer = 255., batch_size = 100000): 56 | """ 57 | Constructor, needs a :class:`tools.pre_processing.PreProcessingInput` as 58 | input source and a :class:`tools.pre_processing.PreProcessingOutput` as 59 | output destination. 
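        A usage sketch (the LMDB paths are placeholders)::

            pp_in = tools.pre_processing.PreProcessingInputLMDB('tests/raw_lmdb')
            pp_out = tools.pre_processing.PreProcessingOutputLMDB('tests/normalized_lmdb')
            tools.pre_processing.PreProcessingNormalize(pp_in, pp_out, 255.).run()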
60 | 61 | :param source: input source 62 | :type source: (tools.pre_processing.PreProcessingInput) 63 | :param dest: output destination 64 | :type dest: (tools.pre_processing.PreProcessingOutput) 65 | :param normalizer: value to normalize by 66 | :type normalizer: float 67 | :param batch_size: batch size in which to process the images 68 | :type batch_size: int 69 | """ 70 | 71 | super(PreProcessingNormalize, self).__init__(source, dest, batch_size) 72 | 73 | self._normalizer = normalizer 74 | """ (float) The value to normalize by. """ 75 | 76 | def run(self): 77 | """ 78 | Run pre-procesisng, this will normalize the images by the overall mean. 79 | 80 | :return: mean of normalized data 81 | :rtype: numpy.ndarray 82 | """ 83 | 84 | while not self._source.end(): 85 | images, labels = self._source.read(self._batch_size) 86 | 87 | normalized = [] 88 | for i in range(len(images)): 89 | normalized_image = images[i]/float(self._normalizer) 90 | normalized.append(normalized_image) 91 | 92 | self._dest.write(normalized, labels) 93 | 94 | class PreProcessingSplit: 95 | """ 96 | Pre processing utilities to split the data into training and test set 97 | or training, validaiton and test sets. 98 | """ 99 | 100 | def __init__(self, source, dests, split = (0.9, 0.1), batch_size = 100000): 101 | """ 102 | Constructor, needs a :class:`tools.pre_processing.PreProcessingInput` as 103 | input source and a :class:`tools.pre_processing.PreProcessingOutput` as 104 | output destination. 105 | 106 | :param source: input source 107 | :type source: (tools.pre_processing.PreProcessingInput) 108 | :param dests: output destinations for training/validation/test sets 109 | :type dests: ((tools.pre_processing.PreProcessingOutput, tools.pre_processing.PreProcessingOutput, tools.pre_processing.PreProcessingOutput)) 110 | :param batch_size: batch size in which to process the images 111 | :type batch_size: int 112 | :param split: the train/validation/test split, a tuple of probabilities 113 | summing to one, either 3 probabilities or two (without validation set) 114 | :type split: (int, int, int) or (int, int) 115 | """ 116 | 117 | assert len(split) == 2 or len(split) == 3, "split should contain 2 or 3 probabilities" 118 | 119 | # http://stackoverflow.com/questions/5595425/what-is-the-best-way-to-compare-floats-for-almost-equality-in-python 120 | def isclose(a, b, relative_tolerance = 1e-09, absolute_tolerance = 1e-06): 121 | return abs(a - b) <= max(relative_tolerance * max(abs(a), abs(b)), absolute_tolerance) 122 | 123 | if len(split) == 2: 124 | assert isclose(split[0] + split[1], 1.0), "the probabilities do not sum to 1" 125 | if len(split) == 3: 126 | assert isclose(split[0] + split[1] + split[2], 1.0), "the porbabilities do not sum to 1" 127 | 128 | assert len(split) == len(dests), "number of destinations does not fit number of split probabilities" 129 | 130 | self._source = source 131 | """ (tools.pre_processing.PreProcessingInput) Input source. """ 132 | 133 | self._dests = dests 134 | """ (tools.pre_processing.PreProcessingOutput,tools.pre_processing.PreProcessingOutput,tools.pre_processing.PreProcessingOutput) Output source. """ 135 | 136 | self._split = split 137 | """ ((float, float) or (float, float, float)) split in training/validation/test sets. """ 138 | 139 | self._batch_size = batch_size 140 | """ (int) Batch size to process images in. """ 141 | 142 | def run(self): 143 | """ 144 | Run pre-processing, decide to output in training/validation/test sets. 
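        A sketch of a 90/10 train/test split (the LMDB paths are placeholders)::

            pp_in = tools.pre_processing.PreProcessingInputLMDB('tests/full_lmdb')
            dests = (tools.pre_processing.PreProcessingOutputLMDB('tests/train_lmdb'),
                     tools.pre_processing.PreProcessingOutputLMDB('tests/test_lmdb'))
            tools.pre_processing.PreProcessingSplit(pp_in, dests, (0.9, 0.1)).run()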
145 | """ 146 | 147 | while not self._source.end(): 148 | images, labels = self._source.read(self._batch_size) 149 | 150 | training_images = [] 151 | validation_images = [] 152 | test_images = [] 153 | 154 | training_labels = [] 155 | validation_labels = [] 156 | test_labels = [] 157 | 158 | for n in range(len(images)): 159 | r = random.random() 160 | if r < self._split[0]: 161 | training_images.append(images[n]) 162 | if len(labels) > 0: 163 | training_labels.append(labels[n]) 164 | 165 | else: 166 | if len(self._split) == 2: 167 | test_images.append(images[n]) 168 | if len(labels) > 0: 169 | test_labels.append(labels[n]) 170 | 171 | else: 172 | if r >= self._split[0] + self._split[1]: 173 | test_images.append(images[n]) 174 | if len(labels) > 0: 175 | test_labels.append(labels[n]) 176 | 177 | else: 178 | validation_images.append(images[n]) 179 | if len(labels) > 0: 180 | validation_labels.append(labels[n]) 181 | 182 | self._dests[0].write(training_images, training_labels) 183 | 184 | if len(self._split) == 2: 185 | self._dests[1].write(test_images, test_labels) 186 | else: 187 | self._dests[1].write(validation_images, validation_labels) 188 | self._dests[2].write(test_images, test_labels) 189 | 190 | class PreProcessingSubsample(PreProcessing): 191 | """ 192 | Pre processing utilities to shuffle the data. This is a simple approach of 193 | shuffling that only works if the batch size is larger than the dataset size! 194 | """ 195 | 196 | def __init__(self, source, dest, p = 0.5, batch_size = 100000): 197 | """ 198 | Constructor, needs a :class:`tools.pre_processing.PreProcessingInput` as 199 | input source and a :class:`tools.pre_processing.PreProcessingOutput` as 200 | output destination. 201 | 202 | :param source: input source 203 | :type source: (tools.pre_processing.PreProcessingInput) 204 | :param dest: output destination 205 | :type dest: (tools.pre_processing.PreProcessingOutput) 206 | :param p: the probability of taking a sample 207 | :type p: float 208 | :param batch_size: batch size in which to process the images 209 | :type batch_size: int 210 | """ 211 | 212 | super(PreProcessingSubsample, self).__init__(source, dest, batch_size) 213 | 214 | self._p = p 215 | """ (float) Probability of taking a sample. """ 216 | 217 | def run(self): 218 | """ 219 | Run pre-processing, i.e. shuffle the data. 220 | """ 221 | 222 | while not self._source.end(): 223 | 224 | read_images, read_labels = self._source.read(self._batch_size) 225 | indices = numpy.random.choice(len(read_images), len(read_images)) 226 | 227 | write_images = [] 228 | write_labels = [] 229 | 230 | for index in indices: 231 | r = random.random() 232 | if r < self._p: 233 | write_images.append(read_images[index]) 234 | if len(read_labels) > 0: 235 | write_labels.append(read_labels[index]) 236 | 237 | self._dest.write(write_images, write_labels) 238 | 239 | class PreProcessingShuffle(PreProcessing): 240 | """ 241 | Pre processing utilities to shuffle the data. This is a simple approach of 242 | shuffling that only works if the batch size is larger than the dataset size! 243 | """ 244 | 245 | def __init__(self, source, dest, iterations = 10, batch_size = 100000): 246 | """ 247 | Constructor, needs a :class:`tools.pre_processing.PreProcessingInput` as 248 | input source and a :class:`tools.pre_processing.PreProcessingOutput` as 249 | output destination. 
250 | 251 | :param source: input source 252 | :type source: (tools.pre_processing.PreProcessingInput) 253 | :param dest: output destination 254 | :type dest: (tools.pre_processing.PreProcessingOutput) 255 | :param batch_size: batch size in which to process the images 256 | :type batch_size: int 257 | """ 258 | 259 | self._source = source 260 | """ (tools.pre_processing.PreProcessingInput) Input source. """ 261 | 262 | self._dest = dest 263 | """ (tools.pre_processing.PreProcessingOutput) Output source. """ 264 | 265 | self._batch_size = batch_size 266 | """ (int) Batch size to process images in. """ 267 | 268 | def run(self): 269 | """ 270 | Run pre-processing, i.e. shuffle the data. 271 | """ 272 | 273 | while not self._source.end(): 274 | 275 | read_images, read_labels = self._source.read(self._batch_size) 276 | indices = numpy.random.choice(len(read_images), len(read_images)) 277 | 278 | write_images = [] 279 | write_labels = [] 280 | 281 | for index in indices: 282 | write_images.append(read_images[index]) 283 | if len(read_labels) > 0: 284 | write_labels.append(read_labels[index]) 285 | 286 | self._dest.write(write_images, write_labels) 287 | 288 | class PreProcessingInput: 289 | """ 290 | Provides the input data for :class:`tools.pre_processing.PreProcessing`. 291 | """ 292 | 293 | def reset(self): 294 | """ 295 | Reset reading to start from the beginning. 296 | """ 297 | 298 | raise NotImplementedError("Should have been implemented!") 299 | 300 | def read(self, n): 301 | """ 302 | Read data in batches. 303 | 304 | :param n: number of images to read 305 | :type n: int 306 | :return: images and optionally labels as lists 307 | :rtype: ([numpy.ndarray], [float]) 308 | """ 309 | 310 | raise NotImplementedError("Should have been implemented!") 311 | 312 | def count(self): 313 | """ 314 | Return the count of imags. 315 | 316 | :return: count 317 | :rtype: int 318 | """ 319 | 320 | raise NotImplementedError("Should have been implemented!") 321 | 322 | def end(self): 323 | """ 324 | Whether the end has been reached. 325 | 326 | :return: true if the end has been reached or overstepped 327 | :rtype: bool 328 | """ 329 | 330 | raise NotImplementedError("Should have been implemented!") 331 | 332 | class PreProcessingInputLMDB(PreProcessingInput): 333 | """ 334 | Provides the input data for :class:`tools.pre_processing.PreProcessing` 335 | from an LMDB. 336 | """ 337 | 338 | def __init__(self, lmdb_path): 339 | """ 340 | Constructor, provide path to LMDB. 341 | 342 | :param lmdb_path: path to LMDB 343 | :type lmdb_path: string 344 | """ 345 | 346 | self._lmdb = tools.lmdb_io.LMDB(lmdb_path) 347 | """ (tools.lmdb_io.LMDB) Underlying LMDB. """ 348 | 349 | self._keys = self._lmdb.keys() 350 | """ ([string]) Keys of elements stored in the LMDB. """ 351 | 352 | self._pointer = 0 353 | """ (int) Current index to start reading. """ 354 | 355 | def reset(self): 356 | """ 357 | Reset reading to start from the beginning. 358 | """ 359 | 360 | self._pointer = 0 361 | 362 | def read(self, n): 363 | """ 364 | Read data in batches. 
365 | 366 | :param n: number of images to read 367 | :type n: int 368 | :return: images and optionally labels as lists 369 | :rtype: ([numpy.ndarray], [float]) 370 | """ 371 | 372 | images = [] 373 | labels = [] 374 | keys = self._keys[self._pointer: min(self._pointer + n, len(self._keys))] 375 | 376 | for key in keys: 377 | image, label, key = self._lmdb.read_single(key) 378 | 379 | images.append(image) 380 | labels.append(label) 381 | 382 | self._pointer += n 383 | 384 | return images, labels 385 | 386 | def count(self): 387 | """ 388 | Return the count of imags. 389 | 390 | :return: count 391 | :rtype: int 392 | """ 393 | 394 | return len(self._keys) 395 | 396 | def end(self): 397 | """ 398 | Whether the end has been reached. 399 | 400 | :return: true if the end has been reached or overstepped 401 | :rtype: bool 402 | """ 403 | 404 | return self._pointer >= len(self._keys) 405 | 406 | class PreProcessingInputFiles(PreProcessingInput): 407 | """ 408 | Provide input data for :class:`tools.pre_processing.PreProcessing` based on a list 409 | of file paths. 410 | """ 411 | 412 | def __init__(self, files, labels = []): 413 | """ 414 | Constructor, provide list of files and optional list of labels. 415 | 416 | :param files: file paths 417 | :type files: [string] 418 | :param labels: labels 419 | :type labels: [float] 420 | """ 421 | 422 | assert len(files) > 0, "files is empty" 423 | 424 | self._files = files 425 | """ ([string]) File paths. """ 426 | 427 | if len(labels) > 0: 428 | assert len(labels) == len(files), "if labels are provided there needs to be a label for each file" 429 | 430 | self._labels = labels 431 | """ ([float]) Labels. """ 432 | 433 | self._pointer = 0 434 | """ (int) Current index to start reading. """ 435 | 436 | def reset(self): 437 | """ 438 | Reset reading to start from the beginning. 439 | """ 440 | 441 | self._pointer = 0 442 | 443 | def read(self, n): 444 | """ 445 | Read data in batches. 446 | 447 | :param n: number of images to read 448 | :type n: int 449 | :return: images and optionally labels as lists 450 | :rtype: ([numpy.ndarray], [float]) 451 | """ 452 | 453 | images = [] 454 | labels = [] 455 | files = self._files[self._pointer: min(self._pointer + n, len(self._files))] 456 | 457 | labels = [] 458 | if len(self._labels) > 0: 459 | labels = self._labels[self._pointer: min(self._pointer + n, len(self._labels))] 460 | 461 | for i in range(len(files)): 462 | assert os.path.exists(files[i]), "file %s not found" % files[i] 463 | 464 | image = cv2.imread(files[i]) 465 | images.append(image) 466 | 467 | self._pointer += n 468 | 469 | return images, labels 470 | 471 | def count(self): 472 | """ 473 | Return the count of imags. 474 | 475 | :return: count 476 | :rtype: int 477 | """ 478 | 479 | return len(self._files) 480 | 481 | def end(self): 482 | """ 483 | Whether the end has been reached. 484 | 485 | :return: true if the end has been reached or overstepped 486 | :rtype: bool 487 | """ 488 | 489 | return self._pointer >= len(self._files) 490 | 491 | class PreProcessingInputCSV: 492 | """ 493 | Allows :class:`tools.pre_processing.PreProcessing` to take input from a 494 | CSV file. 495 | """ 496 | 497 | def __init__(self, csv_file, delimiter = ',', label_column = -1, label_column_mapping = {}): 498 | """ 499 | Constructor. 
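        A usage sketch following the CSV format used in the unit tests
        (three feature columns, categoric label names in the last column)::

            pp_in = tools.pre_processing.PreProcessingInputCSV('tests/csv.csv', ',', 3,
                                                               {'eins': 0, 'zwei': 1})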
500 | 501 | :param csv_file: path to the csv file to use 502 | :type csv_file: string 503 | :param csv_delimiter: delimited used between cells 504 | :type csv_delimiter: string 505 | :param label_column: the label column index, or -1 if label is not 506 | provided in the CSV file 507 | :type label_column: int 508 | :param label_column_mapping: the mapping from categoric labels to label 509 | indices if the labels provided in the CSV file are category names, 510 | or an empty object if the labels are already saved as integers 511 | :type label_column_mapping: {string: int} 512 | """ 513 | 514 | assert os.path.exists(csv_file), "the CSV file could not be found" 515 | self._csv_file = csv_file 516 | """ (string) CSV file path. """ 517 | 518 | self._delimiter = delimiter 519 | """ (string) Delimiter between cells for CSV file. """ 520 | 521 | self._label_column = label_column 522 | """ (int) The column to use for labels, -1 if label is not present. """ 523 | 524 | self._label_column_mapping = label_column_mapping 525 | """ ({string: intr}) The mapping from categoric label names to label indices if required. """ 526 | 527 | self._columns = -1 528 | """ (int) Number of columns. """ 529 | 530 | self._rows = 0 531 | """ (int) Number of rows. """ 532 | 533 | self._pointer = 0 534 | """ (int) Pointer to current row. """ 535 | 536 | with open(self._csv_file) as f: 537 | self._rows = 0 538 | for cells in csv.reader(f, delimiter = self._delimiter): 539 | cells = [cell.strip() for cell in cells if len(cell.strip()) > 0] 540 | 541 | if self._columns < 0: 542 | self._columns = len(cells) 543 | 544 | if len(cells) > 0: 545 | assert self._columns == len(cells), "CSV file does not contain a consistent number of columns" 546 | self._rows += 1 547 | 548 | def reset(self): 549 | """ 550 | Reset reading to start from the beginning. 551 | """ 552 | 553 | self._pointer = 0 554 | 555 | def read(self, n): 556 | """ 557 | Read data in batches. 558 | 559 | :param n: number of images to read 560 | :type n: int 561 | :return: images and optionally labels as lists 562 | :rtype: ([numpy.ndarray], [float]) 563 | """ 564 | 565 | images = [] 566 | labels = [] 567 | 568 | with open(self._csv_file) as f: 569 | row = 0 570 | for cells in csv.reader(f, delimiter = self._delimiter): 571 | if row == self._pointer and n > 0: 572 | cells = [cell.strip() for cell in cells if len(cell.strip()) > 0] 573 | 574 | if len(cells) > 0: 575 | assert self._columns == len(cells), "CSV file does not contain a consistent number of columns" 576 | 577 | if self._label_column < 0: 578 | cells = [float(cell) for cell in cells] 579 | else: 580 | label = cells[self._label_column] 581 | cells = cells[0:self._label_column] + cells[self._label_column + 1:] 582 | 583 | if len(self._label_column_mapping) > 0: 584 | assert label in self._label_column_mapping, "label %s not found in label_column_mapping" % label 585 | label = int(self._label_column_mapping[label]) 586 | 587 | labels.append(label) 588 | images.append(numpy.array(cells).reshape(len(cells), 1, 1).astype(float)) 589 | 590 | self._pointer += 1 591 | n -= 1 592 | 593 | row += 1 594 | 595 | return images, labels 596 | 597 | def count(self): 598 | """ 599 | Return the count of imags. 600 | 601 | :return: count 602 | :rtype: int 603 | """ 604 | 605 | return self._rows 606 | 607 | def end(self): 608 | """ 609 | Whether the end has been reached. 
617 | class PreProcessingOutput:
618 |     """
619 |     Allows :class:`tools.pre_processing.PreProcessing` to write its output.
620 |     """
621 | 
622 |     def write(self, images, labels = []):
623 |         """
624 |         Write the images and the given labels as output.
625 | 
626 |         :param images: list of images as numpy.ndarray
627 |         :type images: [numpy.ndarray]
628 |         """
629 | 
630 |         raise NotImplementedError("Should have been implemented!")
631 | 
632 | class PreProcessingOutputLMDB(PreProcessingOutput):
633 |     """
634 |     Allows :class:`tools.pre_processing.PreProcessing` to write its output
635 |     to an LMDB.
636 |     """
637 | 
638 |     def __init__(self, lmdb_path):
639 |         """
640 |         Constructor, provide path to LMDB.
641 | 
642 |         :param lmdb_path: path to LMDB
643 |         :type lmdb_path: string
644 |         """
645 | 
646 |         self._lmdb = tools.lmdb_io.LMDB(lmdb_path)
647 |         """ (tools.lmdb_io.LMDB) Underlying LMDB. """
648 | 
649 |     def write(self, images, labels = []):
650 |         """
651 |         Write the images and the given labels as output.
652 | 
653 |         :param images: list of images as numpy.ndarray
654 |         :type images: [numpy.ndarray]
655 |         """
656 | 
657 |         self._lmdb.write(images, labels)
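# A sketch of writing a batch of images to an LMDB via PreProcessingOutputLMDB;
# the LMDB path, image shapes and labels below are placeholders, and the exact
# dtype/shape expected by tools.lmdb_io.LMDB may differ in practice:

import numpy
import tools.pre_processing as pre_processing

pp_out = pre_processing.PreProcessingOutputLMDB('examples/output_lmdb')
images = [numpy.zeros((28, 28, 1), dtype = numpy.uint8) for i in range(10)]  # dummy images
pp_out.write(images, list(range(10)))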
--------------------------------------------------------------------------------
/tools/prototxt.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for working with prototxt files.
3 | """
4 | 
5 | from __future__ import print_function
6 | import matplotlib.pyplot as plt
7 | import multiprocessing
8 | import numpy
9 | import time
10 | import os
11 | import re
12 | 
13 | def train2deploy(train_prototxt, size, deploy_prototxt):
14 |     """
15 |     Convert a train prototxt file to a deploy prototxt file by removing the labels
16 |     and adding the correct input shape.
17 | 
18 |     :param train_prototxt: path to train prototxt
19 |     :type train_prototxt: string
20 |     :param size: blob input size as (batch_size, channels, height, width)
21 |     :type size: (int, int, int, int)
22 |     :param deploy_prototxt: path to deploy prototxt
23 |     :type deploy_prototxt: string
24 |     """
25 | 
26 |     assert len(size) == 4
27 | 
28 |     def replace_input_type(prototxt):
29 |         """
30 |         Replace the input type.
31 |         """
32 | 
33 |         return prototxt.replace('type: "Data"', 'type: "Input"')
34 | 
35 |     def remove_labels(prototxt):
36 |         """
37 |         Remove the label top.
38 |         """
39 | 
40 |         return prototxt.replace('top: "labels"\n', '')
41 | 
42 |     def replace_input_param(prototxt, size):
43 |         """
44 |         Replace the data_param block with an input_param of the given size.
45 |         """
46 | 
47 |         started = False
48 |         lines_old = prototxt.split('\n')
49 |         lines_new = []
50 | 
51 |         for line in lines_old:
52 |             if line.find('data_param') >= 0:
53 |                 started = True
54 |             elif started:
55 |                 if line.find('}') >= 0:
56 |                     lines_new.append(' input_param { shape: { dim: ' + str(size[0]) + ' dim: ' + str(size[1]) + ' dim: ' + str(size[2]) + ' dim: ' + str(size[3]) + ' } }')
57 |                     started = False
58 |             else:
59 |                 lines_new.append(line)
60 | 
61 |         return '\n'.join(lines_new)
62 | 
63 |     def remove_transform_param(prototxt):
64 |         """
65 |         Remove the transform_param block.
66 |         """
67 | 
68 |         occurrences = [(m.start(), m.end()) for m in re.finditer(r'\n[ \t]*transform_param[ \t]*\{[a-zA-Z0-9,.:" \t\r\n]*\}', prototxt)]
69 | 
70 |         if len(occurrences) > 0:
71 |             start = occurrences[-1][0]
72 |             end = occurrences[-1][1]
73 |             return prototxt[:start] + prototxt[end:]
74 |         else:
75 |             return prototxt
76 | 
77 |     def remove_loss(prototxt):
78 |         """
79 |         Remove the loss layer.
80 |         """
81 | 
82 |         occurrences = [m.start() for m in re.finditer(r'layer[ \t]*\{[a-zA-Z:" \t\r\n]*top:[ \t]*"loss"', prototxt)]
83 | 
84 |         if len(occurrences) > 0:
85 |             index = occurrences[-1]
86 |             return prototxt[:index]
87 |         else:
88 |             return prototxt
89 | 
90 |     with open(train_prototxt) as train:
91 |         with open(deploy_prototxt, 'w') as deploy:
92 |             prototxt_old = train.read()
93 |             prototxt_new = replace_input_type(prototxt_old)
94 |             prototxt_new = remove_labels(prototxt_new)
95 |             prototxt_new = replace_input_param(prototxt_new, size)
96 |             prototxt_new = remove_transform_param(prototxt_new)
97 |             prototxt_new = remove_loss(prototxt_new)
98 |             deploy.write(prototxt_new)
99 | 
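# A minimal sketch of converting a train prototxt into a deploy prototxt; the
# file names are placeholders and the size tuple is (batch_size, channels,
# height, width):

import tools.prototxt as prototxt

prototxt.train2deploy('examples/mnist/train.prototxt', (1, 1, 28, 28), 'examples/mnist/deploy.prototxt')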
--------------------------------------------------------------------------------
/tools/visualization.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization capabilities.
3 | """
4 | 
5 | import cv2
6 | import numpy
7 | 
8 | def get_layers(net):
9 |     """
10 |     Get the layer names of the network.
11 | 
12 |     :param net: caffe network
13 |     :type net: caffe.Net
14 |     :return: layer names
15 |     :rtype: [string]
16 |     """
17 | 
18 |     return [layer for layer in net.params.keys()]
19 | 
20 | def visualize_kernels(net, layer, zoom = 5):
21 |     """
22 |     Visualize the kernels in the given layer.
23 | 
24 |     :param net: caffe network
25 |     :type net: caffe.Net
26 |     :param layer: layer name
27 |     :type layer: string
28 |     :param zoom: the number of pixels (in width and height) per kernel weight
29 |     :type zoom: int
30 |     :return: image visualizing the kernels in a grid
31 |     :rtype: numpy.ndarray
32 |     """
33 | 
34 |     assert layer in get_layers(net), "layer %s not found" % layer
35 | 
36 |     num_kernels = net.params[layer][0].data.shape[0]
37 |     num_channels = net.params[layer][0].data.shape[1]
38 |     kernel_height = net.params[layer][0].data.shape[2]
39 |     kernel_width = net.params[layer][0].data.shape[3]
40 | 
41 |     image = numpy.zeros((num_kernels*zoom*kernel_height, num_channels*zoom*kernel_width))
42 |     for k in range(num_kernels):
43 |         for c in range(num_channels):
44 |             kernel = net.params[layer][0].data[k, c, :, :]
45 |             kernel = cv2.resize(kernel, (zoom*kernel_width, zoom*kernel_height), kernel, 0, 0, cv2.INTER_NEAREST) # cv2.resize expects dsize as (width, height)
46 |             kernel = (kernel - numpy.min(kernel))/(numpy.max(kernel) - numpy.min(kernel))
47 |             image[k*zoom*kernel_height:(k + 1)*zoom*kernel_height, c*zoom*kernel_width:(c + 1)*zoom*kernel_width] = kernel
48 | 
49 |     return image
50 | 
51 | def visualize_weights(net, layer, zoom = 2):
52 |     """
53 |     Visualize the weights in a fully connected layer.
54 | 
55 |     :param net: caffe network
56 |     :type net: caffe.Net
57 |     :param layer: layer name
58 |     :type layer: string
59 |     :param zoom: the number of pixels (in width and height) per weight
60 |     :type zoom: int
61 |     :return: image visualizing the weights in a grid
62 |     :rtype: numpy.ndarray
63 |     """
64 | 
65 |     assert layer in get_layers(net), "layer %s not found" % layer
66 | 
67 |     weights = net.params[layer][0].data
68 |     weights = (weights - numpy.min(weights))/(numpy.max(weights) - numpy.min(weights))
69 |     return cv2.resize(weights, (weights.shape[1]*zoom, weights.shape[0]*zoom), weights, 0, 0, cv2.INTER_NEAREST) # cv2.resize expects dsize as (width, height)
--------------------------------------------------------------------------------
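A minimal sketch of visualizing a trained network with the helpers above, assuming a deploy prototxt, a trained caffemodel and a first convolutional layer named 'conv1' (all of these names are placeholders):

import caffe
import cv2
import numpy
import tools.visualization as visualization

net = caffe.Net('examples/mnist/deploy.prototxt', 'examples/mnist/snapshot_iter_1000.caffemodel', caffe.TEST)
print(visualization.get_layers(net))

kernels = visualization.visualize_kernels(net, 'conv1', zoom = 5)
cv2.imwrite('conv1_kernels.png', (kernels*255).astype(numpy.uint8))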