├── .gitignore ├── .gitmodules ├── .travis.yml ├── Makefile ├── README.MD ├── barrista ├── __init__.py ├── config.py ├── design.py ├── initialization.py ├── monitoring.py ├── net.py ├── parallel.py ├── solver.py └── tools.py ├── documentation ├── Makefile ├── _static │ └── barrista.jpg ├── barrista.rst ├── conf.py ├── index.rst ├── make.bat ├── setup.rst └── usage.rst ├── examples ├── MNIST │ ├── README.txt │ ├── data.py │ ├── models │ │ └── basic.py │ ├── test.py │ ├── train.py │ └── visualize.py ├── residual-nets │ ├── README.txt │ ├── data.py │ ├── models │ │ ├── msra3.py │ │ └── msra9.py │ ├── test.py │ ├── train.py │ └── visualize.py └── showcase.py ├── license.txt ├── patches ├── barrista-patch-caffe-6eae122a8eb84f8371dde815986cd7524fc4cbaa.patch ├── barrista-patch-caffe-dc831aa8f5d3b7d9473958f5b9e745c98755a0a6.patch ├── barrista-patch-caffe-release-candidater-(tag-rc2).patch └── build_support │ ├── build.sh │ ├── pylint_call.sh │ └── travis_setup.sh ├── requirements.txt ├── setup.cfg ├── setup.py └── tests.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .nfs* 3 | 4 | # Coverage analysis. 5 | .coverage 6 | htmlcov 7 | 8 | # The setuptools temp-folders. 9 | *.egg-info 10 | *.egg 11 | 12 | # Generated documentation. 13 | documentation/_build 14 | 15 | # Example data results 16 | examples/MNIST/data 17 | examples/MNIST/models/*.png 18 | examples/MNIST/results 19 | examples/residual-nets/data 20 | examples/residual-nets/models/*.png 21 | examples/residual-nets/results 22 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "caffe"] 2 | path = caffe 3 | url = https://github.com/classner/caffe.git 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Use caffe and test barrista. 2 | env: 3 | matrix: 4 | - WITH_CUDA=false WITH_CMAKE=true WITH_IO=true PYTHON_VERSION=2 5 | # Disabling python 3 build for now, since caffe can't be built with the 6 | # most recent version. See https://github.com/BVLC/caffe/issues/2464. 7 | #- WITH_CUDA=false WITH_CMAKE=true WITH_IO=true PYTHON_VERSION=3 8 | 9 | language: cpp 10 | 11 | cache: 12 | apt: true 13 | directories: 14 | - /home/travis/miniconda 15 | - /home/travis/miniconda2 16 | - /home/travis/miniconda3 17 | 18 | compiler: gcc 19 | 20 | before_install: 21 | - cd caffe 22 | - export CAFFE_PYTHON_FOLDER=`pwd`/python 23 | - export CAFFE_BIN_FOLDER=`pwd`/build/install/bin 24 | - export NUM_THREADS=2 25 | - export SCRIPTS=./scripts/travis 26 | - export CONDA_DIR="/home/travis/miniconda$PYTHON_VERSION" 27 | 28 | install: 29 | - sudo -E ../patches/build_support/travis_setup.sh 30 | 31 | before_script: 32 | - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib:/usr/local/cuda/lib64:$CONDA_DIR/lib 33 | - export PATH=$CONDA_DIR/bin:$CONDA_DIR/scripts:$PATH 34 | - sudo -E $CONDA_DIR/bin/conda install --yes opencv || true 35 | - sudo -E $CONDA_DIR/bin/conda install --yes scikit-image scikit-learn pylint 36 | - sudo -E $CONDA_DIR/bin/pip install --upgrade pip 37 | - sudo -E $CONDA_DIR/bin/pip install progressbar2 38 | - cd .. 39 | - patches/build_support/pylint_call.sh 40 | - cd caffe 41 | 42 | script: 43 | - ../patches/build_support/build.sh 44 | - cd .. 
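# Note: GLOG_minloglevel=2 (below) silences caffe's INFO and WARNING output during the test run, so only errors reach the CI log.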
45 | - export GLOG_minloglevel=2 46 | - coverage run --source=barrista setup.py test 47 | 48 | after_success: 49 | - coveralls 50 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GH_PAGES_SOURCES = barrista documentation Makefile 2 | GH_PAGES_BUILD_BRANCH = master 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | cd documentation && $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/barrista.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/barrista.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/barrista" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/barrista" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 
157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | 179 | gh-pages: 180 | git checkout gh-pages 181 | rm -rf * 182 | git checkout $(GH_PAGES_BUILD_BRANCH) $(GH_PAGES_SOURCES) 183 | git reset HEAD 184 | mkdir -p documentation 185 | make html 186 | mv -fv documentation/_build/html/* ./ 187 | rm -rf $(GH_PAGES_SOURCES) 188 | touch .nojekyll 189 | git add -A 190 | git commit -m "Generated gh-pages for `git log $(GH_PAGES_BUILD_BRANCH) -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout $(GH_PAGES_BUILD_BRANCH) 191 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | 2 | [![Build Status](https://travis-ci.org/classner/barrista.svg?branch=unstable)](https://travis-ci.org/classner/barrista) 3 | [![Requirements Status](https://requires.io/github/classner/barrista/requirements.svg?branch=unstable)](https://requires.io/github/classner/barrista/requirements/?branch=unstable) 4 | [![Coverage Status](https://coveralls.io/repos/github/classner/barrista/badge.svg?branch=unstable)](https://coveralls.io/github/classner/barrista?branch=unstable) 5 | 6 | # Unmaintained 7 | 8 | With the rise of widespread, company-backed deep learning toolkits, we no 9 | longer maintain barrista. It will not work with recent versions of caffe 10 | (please use the version that is linked as a submodule). 11 | 12 | 13 | # The Barrista 14 | 15 | The barrista is a tool to enjoy [caffe](https://github.com/BVLC/caffe.git) 16 | particularly well from Python. It exposes the entire library functionality 17 | to Python, and you can create, modify, train, load and save 'caffe' networks 18 | conveniently through one API (see the quick example below). For more information, 19 | see the dedicated [homepage & documentation](https://classner.github.io/barrista). 20 | 21 | ## CI status 22 | 23 | The master branch is always stable, with all tests passing. Development 24 | takes place on the unstable branch, and you can see the current build status 25 | here: 26 | 27 | [![Build Status](https://travis-ci.org/classner/barrista.svg?branch=unstable)](https://travis-ci.org/classner/barrista) unstable branch build status, Python 2 (Python 3 is deactivated for the moment, because of an instability with caffe on Python 3).
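## Quick example

The following sketch is adapted from the examples in the documentation; it assumes a working caffe build that barrista can find (e.g., via the `CAFFE_PYTHON_FOLDER` environment variable):

```python
import numpy as np
import barrista.design as ds
import barrista.solver

# Batch size 10, 3 channels of a 51x51 signal, 10 labels.
netspec = ds.NetSpecification([[10, 3, 51, 51], [10]],
                              inputs=['data', 'annotations'])
netspec.layers.append(ds.ConvolutionLayer(Convolution_kernel_size=3,
                                          Convolution_pad=1,
                                          Convolution_num_output=1))
# Layers are wired together automatically unless specified otherwise.
netspec.layers.append(ds.InnerProductLayer(tops=['net_out'],
                                           InnerProduct_num_output=10))
netspec.layers.append(ds.SoftmaxWithLossLayer(bottoms=['net_out',
                                                       'annotations']))
net = netspec.instantiate()

# Inputs are batched transparently.
net.fit(1000,
        barrista.solver.SGDSolver(base_lr=0.01),
        X={'data': np.ones((21, 3, 51, 51)),
           'annotations': np.zeros((21,))})
```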
28 | -------------------------------------------------------------------------------- /barrista/__init__.py: -------------------------------------------------------------------------------- 1 | """The barrista main module.""" 2 | -------------------------------------------------------------------------------- /barrista/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The configuration module for barrista. 4 | 5 | It is possible to programmatically change the configuration. For this, import 6 | `barrista.config` and edit the values as required before importing any other 7 | of the submodules. The configuration change is then taken into account. 8 | """ 9 | 10 | import os as _os 11 | 12 | #: This folder must contain the ``caffe`` module and is added to the python 13 | #: path after the first inclusion of the `initialization` module. 14 | if 'CAFFE_PYTHON_FOLDER' in list(_os.environ.keys()): 15 | CAFFE_PYTHON_FOLDER = _os.environ['CAFFE_PYTHON_FOLDER'] 16 | else: # pragma: no cover 17 | CAFFE_PYTHON_FOLDER = _os.path.abspath( 18 | _os.path.join( 19 | _os.path.dirname(__file__), '..', 'caffe', 'build', 'install', 'python')) 20 | 21 | #: This folder contains the file ``upgrade_net_proto_text``. 22 | if 'CAFFE_BIN_FOLDER' in list(_os.environ.keys()): 23 | CAFFE_BIN_FOLDER = _os.environ['CAFFE_BIN_FOLDER'] 24 | else: # pragma: no cover 25 | CAFFE_BIN_FOLDER = _os.path.join(CAFFE_PYTHON_FOLDER, 26 | '..', 'bin') 27 | 28 | #: This dictionary specifies the layer types and their configuration 29 | #: parameters. The keys are the layer keys, and the values a list of 30 | #: strings, where each string is the name of a parameter prefixed with 31 | #: `_caffe_pb2.` . 32 | LAYER_TYPES = {'AbsVal': [], 33 | 'Accuracy': ['AccuracyParameter'], 34 | 'ArgMax': ['ArgMaxParameter'], 35 | 'BatchNorm': ['BatchNormParameter'], 36 | 'Concat': ['ConcatParameter'], 37 | 'ContrastiveLoss': ['ContrastiveLossParameter'], 38 | 'Convolution': ['ConvolutionParameter'], 39 | 'Data': ['DataParameter'], 40 | 'Dropout': ['DropoutParameter'], 41 | 'DummyData': ['DummyDataParameter'], 42 | 'Embed': ['EmbedParameter'], 43 | 'Eltwise': ['EltwiseParameter'], 44 | 'EuclideanLoss': ['LossParameter'], 45 | 'Exp': ['ExpParameter'], 46 | 'Filter': [], 47 | 'Flatten': ['FlattenParameter'], 48 | 'HDF5Data': ['HDF5DataParameter'], 49 | 'HDF5Output': ['HDF5OutputParameter'], 50 | 'HingeLoss': ['HingeLossParameter'], 51 | 'ImageData': ['ImageDataParameter'], 52 | 'InfogainLoss': ['InfogainLossParameter'], 53 | 'InnerProduct': ['InnerProductParameter'], 54 | 'Log': ['LogParameter'], 55 | 'LRN': ['LRNParameter'], 56 | # Do not add this layer! It is superfluous with this interface 57 | # and might just be a source of bugs. 
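# (barrista fills input blobs directly from Python, e.g. via the X arguments of fit/predict, which already covers the MemoryData use case.)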
58 | # 'MemoryData': ['MemoryDataParameter'], 59 | 'MultinomialLogisticLoss': ['LossParameter'], 60 | 'MVN': ['MVNParameter'], 61 | 'Pooling': ['PoolingParameter'], 62 | 'Power': ['PowerParameter'], 63 | 'PReLU': ['PReLUParameter'], 64 | 'Python': ['PythonParameter'], 65 | 'Reduction': ['ReductionParameter'], 66 | 'ReLU': ['ReLUParameter'], 67 | 'Resample': [], 68 | 'Reshape': ['ReshapeParameter'], 69 | 'Scale': ['ScaleParameter'], 70 | 'Sigmoid': ['SigmoidParameter'], 71 | 'SigmoidCrossEntropyLoss': ['LossParameter'], 72 | 'Silence': [], 73 | 'Slice': ['SliceParameter'], 74 | 'Softmax': ['SoftmaxParameter'], 75 | 'SoftmaxWithLoss': ['SoftmaxParameter', 76 | 'LossParameter'], 77 | 'Split': [], 78 | 'SPP': ['SPPParameter'], 79 | 'TanH': ['TanHParameter'], 80 | 'Threshold': ['ThresholdParameter'], 81 | 'VariableHingeLoss': ['LossParameter'], 82 | 'WindowData': ['WindowDataParameter']} 83 | -------------------------------------------------------------------------------- /barrista/initialization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Initialization module. 4 | 5 | Adds caffe to the pythonpath when imported. Any changes of :mod:`config` 6 | objects must be done before the import. Any imports of caffe related objects 7 | may only be done after. 8 | """ 9 | # pylint: disable=E0611, F0401, wrong-import-order, wrong-import-position 10 | 11 | from .config import CAFFE_PYTHON_FOLDER as _PYCAFFE_FOLDER 12 | import os as _os 13 | import sys as _sys 14 | import logging as _logging 15 | 16 | _LOGGER = _logging.getLogger(__name__) 17 | 18 | if not _os.path.exists(_os.path.join(_PYCAFFE_FOLDER, 'caffe')): # pragma: no cover 19 | _LOGGER.warn('The caffe module does not exist in %s! It is specified as ' + 20 | 'barrista.CAFFE_PYTHON_FOLDER! 
Trying to fall back on ' + 21 | 'caffe on the python path.', 22 | _PYCAFFE_FOLDER) 23 | try: 24 | # pylint: disable=W0611 25 | import caffe as _caffe 26 | except ImportError: 27 | raise Exception('Failed to add the CAFFE_PYTHON_FOLDER and caffe is ' + 28 | 'not on the PYTHONPATH!') 29 | else: 30 | _sys.path.insert(0, _PYCAFFE_FOLDER) 31 | 32 | 33 | def init(): 34 | """Empty at the moment.""" 35 | pass 36 | -------------------------------------------------------------------------------- /barrista/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 2 | """Collection of parallel tools.""" 3 | # pylint: disable=invalid-name, protected-access, redefined-outer-name 4 | from __future__ import print_function 5 | 6 | import logging as _logging 7 | import warnings as _warnings 8 | _DEBUG_SERIAL = False 9 | if _DEBUG_SERIAL: 10 | import multiprocessing.dummy as _multiprocessing 11 | else: 12 | import multiprocessing as _multiprocessing 13 | from multiprocessing import Array as _mpa 14 | from multiprocessing import log_to_stderr as _log_to_stderr 15 | import numpy as _np 16 | 17 | import barrista.monitoring as _monitoring 18 | 19 | 20 | class DummyNet(object): # pylint: disable=too-few-public-methods 21 | 22 | """Drop-in replacement to simulate blobs for the ParallelMonitors.""" 23 | 24 | def __init__(self): 25 | self._filled = False 26 | self.blobs = {} 27 | 28 | 29 | class DummyBlob(object): 30 | 31 | """Replacement network blob using SharedMemory for the ParallelMonitors.""" 32 | 33 | def __init__(self, shared_shape, shared_data, real_shape): 34 | self.shared_data = shared_data 35 | self.shared_shape = shared_shape 36 | self.real_shape = real_shape 37 | 38 | @property 39 | def shape(self): 40 | """Get the current blob shape.""" 41 | shape = _np.ctypeslib.as_array(self.shared_shape) 42 | return shape 43 | 44 | @property 45 | def data(self): 46 | """Get the blob data.""" 47 | data = _np.ctypeslib.as_array(self.shared_data) 48 | shape = self.shape 49 | if len(shape) == 4: 50 | return data.reshape(self.real_shape)[ 51 | :shape[0], :shape[1], :shape[2], :shape[3]] 52 | else: 53 | return data.reshape(self.real_shape) 54 | 55 | def reshape(self, num, chan, height, width): # pragma: no cover 56 | """Simulate the blob reshape method.""" 57 | shape = self.shape 58 | if len(shape) != 4: 59 | raise Exception("Can only reshape 4D blobs!") 60 | assert num == shape[0] 61 | assert chan == shape[1] 62 | assert height <= self.real_shape[2] 63 | assert width <= self.real_shape[3] 64 | shape[2] = height 65 | shape[3] = width 66 | 67 | 68 | # Coverage can not be detected in sub-threads. 69 | def init_filler(dummynet, filler_cbs, in_train_mode): # pragma: no cover 70 | """Initialize a filler thread.""" 71 | # pylint: disable=global-variable-undefined, global-variable-not-assigned 72 | global net, cbs, train_mode, initialized, logger 73 | logger = _log_to_stderr(_logging.WARN) 74 | logger.debug("Initializing filler. Train mode: %s.", in_train_mode) 75 | net = dummynet 76 | cbs = filler_cbs 77 | train_mode = in_train_mode 78 | initialized = False 79 | 80 | 81 | def run_cbs(cbparams): # pragma: no cover 82 | """Run the callbacks in this filler thread.""" 83 | # pylint: disable=global-variable-undefined, global-variable-not-assigned 84 | global net, cbs, train_mode, initialized, logger 85 | logger.debug("Preparing batch. 
cbparams: %s.", cbparams) 86 | if train_mode: 87 | cbparams['net'] = net 88 | else: 89 | cbparams['testnet'] = net 90 | for cb in cbs: 91 | cb(cbparams) 92 | 93 | 94 | def finalize_cbs(cbparams): # pragma: no cover 95 | """Finalize the callbacks in this filler thread.""" 96 | # pylint: disable=global-variable-undefined, global-variable-not-assigned 97 | global cbs, logger 98 | logger.debug("Finalizing callbacks.") 99 | for cb in cbs: 100 | cb.finalize(cbparams) 101 | 102 | 103 | def init_prebatch(self, # pylint: disable=too-many-locals 104 | net, 105 | callbacks, 106 | train_mode): 107 | """ 108 | Initialize parallel pre-batch processing. 109 | 110 | Should be used with the `run_prebatch` method from this module. 111 | 112 | The object must have the properties: 113 | 114 | * _parallel_batch_res_train (None) 115 | * _parallel_batch_rest_test (None) 116 | * _train_net_dummy (None) 117 | * _parallel_train_filler (None) 118 | * _test_net_dummy (None) 119 | * _parallel_test_filler (None) 120 | 121 | whereas the properties with train or test in their name are only used if 122 | the method is used for the respective `train_mode`. 123 | """ 124 | if train_mode: 125 | assert self._parallel_batch_res_train is None 126 | assert self._train_net_dummy is None 127 | assert self._parallel_train_filler is None 128 | else: 129 | assert self._test_net_dummy is None 130 | assert self._parallel_test_filler is None 131 | assert self._parallel_batch_res_test is None 132 | parallelcbs = [cb for cb in callbacks 133 | if isinstance(cb, _monitoring.ParallelMonitor) and not _DEBUG_SERIAL] 134 | nublobnames = [] 135 | for cb in parallelcbs: 136 | nublobnames.extend(cb.get_parallel_blob_names()) 137 | dummyblobs = list(set(nublobnames)) 138 | dummyshapes = [list(net.blobs[db].shape) for db in dummyblobs] 139 | dummynet = DummyNet() 140 | for bname, bsh in zip(dummyblobs, dummyshapes): 141 | if len(bsh) == 4: 142 | real_shape = (bsh[0], bsh[1], bsh[2] * 3, bsh[3] * 3) 143 | else: 144 | real_shape = bsh 145 | shared_arr = _mpa( 146 | 'f', 147 | _np.zeros(_np.prod(real_shape), dtype='float32'), 148 | lock=False) 149 | shared_sh = _mpa( 150 | 'i', 151 | _np.zeros(len(bsh), dtype='int'), 152 | lock=False) 153 | dummynet.blobs[bname] = DummyBlob(shared_sh, shared_arr, real_shape) 154 | with _warnings.catch_warnings(): 155 | # For more information on why this is necessary, see 156 | # https://www.reddit.com/r/Python/comments/j3qjb/parformatlabpool_replacement 157 | _warnings.simplefilter('ignore', RuntimeWarning) 158 | dummynet.blobs[bname].shape[...] 
= bsh 159 | filler_cbs = [cb for cb in callbacks 160 | if isinstance(cb, _monitoring.ParallelMonitor) and not _DEBUG_SERIAL] 161 | if train_mode: 162 | self._train_net_dummy = dummynet 163 | self._parallel_train_filler = _multiprocessing.Pool( 164 | 1, 165 | initializer=init_filler, 166 | initargs=(dummynet, filler_cbs, True)) 167 | else: 168 | self._test_net_dummy = dummynet 169 | self._parallel_test_filler = _multiprocessing.Pool( 170 | 1, 171 | initializer=init_filler, 172 | initargs=(dummynet, filler_cbs, False)) 173 | 174 | 175 | def _extract_ncbparams(cbparams): 176 | ncbparams = { 177 | 'iter': cbparams['iter'], 178 | 'callback_signal': cbparams['callback_signal'], 179 | 'max_iter': cbparams['max_iter'], 180 | 'batch_size': cbparams['batch_size'] 181 | } 182 | if 'test_interval' in list(cbparams.keys()): 183 | ncbparams['test_interval'] = cbparams['test_interval'] 184 | return ncbparams 185 | 186 | 187 | def run_prebatch(self, # pylint: disable=too-many-branches, too-many-arguments 188 | callbacks, 189 | cbparams, 190 | train_mode, 191 | iter_p1, 192 | run_pre): 193 | """Run the prebatch callbacks.""" 194 | # Prepare the parameters for the parallel workers. 195 | ncbparams = _extract_ncbparams(cbparams) 196 | if train_mode: 197 | dummy = self._train_net_dummy 198 | net = cbparams['net'] 199 | else: 200 | dummy = self._test_net_dummy 201 | net = cbparams['testnet'] 202 | if run_pre: 203 | # Run pre_test or pre_fit. 204 | callback_signal = 'pre_fit' if train_mode else 'pre_test' 205 | cbs_orig = cbparams['callback_signal'] 206 | cbparams['callback_signal'] = callback_signal 207 | for cb in [cb for cb in callbacks 208 | if not isinstance(cb, _monitoring.ParallelMonitor) or _DEBUG_SERIAL]: 209 | cb(cbparams) 210 | cbparams['callback_signal'] = cbs_orig 211 | # For the parallel workers. 212 | ncbparams['callback_signal'] = callback_signal 213 | if train_mode: 214 | self._parallel_train_filler.apply(run_cbs, 215 | args=(ncbparams,)) 216 | else: 217 | self._parallel_test_filler.apply(run_cbs, 218 | args=(ncbparams,)) 219 | # Set the test dummy as cleared, so to achieve reproducibility 220 | # for test results if the test dataset size is not a multiple of 221 | # batch size times test iterations. 222 | dummy._filled = False 223 | ncbparams['callback_signal'] = cbs_orig 224 | if not dummy._filled: 225 | # Run the callbacks. 226 | for cb in [callb for callb in callbacks 227 | if not isinstance(callb, _monitoring.ParallelMonitor) or _DEBUG_SERIAL]: 228 | cb(cbparams) 229 | if train_mode: 230 | self._parallel_batch_res_train =\ 231 | self._parallel_train_filler.apply_async( 232 | run_cbs, 233 | args=(ncbparams,)) 234 | else: 235 | self._parallel_batch_res_test =\ 236 | self._parallel_test_filler.apply_async( 237 | run_cbs, 238 | args=(ncbparams,)) 239 | dummy._filled = True 240 | # Get the parallel results. 241 | if train_mode: 242 | self._parallel_batch_res_train.get() 243 | else: 244 | self._parallel_batch_res_test.get() 245 | # Copy over the prepared data. 246 | for bname in list(dummy.blobs.keys()): 247 | if not _np.all(dummy.blobs[bname].data.shape == 248 | net.blobs[bname].data.shape): 249 | dummyshape = dummy.blobs[bname].data.shape 250 | net.blobs[bname].reshape(dummyshape[0], 251 | dummyshape[1], 252 | dummyshape[2], 253 | dummyshape[3]) 254 | net.blobs[bname].data[...] = dummy.blobs[bname].data 255 | # Start next parallel run. 
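# (Double buffering: the batch for the upcoming iteration iter_p1 is prepared asynchronously in the shared-memory dummy blobs while the solver consumes the batch just copied into the real net; the result is collected at the start of the next call.)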
256 | ncbparams['iter'] = iter_p1 257 | if train_mode: 258 | self._parallel_batch_res_train =\ 259 | self._parallel_train_filler.apply_async( 260 | run_cbs, args=(ncbparams,)) 261 | else: 262 | self._parallel_batch_res_test =\ 263 | self._parallel_test_filler.apply_async( 264 | run_cbs, args=(ncbparams,)) 265 | # Execute the serially-to-execute monitors. 266 | for cb in callbacks: 267 | if not isinstance(cb, _monitoring.ParallelMonitor) or _DEBUG_SERIAL: 268 | cb(cbparams) 269 | 270 | def finalize_prebatch(self, cbparams): 271 | """Cleanup workers and artifacts.""" 272 | ncbparams = _extract_ncbparams(cbparams) 273 | if hasattr(self, '_parallel_train_filler'): 274 | if self._parallel_train_filler is not None: 275 | self._parallel_train_filler.apply(finalize_cbs, args=(ncbparams,)) 276 | self._parallel_train_filler.close() 277 | self._parallel_train_filler.join() 278 | self._parallel_train_filler = None 279 | self._train_net_dummy = None 280 | self._parallel_batch_res_train = None 281 | if (hasattr(self, '_parallel_test_filler') and 282 | self._parallel_test_filler is not None): 283 | self._parallel_test_filler.apply(finalize_cbs, args=(ncbparams,)) 284 | self._parallel_test_filler.close() 285 | self._parallel_test_filler.join() 286 | self._parallel_test_filler = None 287 | self._test_net_dummy = None 288 | self._parallel_batch_res_test = None 289 | -------------------------------------------------------------------------------- /barrista/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Implements some useful tools.""" 3 | # pylint: disable=C0103, wrong-import-order, no-member 4 | from __future__ import print_function 5 | 6 | 7 | import warnings as _warnings 8 | import os as _os 9 | import sys as _sys 10 | 11 | import numpy as np 12 | from tempfile import mkdtemp 13 | 14 | 15 | def pbufToPyEnum(pbufenum): 16 | r"""Helper function to create a Python enum out of a protobuf one.""" 17 | enums = dict(list(pbufenum.items())) 18 | return type('Enum', (), enums) 19 | 20 | 21 | def chunks(seq, size): 22 | r""" 23 | Create chunks of ``size`` of ``seq``. 24 | 25 | See http://stackoverflow.com/questions/434287/ 26 | what-is-the-most-pythonic-way-to-iterate-over-a-list-in-chunks. 27 | """ 28 | return (seq[pos:pos + size] for pos in range(0, len(seq), size)) 29 | 30 | 31 | def pad(image, input_dims, get_padding=False, val=0, pad_at_least=False): 32 | r""" 33 | Pad an image with given scale to the appropriate dimensions. 34 | 35 | The scaled image must fit into the input_dims, otherwise an exception is 36 | thrown! 37 | 38 | :param image: 3D numpy array. 39 | The image to pad, as (C, H, W). 40 | 41 | :param input_dims: tuple(int). 42 | A two-tuple of ints with the first value specifying height, the second 43 | width. 44 | 45 | :param get_padding: bool. 46 | If set to True, returns a second value, which is a tuple of two-tuples, 47 | where each tuple contains the two values for left-right paddings for 48 | one of the image dimensions. 49 | 50 | :param val: int. 51 | The value to pad with. 52 | 53 | :param pad_at_least: bool. 54 | If set to True, allows input_dims that are smaller than the image size. 55 | Otherwise, it throws in this case. 56 | 57 | :returns: 3D array, padded or, if ``get_padding``, 58 | (3D array, tuple(two-tuples)). 
59 | """ 60 | assert len(input_dims) == 2 61 | assert image.ndim == 3 62 | IMAGE_DIMS = np.array(image.shape[1:]) 63 | SCALED_DIMS = IMAGE_DIMS[:].astype('int') 64 | WORK_SCALE_IMAGE = image 65 | PAD_WIDTH = (input_dims[1] - SCALED_DIMS[1]) / 2.0 66 | PAD_HEIGHT = (input_dims[0] - SCALED_DIMS[0]) / 2.0 67 | if not pad_at_least: 68 | assert PAD_WIDTH >= 0. and PAD_HEIGHT >= 0. 69 | else: 70 | PAD_WIDTH = max(0, PAD_WIDTH) 71 | PAD_HEIGHT = max(0, PAD_HEIGHT) 72 | # Padding is done, e.g., in deeplab, first with the mean values, 73 | # only to subtract the mean of the entire image, resulting in 74 | # 0. values in the padded areas. We're doing that here directly. 75 | padding = ((0, 0), 76 | (int(np.floor(PAD_HEIGHT)), int(np.ceil(PAD_HEIGHT))), 77 | (int(np.floor(PAD_WIDTH)), int(np.ceil(PAD_WIDTH)))) 78 | padded = np.pad(WORK_SCALE_IMAGE, padding, 'constant', constant_values=val) 79 | if get_padding: 80 | return padded, padding 81 | else: 82 | return padded 83 | 84 | 85 | # pylint: disable=R0903 86 | class TemporaryDirectory(object): # pragma: no cover 87 | 88 | """Create and return a temporary directory. 89 | 90 | This has the same behavior as mkdtemp but can be used as a context manager. 91 | For example: 92 | 93 | with TemporaryDirectory() as tmpdir: 94 | ... 95 | 96 | Upon exiting the context, the directory and everything contained 97 | in it are removed. 98 | 99 | Source: 100 | http://stackoverflow.com/questions/19296146/ 101 | tempfile-temporarydirectory-context-manager-in-python-2-7. 102 | """ 103 | 104 | # pylint: disable=W0622 105 | def __init__(self, suffix="", prefix="tmp", dir=None): 106 | """Same parameters as ``mkdtemp``.""" 107 | self._closed = False 108 | self.name = None # Handle mkdtemp raising an exception 109 | self.name = mkdtemp(suffix, prefix, dir) 110 | 111 | def __repr__(self): 112 | """Plain string representation.""" 113 | return "<{} {!r}>".format(self.__class__.__name__, self.name) 114 | 115 | def __enter__(self): 116 | """When entering the context.""" 117 | return self.name 118 | 119 | def cleanup(self, _warn=False): 120 | """Guarantee a cleaned up state.""" 121 | if self.name and not self._closed: 122 | try: 123 | self._rmtree(self.name) 124 | except (TypeError, AttributeError) as ex: 125 | # Issue #10188: Emit a warning on stderr 126 | # if the directory could not be cleaned 127 | # up due to missing globals 128 | if "None" not in str(ex): 129 | raise 130 | print("ERROR: {!r} while cleaning up {!r}".format(ex, self,), 131 | file=_sys.stderr) 132 | return 133 | self._closed = True 134 | if _warn: 135 | self._warn("Implicitly cleaning up {!r}".format(self), 136 | self._warn.ResourceWarning) 137 | 138 | def __exit__(self, exc, value, tb): 139 | """On leaving the context.""" 140 | self.cleanup() 141 | 142 | def __del__(self): 143 | """On deleting the context.""" 144 | # Issue a ResourceWarning if implicit cleanup needed. 145 | self.cleanup(_warn=True) 146 | 147 | # The following code attempts to make 148 | # this class tolerant of the module nulling out process 149 | # that happens during CPython interpreter shutdown 150 | # Alas, it doesn't actually manage it. See issue #10188. 
151 | _listdir = staticmethod(_os.listdir) 152 | _path_join = staticmethod(_os.path.join) 153 | _isdir = staticmethod(_os.path.isdir) 154 | _islink = staticmethod(_os.path.islink) 155 | _remove = staticmethod(_os.remove) 156 | _rmdir = staticmethod(_os.rmdir) 157 | _warn = _warnings.warn 158 | 159 | def _rmtree(self, path): 160 | """ 161 | Essentially a stripped down version of shutil.rmtree. 162 | 163 | We can't use globals because they may be None'ed out at shutdown. 164 | """ 165 | for name in self._listdir(path): 166 | fullname = self._path_join(path, name) 167 | try: 168 | isdir = self._isdir(fullname) and not self._islink(fullname) 169 | except OSError: 170 | isdir = False 171 | if isdir: 172 | self._rmtree(fullname) 173 | else: 174 | try: 175 | self._remove(fullname) 176 | except OSError: 177 | pass 178 | try: 179 | self._rmdir(path) 180 | except OSError: 181 | pass 182 | -------------------------------------------------------------------------------- /documentation/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/barrista.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/barrista.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/barrista" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/barrista" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | -------------------------------------------------------------------------------- /documentation/_static/barrista.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/classner/barrista/230ad0ecfdac22aa95b38e5aeedc73fcd625a94a/documentation/_static/barrista.jpg -------------------------------------------------------------------------------- /documentation/barrista.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | ================= 3 | 4 | 5 | barrista.config module 6 | ---------------------- 7 | 8 | .. 
automodule:: barrista.config 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | barrista.design module 14 | ---------------------- 15 | 16 | .. automodule:: barrista.design 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | barrista.initialization module 22 | ------------------------------ 23 | 24 | .. automodule:: barrista.initialization 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | barrista.monitoring module 30 | ------------------------------ 31 | 32 | .. automodule:: barrista.monitoring 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | barrista.net module 38 | ------------------- 39 | 40 | .. automodule:: barrista.net 41 | :members: 42 | :undoc-members: 43 | :show-inheritance: 44 | 45 | barrista.solver module 46 | ---------------------- 47 | 48 | .. automodule:: barrista.solver 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | 53 | barrista.tools module 54 | --------------------- 55 | 56 | .. automodule:: barrista.tools 57 | :members: 58 | :undoc-members: 59 | :show-inheritance: 60 | 61 | 62 | Module contents 63 | --------------- 64 | 65 | .. automodule:: barrista 66 | :members: 67 | :undoc-members: 68 | :show-inheritance: 69 | -------------------------------------------------------------------------------- /documentation/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # barrista documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Jun 23 17:47:50 2015. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | import sphinx_rtd_theme 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | sys.path.insert(0, os.path.abspath('..')) 23 | 24 | # -- General configuration ------------------------------------------------ 25 | 26 | # If your documentation needs a minimal Sphinx version, state it here. 27 | #needs_sphinx = '1.0' 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = [ 33 | 'sphinx.ext.autodoc', 34 | 'sphinx.ext.viewcode', 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.graphviz' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix of source filenames. 43 | source_suffix = '.rst' 44 | 45 | # The encoding of source files. 46 | #source_encoding = 'utf-8-sig' 47 | 48 | # The master toctree document. 49 | master_doc = 'index' 50 | 51 | # General information about the project. 52 | project = u'barrista' 53 | copyright = u'2016, University of Tuebingen' 54 | 55 | # The version info for the project you're documenting, acts as replacement for 56 | # |version| and |release|, also used in various other places throughout the 57 | # built documents. 58 | # 59 | # The short X.Y version. 60 | version = '' 61 | # The full version, including alpha/beta/rc tags. 
62 | release = '' 63 | 64 | # The language for content autogenerated by Sphinx. Refer to documentation 65 | # for a list of supported languages. 66 | #language = None 67 | 68 | # There are two options for replacing |today|: either, you set today to some 69 | # non-false value, then it is used: 70 | #today = '' 71 | # Else, today_fmt is used as the format for a strftime call. 72 | #today_fmt = '%B %d, %Y' 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | exclude_patterns = ['_build'] 77 | 78 | # The reST default role (used for this markup: `text`) to use for all 79 | # documents. 80 | #default_role = None 81 | 82 | # If true, '()' will be appended to :func: etc. cross-reference text. 83 | #add_function_parentheses = True 84 | 85 | # If true, the current module name will be prepended to all description 86 | # unit titles (such as .. function::). 87 | #add_module_names = True 88 | 89 | # If true, sectionauthor and moduleauthor directives will be shown in the 90 | # output. They are ignored by default. 91 | #show_authors = False 92 | 93 | # The name of the Pygments (syntax highlighting) style to use. 94 | pygments_style = 'sphinx' 95 | 96 | # A list of ignored prefixes for module index sorting. 97 | #modindex_common_prefix = [] 98 | 99 | # If true, keep warnings as "system message" paragraphs in the built documents. 100 | #keep_warnings = False 101 | 102 | 103 | # -- Options for HTML output ---------------------------------------------- 104 | 105 | # The theme to use for HTML and HTML Help pages. See the documentation for 106 | # a list of builtin themes. 107 | html_theme = 'sphinx_rtd_theme' 108 | 109 | # Theme options are theme-specific and customize the look and feel of a theme 110 | # further. For a list of options available for each theme, see the 111 | # documentation. 112 | #html_theme_options = {} 113 | 114 | # Add any paths that contain custom themes here, relative to this directory. 115 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 116 | 117 | # The name for this set of Sphinx documents. If None, it defaults to 118 | # "<project> v<release> documentation". 119 | #html_title = None 120 | 121 | # A shorter title for the navigation bar. Default is the same as html_title. 122 | #html_short_title = None 123 | 124 | # The name of an image file (relative to this directory) to place at the top 125 | # of the sidebar. 126 | #html_logo = None 127 | 128 | # The name of an image file (within the static path) to use as favicon of the 129 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 130 | # pixels large. 131 | #html_favicon = None 132 | 133 | # Add any paths that contain custom static files (such as style sheets) here, 134 | # relative to this directory. They are copied after the builtin static files, 135 | # so a file named "default.css" will overwrite the builtin "default.css". 136 | html_static_path = ['_static'] 137 | 138 | # Add any extra paths that contain custom files (such as robots.txt or 139 | # .htaccess) here, relative to this directory. These files are copied 140 | # directly to the root of the documentation. 141 | #html_extra_path = [] 142 | 143 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 144 | # using the given strftime format. 145 | #html_last_updated_fmt = '%b %d, %Y' 146 | 147 | # If true, SmartyPants will be used to convert quotes and dashes to 148 | # typographically correct entities.
149 | #html_use_smartypants = True 150 | 151 | # Custom sidebar templates, maps document names to template names. 152 | #html_sidebars = {} 153 | 154 | # Additional templates that should be rendered to pages, maps page names to 155 | # template names. 156 | #html_additional_pages = {} 157 | 158 | # If false, no module index is generated. 159 | #html_domain_indices = True 160 | 161 | # If false, no index is generated. 162 | #html_use_index = True 163 | 164 | # If true, the index is split into individual pages for each letter. 165 | #html_split_index = False 166 | 167 | # If true, links to the reST sources are added to the pages. 168 | html_show_sourcelink = False 169 | 170 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 171 | #html_show_sphinx = True 172 | 173 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 174 | #html_show_copyright = True 175 | 176 | # If true, an OpenSearch description file will be output, and all pages will 177 | # contain a <link> tag referring to it. The value of this option must be the 178 | # base URL from which the finished HTML is served. 179 | #html_use_opensearch = '' 180 | 181 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 182 | #html_file_suffix = None 183 | 184 | # Output file base name for HTML help builder. 185 | htmlhelp_basename = 'barristadoc' 186 | 187 | 188 | # -- Options for LaTeX output --------------------------------------------- 189 | 190 | latex_elements = { 191 | # The paper size ('letterpaper' or 'a4paper'). 192 | #'papersize': 'letterpaper', 193 | 194 | # The font size ('10pt', '11pt' or '12pt'). 195 | #'pointsize': '10pt', 196 | 197 | # Additional stuff for the LaTeX preamble. 198 | #'preamble': '', 199 | } 200 | 201 | # Grouping the document tree into LaTeX files. List of tuples 202 | # (source start file, target name, title, 203 | # author, documentclass [howto, manual, or own class]). 204 | latex_documents = [ 205 | ('index', 'barrista.tex', u'barrista Documentation', 206 | u'Author', 'manual'), 207 | ] 208 | 209 | # The name of an image file (relative to this directory) to place at the top of 210 | # the title page. 211 | #latex_logo = None 212 | 213 | # For "manual" documents, if this is true, then toplevel headings are parts, 214 | # not chapters. 215 | #latex_use_parts = False 216 | 217 | # If true, show page references after internal links. 218 | #latex_show_pagerefs = False 219 | 220 | # If true, show URL addresses after external links. 221 | #latex_show_urls = False 222 | 223 | # Documents to append as an appendix to all manuals. 224 | #latex_appendices = [] 225 | 226 | # If false, no module index is generated. 227 | #latex_domain_indices = True 228 | 229 | 230 | # -- Options for manual page output --------------------------------------- 231 | 232 | # One entry per manual page. List of tuples 233 | # (source start file, name, description, authors, manual section). 234 | man_pages = [ 235 | ('index', 'barrista', u'barrista Documentation', 236 | [u'Author'], 1) 237 | ] 238 | 239 | # If true, show URL addresses after external links. 240 | #man_show_urls = False 241 | 242 | 243 | # -- Options for Texinfo output ------------------------------------------- 244 | 245 | # Grouping the document tree into Texinfo files.
List of tuples 246 | # (source start file, target name, title, author, 247 | # dir menu entry, description, category) 248 | texinfo_documents = [ 249 | ('index', 'barrista', u'barrista Documentation', 250 | u'Author', 'barrista', 'One line description of project.', 251 | 'Miscellaneous'), 252 | ] 253 | 254 | # Documents to append as an appendix to all manuals. 255 | #texinfo_appendices = [] 256 | 257 | # If false, no module index is generated. 258 | #texinfo_domain_indices = True 259 | 260 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 261 | #texinfo_show_urls = 'footnote' 262 | 263 | # If true, do not generate a @detailmenu in the "Top" node's menu. 264 | #texinfo_no_detailmenu = False 265 | 266 | 267 | # -- Options for Epub output ---------------------------------------------- 268 | 269 | # Bibliographic Dublin Core info. 270 | epub_title = u'barrista' 271 | epub_author = u'Author' 272 | epub_publisher = u'Author' 273 | epub_copyright = u'2015, Author' 274 | 275 | # The basename for the epub file. It defaults to the project name. 276 | #epub_basename = u'barrista' 277 | 278 | # The HTML theme for the epub output. Since the default themes are not optimized 279 | # for small screen space, using the same theme for HTML and epub output is 280 | # usually not wise. This defaults to 'epub', a theme designed to save visual 281 | # space. 282 | #epub_theme = 'epub' 283 | 284 | # The language of the text. It defaults to the language option 285 | # or en if the language is not set. 286 | #epub_language = '' 287 | 288 | # The scheme of the identifier. Typical schemes are ISBN or URL. 289 | #epub_scheme = '' 290 | 291 | # The unique identifier of the text. This can be an ISBN number 292 | # or the project homepage. 293 | #epub_identifier = '' 294 | 295 | # A unique identification for the text. 296 | #epub_uid = '' 297 | 298 | # A tuple containing the cover image and cover page html template filenames. 299 | #epub_cover = () 300 | 301 | # A sequence of (type, uri, title) tuples for the guide element of content.opf. 302 | #epub_guide = () 303 | 304 | # HTML files that should be inserted before the pages created by sphinx. 305 | # The format is a list of tuples containing the path and title. 306 | #epub_pre_files = [] 307 | 308 | # HTML files that should be inserted after the pages created by sphinx. 309 | # The format is a list of tuples containing the path and title. 310 | #epub_post_files = [] 311 | 312 | # A list of files that should not be packed into the epub file. 313 | epub_exclude_files = ['search.html'] 314 | 315 | # The depth of the table of contents in toc.ncx. 316 | #epub_tocdepth = 3 317 | 318 | # Allow duplicate toc entries. 319 | #epub_tocdup = True 320 | 321 | # Choose between 'default' and 'includehidden'. 322 | #epub_tocscope = 'default' 323 | 324 | # Fix unsupported image types using the PIL. 325 | #epub_fix_images = False 326 | 327 | # Scale large images. 328 | #epub_max_image_width = 0 329 | 330 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 331 | #epub_show_urls = 'inline' 332 | 333 | # If false, no index is generated. 334 | #epub_use_index = True 335 | -------------------------------------------------------------------------------- /documentation/index.rst: -------------------------------------------------------------------------------- 1 | .. barrista documentation master file, created by 2 | sphinx-quickstart on Tue Jun 23 17:47:50 2015.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 |
7 | .. figure:: _static/barrista.jpg
8 | :scale: 35%
9 | :alt: Gwilym Davies composing his signature drink at the 2009 World Barista Championship in Atlanta, Georgia.
10 | :align: right
11 |
12 | `Photo by Liz Clayton `_,
13 | `CC `_.
14 |
15 |
16 | Welcome to barrista's documentation!
17 | ====================================
18 |
19 | Barrista will serve your `caffe `_
20 | right! It is a Python library that offers a full-featured interface to `caffe`,
21 | similar to `keras `_ and
22 | `Theano `_.
23 |
24 | Why barrista?
25 | =============
26 |
27 | `barrista` gives you full, pythonic control over the entire `caffe` framework.
28 | It differs from the plain `caffe` Python interface in that it
29 | exposes the entire caffe functionality (from net design through training and
30 | prediction) in a principled way to Python.
31 |
32 | * Design your nets with the full power of `caffe` within Python. Creating
33 | a network is as easy as::
34 |
35 | import barrista.design as ds
36 | netspec = ds.NetSpecification([[10, 3, 51, 51], [10]],
37 | # batch size 10, 3 channels of a 51x51 signal, 10 labels
38 | inputs=['data', 'annotations'])
39 | netspec.layers.append(ds.ConvolutionLayer(Convolution_kernel_size=3,
40 | Convolution_pad=1,
41 | Convolution_num_output=1))
42 | # The layers are wired together automatically, unless you specify something else:
43 | netspec.layers.append(ds.InnerProductLayer(tops=['net_out'],
44 | InnerProduct_num_output=10))
45 | netspec.layers.append(ds.SoftmaxWithLossLayer(bottoms=['net_out',
46 | 'annotations']))
47 | net = netspec.instantiate()
48 |
49 | * `barrista` naturally understands and writes every .prototxt
50 | file your `caffe` version does! Load `.prototxt` files (also from the
51 | `model zoo `_!)
52 | and use or modify the networks with barrista::
53 |
54 | netspec.to_prototxt(output_filename='net.prototxt')
55 | net.save('net.caffemodel') # Save the weights.
56 | new_netspec = ds.NetSpecification.from_prototxt(filename='net.prototxt')
57 | new_network = new_netspec.instantiate()
58 | new_network.load_blobs_from('net.caffemodel') # Load the weights.
59 |
60 | * Use your networks in a principled way from Python. You get transparent support
61 | for repetitive tasks like batching or padding, with a clear separation of
62 | preprocessing::
63 |
64 | import barrista.solver
65 | net.fit(1000,
66 | barrista.solver.SGDSolver(base_lr=0.01),
67 | X={'data': np.ones((21, 3, 51, 51)), # Automatically batched.
68 | 'annotations': np.zeros((21,))})
69 | net.predict({'data': np.zeros((5, 3, 51, 51)), [...]})
70 |
71 | * Use callbacks (functions that are called before and after processing a batch)
72 | to monitor training and prediction, or to dynamically modify
73 | the data used in the batches. `barrista` comes with a standard set of
74 | frequently used callbacks and it is very easy to add your own::
75 |
76 | import barrista.monitoring
77 | net.fit(# ... as before
78 | train_callbacks=[
79 | # Write the network weights every 100 iterations to disk.
80 | barrista.monitoring.Checkpointer('/tmp', 100),
81 | # Get a progress bar with ETA.
82 | barrista.monitoring.ProgressIndicator()])
83 |
84 | * `barrista` is always fully consistent with `caffe`. We internally inspect the
85 | protobuf module generated by `caffe` to infer the interface.
Adding your
86 | own layers is as easy as::
87 |
88 | import barrista.config
89 | barrista.config.LAYER_TYPES['Amazing'] = ['AmazingParameter']
90 |
91 | * It runs on every platform `caffe` runs on, and can be used with Python 2 and 3.
92 |
93 | What's the license?
94 | ===================
95 |
96 | You can use barrista under the MIT License, which means you may use it freely
97 | in any project. The full license can be found in the main folder of the
98 | `barrista` repository.
99 |
100 | Get the source & documentation
101 | ==============================
102 |
103 | The source is hosted on `github `_ and the
104 | documentation/homepage is available on github pages at
105 | `http://classner.github.io/barrista `_.
106 |
107 | Documentation
108 | =============
109 |
110 | .. toctree::
111 | :maxdepth: 2
112 |
113 | setup
114 | usage
115 | barrista
116 |
117 | Indices and tables
118 | ==================
119 |
120 | * :ref:`genindex`
121 | * :ref:`modindex`
122 |
123 | About
124 | =====
125 |
126 | This software was created at the `Bernstein Center for Computational
127 | Neuroscience `_ at the `University of Tuebingen
128 | `_ and
129 | the `Max Planck Institute for Intelligent Systems `_ in
130 | Tuebingen. The main contributors are:
131 |
132 | * `Christoph Lassner `_,
133 | * `Daniel Kappler `_,
134 | * `Martin Kiefel `_ and
135 | * `Peter Gehler `_.
136 |
137 | We thank `Yangqing Jia `_ and the
138 | `BVLC vision group `_ for creating the great
139 | `caffe` package!
140 |
--------------------------------------------------------------------------------
/documentation/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | REM Command file for Sphinx documentation
4 |
5 | if "%SPHINXBUILD%" == "" (
6 | set SPHINXBUILD=sphinx-build
7 | )
8 | set BUILDDIR=_build
9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
10 | set I18NSPHINXOPTS=%SPHINXOPTS% .
11 | if NOT "%PAPER%" == "" (
12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
14 | )
15 |
16 | if "%1" == "" goto help
17 |
18 | if "%1" == "help" (
19 | :help
20 | echo.Please use `make ^<target^>` where ^<target^> is one of
21 | echo. html to make standalone HTML files
22 | echo. dirhtml to make HTML files named index.html in directories
23 | echo. singlehtml to make a single large HTML file
24 | echo. pickle to make pickle files
25 | echo. json to make JSON files
26 | echo. htmlhelp to make HTML files and a HTML help project
27 | echo. qthelp to make HTML files and a qthelp project
28 | echo. devhelp to make HTML files and a Devhelp project
29 | echo. epub to make an epub
30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
31 | echo. text to make text files
32 | echo. man to make manual pages
33 | echo. texinfo to make Texinfo files
34 | echo. gettext to make PO message catalogs
35 | echo. changes to make an overview over all changed/added/deprecated items
36 | echo. xml to make Docutils-native XML files
37 | echo. pseudoxml to make pseudoxml-XML files for display purposes
38 | echo. linkcheck to check all external links for integrity
39 | echo. doctest to run all doctests embedded in the documentation if enabled
40 | goto end
41 | )
42 |
43 | if "%1" == "clean" (
44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
45 | del /q /s %BUILDDIR%\*
46 | goto end
47 | )
48 |
49 |
50 | %SPHINXBUILD% 2> nul
51 | if errorlevel 9009 (
52 | echo.
53 | echo.The 'sphinx-build' command was not found.
Make sure you have Sphinx
54 | echo.installed, then set the SPHINXBUILD environment variable to point
55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
56 | echo.may add the Sphinx directory to PATH.
57 | echo.
58 | echo.If you don't have Sphinx installed, grab it from
59 | echo.http://sphinx-doc.org/
60 | exit /b 1
61 | )
62 |
63 | if "%1" == "html" (
64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
65 | if errorlevel 1 exit /b 1
66 | echo.
67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html.
68 | goto end
69 | )
70 |
71 | if "%1" == "dirhtml" (
72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
73 | if errorlevel 1 exit /b 1
74 | echo.
75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
76 | goto end
77 | )
78 |
79 | if "%1" == "singlehtml" (
80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
81 | if errorlevel 1 exit /b 1
82 | echo.
83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
84 | goto end
85 | )
86 |
87 | if "%1" == "pickle" (
88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
89 | if errorlevel 1 exit /b 1
90 | echo.
91 | echo.Build finished; now you can process the pickle files.
92 | goto end
93 | )
94 |
95 | if "%1" == "json" (
96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
97 | if errorlevel 1 exit /b 1
98 | echo.
99 | echo.Build finished; now you can process the JSON files.
100 | goto end
101 | )
102 |
103 | if "%1" == "htmlhelp" (
104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
105 | if errorlevel 1 exit /b 1
106 | echo.
107 | echo.Build finished; now you can run HTML Help Workshop with the ^
108 | .hhp project file in %BUILDDIR%/htmlhelp.
109 | goto end
110 | )
111 |
112 | if "%1" == "qthelp" (
113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
114 | if errorlevel 1 exit /b 1
115 | echo.
116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^
117 | .qhcp project file in %BUILDDIR%/qthelp, like this:
118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\barrista.qhcp
119 | echo.To view the help file:
120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\barrista.qhc
121 | goto end
122 | )
123 |
124 | if "%1" == "devhelp" (
125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
126 | if errorlevel 1 exit /b 1
127 | echo.
128 | echo.Build finished.
129 | goto end
130 | )
131 |
132 | if "%1" == "epub" (
133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
134 | if errorlevel 1 exit /b 1
135 | echo.
136 | echo.Build finished. The epub file is in %BUILDDIR%/epub.
137 | goto end
138 | )
139 |
140 | if "%1" == "latex" (
141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
142 | if errorlevel 1 exit /b 1
143 | echo.
144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
145 | goto end
146 | )
147 |
148 | if "%1" == "latexpdf" (
149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
150 | cd %BUILDDIR%/latex
151 | make all-pdf
152 | cd %BUILDDIR%/..
153 | echo.
154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex.
155 | goto end
156 | )
157 |
158 | if "%1" == "latexpdfja" (
159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
160 | cd %BUILDDIR%/latex
161 | make all-pdf-ja
162 | cd %BUILDDIR%/..
163 | echo.
164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex.
165 | goto end
166 | )
167 |
168 | if "%1" == "text" (
169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
170 | if errorlevel 1 exit /b 1
171 | echo.
172 | echo.Build finished. The text files are in %BUILDDIR%/text.
173 | goto end
174 | )
175 |
176 | if "%1" == "man" (
177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
178 | if errorlevel 1 exit /b 1
179 | echo.
180 | echo.Build finished. The manual pages are in %BUILDDIR%/man.
181 | goto end
182 | )
183 |
184 | if "%1" == "texinfo" (
185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
186 | if errorlevel 1 exit /b 1
187 | echo.
188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
189 | goto end
190 | )
191 |
192 | if "%1" == "gettext" (
193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
194 | if errorlevel 1 exit /b 1
195 | echo.
196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
197 | goto end
198 | )
199 |
200 | if "%1" == "changes" (
201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
202 | if errorlevel 1 exit /b 1
203 | echo.
204 | echo.The overview file is in %BUILDDIR%/changes.
205 | goto end
206 | )
207 |
208 | if "%1" == "linkcheck" (
209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
210 | if errorlevel 1 exit /b 1
211 | echo.
212 | echo.Link check complete; look for any errors in the above output ^
213 | or in %BUILDDIR%/linkcheck/output.txt.
214 | goto end
215 | )
216 |
217 | if "%1" == "doctest" (
218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
219 | if errorlevel 1 exit /b 1
220 | echo.
221 | echo.Testing of doctests in the sources finished, look at the ^
222 | results in %BUILDDIR%/doctest/output.txt.
223 | goto end
224 | )
225 |
226 | if "%1" == "xml" (
227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
228 | if errorlevel 1 exit /b 1
229 | echo.
230 | echo.Build finished. The XML files are in %BUILDDIR%/xml.
231 | goto end
232 | )
233 |
234 | if "%1" == "pseudoxml" (
235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
236 | if errorlevel 1 exit /b 1
237 | echo.
238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
239 | goto end
240 | )
241 |
242 | :end
243 |
--------------------------------------------------------------------------------
/documentation/setup.rst:
--------------------------------------------------------------------------------
1 | Setup
2 | =====
3 |
4 | ============
5 | Requirements
6 | ============
7 |
8 | This package has the following requirements:
9 |
10 | * caffe, rc2 or newer, built with the Python interface,
11 | * OpenCV (for automatic rescaling),
12 | * setuptools (for installation and building the documentation),
13 | * sphinx (for building the documentation).
14 |
15 | ===================
16 | Caffe modifications
17 | ===================
18 |
19 | While barrista 'serves' caffe right, some beans must be added for the
20 | perfect flavor. This translates to just a few lines of C++ code that must
21 | be changed in the core library.
22 |
23 | We offer a .patch file for various versions since caffe rc2. But even
24 | if you are not using one of these versions, don't worry: you will easily
25 | be able to apply the changes by hand. Alternatively, you can run::
26 |
27 | git submodule update --init
28 |
29 | in the barrista folder, and an already patched caffe version is checked out
30 | into the `caffe` subfolder of barrista.
We keep this in sync with upstream caffe at
31 | regular intervals, and you can build it as described by the
32 | `BVLC team `_.
33 |
34 | The patch files are located in the barrista folder `patches` and can be applied
35 | (for example) by navigating to your caffe root folder and executing::
36 |
37 | git apply ../path/to/barrista/patches/barrista-patch-[X].patch
38 |
39 | If there is no patch available for your specific caffe version, you should still
40 | be able to quickly find out what lines to change by having a look at the
41 | `.patch` files.
42 |
43 | =============
44 | Configuration
45 | =============
46 |
47 | Now that your `caffe` is ready, we will set up `barrista` so that it knows
48 | what caffe to work with!
49 |
50 | There are three ways to do this:
51 |
52 | #. edit the source files,
53 | #. set environment variables,
54 | #. do it on-the-fly within your code.
55 |
56 | The first possibility is not as `clean` as the others (since, e.g., a
57 | barrista update might break your config), but in case you want to do it:
58 | the file ``barrista/config.py`` is all yours! It contains all relevant
59 | information (the meaning of the variable ``LAYER_TYPES`` is discussed in
60 | detail in the section :ref:`registering-layers`).
61 |
62 | A clean way to let barrista know the location of the `caffe` to use is
63 | to set the two environment variables :py:data:`CAFFE_PYTHON_FOLDER` and
64 | :py:data:`CAFFE_BIN_FOLDER`. The first points to the folder where
65 | caffe's Python module is located (usually ``caffe/python``); the second points
66 | to the folder where all caffe executables are stored after having run
67 | the installation (usually ``caffe/build/install/bin``). This folder must contain
68 | the executable ``upgrade_net_proto_text``.
69 |
70 | If you are working with many different `caffe` builds or want to swap them
71 | on the fly, you can also do this easily! In this case::
72 |
73 | # Import barrista's config module before any of the other modules!
74 | import barrista.config
75 | # Change the two variable values to your liking:
76 | barrista.config.CAFFE_PYTHON_FOLDER = 'a/crazy/folder'
77 | barrista.config.CAFFE_BIN_FOLDER = 'another/funny/one'
78 | # Now use your configured barrista:
79 | import barrista.design
80 | # ...
81 |
82 | A side note on the google logging facilities of caffe: the output is very useful
83 | for debugging, but may clutter your terminal unnecessarily otherwise. To change
84 | the log level, simply::
85 |
86 | export GLOG_minloglevel=2
87 |
88 | and add this line to your favorite place, e.g., `.bashrc`. The levels are 0 (debug,
89 | default), 1 (info), 2 (warnings), 3 (errors).
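If you prefer to control this from within your experiment code rather than
your shell configuration, the variable can equally be set from Python. This
is a minimal sketch (plain environment handling, nothing barrista-specific);
the assignment must happen before `caffe` or any barrista submodule is first
imported::

    import os
    # Show only warnings and errors from caffe's C++ core; this must be
    # set before the first barrista/caffe import.
    os.environ['GLOG_minloglevel'] = '2'
    import barrista.design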
90 |
91 | =======
92 | Testing
93 | =======
94 |
95 | Before installing barrista, you can (and should) run the tests to verify
96 | proper configuration and compatibility by executing::
97 |
98 | python setup.py test
99 |
100 | ============
101 | Installation
102 | ============
103 |
104 | You can install the project into your Python installation by running::
105 |
106 | python setup.py install
107 |
108 | or, to be able to modify the software in its current folder while using it,
109 | run::
110 |
111 | python setup.py develop
112 |
113 | If you want to build the documentation, run::
114 |
115 | python setup.py build_sphinx
116 |
117 | Should this command fail with an error message along the lines of
118 | '`build_sphinx` is an invalid command', just upgrade your setuptools
119 | installation by running ``pip install --upgrade setuptools`` and it
120 | should work.
121 |
122 | .. _registering-layers:
123 |
124 | =============================
125 | Registering additional layers
126 | =============================
127 |
128 | The layer types and their parameters can, unfortunately, not be inferred
129 | from the caffe protobuf protocols in a fully automatic manner.
130 | Your standard barrista knows all the popularly used layer types that come with
131 | the default caffe installation, but if you extend caffe, it is necessary
132 | to register your new layers by hand. There are again two possibilities:
133 |
134 | #. edit the file ``config.py`` by hand,
135 | #. register the layer types on-the-fly during use.
136 |
137 | The responsible object is :py:data:`barrista.config.LAYER_TYPES`. It is a dictionary
138 | with the layer name strings as keys, and a list of names of protobuf objects
139 | that are the layer's parameters, e.g.,::
140 |
141 | 'Convolution': ['ConvolutionParameter']
142 |
143 | is responsible for registering the convolution layer.
144 | You can simply edit ``config.py`` and add your own layers, or add them on the
145 | fly as follows::
146 |
147 | import barrista.config
148 | # This must be done before importing any other submodule!
149 | barrista.config.LAYER_TYPES['Convolution'] = ['ConvolutionParameter']
150 | import barrista.design
151 | ...
152 |
--------------------------------------------------------------------------------
/documentation/usage.rst:
--------------------------------------------------------------------------------
1 | Using `barrista`
2 | ================
3 |
4 | This file gives a comprehensive walkthrough of nearly all
5 | features offered by barrista. If you want to get your hands dirty right away,
6 | there is a comprehensive example of a VGG-like net being trained and applied
7 | in the file ``examples/showcase.py`` in the barrista repository.
8 |
9 | ==================================
10 | Importing and configuring barrista
11 | ==================================
12 |
13 | If you have `caffe` on your path, you can use barrista right away and
14 | import and use any of its submodules. Otherwise, you can configure it
15 | to use a specific `caffe` version on the fly as follows::
16 |
17 | import barrista.config
18 | # This must be done before importing any other submodule.
19 | barrista.config.CAFFE_PYTHON_FOLDER = 'your/path'
20 | barrista.config.CAFFE_BIN_FOLDER = 'your/bin/path'
21 | import barrista.design
22 | ...
23 |
24 | For an exact description of the two parameters, see
25 | :py:data:`barrista.config.CAFFE_PYTHON_FOLDER` and
26 | :py:data:`barrista.config.CAFFE_BIN_FOLDER`.
27 |
28 | ================================
29 | Creating a network specification
30 | ================================
31 |
32 | The module :py:mod:`barrista.design` contains methods and classes to
33 | design `caffe` models. We will use it in the following example to create
34 | a simple, `VGG`-like model::
35 |
36 | import barrista.design as design
37 | from barrista.design import (ConvolutionLayer, ReLULayer, PoolingLayer,
38 | DropoutLayer, InnerProductLayer,
39 | SoftmaxLayer, SoftmaxWithLossLayer,
40 | AccuracyLayer)
41 |
42 | # The only required parameter is a list of lists with the input shape
43 | # specification for the network. In this case, we also specify names
44 | # for the input layers.
45 | netspec = design.NetSpecification([[10, 3, 51, 51], [10]],
46 | inputs=['data', 'annotations'])
47 |
48 | layers = []
49 | conv_params = {'Convolution_kernel_size': 3,
50 | 'Convolution_num_output': 32,
51 | 'Convolution_pad': 1}
52 |
53 | # If not specified, the first bottom blob of each layer is automatically
54 | # wired to the first top of the preceding layer. If you are using
55 | # multi-in/out layers, you have to specify tops and bottoms manually.
56 |
57 | layers.append(ConvolutionLayer(**conv_params))
58 | layers.append(ReLULayer())
59 | layers.append(ConvolutionLayer(**conv_params))
60 | layers.append(ReLULayer())
61 | layers.append(PoolingLayer(Pooling_kernel_size=2))
62 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
63 |
64 | conv_params['Convolution_num_output'] = 64
65 | layers.append(ConvolutionLayer(**conv_params))
66 | layers.append(ReLULayer())
67 | layers.append(ConvolutionLayer(**conv_params))
68 | layers.append(ReLULayer())
69 | layers.append(PoolingLayer(Pooling_kernel_size=2))
70 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
71 |
72 | layers.append(InnerProductLayer(InnerProduct_num_output=256))
73 | layers.append(ReLULayer())
74 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
75 |
76 | layers.append(InnerProductLayer(InnerProduct_num_output=10))
77 | layers.append(SoftmaxLayer())
78 |
79 | netspec.layers.extend(layers)
80 |
81 | The layer names are exactly the same as in the prototxt format. All direct
82 | parameters for a layer can be set via its constructor, or assigned later
83 | as object properties. If you have to use sub-objects (or rather messages,
84 | in prototxt-speak), they are all available from the object
85 | :py:data:`barrista.design.PROTODETAIL`.
86 |
87 | You can now inspect the specification and programmatically change its parameters.
88 | To get the prototxt representation, use the method
89 | :py:func:`barrista.design.NetSpecification.to_prototxt`::
90 |
91 | print(netspec.to_prototxt())
92 |
93 | The method has an additional parameter ``output_filename`` that can be used to
94 | directly create prototxt files::
95 |
96 | netspec.to_prototxt(output_filename='test.prototxt')
97 |
98 | =====================
99 | Visualizing a network
100 | =====================
101 |
102 | It is possible to visualize a network specification or an instantiated
103 | network by calling its :py:func:`barrista.design.NetSpecification.visualize`
104 | or :py:func:`barrista.net.Net.visualize` function. You can directly
105 | display the result or write it to a file::
106 |
107 | # Create the visualization and display it.
108 | viz = netspec.visualize(display=True)
109 | # Write it to a file.
110 | import cv2
111 | cv2.imwrite('/tmp/test.png', viz)
112 |
113 | =================================
114 | Importing a network specification
115 | =================================
116 |
117 | You can work with all your already prepared prototxt files as well! Use the
118 | method :py:func:`barrista.design.NetSpecification.from_prototxt` to load
119 | any valid caffe model (of any version!) and inspect and modify it in this
120 | framework::
121 |
122 | netspec_reloaded = design.NetSpecification.from_prototxt(filename='test.prototxt')
123 |
124 | ===============
125 | Using a network
126 | ===============
127 |
128 | Apart from diagnostic or logging
129 | purposes, it is not necessary to work with prototxt specifications any more.
130 | Simply run::
131 |
132 | net = netspec.instantiate()
133 |
134 | to get a fully working network object. It is subclassed from
135 | ``caffe.Net``, so it comes with all the methods you are familiar with. But
136 | be prepared for some more convenience! You can set cpu or gpu mode by
137 | using :py:func:`barrista.net.set_mode_cpu` and
138 | :py:func:`barrista.net.set_mode_gpu`.
139 |
140 | Loading parameters
141 | ~~~~~~~~~~~~~~~~~~
142 |
143 | With this, the blobs can be loaded as::
144 |
145 | net.load_blobs_from('your/path/to/blobs.caffemodel')
146 |
147 | and to restore a solver, use::
148 |
149 | solver.restore('your/path/to/xyz.solverstate', net)
150 |
151 | **CAUTION**: The blobs are stored in the ``.caffemodel``s by name. Blobs will be
152 | matched to network layers with the same name. If a name does not match, the
153 | blob is simply ignored! This gives a powerful mechanism for partially loading
154 | blobs, but be careful when renaming your layers!
155 |
156 | Training a network
157 | ~~~~~~~~~~~~~~~~~~
158 |
159 | To train a network, you can use the `scikit-learn` like method
160 | :py:func:`barrista.net.Net.fit`. It is very powerful and can be used in many
161 | different ways! While maintaining nearly all configurability of the caffe
162 | solvers, it adds callback functionality and is a lot easier to use.
163 |
164 | The only required method parameter is the number of iterations that you want
165 | to train your network for. If you configured it with data-layers that are
166 | loading data from external sources, you just have to decide on the kind
167 | of solver to use and possibly specify its learning rate. For this example,
168 | we use in-memory data from Python for the training, and some monitors to
169 | generate outputs::
170 |
171 | from barrista import solver
172 | from barrista.monitoring import ProgressIndicator, Checkpointer
173 |
174 | X = np.zeros((11, 3, 51, 51), dtype='float32')
175 | Y = np.ones((11, 1), dtype='float32')
176 |
177 | # Configure our monitors.
178 | progress = ProgressIndicator()
179 | checkptr = Checkpointer('test_net_', 50)
180 | # Run the training.
181 | net.fit(100,
182 | solver.SGDSolver(base_lr=0.01, snapshot_prefix='test_net_'),
183 | {'data': X, # 'data' and 'annotations' are the input layer names.
184 | 'annotations': Y}, # optional (if you have, e.g., a DataLayer)
185 | test_interval=50, # optional
186 | X_val={'data': X, # optional
187 | 'annotations': Y},
188 | train_callbacks=[progress, checkptr], # optional
189 | test_callbacks=[progress]) # optional
190 |
191 | The parameters ``test_interval`` and ``X_val`` are optional. If they
192 | are specified, there is a test performed on the validation set at
193 | regular intervals.
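Monitors can be combined for a more complete training setup. Continuing the
example above, the following sketch mirrors the configuration used in
``examples/MNIST/train.py`` in this repository; the logged key names
('train_loss' etc.) and the output folder are placeholders that must match
your own network and setup::

    from barrista.monitoring import Checkpointer, JSONLogger, ProgressIndicator
    progress = ProgressIndicator()
    # Writes the listed values to a JSON log for later plotting.
    logger = JSONLogger('results_folder',
                        'model',
                        {'train': ['train_loss', 'train_accuracy'],
                         'test': ['test_loss', 'test_accuracy']})
    net.fit(100,
            solver.SGDSolver(base_lr=0.01, snapshot_prefix='test_net_'),
            X={'data': X, 'annotations': Y},
            X_val={'data': X, 'annotations': Y},
            test_interval=50,
            train_callbacks=[progress, logger,
                             Checkpointer('test_net_', 50)],
            test_callbacks=[progress, logger])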
194 |
195 | Note that all iteration parameters refer to 'true' iterations, i.e.,
196 | not batch iterations but sample iterations. This is why they must be a
197 | multiple of the batch size (e.g., for a network with a batch size of 10,
198 | you have to do at least 10 training iterations, and one batch will be
199 | used for the training).
200 |
201 | The :py:class:`barrista.monitoring.Checkpointer` is used to write the network
202 | blobs to a file, which can be loaded later using the function
203 | :py:func:`barrista.net.Net.load_blobs_from`, as well as the respective
204 | solverstate. The ``snapshot_prefix`` provided to the solver and the
205 | checkpointer prefix must match for this to work correctly.
206 |
207 | Getting predictions
208 | ~~~~~~~~~~~~~~~~~~~
209 |
210 | In the spirit of the `scikit-learn` library, we added the method
211 | :py:func:`barrista.net.Net.predict` to get predictions for you, while
212 | maintaining a clear separation of data preprocessing:
213 |
214 | * It is YOUR responsibility to prepare the data in an iterable object
215 | of numpy arrays with the correctly matching first dimension (i.e.,
216 | the number of channels).
217 | * The method will match the data to the input size of the network and
218 | forward propagate it in batches.
219 |
220 | By default, it rescales the examples using
221 | bicubic interpolation to the full input field size of the network, but if you
222 | set ``pad_instead_of_rescale``, they will instead be padded to be centered in
223 | the input field. If you choose padding and ``return_unprocessed_outputs`` is
224 | set to ``False``, the data will automatically be reduced to the relevant
225 | area.
226 |
227 | You may
228 | optionally set callback functions that run in between the batches to, e.g.,
229 | update progress indicators::
230 |
231 | from barrista.monitoring import ProgressIndicator
232 | # Only the number of channels (3) must match.
233 | inputs = np.zeros((20, 3, 10, 10))
234 | results = net.predict(inputs,
235 | test_callbacks=[ProgressIndicator()])
236 | # This works for single-input networks. If you have multiple inputs, just
237 | # provide a dictionary of layer names with arrays, as for the fit method.
238 | # Similarly, in case of a single-output network, this method returns a
239 | # single list of predictions, or, in case of a multi-output network,
240 | # a dictionary of output layer names with their respective output lists.
241 | print(results)
242 |
243 | ========================================================
244 | Using different architectures to ``fit`` and ``predict``
245 | ========================================================
246 |
247 | You have many possibilities to condition the network layout for the very same
248 | network depending on its state. It has
249 | :py:attr:`barrista.design.NetSpecification.phase`,
250 | :py:attr:`barrista.design.NetSpecification.level` and
251 | :py:attr:`barrista.design.NetSpecification.stages`. The ``phase`` is used
252 | to configure the net during the 'fit' process to alternate between training
253 | and validation sets. We offer a simple way of using the ``stages`` to switch
254 | between different architectures for 'fit' and 'predict'.
255 |
256 | When designing a network, you can specify the optional parameters
257 | ``predict_inputs`` and ``predict_input_shapes``.
If you do so, when
258 | instantiating the net, a second version of the net with the stages set only
259 | to ``predict`` is created (sharing its weights with the main network) and
260 | automatically used when calling the :py:func:`barrista.net.Net.predict`
261 | method (for an illustration of this behavior, see also the documentation for
262 | :py:class:`barrista.design.NetSpecification`).
263 | This way, your networks behave just as expected for both fitting and
264 | prediction, with no manual switching required::
265 |
266 | netspec = design.NetSpecification([[10, 3, 51, 51], [10]],
267 | inputs=['data', 'annotations'],
268 | predict_inputs=['data'],
269 | predict_input_shapes=[[10, 3, 51, 51]])
270 | # ... add layers as usual.
271 | # This is the last regular one. Use `tops` to give its outputs a
272 | # simple-to-remember name.
273 | layers.append(InnerProductLayer(tops=['net_out'], InnerProduct_num_output=10))
274 | # Add a layer for being used by the `predict` method:
275 | layers.append(SoftmaxLayer(bottoms=['net_out'],
276 | tops=['out'],
277 | include_stages=['predict']))
278 | # Add layers for being used by the `fit` method:
279 | layers.append(SoftmaxWithLossLayer(bottoms=['net_out', 'annotations'],
280 | include_stages=['fit']))
281 | layers.append(AccuracyLayer(name='accuracy',
282 | bottoms=['net_out', 'annotations'],
283 | include_stages=['fit']))
284 |
285 | Remember that you can additionally use any other conditional criteria such as
286 | ``phase`` and ``level`` to further customize the net.
287 |
288 | Once instantiated, this net will output loss and accuracy when its
289 | :py:func:`barrista.net.Net.fit`
290 | method is called, and output softmaxed values when its
291 | :py:func:`barrista.net.Net.predict` method is called. You can find an example
292 | for this in the file ``examples/showcase.py``.
293 |
--------------------------------------------------------------------------------
/examples/MNIST/README.txt:
--------------------------------------------------------------------------------
1 | MNIST example
2 | =============
3 |
4 | This folder contains a full-featured example for an MNIST model. It is
5 | not tuned to reach a specific performance, but is rather meant to be
6 | instructive, with basic infrastructure that can be reused for new data.
7 |
8 | All files are executable and encapsulate one aspect of the model. They can
9 | all be run with `--help` to get more information.
10 |
11 | To run the training, simply run
12 |
13 | ./train.py testrun --model_name=basic
14 |
15 | This will run the training and store the results in the folder results/testrun.
16 | The model is exchangeable, and must be a Python module in 'models' that has a
17 | `MODEL` property.
18 |
19 | Happy training!
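Tip: a model module only has to expose an instantiated network as its
`MODEL` property. A minimal sketch of that contract (input shapes as used by
data.py; see models/basic.py for the full architecture):

    import barrista.design as ds

    _netspec = ds.NetSpecification([[100, 1, 28, 28], [100,]],
                                   inputs=['data', 'labels'],
                                   predict_inputs=['data'],
                                   predict_input_shapes=[[100, 1, 28, 28]])
    _netspec.layers.append(ds.InnerProductLayer(tops=['net_out'],
                                                InnerProduct_num_output=10))
    _netspec.layers.append(ds.SoftmaxLayer(bottoms=['net_out'],
                                           include_stages=['predict']))
    _netspec.layers.append(ds.SoftmaxWithLossLayer(bottoms=['net_out', 'labels'],
                                                   include_stages=['fit']))
    MODEL = _netspec.instantiate()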
20 | -------------------------------------------------------------------------------- /examples/MNIST/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Preparing the data.""" 3 | # pylint: disable=invalid-name, no-member 4 | from __future__ import print_function 5 | 6 | import os as _os 7 | import logging as _logging 8 | import cv2 as _cv2 9 | import numpy as _np 10 | 11 | import click as _click 12 | import progressbar as _progressbar 13 | from sklearn.datasets import fetch_mldata as _fetch_mldata 14 | 15 | 16 | _LOGGER = _logging.getLogger(__name__) 17 | _DATA_FOLDER = _os.path.join(_os.path.dirname(__file__), 18 | 'data') 19 | if not _os.path.exists(_DATA_FOLDER): 20 | _LOGGER.info("Data folder not found. Creating...") 21 | _os.mkdir(_DATA_FOLDER) 22 | 23 | 24 | def training_data(): 25 | """Get the `MNIST original` training data.""" 26 | _np.random.seed(1) 27 | permutation = _np.random.permutation(range(60000)) 28 | mnist = _fetch_mldata('MNIST original', 29 | data_home=_os.path.join(_DATA_FOLDER, 30 | 'MNIST_original')) 31 | return (mnist.data[:60000, :][permutation, :].reshape((60000, 1, 28, 28)).astype('float32'), 32 | mnist.target[:60000][permutation].reshape((60000, 1)).astype('float32')) 33 | 34 | 35 | def test_data(): 36 | """Get the `MNIST original` test data.""" 37 | mnist = _fetch_mldata('MNIST original', 38 | data_home=_os.path.join(_DATA_FOLDER, 39 | 'MNIST_original')) 40 | return (mnist.data[60000:, :].reshape((10000, 1, 28, 28)).astype('float32'), 41 | mnist.target[60000:].reshape((10000, 1)).astype('float32')) 42 | 43 | 44 | @_click.group() 45 | def _cli(): 46 | """Handle the experiment data.""" 47 | pass 48 | 49 | @_cli.command() 50 | def validate_storage(): 51 | """Validate the data.""" 52 | _LOGGER.info("Validating storage...") 53 | val_folder = _os.path.join(_DATA_FOLDER, 'images') 54 | _LOGGER.info("Writing images to %s.", 55 | val_folder) 56 | if not _os.path.exists(val_folder): 57 | _os.mkdir(val_folder) 58 | _LOGGER.info("Train...") 59 | tr_folder = _os.path.join(val_folder, 'train') 60 | if not _os.path.exists(tr_folder): 61 | _os.mkdir(tr_folder) 62 | tr_data, tr_labels = training_data() 63 | pbar = _progressbar.ProgressBar(maxval=60000 - 1, 64 | widgets=[_progressbar.Percentage(), 65 | _progressbar.Bar(), 66 | _progressbar.ETA()]) 67 | pbar.start() 68 | for idx in range(60000): 69 | _cv2.imwrite(_os.path.join(tr_folder, '%05d_%d.jpg' % (idx, 70 | int(tr_labels[idx, 0]))), 71 | tr_data[idx, 0]) 72 | pbar.update(idx) 73 | pbar.finish() 74 | _LOGGER.info("Test...") 75 | te_folder = _os.path.join(val_folder, 'test') 76 | if not _os.path.exists(te_folder): 77 | _os.mkdir(te_folder) 78 | te_data, te_labels = test_data() 79 | pbar = _progressbar.ProgressBar(maxval=10000 - 1, 80 | widgets=[_progressbar.Percentage(), 81 | _progressbar.Bar(), 82 | _progressbar.ETA()]) 83 | pbar.start() 84 | for idx in range(10000): 85 | _cv2.imwrite(_os.path.join(te_folder, '%05d_%d.jpg' % (idx, 86 | int(te_labels[idx, 0]))), 87 | te_data[idx, 0]) 88 | pbar.update(idx) 89 | pbar.finish() 90 | 91 | if __name__ == '__main__': 92 | _logging.basicConfig(level=_logging.INFO) 93 | _cli() 94 | -------------------------------------------------------------------------------- /examples/MNIST/models/basic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """A simple network to run on MNIST.""" 3 | # pylint: disable=wrong-import-position, invalid-name, no-member 4 
|
5 | import logging as _logging
6 | import cv2 as _cv2
7 | import numpy as _np
8 |
9 | import barrista.design as _ds
10 |
11 | _LOGGER = _logging.getLogger()
12 |
13 | _netspec = _ds.NetSpecification([[100, 1, 28, 28], [100,]],
14 | inputs=['data', 'labels'],
15 | predict_inputs=['data'],
16 | predict_input_shapes=[[100, 1, 28, 28]])
17 |
18 | _layers = []
19 |
20 | # Build the network.
21 | _layers.append(_ds.ConvolutionLayer(
22 | name='conv1',
23 | bottoms=['data'],
24 | Convolution_num_output=32,
25 | Convolution_kernel_size=(3, 3),
26 | Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(
27 | type='uniform',
28 | min=-_np.sqrt(1./(3.*3.*1.)),
29 | max=_np.sqrt(1./(3.*3.*1.)))))
30 | _layers.append(_ds.ReLULayer())
31 | _layers.append(_ds.PoolingLayer(
32 | Pooling_kernel_size=2,
33 | Pooling_stride=2))
34 | _layers.append(_ds.ConvolutionLayer(
35 | name='conv2',
36 | Convolution_num_output=32,
37 | Convolution_kernel_size=(3, 3),
38 | Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(
39 | type='uniform',
40 | min=-_np.sqrt(1./(3.*3.*32.)),
41 | max=_np.sqrt(1./(3.*3.*32.)))))
42 | _layers.append(_ds.ReLULayer())
43 | _layers.append(_ds.PoolingLayer(
44 | Pooling_kernel_size=2,
45 | Pooling_stride=2))
46 | _layers.append(_ds.InnerProductLayer(
47 | name='out_ip1',
48 | InnerProduct_num_output=256,
49 | InnerProduct_weight_filler=_ds.PROTODETAIL.FillerParameter(
50 | type='uniform',
51 | min=-_np.sqrt(1./1152.),
52 | max=_np.sqrt(1./1152.))))
53 | _layers.append(_ds.ReLULayer())
54 | _layers.append(_ds.InnerProductLayer(
55 | InnerProduct_num_output=10,
56 | name='net_out',
57 | tops=['net_out'],
58 | InnerProduct_weight_filler=_ds.PROTODETAIL.FillerParameter(
59 | type='uniform',
60 | min=-_np.sqrt(1./256.),
61 | max=_np.sqrt(1./256.))))
62 |
63 | _layers.append(_ds.SoftmaxLayer(
64 | name='score',
65 | bottoms=['net_out'],
66 | include_stages=['predict']))
67 | _layers.append(_ds.SoftmaxWithLossLayer(
68 | name='loss',
69 | bottoms=['net_out', 'labels'],
70 | include_stages=['fit']))
71 | _layers.append(_ds.AccuracyLayer(
72 | name='accuracy',
73 | bottoms=['net_out', 'labels'],
74 | include_stages=['fit']))
75 |
76 |
77 | _netspec.layers = _layers
78 |
79 | MODEL = _netspec.instantiate()
80 |
81 |
82 | if __name__ == '__main__':
83 | _logging.basicConfig(level=_logging.INFO)
84 | _LOGGER = _logging.getLogger(__name__)
85 |
86 | name = __file__ + '_vis.png'
87 | _LOGGER.info("Rendering model to %s.",
88 | name)
89 |
90 | vis = MODEL.visualize()
91 | _cv2.imwrite(name, vis)
92 |
93 | _LOGGER.info("Done.")
94 |
--------------------------------------------------------------------------------
/examples/MNIST/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Test a trained model."""
3 | # pylint: disable=no-member, invalid-name
4 | from __future__ import print_function
5 |
6 | import logging
7 | import click
8 | import numpy as np
9 | from sklearn.metrics import accuracy_score
10 |
11 | import barrista.monitoring as mnt
12 | import barrista.net as bnet
13 | from train import _model
14 | from data import training_data, test_data
15 |
16 | _LOGGER = logging.getLogger(__name__)
17 |
18 |
19 | @click.group()
20 | def cli():
21 | """Test a model."""
22 | pass
23 |
24 |
25 | @cli.command()
26 | @click.argument('result_folder', type=click.STRING)
27 | @click.option('--epoch', type=click.INT, default=None,
28 | help="The epoch of the model to use.")
29 | @click.option('--image_idx', type=click.INT,
default=0,
30 | help="The image to visualize.")
31 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True,
32 | help='Use the CPU. If not set, use the GPU.')
33 | # pylint: disable=too-many-locals
34 | def test_image(
35 | result_folder,
36 | epoch=None,
37 | image_idx=0,
38 | use_cpu=False):
39 | """Test a network on one test image."""
40 | if use_cpu:
41 | bnet.set_mode_cpu()
42 | else:
43 | bnet.set_mode_gpu()
44 | _LOGGER.info("Loading data...")
45 | tr_data, _ = training_data()
46 | te_data, _ = test_data()
47 | _LOGGER.info("Loading network...")
48 | # Load the model; no solver is needed for prediction.
49 | model, _, _, _ = _model(result_folder,
50 | tr_data.shape[0],
51 | epoch=epoch, no_solver=True)
52 | _LOGGER.info("Predicting...")
53 | results = model.predict(te_data,
54 | test_callbacks=[mnt.ProgressIndicator()])
55 | _LOGGER.info("Prediction for image %d: %s.",
56 | image_idx, str(results[image_idx]))
57 |
58 |
59 | @cli.command()
60 | @click.argument('result_folder', type=click.STRING)
61 | @click.option('--epoch', type=click.INT, default=None,
62 | help="The epoch of the model to use.")
63 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True,
64 | help='Use the CPU. If not set, use the GPU.')
65 | # pylint: disable=too-many-locals
66 | def score(
67 | result_folder,
68 | epoch=None,
69 | use_cpu=False):
70 | """Test a network on the dataset."""
71 | if use_cpu:
72 | bnet.set_mode_cpu()
73 | else:
74 | bnet.set_mode_gpu()
75 | _LOGGER.info("Loading data...")
76 | tr_data, _ = training_data()
77 | te_data, te_labels = test_data()
78 | _LOGGER.info("Loading network...")
79 | # Load the model.
80 | model, _, _, _ = _model(result_folder,
81 | tr_data.shape[0],
82 | epoch=epoch,
83 | no_solver=True)
84 | _LOGGER.info("Predicting...")
85 | results = model.predict(te_data,
86 | test_callbacks=[mnt.ProgressIndicator()])
87 | _LOGGER.info("Accuracy: %f.",
88 | accuracy_score(te_labels,
89 | np.argmax(np.array(results), axis=1)))
90 |
91 |
92 | if __name__ == '__main__':
93 | logging.basicConfig(level=logging.INFO)
94 | cli()
95 |
--------------------------------------------------------------------------------
/examples/MNIST/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Training of the network."""
3 | # pylint: disable=wrong-import-position
4 | from __future__ import print_function
5 | import os
6 | import sys
7 | import imp
8 | import shutil
9 | import glob
10 | import logging
11 |
12 | import click
13 | from natsort import natsorted
14 | import barrista.solver as sv
15 | import barrista.net as bnet
16 | import barrista.monitoring as mnt
17 |
18 | from data import training_data, test_data
19 |
20 |
21 | _LOGGER = logging.getLogger(__name__)
22 | LOGFORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
23 | RESULT_FOLDER = os.path.join(os.path.dirname(__file__),
24 | 'results')
25 | if not os.path.exists(RESULT_FOLDER):
26 | os.mkdir(RESULT_FOLDER)
27 |
28 |
29 | # pylint: disable=too-many-arguments, too-many-locals, too-many-branches
30 | # pylint: disable=too-many-statements
31 | def _model(result_folder,
32 | epoch_size,
33 | model_name=None,
34 | epoch=None,
35 | write_every=10,
36 | optimizer_name='sgd',
37 | lr_param=0.01,
38 | lr_decay_sched=None,
39 | lr_decay_ratio=0.1,
40 | mom_param=0.9,
41 | wd_param=1E-4,
42 | no_solver=False,
43 | allow_overwrite=False):
44 | """Get a model and optimizer either loaded or created."""
45 | if epoch is not None:
46 | write_every = min(write_every,
epoch)
47 | optimizer_name = str(optimizer_name)
48 | out_folder = os.path.join('results', result_folder)
49 | if optimizer_name == 'sgd':
50 | if lr_decay_sched is not None and lr_decay_sched != '':
51 | lr_policy = 'multistep'
52 | # Each value must be multiplied by the epoch size (possibly
53 | # rounded). This is done later once the batch size is known.
54 | lr_decay_sched = [int(val) for val in lr_decay_sched.split(',')]
55 | else:
56 | lr_policy = 'fixed'
57 | optimizer = sv.SGDSolver(base_lr=lr_param,
58 | momentum=mom_param,
59 | weight_decay=wd_param,
60 | lr_policy=lr_policy,
61 | gamma=lr_decay_ratio,
62 | stepvalue=lr_decay_sched,
63 | snapshot_prefix=os.path.join(
64 | str(out_folder), 'model'))
65 | else:
66 | assert lr_decay_sched is not None, (
67 | "LR decay schedule only supported for SGD!")
68 | optimizer = sv.AdamSolver(base_lr=lr_param,  # pylint: disable=redefined-variable-type
69 | weight_decay=wd_param,
70 | snapshot_prefix=os.path.join(
71 | str(out_folder), 'model'))
72 | if os.path.exists(os.path.join('results', result_folder)) and (
73 | not allow_overwrite or (allow_overwrite and model_name is None)):
74 | assert model_name is None, (
75 | "This result path already exists! "
76 | "If you still want to use it, add the flag `--allow_overwrite`.")
77 | logging.basicConfig(
78 | level=logging.INFO,
79 | format=LOGFORMAT,
80 | filename=os.path.join('results', result_folder, 'train.log'),
81 | filemode='a')
82 | _LOGGER.info("Provided arguments: %s.", str(sys.argv))
83 | # Load the model from there.
84 | modelmod = imp.load_source(
85 | '_modelmod',
86 | os.path.join('results', result_folder, 'model.py'))
87 | model = modelmod.MODEL
88 | batch_size = model.blobs['data'].shape[0]
89 | checkpoint_step = round_to_mbsize(epoch_size * write_every, batch_size) // batch_size
90 | if epoch is None:
91 | # Use the last one.
92 | modelfiles = glob.glob(os.path.join('results',
93 | result_folder,
94 | 'model_iter_*.caffemodel'))
95 | if len(modelfiles) == 0:
96 | raise Exception("No model found to resume from!")
97 | lastm = natsorted(modelfiles)[-1]
98 | batch_iters = int(os.path.basename(lastm).split('.')[0][11:])
99 | base_iter = batch_iters * batch_size
100 | cmfilename = lastm
101 | ssfilename = cmfilename[:-10] + 'solverstate'
102 | else:
103 | assert epoch % write_every == 0, (
104 | "Writing every %d epochs. Please use a multiple of it!" % write_every)
105 | cmfilename = os.path.join('results',
106 | result_folder,
107 | 'model_iter_%d.caffemodel' % (
108 | epoch // write_every * checkpoint_step))
109 | ssfilename = os.path.join('results',
110 | result_folder,
111 | 'model_iter_%d.solverstate' % (
112 | epoch // write_every * checkpoint_step))
113 | base_iter = epoch * epoch_size
114 | assert os.path.exists(cmfilename), (
115 | "Could not find model parameter file at %s!" % (cmfilename))
116 | assert os.path.exists(ssfilename), (
117 | "Could not find solverstate file at %s!" % (ssfilename))
118 |
119 | _LOGGER.info("Loading model from %s...", cmfilename)
120 | model.load_blobs_from(str(cmfilename))
121 | if not no_solver:
122 | _LOGGER.info("Loading solverstate from %s...", ssfilename)
123 | if lr_decay_sched is not None:
124 | # pylint: disable=protected-access
125 | optimizer._parameter_dict['stepvalue'] = [
126 | round_to_mbsize(val * epoch_size, batch_size)
127 | for val in lr_decay_sched]
128 | optimizer.restore(str(ssfilename), model)
129 | else:
130 | # Create the result folder.
131 | assert model_name is not None, ( 132 | "If a new result_folder is specified, a model name must be given!") 133 | out_folder = os.path.join(RESULT_FOLDER, result_folder) 134 | if os.path.exists(out_folder): 135 | # Reset, because an overwrite was requested. 136 | shutil.rmtree(out_folder) 137 | os.mkdir(out_folder) 138 | os.mkdir(os.path.join(out_folder, 'visualizations')) 139 | logging.basicConfig( 140 | level=logging.INFO, 141 | format=LOGFORMAT, 142 | filename=os.path.join(out_folder, 'train.log'), 143 | filemode='w') 144 | _LOGGER.info("Provided arguments: %s.", str(sys.argv)) 145 | _LOGGER.info("Result folder created: %s.", out_folder) 146 | _LOGGER.info("Freezing experimental setup...") 147 | # Copy the contents over. 148 | shutil.copy2(os.path.join('models', model_name + '.py'), 149 | os.path.join(out_folder, 'model.py')) 150 | for pyfile in glob.glob(os.path.join(os.path.dirname(__file__), 151 | '*.py')): 152 | shutil.copy2(pyfile, 153 | os.path.join(out_folder, os.path.basename(pyfile))) 154 | _LOGGER.info("Creating model...") 155 | # Get the model. 156 | modelmod = imp.load_source('_modelmod', 157 | os.path.join(out_folder, 'model.py')) 158 | model = modelmod.MODEL 159 | if not no_solver and lr_decay_sched is not None: 160 | batch_size = model.blobs['data'].shape[0] 161 | # pylint: disable=protected-access 162 | optimizer._parameter_dict['stepvalue'] = [ 163 | round_to_mbsize(val * epoch_size, batch_size) 164 | for val in lr_decay_sched] 165 | base_iter = 0 166 | if no_solver: 167 | return model, None, out_folder, base_iter 168 | else: 169 | return model, optimizer, out_folder, base_iter 170 | 171 | 172 | def round_to_mbsize(value, batch_size): 173 | """Round value to multiple of batch size, if required.""" 174 | if value % batch_size == 0: 175 | return value 176 | else: 177 | return value + batch_size - value % batch_size 178 | 179 | @click.command() 180 | @click.argument("result_folder", type=click.STRING) # pylint: disable=no-member 181 | @click.option("--model_name", type=click.STRING, 182 | help='Model name to use, if a new trial should be created.') 183 | @click.option("--epoch", type=click.INT, default=None, 184 | help='Epoch to start from, if training is resumed.') 185 | @click.option("--num_epoch", type=click.INT, default=3, 186 | help='Final number of epochs to reach. Default: 3.') 187 | @click.option("--optimizer_name", type=click.Choice(['adam', 'sgd']), 188 | default='sgd', 189 | help='Optimizer to use. Default: sgd.') 190 | @click.option("--lr_param", type=click.FLOAT, default=0.001, 191 | help='The base learning rate to use. Default: 0.001.') 192 | @click.option("--lr_decay_sched", type=click.STRING, default='90,135', 193 | help='Scheduled learning rate changes.') 194 | @click.option("--lr_decay_ratio", type=float, default=0.1, 195 | help='Ratio for the change.') 196 | @click.option("--mom_param", type=click.FLOAT, default=0.9, 197 | help='The momentum to use if SGD is the optimizer. Default: 0.9.') 198 | @click.option("--wd_param", type=click.FLOAT, default=0.0001, 199 | help='The weight decay to use. Default: 0.0001.') 200 | @click.option("--monitor", type=click.BOOL, default=False, is_flag=True, 201 | help='Use extended monitoring (slows down training).') 202 | @click.option("--allow_overwrite", type=click.BOOL, default=False, is_flag=True, 203 | help='Allow reuse of an existing result directory.') 204 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True, 205 | help='Use the CPU. 
If not set, use the GPU.')
206 | # pylint: disable=too-many-arguments, unused-argument
207 | def cli(result_folder,
208 | model_name=None,
209 | epoch=None,
210 | num_epoch=3,
211 | optimizer_name='sgd',
212 | lr_param=0.001,
213 | lr_decay_sched='90,135',
214 | lr_decay_ratio=0.1,
215 | mom_param=0.9,
216 | wd_param=0.0001,
217 | monitor=False,
218 | allow_overwrite=False,
219 | use_cpu=False):
220 | """Train a model."""
221 | print("Parameters: ", sys.argv)
222 | if use_cpu:
223 | bnet.set_mode_cpu()
224 | else:
225 | bnet.set_mode_gpu()
226 | # Load the data.
227 | tr_data, tr_labels = training_data()
228 | te_data, te_labels = test_data()
229 | # Set up the output folder, including logging.
230 | model, optimizer, out_folder, base_iter = _model(
231 | result_folder,
232 | tr_data.shape[0],
233 | model_name,
234 | epoch,
235 | 1,
236 | optimizer_name,
237 | lr_param,
238 | lr_decay_sched,
239 | lr_decay_ratio,
240 | mom_param,
241 | wd_param,
242 | False,
243 | allow_overwrite)
244 | batch_size = model.blobs['data'].shape[0]
245 | logger = mnt.JSONLogger(str(out_folder),
246 | 'model',
247 | {'train': ['train_loss', 'train_accuracy'],
248 | 'test': ['test_loss', 'test_accuracy']},
249 | base_iter=base_iter,
250 | write_every=round_to_mbsize(50000, batch_size),
251 | create_plot=monitor)
252 | progr_ind = mnt.ProgressIndicator()
253 |
254 | if monitor:
255 | extra_monitors = [
256 | mnt.ActivationMonitor(round_to_mbsize(10000, batch_size),
257 | os.path.join(str(out_folder),
258 | 'visualizations' + os.sep),
259 | sample={'data': tr_data[0]}),
260 | mnt.FilterMonitor(round_to_mbsize(10000, batch_size),
261 | os.path.join(str(out_folder),
262 | 'visualizations' + os.sep)),
263 | mnt.GradientMonitor(round_to_mbsize(10000, batch_size),
264 | os.path.join(str(out_folder),
265 | 'visualizations' + os.sep),
266 | relative=True),
267 | ]
268 | else:
269 | extra_monitors = []
270 | model.fit(round_to_mbsize(num_epoch * tr_data.shape[0], batch_size),
271 | optimizer,
272 | X={'data': tr_data, 'labels': tr_labels},
273 | X_val={'data': te_data, 'labels': te_labels},
274 | test_interval=round_to_mbsize(tr_data.shape[0], batch_size),
275 | train_callbacks=[
276 | progr_ind,
277 | logger,
278 | mnt.Checkpointer(os.path.join(str(out_folder),
279 | 'model'),
280 | round_to_mbsize(tr_data.shape[0], batch_size),
281 | base_iterations=base_iter),
282 | ] + extra_monitors,
283 | test_callbacks=[
284 | progr_ind,
285 | logger])
286 |
287 |
288 | if __name__ == '__main__':
289 | cli()  # pylint: disable=no-value-for-parameter
290 |
--------------------------------------------------------------------------------
/examples/MNIST/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Create visualizations."""
3 | # pylint: disable=no-member, invalid-name, wrong-import-position
4 | from __future__ import print_function
5 |
6 | import os
7 | import json
8 | import logging
9 | import click
10 |
11 | import numpy as np
12 | _LOGGER = logging.getLogger(__name__)
13 | try:
14 | import matplotlib.pyplot as plt
15 | MPL_AVAILABLE = True
16 | except ImportError:
17 | print("Matplotlib could not be imported!")
18 | MPL_AVAILABLE = False
19 |
20 |
21 | def _sorted_ar(inf, key):
22 | iters = []
23 | vals = []
24 | for values in inf:
25 | if key in values:
26 | iters.append(int(values['NumIters']))
27 | vals.append(float(values[key]))
28 | sortperm = np.argsort(iters)
29 | arr = np.array([iters, vals]).T
30 | return arr[sortperm, :]
31 |
32 |
33 |
def _get_information(logfile):
34 | _LOGGER.info("Getting log information from %s...", logfile)
35 | with open(logfile, 'r') as infile:
36 | perfdict = json.load(infile)
37 |
38 | train_ce = _sorted_ar(perfdict['train'], 'train_loss')
39 | train_ac = _sorted_ar(perfdict['train'], 'train_accuracy')
40 | test_ce = _sorted_ar(perfdict['test'], 'test_loss')
41 | test_ac = _sorted_ar(perfdict['test'], 'test_accuracy')
42 | return train_ce, train_ac, test_ce, test_ac
43 |
44 |
45 | @click.group()
46 | def cli():
47 | """Create visualizations for model results."""
48 | pass
49 |
50 | @cli.command()
51 | @click.argument('model_name', type=click.STRING)
52 | @click.option('--display', is_flag=True, default=False,
53 | help='Do not write the output, but display the plot.')
54 | def performance(model_name, display=False):
55 | """Create performance plots."""
56 | _LOGGER.info('Creating performance plot for model `%s`.',
57 | model_name)
58 | if display:
59 | outfile = None
60 | else:
61 | outfile = os.path.join('results', model_name, 'performance.png')
62 | draw_perfplots(os.path.join('results', model_name, 'barrista_model.json'),
63 | outfile)
64 | _LOGGER.info("Done.")
65 |
66 |
67 | def draw_perfplots(logfile, outfile=None):
68 | """Draw the performance plots."""
69 | train_ce, train_ac, test_ce, test_ac =\
70 | _get_information(logfile)
71 |
72 | if not MPL_AVAILABLE:
73 | raise Exception("This method requires Matplotlib!")
74 | _, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
75 | # Loss.
76 | ax1.set_title("Loss")
77 | ax1.plot(train_ce[:, 0], train_ce[:, 1],
78 | label='Training', c='b', alpha=0.7)
79 | ax1.plot(test_ce[:, 0], test_ce[:, 1],
80 | label='Test', c='g', alpha=0.7)
81 | ax1.scatter(test_ce[:, 0], test_ce[:, 1],
82 | c='g', s=50)
83 |
84 | ax1.set_ylabel('Cross-Entropy-Loss')
85 | ax1.grid()
86 | # Accuracy.
87 | ax2.set_title("Accuracy")
88 | ax2.plot(train_ac[:, 0], train_ac[:, 1],
89 | label='Training', c='b', alpha=0.7)
90 | ax2.plot(test_ac[:, 0], test_ac[:, 1],
91 | label='Test', c='g', alpha=0.7)
92 | ax2.scatter(test_ac[:, 0], test_ac[:, 1],
93 | c='g', s=50)
94 | ax2.set_ylabel('Accuracy')
95 | ax2.grid()
96 |
97 | ax1.legend()
98 | if outfile is not None:
99 | plt.savefig(outfile, bbox_inches='tight')
100 | else:
101 | plt.show()
102 |
103 |
104 | if __name__ == '__main__':
105 | logging.basicConfig(level=logging.INFO)
106 | cli()
107 |
--------------------------------------------------------------------------------
/examples/residual-nets/README.txt:
--------------------------------------------------------------------------------
1 | Residual Network example
2 | ========================
3 |
4 | This folder contains a full-featured example for deep residual networks.
5 | It reaches performance comparable to the networks described in
6 | "Deep Residual Learning for Image Recognition", He et al., 2015.
7 |
8 | All files are executable and encapsulate one aspect of the model. They can
9 | all be run with `--help` to get more information.
10 |
11 | To run the training, simply run
12 |
13 | ./train.py testrun --model_name=msra3
14 |
15 | This will run the training and store the results in the folder results/testrun.
16 | The model is exchangeable, and must be a Python module in 'models' that has a
17 | `MODEL` property.
18 |
19 | The `msra3` model creates a residual network with 3 residual blocks per network
20 | part with the same image size. The proposed network has 3 such parts. In
21 | total, this corresponds to the 20-layer network from the original paper.
22 | The constructing function simply takes this number of blocks as a parameter.
23 | `msra9` thus constructs the 56-layer network, and you can easily play
24 | around with much deeper architectures.
25 |
26 | Happy training!
27 |
--------------------------------------------------------------------------------
/examples/residual-nets/data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Preparing the data."""
3 | # pylint: disable=invalid-name, no-member
4 | from __future__ import print_function
5 |
6 | import os as _os
7 | import logging as _logging
8 | import cv2 as _cv2
9 | import numpy as _np
10 |
11 | import click as _click
12 | import progressbar as _progressbar
13 | import skdata.cifar10 as _skdc10
14 |
15 |
16 | _LOGGER = _logging.getLogger(__name__)
17 | _DATA_FOLDER = _os.path.join(_os.path.dirname(__file__),
18 |                              'data')
19 | _MEAN = None
20 | if not _os.path.exists(_DATA_FOLDER):
21 |     _LOGGER.info("Data folder not found. Creating...")
22 |     _os.mkdir(_DATA_FOLDER)
23 |
24 |
25 | def training_data():
26 |     """Get the `CIFAR-10` training data."""
27 |     global _MEAN  # pylint: disable=global-statement
28 |     _np.random.seed(1)
29 |     view = _skdc10.view.OfficialImageClassificationTask()
30 |     permutation = _np.random.permutation(range(50000))
31 |     if _MEAN is None:
32 |         _MEAN = view.train.x.reshape((50000 * 32 * 32, 3)).mean(axis=0)
33 |     return ((view.train.x[:50000, :][permutation, :] - _MEAN).
34 |             transpose((0, 3, 1, 2)).astype('float32'),
35 |             view.train.y[:50000][permutation].reshape((50000, 1)).astype('float32'))
36 |
37 |
38 | def test_data():
39 |     """Get the `CIFAR-10` test data."""
40 |     global _MEAN  # pylint: disable=global-statement
41 |     _np.random.seed(1)
42 |     view = _skdc10.view.OfficialImageClassificationTask()
43 |     permutation = _np.random.permutation(range(10000))
44 |     if _MEAN is None:
45 |         _MEAN = view.train.x.reshape((50000 * 32 * 32, 3)).mean(axis=0)
46 |     return ((view.test.x[:10000, :][permutation, :] - _MEAN).
47 | transpose((0, 3, 1, 2)).astype('float32'), 48 | view.test.y[:10000][permutation].reshape((10000, 1)).astype('float32')) 49 | 50 | 51 | @_click.group() 52 | def _cli(): 53 | """Handle the experiment data.""" 54 | pass 55 | 56 | @_cli.command() 57 | def validate_storage(): 58 | """Validate the data.""" 59 | _LOGGER.info("Validating storage...") 60 | val_folder = _os.path.join(_DATA_FOLDER, 'images') 61 | _LOGGER.info("Writing images to %s.", 62 | val_folder) 63 | if not _os.path.exists(val_folder): 64 | _os.mkdir(val_folder) 65 | _LOGGER.info("Train...") 66 | tr_folder = _os.path.join(val_folder, 'train') 67 | if not _os.path.exists(tr_folder): 68 | _os.mkdir(tr_folder) 69 | tr_data, tr_labels = training_data() 70 | _LOGGER.info("Mean determined as: %s.", str(_MEAN)) 71 | pbar = _progressbar.ProgressBar(maxval=50000 - 1, 72 | widgets=[_progressbar.Percentage(), 73 | _progressbar.Bar(), 74 | _progressbar.ETA()]) 75 | pbar.start() 76 | for idx in range(50000): 77 | _cv2.imwrite(_os.path.join(tr_folder, '%05d_%d.jpg' % (idx, 78 | int(tr_labels[idx, 0]))), 79 | (tr_data[idx,].transpose((1, 2, 0)) + _MEAN).astype('uint8')) 80 | pbar.update(idx) 81 | pbar.finish() 82 | _LOGGER.info("Test...") 83 | te_folder = _os.path.join(val_folder, 'test') 84 | if not _os.path.exists(te_folder): 85 | _os.mkdir(te_folder) 86 | te_data, te_labels = test_data() 87 | pbar = _progressbar.ProgressBar(maxval=10000 - 1, 88 | widgets=[_progressbar.Percentage(), 89 | _progressbar.Bar(), 90 | _progressbar.ETA()]) 91 | pbar.start() 92 | for idx in range(10000): 93 | _cv2.imwrite(_os.path.join(te_folder, '%05d_%d.jpg' % (idx, 94 | int(te_labels[idx, 0]))), 95 | (te_data[idx,].transpose((1, 2, 0)) + _MEAN).astype('uint8')) 96 | pbar.update(idx) 97 | pbar.finish() 98 | 99 | if __name__ == '__main__': 100 | _logging.basicConfig(level=_logging.INFO) 101 | _cli() 102 | -------------------------------------------------------------------------------- /examples/residual-nets/models/msra3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """A reimplementation of the 20-layer MSRA residual net.""" 3 | # pylint: disable=wrong-import-position, invalid-name, no-member 4 | 5 | import logging as _logging 6 | import cv2 as _cv2 7 | import numpy as _np 8 | 9 | import barrista.design as _ds 10 | 11 | _LOGGER = _logging.getLogger() 12 | 13 | _netspec = _ds.NetSpecification([[128, 3, 32, 32], [128,]], 14 | inputs=['data', 'labels'], 15 | predict_inputs=['data'], 16 | predict_input_shapes=[[128, 3, 32, 32]]) 17 | 18 | _layers = [] 19 | _l_idx = 0 20 | _USE_GLOBAL_STATS = None 21 | 22 | def ResBlock(n_in, in_name, n_out, stride=1): 23 | """Create a residual block.""" 24 | global _l_idx # pylint: disable=global-statement 25 | layers = [] 26 | layers.append(_ds.ConvolutionLayer( 27 | name='resblock{}_conv1'.format(_l_idx), 28 | bottoms=[in_name], 29 | Convolution_num_output=n_out, 30 | Convolution_kernel_size=(3, 3), 31 | Convolution_stride=(stride, stride), 32 | Convolution_pad=(1, 1), 33 | Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(type='msra') 34 | )) 35 | layers.append(_ds.BatchNormLayer( 36 | name='resblock{}_bn1'.format(_l_idx), 37 | BatchNorm_use_global_stats=_USE_GLOBAL_STATS, 38 | BatchNorm_moving_average_fraction=0.9 39 | )) 40 | layers.append(_ds.ScaleLayer( 41 | name='resblock{}_scale1'.format(_l_idx), 42 | Scale_bias_term=True 43 | )) 44 | layers.append(_ds.ReLULayer(name='resblock{}_relu1'.format(_l_idx))) 45 | layers.append(_ds.ConvolutionLayer( 46 
| Convolution_num_output=n_out, 47 | Convolution_kernel_size=(3, 3), 48 | Convolution_stride=(1, 1), 49 | Convolution_pad=(1, 1), 50 | Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(type='msra') 51 | )) 52 | layers.append(_ds.BatchNormLayer( 53 | name='resblock{}_bn2'.format(_l_idx), 54 | BatchNorm_use_global_stats=_USE_GLOBAL_STATS, 55 | BatchNorm_moving_average_fraction=0.9 56 | )) 57 | layers.append(_ds.ScaleLayer( 58 | name='resblock{}_scale2'.format(_l_idx), 59 | Scale_bias_term=True 60 | )) 61 | sum_in = [in_name, 'resblock{}_scale2'.format(_l_idx)] 62 | if n_in != n_out: 63 | layers.append(_ds.ConvolutionLayer( 64 | name='resblock{}_sidepath'.format(_l_idx), 65 | bottoms=[in_name], 66 | Convolution_num_output=n_out, 67 | Convolution_kernel_size=(1, 1), 68 | Convolution_stride=(stride, stride), 69 | Convolution_pad=(0, 0) 70 | )) 71 | sum_in[0] = 'resblock{}_sidepath'.format(_l_idx) 72 | layers.append(_ds.EltwiseLayer( 73 | name='resblock{}_sum'.format(_l_idx), 74 | bottoms=sum_in, 75 | Eltwise_operation=_ds.PROTODETAIL.EltwiseParameter.SUM)) 76 | layers.append(_ds.ReLULayer(name='resblock{}_out'.format(_l_idx))) 77 | _l_idx += 1 78 | return layers, 'resblock{}_out'.format(_l_idx - 1) 79 | 80 | 81 | def _construct_resnet(blocks_per_part): 82 | _layers.append(_ds.ConvolutionLayer( 83 | name='conv_initial', 84 | bottoms=['data'], 85 | Convolution_num_output=16, 86 | Convolution_kernel_size=(3, 3), 87 | Convolution_stride=(1, 1), 88 | Convolution_pad=(1, 1), 89 | Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter( 90 | type='msra'))) 91 | _layers.append(_ds.BatchNormLayer( 92 | name='bn_initial', 93 | BatchNorm_use_global_stats=_USE_GLOBAL_STATS, 94 | BatchNorm_moving_average_fraction=0.9 95 | )) 96 | _layers.append(_ds.ScaleLayer( 97 | name='scale_initial', 98 | Scale_bias_term=True 99 | )) 100 | _layers.append(_ds.ReLULayer(name='relu_initial')) 101 | last_out = 'relu_initial' 102 | for i in range(blocks_per_part): 103 | layers, last_out = ResBlock(16, last_out, 16) 104 | _layers.extend(layers) 105 | for i in range(blocks_per_part): 106 | layers, last_out = ResBlock(32 if i > 0 else 16, 107 | last_out, 108 | 32, 109 | 1 if i > 0 else 2) 110 | _layers.extend(layers) 111 | for i in range(blocks_per_part): 112 | layers, last_out = ResBlock(64 if i > 0 else 32, 113 | last_out, 114 | 64, 115 | 1 if i > 0 else 2) 116 | _layers.extend(layers) 117 | _layers.append(_ds.PoolingLayer( 118 | name='avpool', 119 | Pooling_kernel_size=8, 120 | Pooling_stride=1, 121 | Pooling_pad=0, 122 | Pooling_pool=_ds.PROTODETAIL.PoolingParameter.AVE)) 123 | _layers.append(_ds.InnerProductLayer( 124 | InnerProduct_num_output=10, 125 | InnerProduct_weight_filler=_ds.PROTODETAIL.FillerParameter( 126 | type='uniform', 127 | min=-_np.sqrt(2./64.), 128 | max=_np.sqrt(2./64.)), 129 | name='net_out')) 130 | _layers.append(_ds.BatchNormLayer( 131 | name='net_out_bn', 132 | BatchNorm_use_global_stats=_USE_GLOBAL_STATS, 133 | BatchNorm_moving_average_fraction=0.9 134 | )) 135 | _layers.append(_ds.ScaleLayer( 136 | name='net_out_bnscale')) 137 | _layers.append(_ds.SoftmaxLayer( 138 | name='score', 139 | bottoms=['net_out_bnscale'], 140 | include_stages=['predict'])) 141 | _layers.append(_ds.SoftmaxWithLossLayer( 142 | name='loss', 143 | bottoms=['net_out_bnscale', 'labels'], 144 | include_stages=['fit'])) 145 | _layers.append(_ds.AccuracyLayer( 146 | name='accuracy', 147 | bottoms=['net_out_bnscale', 'labels'], 148 | include_stages=['fit'])) 149 | 150 | _construct_resnet(3) 151 | _netspec.layers = _layers 
152 | MODEL = _netspec.instantiate()
153 | for pname, pval in MODEL.params.items():
154 |     if 'sidepath' in pname:
155 |         w_ary = _np.zeros((pval[0].data.shape[1], pval[0].data.shape[0]),
156 |                           dtype='float32')
157 |         w_ary[:, :pval[0].data.shape[1]] = _np.eye(pval[0].data.shape[1])
158 |         w_ary = w_ary.T
159 |         pval[0].data[:] = w_ary.reshape(pval[0].data.shape)
160 |
161 |
162 | if __name__ == '__main__':
163 |     _logging.basicConfig(level=_logging.INFO)
164 |     _LOGGER = _logging.getLogger(__name__)
165 |     name = __file__ + '_vis.png'
166 |     _LOGGER.info("Rendering model to %s.",
167 |                  name)
168 |     vis = MODEL.visualize()
169 |     _cv2.imwrite(name, vis)
170 |     _LOGGER.info("Done.")
171 |
--------------------------------------------------------------------------------
/examples/residual-nets/models/msra9.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """A reimplementation of the 56-layer MSRA residual net."""
3 | # pylint: disable=wrong-import-position, invalid-name, no-member
4 |
5 | import logging as _logging
6 | import cv2 as _cv2
7 | import numpy as _np
8 |
9 | import barrista.design as _ds
10 |
11 | _LOGGER = _logging.getLogger()
12 |
13 | _netspec = _ds.NetSpecification([[128, 3, 32, 32], [128,]],
14 |                                 inputs=['data', 'labels'],
15 |                                 predict_inputs=['data'],
16 |                                 predict_input_shapes=[[128, 3, 32, 32]])
17 |
18 | _layers = []
19 | _l_idx = 0
20 | _USE_GLOBAL_STATS = None
21 |
22 | def ResBlock(n_in, in_name, n_out, stride=1):
23 |     """Create a residual block."""
24 |     global _l_idx  # pylint: disable=global-statement
25 |     layers = []
26 |     layers.append(_ds.ConvolutionLayer(
27 |         name='resblock{}_conv1'.format(_l_idx),
28 |         bottoms=[in_name],
29 |         Convolution_num_output=n_out,
30 |         Convolution_kernel_size=(3, 3),
31 |         Convolution_stride=(stride, stride),
32 |         Convolution_pad=(1, 1),
33 |         Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(type='msra')
34 |     ))
35 |     layers.append(_ds.BatchNormLayer(
36 |         name='resblock{}_bn1'.format(_l_idx),
37 |         BatchNorm_use_global_stats=_USE_GLOBAL_STATS,
38 |         BatchNorm_moving_average_fraction=0.9
39 |     ))
40 |     layers.append(_ds.ScaleLayer(
41 |         name='resblock{}_scale1'.format(_l_idx),
42 |         Scale_bias_term=True
43 |     ))
44 |     layers.append(_ds.ReLULayer(name='resblock{}_relu1'.format(_l_idx)))
45 |     layers.append(_ds.ConvolutionLayer(
46 |         Convolution_num_output=n_out,
47 |         Convolution_kernel_size=(3, 3),
48 |         Convolution_stride=(1, 1),
49 |         Convolution_pad=(1, 1),
50 |         Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(type='msra')
51 |     ))
52 |     layers.append(_ds.BatchNormLayer(
53 |         name='resblock{}_bn2'.format(_l_idx),
54 |         BatchNorm_use_global_stats=_USE_GLOBAL_STATS,
55 |         BatchNorm_moving_average_fraction=0.9
56 |     ))
57 |     layers.append(_ds.ScaleLayer(
58 |         name='resblock{}_scale2'.format(_l_idx),
59 |         Scale_bias_term=True
60 |     ))
61 |     sum_in = [in_name, 'resblock{}_scale2'.format(_l_idx)]
62 |     if n_in != n_out:
63 |         layers.append(_ds.ConvolutionLayer(
64 |             name='resblock{}_sidepath'.format(_l_idx),
65 |             bottoms=[in_name],
66 |             Convolution_num_output=n_out,
67 |             Convolution_kernel_size=(1, 1),
68 |             Convolution_stride=(stride, stride),
69 |             Convolution_pad=(0, 0)
70 |         ))
71 |         sum_in[0] = 'resblock{}_sidepath'.format(_l_idx)
72 |     layers.append(_ds.EltwiseLayer(
73 |         name='resblock{}_sum'.format(_l_idx),
74 |         bottoms=sum_in,
75 |         Eltwise_operation=_ds.PROTODETAIL.EltwiseParameter.SUM))
76 |     layers.append(_ds.ReLULayer(name='resblock{}_out'.format(_l_idx)))
77 |     _l_idx += 1
78 |     return layers, 'resblock{}_out'.format(_l_idx - 1)
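# (When `n_in != n_out`, a 1x1 'sidepath' convolution projects the input to
# the new width; after instantiation further below, its weights are set to an
# identity on the first `n_in` channels and zeros for the added ones.)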
79 |
80 |
81 | def _construct_resnet(blocks_per_part):
82 |     _layers.append(_ds.ConvolutionLayer(
83 |         name='conv_initial',
84 |         bottoms=['data'],
85 |         Convolution_num_output=16,
86 |         Convolution_kernel_size=(3, 3),
87 |         Convolution_stride=(1, 1),
88 |         Convolution_pad=(1, 1),
89 |         Convolution_weight_filler=_ds.PROTODETAIL.FillerParameter(
90 |             type='msra')))
91 |     _layers.append(_ds.BatchNormLayer(
92 |         name='bn_initial',
93 |         BatchNorm_use_global_stats=_USE_GLOBAL_STATS,
94 |         BatchNorm_moving_average_fraction=0.9
95 |     ))
96 |     _layers.append(_ds.ScaleLayer(
97 |         name='scale_initial',
98 |         Scale_bias_term=True
99 |     ))
100 |     _layers.append(_ds.ReLULayer(name='relu_initial'))
101 |     last_out = 'relu_initial'
102 |     for i in range(blocks_per_part):
103 |         layers, last_out = ResBlock(16, last_out, 16)
104 |         _layers.extend(layers)
105 |     for i in range(blocks_per_part):
106 |         layers, last_out = ResBlock(32 if i > 0 else 16,
107 |                                     last_out,
108 |                                     32,
109 |                                     1 if i > 0 else 2)
110 |         _layers.extend(layers)
111 |     for i in range(blocks_per_part):
112 |         layers, last_out = ResBlock(64 if i > 0 else 32,
113 |                                     last_out,
114 |                                     64,
115 |                                     1 if i > 0 else 2)
116 |         _layers.extend(layers)
117 |     _layers.append(_ds.PoolingLayer(
118 |         name='avpool',
119 |         Pooling_kernel_size=8,
120 |         Pooling_stride=1,
121 |         Pooling_pad=0,
122 |         Pooling_pool=_ds.PROTODETAIL.PoolingParameter.AVE))
123 |     _layers.append(_ds.InnerProductLayer(
124 |         InnerProduct_num_output=10,
125 |         InnerProduct_weight_filler=_ds.PROTODETAIL.FillerParameter(
126 |             type='uniform',
127 |             min=-_np.sqrt(2./64.),
128 |             max=_np.sqrt(2./64.)),
129 |         name='net_out'))
130 |     _layers.append(_ds.BatchNormLayer(
131 |         name='net_out_bn',
132 |         BatchNorm_use_global_stats=_USE_GLOBAL_STATS,
133 |         BatchNorm_moving_average_fraction=0.9
134 |     ))
135 |     _layers.append(_ds.ScaleLayer(
136 |         name='net_out_bnscale'))
137 |     _layers.append(_ds.SoftmaxLayer(
138 |         name='score',
139 |         bottoms=['net_out_bnscale'],
140 |         include_stages=['predict']))
141 |     _layers.append(_ds.SoftmaxWithLossLayer(
142 |         name='loss',
143 |         bottoms=['net_out_bnscale', 'labels'],
144 |         include_stages=['fit']))
145 |     _layers.append(_ds.AccuracyLayer(
146 |         name='accuracy',
147 |         bottoms=['net_out_bnscale', 'labels'],
148 |         include_stages=['fit']))
149 |
150 | _construct_resnet(9)
151 | _netspec.layers = _layers
152 | MODEL = _netspec.instantiate()
153 | for pname, pval in MODEL.params.items():
154 |     if 'sidepath' in pname:
155 |         w_ary = _np.zeros((pval[0].data.shape[1], pval[0].data.shape[0]),
156 |                           dtype='float32')
157 |         w_ary[:, :pval[0].data.shape[1]] = _np.eye(pval[0].data.shape[1])
158 |         w_ary = w_ary.T
159 |         pval[0].data[:] = w_ary.reshape(pval[0].data.shape)
160 |
161 |
162 | if __name__ == '__main__':
163 |     _logging.basicConfig(level=_logging.INFO)
164 |     _LOGGER = _logging.getLogger(__name__)
165 |     name = __file__ + '_vis.png'
166 |     _LOGGER.info("Rendering model to %s.",
167 |                  name)
168 |     vis = MODEL.visualize()
169 |     _cv2.imwrite(name, vis)
170 |     _LOGGER.info("Done.")
171 |
--------------------------------------------------------------------------------
/examples/residual-nets/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Test a model."""
3 | # pylint: disable=no-member, invalid-name
4 | from __future__ import print_function
5 |
6 | import logging
7 | import click
8 | import numpy as np
9 | from sklearn.metrics import accuracy_score
10 |
11 | import barrista.monitoring as
mnt 12 | import barrista.net as bnet 13 | from train import _model, RandCropMonitor 14 | from data import training_data, test_data 15 | 16 | _LOGGER = logging.getLogger(__name__) 17 | 18 | 19 | @click.group() 20 | def cli(): 21 | """Test a model.""" 22 | pass 23 | 24 | 25 | @cli.command() 26 | @click.argument('result_folder', type=click.STRING) 27 | @click.option('--epoch', type=click.INT, default=None, 28 | help="The epoch of the model to use.") 29 | @click.option('--image_idx', type=click.INT, default=0, 30 | help="The image to visualize.") 31 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True, 32 | help='Use the CPU. If not set, use the GPU.') 33 | # pylint: disable=too-many-locals 34 | def test_image( 35 | result_folder, 36 | epoch=None, 37 | image_idx=0, 38 | use_cpu=False): 39 | """Test a network on one test image.""" 40 | if use_cpu: 41 | bnet.set_mode_cpu() 42 | else: 43 | bnet.set_mode_gpu() 44 | _LOGGER.info("Loading data...") 45 | tr_data, _ = training_data() 46 | te_data, _ = test_data() 47 | from data import _MEAN 48 | _LOGGER.info("Loading network...") 49 | # Load the model for training. 50 | model, _, _, _ = _model(result_folder, 51 | tr_data.shape[0], 52 | epoch=epoch) 53 | _LOGGER.info("Predicting...") 54 | results = model.predict(te_data[:image_idx + 1], 55 | test_callbacks=[ 56 | RandCropMonitor('data', _MEAN), 57 | mnt.ProgressIndicator() 58 | ], 59 | out_blob_names=['score']) 60 | _LOGGER.info("Prediction for image %d: %s.", 61 | image_idx, str(results[image_idx])) 62 | 63 | 64 | @cli.command() 65 | @click.argument('result_folder', type=click.STRING) 66 | @click.option('--epoch', type=click.INT, default=None, 67 | help="The epoch of the model to use.") 68 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True, 69 | help='Use the CPU. If not set, use the GPU.') 70 | # pylint: disable=too-many-locals 71 | def score( 72 | result_folder, 73 | epoch=None, 74 | use_cpu=False): 75 | """Test a network on the dataset.""" 76 | if use_cpu: 77 | bnet.set_mode_cpu() 78 | else: 79 | bnet.set_mode_gpu() 80 | _LOGGER.info("Loading data...") 81 | tr_data, _ = training_data() 82 | te_data, te_labels = test_data() 83 | from data import _MEAN 84 | _LOGGER.info("Loading network...") 85 | # Load the model. 
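# (`no_solver=True` below makes `_model` return `None` in place of the
# optimizer; only the model itself is needed for scoring.)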
86 | model, _, _, _ = _model(result_folder, 87 | tr_data.shape[0], 88 | epoch=epoch, 89 | no_solver=True) 90 | _LOGGER.info("Predicting...") 91 | results = model.predict(te_data, 92 | test_callbacks=[ 93 | RandCropMonitor('data', _MEAN), 94 | mnt.ProgressIndicator() 95 | ], 96 | out_blob_names=['score']) 97 | _LOGGER.info("Accuracy: %f.", 98 | accuracy_score(te_labels, 99 | np.argmax(np.array(results), axis=1))) 100 | 101 | 102 | if __name__ == '__main__': 103 | logging.basicConfig(level=logging.INFO) 104 | cli() 105 | -------------------------------------------------------------------------------- /examples/residual-nets/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Training of the network.""" 3 | # pylint: disable=wrong-import-position 4 | from __future__ import print_function 5 | import os 6 | import sys 7 | import imp 8 | import shutil 9 | import glob 10 | import logging 11 | import numpy as np 12 | from scipy.misc import imresize 13 | 14 | import click 15 | from natsort import natsorted 16 | import barrista.solver as sv 17 | import barrista.net as bnet 18 | import barrista.monitoring as mnt 19 | 20 | from data import training_data, test_data 21 | 22 | 23 | _LOGGER = logging.getLogger(__name__) 24 | LOGFORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s' 25 | RESULT_FOLDER = os.path.join(os.path.dirname(__file__), 26 | 'results') 27 | if not os.path.exists(RESULT_FOLDER): 28 | os.mkdir(RESULT_FOLDER) 29 | 30 | 31 | # pylint: disable=too-many-arguments, too-many-locals, too-many-branches 32 | # pylint: disable=too-many-statements 33 | def _model(result_folder, 34 | epoch_size, 35 | model_name=None, 36 | epoch=None, 37 | write_every=10, 38 | optimizer_name='sgd', 39 | lr_param=0.01, 40 | lr_decay_sched=None, 41 | lr_decay_ratio=0.1, 42 | mom_param=0.9, 43 | wd_param=1E-4, 44 | no_solver=False, 45 | allow_overwrite=False): 46 | """Get a model and optimizer either loaded or created.""" 47 | if epoch is not None: 48 | write_every = min(write_every, epoch) 49 | optimizer_name = str(optimizer_name) 50 | out_folder = os.path.join('results', result_folder) 51 | if optimizer_name == 'sgd': 52 | if lr_decay_sched is not None and lr_decay_sched != '': 53 | lr_policy = 'multistep' 54 | # Each value must be multiplied with the epoch size (possibly 55 | # rounded). This is done later once the batch size is known. 56 | lr_decay_sched = [int(val) for val in lr_decay_sched.split(',')] 57 | else: 58 | lr_policy = 'fixed' 59 | optimizer = sv.SGDSolver(base_lr=lr_param, 60 | momentum=mom_param, 61 | weight_decay=wd_param, 62 | lr_policy=lr_policy, 63 | gamma=lr_decay_ratio, 64 | stepvalue=lr_decay_sched, 65 | snapshot_prefix=os.path.join( 66 | str(out_folder), 'model')) 67 | else: 68 | assert lr_decay_sched is not None, ( 69 | "LR decay schedule only supported for SGD!") 70 | optimizer = sv.AdamSolver(base_lr=lr_param, # pylint: disable=redefined-variable-type 71 | weight_decay=wd_param, 72 | snapshot_prefix=os.path.join( 73 | str(out_folder), 'model')) 74 | if os.path.exists(os.path.join('results', result_folder)) and ( 75 | not allow_overwrite or (allow_overwrite and model_name is None)): 76 | assert model_name is None, ( 77 | "This result path already exists! 
" 78 | "If you still want to use it, add the flag `--allow_overwrite`.") 79 | logging.basicConfig( 80 | level=logging.INFO, 81 | format=LOGFORMAT, 82 | filename=os.path.join('results', result_folder, 'train.log'), 83 | filemode='a') 84 | _LOGGER.info("Provided arguments: %s.", str(sys.argv)) 85 | # Load the data from there. 86 | modelmod = imp.load_source( 87 | '_modelmod', 88 | os.path.join('results', result_folder, 'model.py')) 89 | model = modelmod.MODEL 90 | batch_size = model.blobs['data'].shape[0] 91 | checkpoint_step = round_to_mbsize(epoch_size * write_every, batch_size) / batch_size 92 | if epoch is None: 93 | # Use the last one. 94 | modelfiles = glob.glob(os.path.join('results', 95 | result_folder, 96 | 'model_iter_*.caffemodel')) 97 | if len(modelfiles) == 0: 98 | raise Exception("No model found to resume from!") 99 | lastm = natsorted(modelfiles)[-1] 100 | batch_iters = int(os.path.basename(lastm).split('.')[0][11:]) 101 | base_iter = batch_iters * batch_size 102 | cmfilename = lastm 103 | ssfilename = cmfilename[:-10] + 'solverstate' 104 | else: 105 | assert epoch % write_every == 0, ( 106 | "Writing every %d epochs. Please use a multiple of it!") 107 | cmfilename = os.path.join('results', 108 | result_folder, 109 | 'model_iter_%d.caffemodel' % ( 110 | epoch / write_every * checkpoint_step)) 111 | ssfilename = os.path.join('results', 112 | result_folder, 113 | 'model_iter_%d.solverstate' % ( 114 | epoch / write_every * checkpoint_step)) 115 | base_iter = epoch * epoch_size 116 | assert os.path.exists(cmfilename), ( 117 | "Could not find model parameter file at %s!" % (cmfilename)) 118 | assert os.path.exists(ssfilename), ( 119 | "Could not find solverstate file at %s!" % (ssfilename)) 120 | 121 | _LOGGER.info("Loading model from %s...", cmfilename) 122 | model.load_blobs_from(str(cmfilename)) 123 | if not no_solver: 124 | _LOGGER.info("Loading solverstate from %s...", ssfilename) 125 | if lr_decay_sched is not None: 126 | # pylint: disable=protected-access 127 | optimizer._parameter_dict['stepvalue'] = [ 128 | round_to_mbsize(val * epoch_size, batch_size) 129 | for val in lr_decay_sched] 130 | optimizer.restore(str(ssfilename), model) 131 | else: 132 | # Create the result folder. 133 | assert model_name is not None, ( 134 | "If a new result_folder is specified, a model name must be given!") 135 | out_folder = os.path.join(RESULT_FOLDER, result_folder) 136 | if os.path.exists(out_folder): 137 | # Reset, because an overwrite was requested. 138 | shutil.rmtree(out_folder) 139 | os.mkdir(out_folder) 140 | os.mkdir(os.path.join(out_folder, 'visualizations')) 141 | logging.basicConfig( 142 | level=logging.INFO, 143 | format=LOGFORMAT, 144 | filename=os.path.join(out_folder, 'train.log'), 145 | filemode='w') 146 | _LOGGER.info("Provided arguments: %s.", str(sys.argv)) 147 | _LOGGER.info("Result folder created: %s.", out_folder) 148 | _LOGGER.info("Freezing experimental setup...") 149 | # Copy the contents over. 150 | shutil.copy2(os.path.join('models', model_name + '.py'), 151 | os.path.join(out_folder, 'model.py')) 152 | for pyfile in glob.glob(os.path.join(os.path.dirname(__file__), 153 | '*.py')): 154 | shutil.copy2(pyfile, 155 | os.path.join(out_folder, os.path.basename(pyfile))) 156 | _LOGGER.info("Creating model...") 157 | # Get the model. 
158 | modelmod = imp.load_source('_modelmod', 159 | os.path.join(out_folder, 'model.py')) 160 | model = modelmod.MODEL 161 | if not no_solver and lr_decay_sched is not None: 162 | batch_size = model.blobs['data'].shape[0] 163 | # pylint: disable=protected-access 164 | optimizer._parameter_dict['stepvalue'] = [ 165 | round_to_mbsize(val * epoch_size, batch_size) 166 | for val in lr_decay_sched] 167 | base_iter = 0 168 | if no_solver: 169 | return model, None, out_folder, base_iter 170 | else: 171 | return model, optimizer, out_folder, base_iter 172 | 173 | 174 | class RandCropMonitor(mnt.ParallelMonitor): 175 | 176 | """Creates random crops.""" 177 | 178 | def __init__(self, layer_name, mean, stretch_up_to=1.1, scaling_size=40): 179 | self._layer_name = layer_name 180 | self._mean = mean 181 | self._stretch_up_to = stretch_up_to 182 | self._scaling_size = scaling_size 183 | 184 | def get_parallel_blob_names(self): 185 | return [self._layer_name] 186 | 187 | def _pre_train_batch(self, kwargs): 188 | net = kwargs['net'] 189 | for sample_idx in range(len(net.blobs[self._layer_name].data)): 190 | sample = net.blobs[self._layer_name].data[sample_idx] 191 | if sample.min() == 0. and sample.max() == 0.: 192 | raise Exception("invalid data") 193 | stretch_factor = np.random.uniform(low=1., 194 | high=self._stretch_up_to, 195 | size=(2,)) 196 | im = (imresize(sample.transpose((1, 2, 0)) + self._mean, # pylint: disable=invalid-name 197 | (int(self._scaling_size * stretch_factor[0]), 198 | int(self._scaling_size * stretch_factor[1])), 199 | 'bilinear') - self._mean).transpose((2, 0, 1)) 200 | retimx = np.random.randint(low=0, 201 | high=(im.shape[2] - 202 | net.blobs[self._layer_name].data.shape[3] + 1)) 203 | retimy = np.random.randint(low=0, 204 | high=(im.shape[1] - 205 | net.blobs[self._layer_name].data.shape[2] + 1)) 206 | retim = im[:, 207 | retimy:retimy+net.blobs[self._layer_name].data.shape[2], 208 | retimx:retimx+net.blobs[self._layer_name].data.shape[3]] 209 | sample[...] = retim 210 | 211 | def _pre_test_batch(self, kwargs): 212 | net = kwargs['testnet'] 213 | for sample_idx in range(len(net.blobs[self._layer_name].data)): 214 | sample = net.blobs[self._layer_name].data[sample_idx] 215 | stretch_factor = (1., 1.) 216 | im = (imresize(sample.transpose((1, 2, 0)) + self._mean, # pylint: disable=invalid-name 217 | (int(self._scaling_size * stretch_factor[0]), 218 | int(self._scaling_size * stretch_factor[1])), 219 | 'bilinear') - self._mean).transpose((2, 0, 1)) 220 | retimx = (im.shape[2] - net.blobs[self._layer_name].data.shape[3]) // 2 221 | retimy = (im.shape[1] - net.blobs[self._layer_name].data.shape[2]) // 2 222 | retim = im[:, 223 | retimy:retimy+net.blobs[self._layer_name].data.shape[2], 224 | retimx:retimx+net.blobs[self._layer_name].data.shape[3]] 225 | sample[...] = retim 226 | 227 | 228 | def round_to_mbsize(value, batch_size): 229 | """Round value to multiple of batch size, if required.""" 230 | if value % batch_size == 0: 231 | return value 232 | else: 233 | return value + batch_size - value % batch_size 234 | 235 | @click.command() 236 | @click.argument("result_folder", type=click.STRING) # pylint: disable=no-member 237 | @click.option("--model_name", type=click.STRING, 238 | help='Model name to use, if a new trial should be created.') 239 | @click.option("--epoch", type=click.INT, default=None, 240 | help='Epoch to start from, if training is resumed.') 241 | @click.option("--num_epoch", type=click.INT, default=150, 242 | help='Final number of epochs to reach. 
Default: 150.') 243 | @click.option("--optimizer_name", type=click.Choice(['adam', 'sgd']), 244 | default='sgd', 245 | help='Optimizer to use. Default: sgd.') 246 | @click.option("--lr_param", type=click.FLOAT, default=0.1, 247 | help='The base learning rate to use. Default: 0.1.') 248 | @click.option("--lr_decay_sched", type=click.STRING, default='90,135', 249 | help='Scheduled learning rate changes.') 250 | @click.option("--lr_decay_ratio", type=float, default=0.1, 251 | help='Ratio for the change.') 252 | @click.option("--mom_param", type=click.FLOAT, default=0.9, 253 | help='The momentum to use if SGD is the optimizer. Default: 0.9.') 254 | @click.option("--wd_param", type=click.FLOAT, default=0.0001, 255 | help='The weight decay to use. Default: 0.0001.') 256 | @click.option("--monitor", type=click.BOOL, default=False, is_flag=True, 257 | help='Use extended monitoring (slows down training).') 258 | @click.option("--allow_overwrite", type=click.BOOL, default=False, is_flag=True, 259 | help='Allow reuse of an existing result directory.') 260 | @click.option("--use_cpu", type=click.BOOL, default=False, is_flag=True, 261 | help='Use the CPU. If not set, use the GPU.') 262 | # pylint: disable=too-many-arguments, unused-argument 263 | def cli(result_folder, 264 | model_name=None, 265 | epoch=None, 266 | num_epoch=150, 267 | optimizer_name='sgd', 268 | lr_param=0.1, 269 | lr_decay_sched='90,135', 270 | lr_decay_ratio=0.1, 271 | mom_param=0.9, 272 | wd_param=0.0001, 273 | monitor=False, 274 | allow_overwrite=False, 275 | use_cpu=False): 276 | """Train a model.""" 277 | print("Parameters: ", sys.argv) 278 | if use_cpu: 279 | bnet.set_mode_cpu() 280 | else: 281 | bnet.set_mode_gpu() 282 | # Load the data. 283 | tr_data, tr_labels = training_data() 284 | te_data, te_labels = test_data() 285 | from data import _MEAN 286 | # Setup the output folder, including logging. 
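    # (The positional arguments below map onto `_model`'s signature: `10` is
    # `write_every`, i.e. a checkpoint every 10 epochs, and the trailing
    # `False` is `no_solver`, so an optimizer is created as well.)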
287 |     model, optimizer, out_folder, base_iter = _model(
288 |         result_folder,
289 |         tr_data.shape[0],
290 |         model_name,
291 |         epoch,
292 |         10,
293 |         optimizer_name,
294 |         lr_param,
295 |         lr_decay_sched,
296 |         lr_decay_ratio,
297 |         mom_param,
298 |         wd_param,
299 |         False,
300 |         allow_overwrite)
301 |     batch_size = model.blobs['data'].shape[0]
302 |     logger = mnt.JSONLogger(str(out_folder),
303 |                             'model',
304 |                             {'train': ['train_loss', 'train_accuracy'],
305 |                              'test': ['test_loss', 'test_accuracy']},
306 |                             base_iter=base_iter,
307 |                             write_every=round_to_mbsize(10000, batch_size),
308 |                             create_plot=monitor)
309 |     progr_ind = mnt.ProgressIndicator()
310 |     cropper = RandCropMonitor('data', _MEAN)
311 |     if monitor:
312 |         extra_monitors = [
313 |             mnt.ActivationMonitor(round_to_mbsize(10000, batch_size),
314 |                                   os.path.join(str(out_folder),
315 |                                                'visualizations' + os.sep),
316 |                                   selected_blobs=['resblock3_out', 'avpool'],
317 |                                   sample={'data': tr_data[0]}),
318 |             mnt.FilterMonitor(round_to_mbsize(10000, batch_size),
319 |                               os.path.join(str(out_folder),
320 |                                            'visualizations' + os.sep),
321 |                               selected_parameters={'resblock1_conv1': [0],
322 |                                                    'resblock3_conv1': [0],
323 |                                                    'resblock7_conv1': [0]}),
324 |             mnt.GradientMonitor(round_to_mbsize(10000, batch_size),
325 |                                 os.path.join(str(out_folder),
326 |                                              'visualizations' + os.sep),
327 |                                 relative=True,
328 |                                 selected_parameters={'resblock1_conv1': [0, 1],
329 |                                                      'resblock3_conv1': [0, 1],
330 |                                                      'resblock7_conv1': [0, 1]}),
331 |         ]
332 |     else:
333 |         extra_monitors = []
334 |     model.fit(round_to_mbsize(num_epoch * 50000, batch_size),
335 |               optimizer,
336 |               X={'data': tr_data, 'labels': tr_labels},
337 |               X_val={'data': te_data, 'labels': te_labels},
338 |               test_interval=round_to_mbsize(50000, batch_size),
339 |               train_callbacks=[
340 |                   progr_ind,
341 |                   logger,
342 |                   mnt.RotatingMirroringMonitor({'data': 0}, 0, 0.5),
343 |                   cropper,
344 |                   mnt.Checkpointer(os.path.join(str(out_folder),
345 |                                                 'model'),
346 |                                    round_to_mbsize(50000 * 10, batch_size),
347 |                                    base_iterations=base_iter),
348 |               ] + extra_monitors,
349 |               test_callbacks=[
350 |                   progr_ind,
351 |                   cropper,
352 |                   logger],
353 |               shuffle=True
354 |               )
355 |
356 |
357 | if __name__ == '__main__':
358 |     cli()  # pylint: disable=no-value-for-parameter
359 |
--------------------------------------------------------------------------------
/examples/residual-nets/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Create visualizations."""
3 | # pylint: disable=no-member, invalid-name, wrong-import-position
4 | from __future__ import print_function
5 |
6 | import os
7 | import json
8 | import logging
9 | import click
10 |
11 | import numpy as np
12 | _LOGGER = logging.getLogger(__name__)
13 | try:
14 |     import matplotlib.pyplot as plt
15 |     MPL_AVAILABLE = True
16 | except ImportError:
17 |     print("Matplotlib could not be imported!")
18 |     MPL_AVAILABLE = False
19 |
20 |
21 | def _sorted_ar(inf, key):
22 |     iters = []
23 |     vals = []
24 |     for values in inf:
25 |         if key in values:
26 |             iters.append(int(values['NumIters']))
27 |             vals.append(float(values[key]))
28 |     sortperm = np.argsort(iters)
29 |     arr = np.array([iters, vals]).T
30 |     return arr[sortperm, :]
31 |
32 |
33 | def _get_information(logfile):
34 |     _LOGGER.info("Getting log information from %s...", logfile)
35 |     with open(logfile, 'r') as infile:
36 |         perfdict = json.load(infile)
37 |
38 |     train_ce = _sorted_ar(perfdict['train'], 'train_loss')
39 |     train_ac = _sorted_ar(perfdict['train'], 'train_accuracy')
40 |     test_ce = _sorted_ar(perfdict['test'], 'test_loss')
41 |     test_ac = _sorted_ar(perfdict['test'], 'test_accuracy')
42 |     return train_ce, train_ac, test_ce, test_ac
43 |
44 |
45 | @click.group()
46 | def cli():
47 |     """Create visualizations for model results."""
48 |     pass
49 |
50 | @cli.command()
51 | @click.argument('model_name', type=click.STRING)
52 | @click.option('--display', is_flag=True, default=False,
53 |               help='Do not write the output, but display the plot.')
54 | def performance(model_name, display=False):
55 |     """Create performance plots."""
56 |     _LOGGER.info('Creating performance plot for model `%s`.',
57 |                  model_name)
58 |     if display:
59 |         outfile = None
60 |     else:
61 |         outfile = os.path.join('results', model_name, 'performance.png')
62 |     draw_perfplots(os.path.join('results', model_name, 'barrista_model.json'),
63 |                    outfile)
64 |     _LOGGER.info("Done.")
65 |
66 |
67 | def draw_perfplots(logfile, outfile=None):
68 |     """Draw the performance plots."""
69 |     train_ce, train_ac, test_ce, test_ac =\
70 |         _get_information(logfile)
71 |
72 |     if not MPL_AVAILABLE:
73 |         raise Exception("This method requires Matplotlib!")
74 |     _, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
75 |     # Loss.
76 |     ax1.set_title("Loss")
77 |     ax1.plot(train_ce[:, 0], train_ce[:, 1],
78 |              label='Training', c='b', alpha=0.7)
79 |     ax1.plot(test_ce[:, 0], test_ce[:, 1],
80 |              label='Test', c='g', alpha=0.7)
81 |     ax1.scatter(test_ce[:, 0], test_ce[:, 1],
82 |                 c='g', s=50)
83 |
84 |     ax1.set_ylabel('Cross-Entropy-Loss')
85 |     ax1.grid()
86 |     # Accuracy.
87 |     ax2.set_title("Accuracy")
88 |     ax2.plot(train_ac[:, 0], train_ac[:, 1],
89 |              label='Training', c='b', alpha=0.7)
90 |     ax2.plot(test_ac[:, 0], test_ac[:, 1],
91 |              label='Test', c='g', alpha=0.7)
92 |     ax2.scatter(test_ac[:, 0], test_ac[:, 1],
93 |                 c='g', s=50)
94 |     ax2.set_ylabel('Accuracy')
95 |     ax2.grid()
96 |
97 |     ax1.legend()
98 |     if outfile is not None:
99 |         plt.savefig(outfile, bbox_inches='tight')
100 |     else:
101 |         plt.show()
102 |
103 |
104 | if __name__ == '__main__':
105 |     logging.basicConfig(level=logging.INFO)
106 |     cli()
107 |
--------------------------------------------------------------------------------
/examples/showcase.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """An illustrative example of the usage of `barrista`."""
3 | # pylint: disable=F0401, C0103, E1101, W0611, no-name-in-module
4 | import os
5 | import sys
6 | import logging
7 |
8 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)),
9 |                                 '..'))
10 | import numpy as np  # noqa
11 |
12 | # This provides us with tools to design a network.
13 | import barrista.design as design  # noqa
14 | from barrista.design import (ConvolutionLayer, ReLULayer, PoolingLayer,
15 |                              DropoutLayer, InnerProductLayer,
16 |                              SoftmaxLayer, SoftmaxWithLossLayer,
17 |                              AccuracyLayer)  # noqa
18 | from barrista.tools import TemporaryDirectory  # noqa
19 | # The monitoring module comes with helpful tools to monitor progress and
20 | # performance.
21 | from barrista.monitoring import (ProgressIndicator, Checkpointer,
22 |                                  JSONLogger)
23 |
24 | from barrista import solver as _solver
25 | logging.basicConfig(level=logging.INFO)
26 |
27 |
28 | # When `predict_inputs` and `predict_input_shapes` are not specified, this
29 | # is used as a straightforward network specification. If they are supplied,
30 | # a virtual second network with stage `predict` is used at prediction time.
31 | netspec = design.NetSpecification([[10, 3, 51, 51], [10]],
32 |                                   inputs=['data', 'annotations'],
33 |                                   predict_inputs=['data'],
34 |                                   predict_input_shapes=[[10, 3, 51, 51]])
35 |
36 | # This is a VGG-like convolutional network. It could even be created
37 | # procedurally!
38 | layers = []
39 | conv_params = {'Convolution_kernel_size': 3,
40 |                'Convolution_num_output': 32,
41 |                'Convolution_pad': 1}
42 |
43 | layers.append(ConvolutionLayer(**conv_params))
44 | layers.append(ReLULayer())
45 | layers.append(ConvolutionLayer(**conv_params))
46 | layers.append(ReLULayer())
47 | layers.append(PoolingLayer(Pooling_kernel_size=2, Pooling_stride=2))
48 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
49 |
50 | conv_params['Convolution_num_output'] = 64
51 | layers.append(ConvolutionLayer(**conv_params))
52 | layers.append(ReLULayer())
53 | layers.append(ConvolutionLayer(**conv_params))
54 | layers.append(ReLULayer())
55 | layers.append(PoolingLayer(Pooling_kernel_size=2, Pooling_stride=2))
56 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
57 |
58 | layers.append(InnerProductLayer(InnerProduct_num_output=256))
59 | layers.append(ReLULayer())
60 | layers.append(DropoutLayer(Dropout_dropout_ratio=0.25))
61 |
62 | layers.append(InnerProductLayer(tops=['net_out'], InnerProduct_num_output=10))
63 |
64 | # Output layer for stage `predict`.
65 | layers.append(SoftmaxLayer(tops=['out'], include_stages=['predict']))
66 |
67 | # Output layers for stage `fit`.
68 | layers.append(SoftmaxWithLossLayer(name='loss',
69 |                                    bottoms=['net_out', 'annotations'],
70 |                                    include_stages=['fit']))
71 | layers.append(AccuracyLayer(name='accuracy',
72 |                             bottoms=['net_out', 'annotations'],
73 |                             include_stages=['fit']))
74 |
75 | netspec.layers.extend(layers)
76 | # Create the network. Notice how all layers are automatically wired! If you
77 | # selectively name layers or blobs, this is taken into account.
78 | net = netspec.instantiate()
79 |
80 | # Let's do some training (the data makes absolutely no sense and is used
81 | # solely for illustrative purposes). Note that the number of inputs may
82 | # be arbitrary, and batching, etc. is automatically taken care of!
83 | X = {'data': np.zeros((11, 3, 51, 51), dtype='float32'),
84 |      'annotations': np.ones((11, 1), dtype='float32')}
85 |
86 | with TemporaryDirectory() as tmpdir:
87 |     # Configure our monitors.
88 |
89 |     progress = ProgressIndicator()
90 |     perforce = JSONLogger(tmpdir,
91 |                           'test',
92 |                           {'test': ['test_loss',
93 |                                     'test_accuracy'],
94 |                            'train': ['train_loss',
95 |                                      'train_accuracy']})
96 |     # This is only commented out to let this example run on the travis worker.
97 |     # When writing to disk with the program, the process gets killed...
98 |     # checkptr = Checkpointer(os.path.join(tmpdir, 'test_net_'), 50)
99 |     # Run the training.
100 |     net.fit(100,
101 |             _solver.SGDSolver(base_lr=0.01),
102 |             X,
103 |             test_interval=50,  # optional
104 |             X_val=X,  # optional
105 |             train_callbacks=[progress,
106 |                              # checkptr,
107 |                              perforce],
108 |             test_callbacks=[progress,
109 |                             perforce]
110 |             )
111 |     print("Fitting done!")
112 |     # Note the flexibility you have with the monitors: they may be used for any
113 |     # task! By using a different JSON logger for batch- and test-callbacks, you
114 |     # can collect the performance in different logs.
115 |
116 |     # Predict some new data. Note that this automatically uses the weights
117 |     # of the trained net, but in the `predict` layout.
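    # (With `predict_input_shapes=[[10, 3, 51, 51]]` above, the 100 samples
    # below are presumably processed in chunks of the prediction batch size,
    # 10; the chunking is handled by barrista.)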
118 | results = net.predict(np.zeros((100, 3, 51, 51), dtype='float32'), 119 | test_callbacks=[ProgressIndicator()]) 120 | print("Predicting done!") 121 | # Reloading a model. 122 | # net.load_blobs_from(os.path.join(tmpdir, 'test_net_50.caffemodel')) 123 | 124 | # Visualizing a model. You can add the parameter `display=True` to directly 125 | # show it. 126 | # pylint: disable=W0212 127 | if design._draw is not None: 128 | viz = netspec.visualize() 129 | import cv2 # pylint: disable=wrong-import-order, wrong-import-position 130 | cv2.imwrite(os.path.join(tmpdir, 'test.png'), viz) 131 | 132 | # Going back to medieval age: 133 | netspec.to_prototxt(output_filename=os.path.join(tmpdir, 134 | 'test.prototxt')) 135 | netspec_rel = design.NetSpecification.from_prototxt( 136 | filename=os.path.join(tmpdir, 'test.prototxt')) 137 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 University of Tuebingen. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /patches/barrista-patch-caffe-6eae122a8eb84f8371dde815986cd7524fc4cbaa.patch: -------------------------------------------------------------------------------- 1 | From e4304745ef9b42914d669884b7a55766ec804b48 Mon Sep 17 00:00:00 2001 2 | Message-Id: 3 | From: Christoph Lassner 4 | Date: Mon, 5 Oct 2015 06:01:44 +0200 5 | Subject: [PATCH] Barrista compatibility patch. 
6 | 7 | --- 8 | include/caffe/solver.hpp | 23 +++++++++++++++++++++++ 9 | python/caffe/_caffe.cpp | 30 ++++++++++++++++++++++++------ 10 | src/caffe/solver.cpp | 28 ++++++++++++++++++++++++++++ 11 | 3 files changed, 75 insertions(+), 6 deletions(-) 12 | 13 | diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp 14 | index 2ecf539..796b761 100644 15 | --- a/include/caffe/solver.hpp 16 | +++ b/include/caffe/solver.hpp 17 | @@ -42,7 +42,11 @@ class Solver { 18 | explicit Solver(const SolverParameter& param, 19 | const Solver* root_solver = NULL); 20 | explicit Solver(const string& param_file, const Solver* root_solver = NULL); 21 | + Solver(const string& param_file, 22 | + shared_ptr > &net); 23 | void Init(const SolverParameter& param); 24 | + void InitForNet(const SolverParameter& param, 25 | + shared_ptr > &net); 26 | void InitTrainNet(); 27 | void InitTestNets(); 28 | 29 | @@ -159,6 +163,9 @@ class SGDSolver : public Solver { 30 | : Solver(param) { PreSolve(); } 31 | explicit SGDSolver(const string& param_file) 32 | : Solver(param_file) { PreSolve(); } 33 | + SGDSolver(const string ¶m_file, 34 | + shared_ptr > net) 35 | + : Solver(param_file, net) { PreSolve(); } 36 | 37 | const vector > >& history() { return history_; } 38 | 39 | @@ -191,6 +198,9 @@ class NesterovSolver : public SGDSolver { 40 | : SGDSolver(param) {} 41 | explicit NesterovSolver(const string& param_file) 42 | : SGDSolver(param_file) {} 43 | + NesterovSolver(const string ¶m_file, 44 | + shared_ptr > net) 45 | + : SGDSolver(param_file, net) {} 46 | 47 | protected: 48 | virtual void ComputeUpdateValue(int param_id, Dtype rate); 49 | @@ -205,6 +215,9 @@ class AdaGradSolver : public SGDSolver { 50 | : SGDSolver(param) { constructor_sanity_check(); } 51 | explicit AdaGradSolver(const string& param_file) 52 | : SGDSolver(param_file) { constructor_sanity_check(); } 53 | + AdaGradSolver(const string ¶m_file, 54 | + shared_ptr > net) 55 | + : SGDSolver(param_file, net) { constructor_sanity_check(); } 56 | 57 | protected: 58 | virtual void ComputeUpdateValue(int param_id, Dtype rate); 59 | @@ -224,6 +237,10 @@ class RMSPropSolver : public SGDSolver { 60 | : SGDSolver(param) { constructor_sanity_check(); } 61 | explicit RMSPropSolver(const string& param_file) 62 | : SGDSolver(param_file) { constructor_sanity_check(); } 63 | + explicit RMSPropSolver(const string& param_file, 64 | + shared_ptr > net) 65 | + : SGDSolver(param_file, net) { constructor_sanity_check(); } 66 | + 67 | 68 | protected: 69 | virtual void ComputeUpdateValue(int param_id, Dtype rate); 70 | @@ -246,6 +263,9 @@ class AdaDeltaSolver : public SGDSolver { 71 | : SGDSolver(param) { AdaDeltaPreSolve(); } 72 | explicit AdaDeltaSolver(const string& param_file) 73 | : SGDSolver(param_file) { AdaDeltaPreSolve(); } 74 | + explicit AdaDeltaSolver(const string& param_file, 75 | + shared_ptr > net) 76 | + : SGDSolver(param_file, net) { AdaDeltaPreSolve(); } 77 | 78 | protected: 79 | void AdaDeltaPreSolve(); 80 | @@ -269,6 +289,9 @@ class AdamSolver : public SGDSolver { 81 | : SGDSolver(param) { AdamPreSolve();} 82 | explicit AdamSolver(const string& param_file) 83 | : SGDSolver(param_file) { AdamPreSolve(); } 84 | + explicit AdamSolver(const string& param_file, 85 | + shared_ptr > net) 86 | + : SGDSolver(param_file, net) { AdamPreSolve(); } 87 | 88 | protected: 89 | void AdamPreSolve(); 90 | diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp 91 | index ccd5776..4075c16 100644 92 | --- a/python/caffe/_caffe.cpp 93 | +++ b/python/caffe/_caffe.cpp 94 | @@ 
-93,6 +93,11 @@ shared_ptr > Net_Init_Load( 95 | return net; 96 | } 97 | 98 | +void Net_load_blobs_from(Net& net, string filename) { 99 | + CheckFile(filename); 100 | + net.CopyTrainedLayersFrom(filename); 101 | +} 102 | + 103 | void Net_Save(const Net& net, string filename) { 104 | NetParameter net_param; 105 | net.ToProto(&net_param, false); 106 | @@ -224,6 +229,7 @@ BOOST_PYTHON_MODULE(_caffe) { 107 | .def("_forward", &Net::ForwardFromTo) 108 | .def("_backward", &Net::BackwardFromTo) 109 | .def("reshape", &Net::Reshape) 110 | + .def("load_blobs_from", &Net_load_blobs_from) 111 | // The cast is to select a particular overload. 112 | .def("copy_from", static_cast::*)(const string)>( 113 | &Net::CopyTrainedLayersFrom)) 114 | @@ -290,22 +296,34 @@ BOOST_PYTHON_MODULE(_caffe) { 115 | 116 | bp::class_, bp::bases >, 117 | shared_ptr >, boost::noncopyable>( 118 | - "SGDSolver", bp::init()); 119 | + "SGDSolver", bp::no_init) 120 | + .def(bp::init()) 121 | + .def(bp::init > >()); 122 | bp::class_, bp::bases >, 123 | shared_ptr >, boost::noncopyable>( 124 | - "NesterovSolver", bp::init()); 125 | + "NesterovSolver", bp::no_init) 126 | + .def(bp::init()) 127 | + .def(bp::init > >()); 128 | bp::class_, bp::bases >, 129 | shared_ptr >, boost::noncopyable>( 130 | - "AdaGradSolver", bp::init()); 131 | + "AdaGradSolver", bp::no_init) 132 | + .def(bp::init()) 133 | + .def(bp::init > >()); 134 | bp::class_, bp::bases >, 135 | shared_ptr >, boost::noncopyable>( 136 | - "RMSPropSolver", bp::init()); 137 | + "RMSPropSolver", bp::no_init) 138 | + .def(bp::init()) 139 | + .def(bp::init > >()); 140 | bp::class_, bp::bases >, 141 | shared_ptr >, boost::noncopyable>( 142 | - "AdaDeltaSolver", bp::init()); 143 | + "AdaDeltaSolver", bp::no_init) 144 | + .def(bp::init()) 145 | + .def(bp::init > >()); 146 | bp::class_, bp::bases >, 147 | shared_ptr >, boost::noncopyable>( 148 | - "AdamSolver", bp::init()); 149 | + "AdamSolver", bp::no_init) 150 | + .def(bp::init()) 151 | + .def(bp::init > >()); 152 | 153 | bp::def("get_solver", &GetSolverFromFile, 154 | bp::return_value_policy()); 155 | diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp 156 | index 12c13dd..3eaf4be 100644 157 | --- a/src/caffe/solver.cpp 158 | +++ b/src/caffe/solver.cpp 159 | @@ -70,6 +70,34 @@ void Solver::Init(const SolverParameter& param) { 160 | } 161 | 162 | template 163 | +Solver::Solver(const string& param_file, 164 | + shared_ptr > &net) 165 | + : net_(), root_solver_(NULL) { 166 | + SolverParameter param; 167 | + ReadProtoFromTextFileOrDie(param_file, ¶m); 168 | + InitForNet(param, net); 169 | +} 170 | + 171 | +template 172 | +void Solver::InitForNet(const SolverParameter& param, 173 | + shared_ptr > &net) { 174 | + LOG(INFO) << "Initializing solver from parameters: " << std::endl 175 | + << param.DebugString(); 176 | + param_ = param; 177 | + CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; 178 | + if (param_.random_seed() >= 0) { 179 | + Caffe::set_random_seed(param_.random_seed()); 180 | + } 181 | + LOG(INFO) << "Solver scaffolding done."; 182 | + iter_ = 0; 183 | + current_step_ = 0; 184 | + // This assumes that the net is configured for training. This method 185 | + // is intended to be used with the Barrista package, and additional 186 | + // checks are performed in the barrista.Net.fit method. 
187 | + net_ = net; 188 | +} 189 | + 190 | +template 191 | void Solver::InitTrainNet() { 192 | const int num_train_nets = param_.has_net() + param_.has_net_param() + 193 | param_.has_train_net() + param_.has_train_net_param(); 194 | -- 195 | 1.9.1 196 | 197 | -------------------------------------------------------------------------------- /patches/barrista-patch-caffe-dc831aa8f5d3b7d9473958f5b9e745c98755a0a6.patch: -------------------------------------------------------------------------------- 1 | diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp 2 | index 1fc52d8..f883a12 100644 3 | --- a/include/caffe/sgd_solvers.hpp 4 | +++ b/include/caffe/sgd_solvers.hpp 5 | @@ -19,6 +19,9 @@ class SGDSolver : public Solver { 6 | : Solver(param) { PreSolve(); } 7 | explicit SGDSolver(const string& param_file) 8 | : Solver(param_file) { PreSolve(); } 9 | + SGDSolver(const string ¶m_file, 10 | + shared_ptr > net) 11 | + : Solver(param_file, net) { PreSolve(); } 12 | virtual inline const char* type() const { return "SGD"; } 13 | 14 | const vector > >& history() { return history_; } 15 | @@ -52,6 +55,9 @@ class NesterovSolver : public SGDSolver { 16 | : SGDSolver(param) {} 17 | explicit NesterovSolver(const string& param_file) 18 | : SGDSolver(param_file) {} 19 | + NesterovSolver(const string ¶m_file, 20 | + shared_ptr > net) 21 | + : SGDSolver(param_file, net) {} 22 | virtual inline const char* type() const { return "Nesterov"; } 23 | 24 | protected: 25 | @@ -67,6 +73,9 @@ class AdaGradSolver : public SGDSolver { 26 | : SGDSolver(param) { constructor_sanity_check(); } 27 | explicit AdaGradSolver(const string& param_file) 28 | : SGDSolver(param_file) { constructor_sanity_check(); } 29 | + AdaGradSolver(const string ¶m_file, 30 | + shared_ptr > net) 31 | + : SGDSolver(param_file, net) { constructor_sanity_check(); } 32 | virtual inline const char* type() const { return "AdaGrad"; } 33 | 34 | protected: 35 | @@ -87,6 +96,9 @@ class RMSPropSolver : public SGDSolver { 36 | : SGDSolver(param) { constructor_sanity_check(); } 37 | explicit RMSPropSolver(const string& param_file) 38 | : SGDSolver(param_file) { constructor_sanity_check(); } 39 | + explicit RMSPropSolver(const string& param_file, 40 | + shared_ptr > net) 41 | + : SGDSolver(param_file, net) { constructor_sanity_check(); } 42 | virtual inline const char* type() const { return "RMSProp"; } 43 | 44 | protected: 45 | @@ -110,6 +122,9 @@ class AdaDeltaSolver : public SGDSolver { 46 | : SGDSolver(param) { AdaDeltaPreSolve(); } 47 | explicit AdaDeltaSolver(const string& param_file) 48 | : SGDSolver(param_file) { AdaDeltaPreSolve(); } 49 | + explicit AdaDeltaSolver(const string& param_file, 50 | + shared_ptr > net) 51 | + : SGDSolver(param_file, net) { AdaDeltaPreSolve(); } 52 | virtual inline const char* type() const { return "AdaDelta"; } 53 | 54 | protected: 55 | @@ -134,6 +149,9 @@ class AdamSolver : public SGDSolver { 56 | : SGDSolver(param) { AdamPreSolve();} 57 | explicit AdamSolver(const string& param_file) 58 | : SGDSolver(param_file) { AdamPreSolve(); } 59 | + explicit AdamSolver(const string& param_file, 60 | + shared_ptr > net) 61 | + : SGDSolver(param_file, net) { AdamPreSolve(); } 62 | virtual inline const char* type() const { return "Adam"; } 63 | 64 | protected: 65 | diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp 66 | index 38259ed..f1cb326 100644 67 | --- a/include/caffe/solver.hpp 68 | +++ b/include/caffe/solver.hpp 69 | @@ -43,7 +43,11 @@ class Solver { 70 | explicit Solver(const SolverParameter& 
param, 71 | const Solver* root_solver = NULL); 72 | explicit Solver(const string& param_file, const Solver* root_solver = NULL); 73 | + Solver(const string& param_file, 74 | + shared_ptr > &net); 75 | void Init(const SolverParameter& param); 76 | + void InitForNet(const SolverParameter& param, 77 | + shared_ptr > &net); 78 | void InitTrainNet(); 79 | void InitTestNets(); 80 | 81 | diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp 82 | index 12a5745..649ea96 100644 83 | --- a/python/caffe/_caffe.cpp 84 | +++ b/python/caffe/_caffe.cpp 85 | @@ -300,22 +300,34 @@ BOOST_PYTHON_MODULE(_caffe) { 86 | 87 | bp::class_, bp::bases >, 88 | shared_ptr >, boost::noncopyable>( 89 | - "SGDSolver", bp::init()); 90 | + "SGDSolver", bp::no_init) 91 | + .def(bp::init()) 92 | + .def(bp::init > >()); 93 | bp::class_, bp::bases >, 94 | shared_ptr >, boost::noncopyable>( 95 | - "NesterovSolver", bp::init()); 96 | + "NesterovSolver", bp::no_init) 97 | + .def(bp::init()) 98 | + .def(bp::init > >()); 99 | bp::class_, bp::bases >, 100 | shared_ptr >, boost::noncopyable>( 101 | - "AdaGradSolver", bp::init()); 102 | + "AdaGradSolver", bp::no_init) 103 | + .def(bp::init()) 104 | + .def(bp::init > >()); 105 | bp::class_, bp::bases >, 106 | shared_ptr >, boost::noncopyable>( 107 | - "RMSPropSolver", bp::init()); 108 | + "RMSPropSolver", bp::no_init) 109 | + .def(bp::init()) 110 | + .def(bp::init > >()); 111 | bp::class_, bp::bases >, 112 | shared_ptr >, boost::noncopyable>( 113 | - "AdaDeltaSolver", bp::init()); 114 | + "AdaDeltaSolver", bp::no_init) 115 | + .def(bp::init()) 116 | + .def(bp::init > >()); 117 | bp::class_, bp::bases >, 118 | shared_ptr >, boost::noncopyable>( 119 | - "AdamSolver", bp::init()); 120 | + "AdamSolver", bp::no_init) 121 | + .def(bp::init()) 122 | + .def(bp::init > >()); 123 | 124 | bp::def("get_solver", &GetSolverFromFile, 125 | bp::return_value_policy()); 126 | diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp 127 | index a5ccf9c..074e116 100644 128 | --- a/src/caffe/solver.cpp 129 | +++ b/src/caffe/solver.cpp 130 | @@ -63,6 +63,34 @@ void Solver::Init(const SolverParameter& param) { 131 | current_step_ = 0; 132 | } 133 | 134 | + template 135 | +Solver::Solver(const string& param_file, 136 | + shared_ptr > &net) 137 | + : net_(), root_solver_(NULL) { 138 | + SolverParameter param; 139 | + ReadProtoFromTextFileOrDie(param_file, ¶m); 140 | + InitForNet(param, net); 141 | +} 142 | + 143 | +template 144 | +void Solver::InitForNet(const SolverParameter& param, 145 | + shared_ptr > &net) { 146 | + LOG(INFO) << "Initializing solver from parameters: " << std::endl 147 | + << param.DebugString(); 148 | + param_ = param; 149 | + CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; 150 | + if (param_.random_seed() >= 0) { 151 | + Caffe::set_random_seed(param_.random_seed()); 152 | + } 153 | + LOG(INFO) << "Solver scaffolding done."; 154 | + iter_ = 0; 155 | + current_step_ = 0; 156 | + // This assumes that the net is configured for training. This method 157 | + // is intended to be used with the Barrista package, and additional 158 | + // checks are performed in the barrista.Net.fit method. 
--------------------------------------------------------------------------------
/patches/barrista-patch-caffe-release-candidater-(tag-rc2).patch:
--------------------------------------------------------------------------------
1 | From 46fe85e02ef6975b2e6076fafebec6261306eb63 Mon Sep 17 00:00:00 2001
2 | Message-Id: <46fe85e02ef6975b2e6076fafebec6261306eb63.1444017656.git.mail@christophlassner.de>
3 | From: Christoph Lassner <mail@christophlassner.de>
4 | Date: Mon, 5 Oct 2015 06:00:41 +0200
5 | Subject: [PATCH] Barrista compatibility patch.
6 |
7 | ---
8 |  include/caffe/solver.hpp | 13 +++++++++++++
9 |  python/caffe/__init__.py |  2 +-
10 |  python/caffe/_caffe.cpp  | 18 +++++++++++++---
11 |  python/caffe/pycaffe.py  |  2 +-
12 |  src/caffe/solver.cpp     | 28 ++++++++++++++++++++++++++++
13 |  5 files changed, 58 insertions(+), 5 deletions(-)
14 |
15 | diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
16 | index 2510de7..2be4be1 100644
17 | --- a/include/caffe/solver.hpp
18 | +++ b/include/caffe/solver.hpp
19 | @@ -19,7 +19,11 @@ class Solver {
20 |   public:
21 |    explicit Solver(const SolverParameter& param);
22 |    explicit Solver(const string& param_file);
23 | +  Solver(const string& param_file,
24 | +         shared_ptr<Net<Dtype> > &net);
25 |    void Init(const SolverParameter& param);
26 | +  void InitForNet(const SolverParameter& param,
27 | +                  shared_ptr<Net<Dtype> > &net);
28 |    void InitTrainNet();
29 |    void InitTestNets();
30 |    // The main entry of the solver function. In default, iter will be zero. Pass
31 | @@ -74,6 +78,9 @@ class SGDSolver : public Solver<Dtype> {
32 |        : Solver<Dtype>(param) { PreSolve(); }
33 |    explicit SGDSolver(const string& param_file)
34 |        : Solver<Dtype>(param_file) { PreSolve(); }
35 | +  SGDSolver(const string &param_file,
36 | +            shared_ptr<Net<Dtype> > net)
37 | +      : Solver<Dtype>(param_file, net) { PreSolve(); }
38 |
39 |    const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
40 |
41 | @@ -100,6 +107,9 @@ class NesterovSolver : public SGDSolver<Dtype> {
42 |        : SGDSolver<Dtype>(param) {}
43 |    explicit NesterovSolver(const string& param_file)
44 |        : SGDSolver<Dtype>(param_file) {}
45 | +  NesterovSolver(const string &param_file,
46 | +                 shared_ptr<Net<Dtype> > net)
47 | +      : SGDSolver<Dtype>(param_file, net) {}
48 |
49 |   protected:
50 |    virtual void ComputeUpdateValue();
51 | @@ -114,6 +124,9 @@ class AdaGradSolver : public SGDSolver<Dtype> {
52 |        : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
53 |    explicit AdaGradSolver(const string& param_file)
54 |        : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
55 | +  AdaGradSolver(const string &param_file,
56 | +                shared_ptr<Net<Dtype> > net)
57 | +      : SGDSolver<Dtype>(param_file, net) { constructor_sanity_check(); }
58 |
59 |   protected:
60 |    virtual void ComputeUpdateValue();
61 | diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py
62 | index 37e8956..b191e0a 100644
63 | --- a/python/caffe/__init__.py
64 | +++ b/python/caffe/__init__.py
65 | @@ -1,4 +1,4 @@
66 | -from .pycaffe import Net, SGDSolver
67 | +from .pycaffe import Net, SGDSolver, AdaGradSolver, NesterovSolver
68 |  from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver
69 |  from .proto.caffe_pb2 import TRAIN, TEST
70 |  from .classifier import Classifier
71 | diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
72 | index a5d0e64..1d94bb8 100644
73 | --- a/python/caffe/_caffe.cpp
74 | +++ b/python/caffe/_caffe.cpp
75 | @@ -92,6 +92,11 @@ shared_ptr<Net<Dtype> > Net_Init_Load(
76 |    return net;
77 |  }
78 |
79 | +void Net_load_blobs_from(Net<Dtype>& net, string filename) {
80 | +  CheckFile(filename);
81 | +  net.CopyTrainedLayersFrom(filename);
82 | +}
83 | +
84 |  void Net_Save(const Net<Dtype>& net, string filename) {
85 |    NetParameter net_param;
86 |    net.ToProto(&net_param, false);
87 | @@ -191,6 +196,7 @@ BOOST_PYTHON_MODULE(_caffe) {
88 |    .def("_forward", &Net<Dtype>::ForwardFromTo)
89 |    .def("_backward", &Net<Dtype>::BackwardFromTo)
90 |    .def("reshape", &Net<Dtype>::Reshape)
91 | +  .def("load_blobs_from", &Net_load_blobs_from)
92 |    // The cast is to select a particular overload.
93 |    .def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
94 |        &Net<Dtype>::CopyTrainedLayersFrom))
95 | @@ -248,13 +254,19 @@ BOOST_PYTHON_MODULE(_caffe) {
96 |
97 |    bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
98 |      shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
99 | -    "SGDSolver", bp::init<string>());
100 | +    "SGDSolver", bp::no_init)
101 | +    .def(bp::init<string>())
102 | +    .def(bp::init<string, shared_ptr<Net<Dtype> > >());
103 |    bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
104 |      shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
105 | -    "NesterovSolver", bp::init<string>());
106 | +    "NesterovSolver", bp::no_init)
107 | +    .def(bp::init<string>())
108 | +    .def(bp::init<string, shared_ptr<Net<Dtype> > >());
109 |    bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
110 |      shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
111 | -    "AdaGradSolver", bp::init<string>());
112 | +    "AdaGradSolver", bp::no_init)
113 | +    .def(bp::init<string>())
114 | +    .def(bp::init<string, shared_ptr<Net<Dtype> > >());
115 |
116 |    bp::def("get_solver", &GetSolverFromFile,
117 |        bp::return_value_policy<bp::manage_new_object>());
118 | diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
119 | index 31c145d..59f7287 100644
120 | --- a/python/caffe/pycaffe.py
121 | +++ b/python/caffe/pycaffe.py
122 | @@ -7,7 +7,7 @@ from collections import OrderedDict
123 |  from itertools import izip_longest
124 |  import numpy as np
125 |
126 | -from ._caffe import Net, SGDSolver
127 | +from ._caffe import Net, SGDSolver, AdaGradSolver, NesterovSolver
128 |  import caffe.io
129 |
130 |  # We directly update methods from Net here (rather than using composition or
131 | diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
132 | index 8ed8aec..5e2f656 100644
133 | --- a/src/caffe/solver.cpp
134 | +++ b/src/caffe/solver.cpp
135 | @@ -28,6 +28,34 @@ Solver<Dtype>::Solver(const string& param_file)
136 |  }
137 |
138 |  template <typename Dtype>
139 | +void Solver<Dtype>::InitForNet(const SolverParameter& param,
140 | +                               shared_ptr<Net<Dtype> > &net) {
141 | +  LOG(INFO) << "Initializing solver from parameters: " << std::endl
142 | +            << param.DebugString();
143 | +  param_ = param;
144 | +  CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
145 | +  if (param_.random_seed() >= 0) {
146 | +    Caffe::set_random_seed(param_.random_seed());
147 | +  }
148 | +  LOG(INFO) << "Solver scaffolding done.";
149 | +  iter_ = 0;
150 | +  current_step_ = 0;
151 | +  // This assumes that the net is configured for training. This method
152 | +  // is intended to be used with the Barrista package, and additional
153 | +  // checks are performed in the barrista.Net.fit method.
154 | +  net_ = net;
155 | +}
156 | +
157 | +template <typename Dtype>
158 | +Solver<Dtype>::Solver(const string& param_file,
159 | +                      shared_ptr<Net<Dtype> > &net)
160 | +    : net_() {
161 | +  SolverParameter param;
162 | +  ReadProtoFromTextFileOrDie(param_file, &param);
163 | +  InitForNet(param, net);
164 | +}
165 | +
166 | +template <typename Dtype>
167 |  void Solver<Dtype>::Init(const SolverParameter& param) {
168 |    LOG(INFO) << "Initializing solver from parameters: " << std::endl
169 |              << param.DebugString();
170 | --
171 | 1.9.1
172 |
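Besides the shared solver constructors, the rc2 patch adds a load_blobs_from binding
that copies trained layer weights into an existing net in place (the same
CopyTrainedLayersFrom call that backs copy_from, exposed under a dedicated name so that
barrista can rely on it). A minimal sketch, assuming a Caffe build with this patch
applied; the file names are hypothetical:

    import caffe

    net = caffe.Net('net.prototxt', caffe.TEST)
    # Copies the blobs of layers with matching names from the weights file
    # into the live net; the net's structure is left untouched.
    net.load_blobs_from('weights.caffemodel')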
--------------------------------------------------------------------------------
/patches/build_support/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Script called by Travis to build and test Caffe.
3 | # Travis CI tests are CPU-only for lack of compatible hardware.
4 |
5 | set -e
6 | MAKE="make --jobs=$NUM_THREADS --keep-going"
7 |
8 | mkdir build
9 | cd build
10 | CPU_ONLY=" -DCPU_ONLY=OFF"
11 | if ! $WITH_CUDA; then
12 |   CPU_ONLY=" -DCPU_ONLY=ON"
13 | fi
14 | PYTHON_ARGS=""
15 | if [ "$PYTHON_VERSION" = "3" ]; then
16 |   PYTHON_ARGS="$PYTHON_ARGS -Dpython_version=3 -DBOOST_LIBRARYDIR=$CONDA_DIR/lib/"
17 | else
18 |   PYTHON_ARGS="$PYTHON_ARGS -Dpython_version=2 -DBOOST_LIBRARYDIR=$CONDA_DIR/lib/ -DPYTHON_EXECUTABLE=$CONDA_DIR/bin/python -DPYTHON_INCLUDE_DIR=$CONDA_DIR/include/python2.7/ -DPYTHON_LIBRARY=$CONDA_DIR/lib/libpython2.7.so"
19 | fi
20 | if $WITH_IO; then
21 |   IO_ARGS="-DUSE_OPENCV=ON -DUSE_LMDB=ON -DUSE_LEVELDB=ON"
22 | else
23 |   IO_ARGS="-DUSE_OPENCV=OFF -DUSE_LMDB=OFF -DUSE_LEVELDB=OFF"
24 | fi
25 | cmake -DBUILD_python=ON -DCMAKE_BUILD_TYPE=Release $CPU_ONLY $PYTHON_ARGS -DCMAKE_INCLUDE_PATH="$CONDA_DIR/include/" -DCMAKE_LIBRARY_PATH="$CONDA_DIR/lib/" $IO_ARGS ..
26 | $MAKE
27 | $MAKE install
28 | cd -
--------------------------------------------------------------------------------
/patches/build_support/pylint_call.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | pylint --disable=star-args,no-member,duplicate-code,wrong-import-order,wrong-import-position,too-many-nested-blocks,too-many-boolean-expressions,deprecated-method barrista && \
3 | pylint --disable=star-args,no-member,duplicate-code,wrong-import-order,wrong-import-position,too-many-nested-blocks,too-many-boolean-expressions,deprecated-method setup.py && \
4 | pylint --disable=star-args,no-member,duplicate-code,wrong-import-order,wrong-import-position,too-many-nested-blocks,too-many-boolean-expressions,deprecated-method tests.py
--------------------------------------------------------------------------------
/patches/build_support/travis_setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # This script must be run with sudo.
3 |
4 | set -e
5 |
6 | MAKE="make --jobs=$NUM_THREADS"
7 | # Install the apt packages for which the Ubuntu 12.04 defaults
8 | # (plus the ppa below) work for Caffe.
9 | # This ppa is for gflags and glog.
10 | add-apt-repository -y ppa:tuleu/precise-backports
11 | apt-get -y update
12 | apt-get install \
13 |     wget git curl \
14 |     python-dev python-numpy python3-dev \
15 |     libleveldb-dev libsnappy-dev libopencv-dev \
16 |     libatlas-dev libatlas-base-dev \
17 |     libhdf5-serial-dev libgflags-dev libgoogle-glog-dev \
18 |     bc
19 |
20 | # Add a special apt-repository to install CMake 2.8.9 for CMake Caffe build,
21 | # if needed. By default, Aptitude in Ubuntu 12.04 installs CMake 2.8.7, but
22 | # Caffe requires a minimum CMake version of 2.8.8.
23 | if $WITH_CMAKE; then
24 |   # cmake 3 will make sure that the python interpreter and libraries match
25 |   wget --no-check-certificate http://www.cmake.org/files/v3.2/cmake-3.2.3-Linux-x86_64.sh -O cmake3.sh
26 |   chmod +x cmake3.sh
27 |   ./cmake3.sh --prefix=/usr/ --skip-license --exclude-subdir
28 | fi
29 |
30 | # Install LMDB.
31 | LMDB_URL=https://github.com/LMDB/lmdb/archive/LMDB_0.9.14.tar.gz
32 | LMDB_FILE=/tmp/lmdb.tar.gz
33 | pushd .
34 | wget $LMDB_URL -O $LMDB_FILE
35 | tar -C /tmp -xzvf $LMDB_FILE
36 | cd /tmp/lmdb*/libraries/liblmdb/
37 | $MAKE
38 | $MAKE install
39 | popd
40 | rm -f $LMDB_FILE
41 |
42 | # Install the Python runtime dependencies via miniconda (this is much faster
43 | # than using pip for everything).
44 | export PATH=$CONDA_DIR/bin:$PATH
45 | #if [ ! -d $CONDA_DIR ]; then
46 | rm -rf $CONDA_DIR
47 | if [ "$PYTHON_VERSION" -eq "3" ]; then
48 |   wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
49 | else
50 |   wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
51 | fi
52 | chmod +x miniconda.sh
53 | ./miniconda.sh -b -p $CONDA_DIR
54 |
55 | conda update --yes conda
56 | if [ "$PYTHON_VERSION" -eq "3" ]; then
57 |   # The version of boost we're using for Python 3 depends on 3.4 for now.
58 |   conda install --yes python=3.4
59 | fi
60 | conda install --yes numpy scipy matplotlib scikit-image pip
61 | # Let conda install boost (so that boost_python matches).
62 | conda install --yes -c https://conda.binstar.org/menpo boost=1.56.0
63 | #fi
64 |
65 | pushd .
66 | wget https://github.com/google/protobuf/archive/3.0.0-GA.tar.gz -O protobuf-3.tar.gz
67 | tar -C /tmp -xzvf protobuf-3.tar.gz
68 | cd /tmp/protobuf-3*/
69 | ./autogen.sh
70 | ./configure --prefix=$CONDA_DIR
71 | $MAKE
72 | $MAKE install
73 | popd
74 |
75 | if [ "$PYTHON_VERSION" -eq "3" ]; then
76 |   pip install --pre protobuf
77 | else
78 |   pip install --pre protobuf
79 |   pip install coveralls
80 |   pip install coverage
81 |   pip install tqdm
82 | fi
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | tqdm
4 | natsort
5 | protobuf
6 | progressbar
7 | scikit-image
8 | scikit-learn
9 | skdata
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [build_sphinx]
2 | source-dir = documentation/
3 | build-dir = documentation/_build
4 | all_files = 1
5 |
6 | [upload_sphinx]
7 | upload-dir = documentation/_build/html
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """Barrista setup script."""
3 | # pylint: disable=C0103
4 | import os
5 | from setuptools import setup
6 |
7 | ###############################################################################
8 | # The install requirements are defined in requirements.txt.
9 | # We only want to define them once, so we re-use that file here.
10 | ###############################################################################
11 | # Path of the requirements file, independent of the calling location.
12 | file_path_requirements = os.path.join(
13 |     os.path.dirname(os.path.abspath(__file__)), 'requirements.txt')
14 |
15 | with open(file_path_requirements, 'r') as fi:
16 |     requirements = fi.read().splitlines()
17 |
18 | setup(name='barrista',
19 |       version='0.4',
20 |       description='Serving your caffe right',
21 |       author='Christoph Lassner',
22 |       author_email='classner@tue.mpg.de',
23 |       test_suite='tests',
24 |       packages=['barrista'],
25 |       install_requires=requirements)
--------------------------------------------------------------------------------
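For reference, the requirements re-use in setup.py simply evaluates to the line list of
the requirements.txt shown above, so the package metadata and the pip requirements
cannot diverge. A short illustration of what the splitlines() call yields:

    with open('requirements.txt') as fi:
        requirements = fi.read().splitlines()
    # -> ['numpy', 'scipy', 'tqdm', 'natsort', 'protobuf',
    #     'progressbar', 'scikit-image', 'scikit-learn', 'skdata']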