├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ └── style.css ├── _templates │ └── layout.html ├── conf.py └── index.rst ├── logo.png ├── requirements.txt ├── scripts └── dtype_map.py ├── setup.py ├── tests ├── __init__.py ├── base.py ├── core.py ├── models │ ├── __init__.py │ ├── simple.py │ └── simple2.py ├── ops.py └── perf │ ├── create_plots.py │ ├── measure_runtimes.py │ └── simple.py └── tfdeploy.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sublime-project 2 | *.sublime-workspace 3 | *.pyc 4 | *.log 5 | *.DS_Store 6 | *.pkl 7 | dist 8 | MANIFEST 9 | docs/_build 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | 3 | dist: trusty 4 | 5 | language: python 6 | 7 | services: 8 | - docker 9 | 10 | matrix: 11 | include: 12 | - env: TF_TAG=1.0.1 TD_TEST_SCIPY=0 TD_TEST_GPU=0 13 | - env: TF_TAG=1.0.1 TD_TEST_SCIPY=1 TD_TEST_GPU=0 14 | - env: TF_TAG=1.0.1-py3 TD_TEST_SCIPY=0 TD_TEST_GPU=0 15 | 16 | install: 17 | - docker pull tensorflow/tensorflow:$TF_TAG 18 | 19 | script: docker run -t --rm -v `pwd`:/root/tfdeploy -w /root/tfdeploy -e TD_TEST_SCIPY=$TD_TEST_SCIPY -e TD_TEST_GPU=$TD_TEST_GPU tensorflow/tensorflow:$TF_TAG python -m unittest tests 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016-2025, Marcel Rieger 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software without 16 | specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | tfdeploy logo 2 | 3 | [![Build Status](https://travis-ci.org/riga/tfdeploy.svg?branch=master)](https://travis-ci.org/riga/tfdeploy) [![Documentation Status](https://readthedocs.org/projects/tfdeploy/badge/?version=latest)](http://tfdeploy.readthedocs.org/en/latest/?badge=latest) [![Package Status](https://badge.fury.io/py/tfdeploy.svg)](https://badge.fury.io/py/tfdeploy) 4 | 5 | Deploy [tensorflow](https://www.tensorflow.org) graphs for *fast* evaluation and export to *tensorflow-less* environments running [NumPy](http://www.numpy.org). 
6 | 7 | > [!NOTE] 8 | > This project started as a personal playground to get an in-depth understanding of TensorFlow's operations and kernels. 9 | > Up to a certain version, the NumPy-based operations in tfdeploy provided full feature parity, but it is obvious that such a project cannot keep up with the vast development speed driven by TensorFlow devs and the open-source community. 10 | > 11 | > Therefore, tfdeploy is **no longer actively maintained**. 12 | > However, the code base remains active as an easy-to-read reference implementation for most of the kernels that constitute the heart of today's ML landscape. 13 | 14 | 15 | ##### Evaluation usage 16 | 17 | ```python 18 | import tfdeploy as td 19 | import numpy as np 20 | 21 | model = td.Model("/path/to/model.pkl") 22 | inp, outp = model.get("input", "output") 23 | 24 | batch = np.random.rand(10000, 784) 25 | result = outp.eval({inp: batch}) 26 | ``` 27 | 28 | 29 | ##### Installation and dependencies 30 | 31 | Via [pip](https://pypi.python.org/pypi/tfdeploy) 32 | 33 | ```bash 34 | pip install tfdeploy 35 | ``` 36 | 37 | or by simply copying the file into your project. 38 | 39 | NumPy ≥ 1.10 should be installed on your system. [SciPy](http://www.scipy.org/) is optional. See [optimization](#optimization) for more info on optional packages. 40 | 41 | By design, TensorFlow is required when creating a model. 42 | 43 | 44 | ### Content 45 | 46 | - [Why?](#why) 47 | - [How?](#how) 48 | - [Convert your graph](#convert-your-graph) 49 | - [Load the model and evaluate](#load-the-model-and-evaluate) 50 | - [Write your own operation](#write-your-own-operation) 51 | - [Ensembles](#ensembles) 52 | - [Optimization](#optimization) 53 | - [Performance](#performance) 54 | - [Contributing](#contributing) 55 | - [Development](#development) 56 | - [Authors](#authors) 57 | - [License](#license) 58 | 59 | ## Why? 60 | 61 | Working with TensorFlow is awesome. 
Model definition and training are simple yet powerful, and the range of built-in features is just striking. 62 | 63 | Deploying models to environments that are not able to run TensorFlow, however, can be difficult (**note** that tfdeploy was developed before TensorFlow Lite was a thing). 64 | 65 | To boil it down, tfdeploy 66 | 67 | - is lightweight. A single file with < 150 lines of core code. Just copy it to your project. 68 | - is [faster](#performance) than using TensorFlow's `Tensor.eval`. 69 | - **does not need TensorFlow** during evaluation. 70 | - only depends on NumPy. 71 | - can load one or more models from a single file. 72 | - does not support GPUs (maybe [gnumpy](http://www.cs.toronto.edu/~tijmen/gnumpy.html) is worth a try here). 73 | 74 | 75 | ## How? 76 | 77 | The central class is `tfdeploy.Model`. The following two examples demonstrate how a model can be created from a TensorFlow graph, saved to and loaded from disk, and finally evaluated. 78 | 79 | ##### Convert your graph 80 | 81 | ```python 82 | import tensorflow as tf 83 | import tfdeploy as td 84 | 85 | # setup tfdeploy (only when creating models) 86 | td.setup(tf) 87 | 88 | # build your graph 89 | sess = tf.Session() 90 | 91 | # use names for input and output layers 92 | x = tf.placeholder("float", shape=[None, 784], name="input") 93 | W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05)) 94 | b = tf.Variable(tf.zeros([100])) 95 | y = tf.nn.softmax(tf.matmul(x, W) + b, name="output") 96 | 97 | sess.run(tf.global_variables_initializer()) 98 | 99 | # ... training ... 
100 | 101 | # create a tfdeploy model and save it to disk 102 | model = td.Model() 103 | model.add(y, sess) # y and all its ops and related tensors are added recursively 104 | model.save("model.pkl") 105 | ``` 106 | 107 | ##### Load the model and evaluate 108 | 109 | ```python 110 | import numpy as np 111 | import tfdeploy as td 112 | 113 | model = td.Model("model.pkl") 114 | 115 | # shorthand to x and y 116 | x, y = model.get("input", "output") 117 | 118 | # evaluate 119 | batch = np.random.rand(10000, 784) 120 | result = y.eval({x: batch}) 121 | ``` 122 | 123 | ##### Write your own `Operation` 124 | 125 | tfdeploy supports most of the `Operation`s [implemented in TensorFlow](https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html). However, if one is missing (in that case, submit a PR or an issue ;) ) or if you're using custom ops, you might want to extend tfdeploy by defining a new op class that inherits from `tfdeploy.Operation`: 126 | 127 | ```python 128 | import tensorflow as tf 129 | import tfdeploy as td 130 | import numpy as np 131 | 132 | # setup tfdeploy (only when creating models) 133 | td.setup(tf) 134 | 135 | # ... write your model here ... 136 | 137 | # let's assume your final tensor "y" relies on an op of type "InvertedSoftmax" 138 | # before creating the td.Model, you should add that op to tfdeploy 139 | 140 | class InvertedSoftmax(td.Operation): 141 | @staticmethod 142 | def func(a): 143 | e = np.exp(-a) 144 | # ops should return a tuple 145 | return np.divide(e, np.sum(e, axis=-1, keepdims=True)), 146 | 147 | # this is equivalent to 148 | # @td.Operation.factory 149 | # def InvertedSoftmax(a): 150 | # e = np.exp(-a) 151 | # return np.divide(e, np.sum(e, axis=-1, keepdims=True)), 152 | 153 | # now we're good to go 154 | model = td.Model() 155 | model.add(y, sess) 156 | model.save("model.pkl") 157 | ``` 158 | 159 | When writing new ops, three things are important: 160 | 161 | - Try to avoid loops, prefer NumPy vectorization. 
162 | - Return a tuple. 163 | - Don't change incoming tensors/arrays in place; always work on and return copies. 164 | 165 | 166 | ## Ensembles 167 | 168 | tfdeploy provides a helper class to evaluate an ensemble of models: `Ensemble`. It can load multiple models, evaluate them and combine their output values using different methods. 169 | 170 | ```python 171 | # create the ensemble 172 | ensemble = td.Ensemble(["model1.pkl", "model2.pkl", ...], method=td.METHOD_MEAN) 173 | 174 | # get input and output tensors (which actually are TensorEnsemble instances) 175 | input, output = ensemble.get("input", "output") 176 | 177 | # evaluate the ensemble just like a normal model 178 | batch = ... 179 | value = output.eval({input: batch}) 180 | ``` 181 | 182 | The return value of `get()` is a `TensorEnsemble` instance. It is basically a wrapper around multiple tensors and should be used as a key in the `feed_dict` of the `eval()` call. 183 | 184 | You can choose between `METHOD_MEAN` (the default), `METHOD_MAX` and `METHOD_MIN`. If you want to use a custom ensembling method, use `METHOD_CUSTOM` and overwrite the static `func_custom()` method of the `TensorEnsemble` instance. 185 | 186 | 187 | ## Optimization 188 | 189 | Most ops are written using pure NumPy. However, multiple implementations of the same op are allowed that may use additional third-party Python packages providing even faster functionality for some situations. 190 | 191 | For example, NumPy does not provide a vectorized *lgamma* function. Thus, the standard `tfdeploy.Lgamma` op uses `math.lgamma` that was previously vectorized using `numpy.vectorize`. For these situations, additional implementations of the same op are possible (the *lgamma* example is quite academic, but this definitely makes sense for more sophisticated ops like pooling). 
We can simply tell the op to use its SciPy implementation instead: 192 | 193 | ```python 194 | td.Lgamma.use_impl(td.IMPL_SCIPY) 195 | ``` 196 | 197 | Currently, allowed implementation types are NumPy (`IMPL_NUMPY`, the default) and SciPy (`IMPL_SCIPY`). 198 | 199 | 200 | ##### Adding additional implementations 201 | 202 | Additional implementations can be added by setting the `impl` attribute of the op factory or by using the `add_impl` decorator of existing operations. The first registered implementation will be the default one. 203 | 204 | ```python 205 | # create the default lgamma op with numpy implementation 206 | lgamma_vec = np.vectorize(math.lgamma) 207 | 208 | @td.Operation.factory 209 | # equivalent to 210 | # @td.Operation.factory(impl=td.IMPL_NUMPY) 211 | def Lgamma(a): 212 | return lgamma_vec(a), 213 | 214 | # add a scipy-based implementation 215 | @Lgamma.add_impl(td.IMPL_SCIPY) 216 | def Lgamma(a): 217 | return sp.special.gammaln(a), 218 | ``` 219 | 220 | 221 | ##### Auto-optimization 222 | 223 | If SciPy is available on your system, it is reasonable to use all ops in their SciPy implementation (if it exists, of course). This should be configured before you create any model from TensorFlow objects using the second argument of the `setup` function: 224 | 225 | ```python 226 | td.setup(tf, td.IMPL_SCIPY) 227 | ``` 228 | 229 | Ops that do not implement `IMPL_SCIPY` stick with the NumPy version (`IMPL_NUMPY`). 230 | 231 | 232 | ## Performance 233 | 234 | tfdeploy is lightweight (1 file, < 150 lines of core code) and fast. Internal evaluation calls add very little overhead and tensor operations use NumPy vectorization. The actual performance depends on the ops in your graph. While most of the TensorFlow ops have a NumPy equivalent or can be constructed from NumPy functions, a few ops require additional Python-based loops (e.g. `BatchMatMul`). 
But in many cases (and for small to medium graphs) it's potentially faster than using TensorFlow's `Tensor.eval`. 235 | 236 | This is a comparison for a basic graph where all ops are vectorized (basically `Add`, `MatMul` and `Softmax`): 237 | 238 | ```bash 239 | > ipython -i tests/perf/simple.py 240 | 241 | In [1]: %timeit -n 100 test_tf() 242 | 100 loops, best of 3: 109 ms per loop 243 | 244 | In [2]: %timeit -n 100 test_td() 245 | 100 loops, best of 3: 60.5 ms per loop 246 | ``` 247 | 248 | ## Contributing 249 | 250 | If you want to contribute with new ops and features, I'm happy to receive pull requests. Just make sure to add a new test case to `tests/core.py` or `tests/ops.py` and run them via: 251 | 252 | ```bash 253 | > python -m unittest tests 254 | ``` 255 | 256 | 257 | ##### Test grid 258 | 259 | In general, tests should be run for different environments: 260 | 261 | | Variation | Values | 262 | | ------------------ | ------- | 263 | | tensorflow version | `1.0.1` | 264 | | python version | 2, 3 | 265 | | `TD_TEST_SCIPY` | 0, 1 | 266 | | `TD_TEST_GPU` | 0, 1 | 267 | 268 | 269 | ##### Docker 270 | 271 | For testing purposes, it is convenient to use docker. 
Fortunately, the official [tensorflow images](https://hub.docker.com/r/tensorflow/tensorflow/) contain all we need: 272 | 273 | ```bash 274 | git clone https://github.com/riga/tfdeploy.git 275 | cd tfdeploy 276 | 277 | docker run --rm -v `pwd`:/root/tfdeploy -w /root/tfdeploy -e "TD_TEST_SCIPY=1" tensorflow/tensorflow:1.0.1 python -m unittest tests 278 | ``` 279 | 280 | 281 | ## Development 282 | 283 | - Source hosted at [GitHub](https://github.com/riga/tfdeploy) 284 | - Report issues, questions, feature requests on [GitHub Issues](https://github.com/riga/tfdeploy/issues) 285 | 286 | 287 | ## Authors 288 | 289 | - [Marcel R.](https://github.com/riga) 290 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " applehelp to make an Apple Help Book" 34 | @echo " devhelp to make HTML files and a Devhelp project" 35 | @echo " epub to make an epub" 36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 37 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 39 | @echo " text to make text files" 40 | @echo " man to make manual pages" 41 | @echo " texinfo to make Texinfo files" 42 | @echo " info to make Texinfo files and run them through makeinfo" 43 | @echo " gettext to make PO message catalogs" 44 | @echo " changes to make an overview of all changed/added/deprecated items" 45 | @echo " xml to make Docutils-native XML files" 46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 47 | @echo " linkcheck to check all external links for integrity" 48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 49 | @echo " coverage to run coverage check of the documentation (if enabled)" 50 | 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | html: 55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 56 | @echo 57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 58 | 59 | dirhtml: 60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 61 | @echo 62 | @echo "Build finished. 
The HTML pages are in $(BUILDDIR)/dirhtml." 63 | 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | pickle: 70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 71 | @echo 72 | @echo "Build finished; now you can process the pickle files." 73 | 74 | json: 75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 76 | @echo 77 | @echo "Build finished; now you can process the JSON files." 78 | 79 | htmlhelp: 80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 81 | @echo 82 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 83 | ".hhp project file in $(BUILDDIR)/htmlhelp." 84 | 85 | qthelp: 86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 87 | @echo 88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/tfdeploy.qhcp" 91 | @echo "To view the help file:" 92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/tfdeploy.qhc" 93 | 94 | applehelp: 95 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 96 | @echo 97 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 98 | @echo "N.B. You won't be able to view it unless you put it in" \ 99 | "~/Library/Documentation/Help or install it in your application" \ 100 | "bundle." 101 | 102 | devhelp: 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $$HOME/.local/share/devhelp/tfdeploy" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/tfdeploy" 109 | @echo "# devhelp" 110 | 111 | epub: 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. 
The epub file is in $(BUILDDIR)/epub." 115 | 116 | latex: 117 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 118 | @echo 119 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 120 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 121 | "(use \`make latexpdf' here to do that automatically)." 122 | 123 | latexpdf: 124 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 125 | @echo "Running LaTeX files through pdflatex..." 126 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 127 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 128 | 129 | latexpdfja: 130 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 131 | @echo "Running LaTeX files through platex and dvipdfmx..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | text: 136 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 137 | @echo 138 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 139 | 140 | man: 141 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 142 | @echo 143 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 144 | 145 | texinfo: 146 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 147 | @echo 148 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 149 | @echo "Run \`make' in that directory to run these through makeinfo" \ 150 | "(use \`make info' here to do that automatically)." 151 | 152 | info: 153 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 154 | @echo "Running Texinfo files through makeinfo..." 155 | make -C $(BUILDDIR)/texinfo info 156 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 157 | 158 | gettext: 159 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 160 | @echo 161 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 
162 | 163 | changes: 164 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 165 | @echo 166 | @echo "The overview file is in $(BUILDDIR)/changes." 167 | 168 | linkcheck: 169 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 170 | @echo 171 | @echo "Link check complete; look for any errors in the above output " \ 172 | "or in $(BUILDDIR)/linkcheck/output.txt." 173 | 174 | doctest: 175 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 176 | @echo "Testing of doctests in the sources finished, look at the " \ 177 | "results in $(BUILDDIR)/doctest/output.txt." 178 | 179 | coverage: 180 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 181 | @echo "Testing of coverage in the sources finished, look at the " \ 182 | "results in $(BUILDDIR)/coverage/python.txt." 183 | 184 | xml: 185 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 186 | @echo 187 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 188 | 189 | pseudoxml: 190 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 191 | @echo 192 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
193 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | h1.logo, 2 | div.sphinxsidebarwrapper > h3 { 3 | display: none; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {# layout.html #} 2 | {% extends "!layout.html" %} 3 | 4 | {% set css_files = css_files + ['_static/style.css'] %} 5 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import sys 5 | import os 6 | import shlex 7 | 8 | 9 | sys.path.insert(0, os.path.abspath("..")) 10 | import tfdeploy as td 11 | 12 | 13 | project = "tfdeploy" 14 | author = td.__author__ 15 | copyright = td.__copyright__ 16 | version = td.__version__ 17 | release = td.__version__ 18 | 19 | 20 | templates_path = ["_templates"] 21 | html_static_path = ["_static"] 22 | master_doc = "index" 23 | source_suffix = ".rst" 24 | 25 | 26 | exclude_patterns = [] 27 | pygments_style = "sphinx" 28 | html_logo = "../logo.png" 29 | html_theme = "alabaster" 30 | html_sidebars = {"**": [ 31 | "about.html", 32 | "localtoc.html", 33 | "searchbox.html"] 34 | } 35 | html_theme_options = { 36 | "github_user": "riga", 37 | "github_repo": "tfdeploy", 38 | "travis_button": True 39 | } 40 | 41 | 42 | extensions = [ 43 | "sphinx.ext.autodoc" 44 | ] 45 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | tfdeploy 2 | ======== 3 | 4 | .. centered:: This page contains only API docs. For more info, visit `tfdeploy on GitHub `_. 5 | 6 | 7 | .. 
toctree:: 8 | :maxdepth: 2 9 | 10 | 11 | .. automodule:: tfdeploy 12 | 13 | 14 | Classes 15 | ^^^^^^^ 16 | 17 | ``Model`` 18 | --------- 19 | 20 | .. autoclass:: Model 21 | :member-order: bysource 22 | :members: 23 | 24 | 25 | ``Tensor`` 26 | ---------- 27 | 28 | .. autoclass:: Tensor 29 | :member-order: bysource 30 | :members: 31 | 32 | 33 | ``Operation`` 34 | ------------- 35 | 36 | .. autoclass:: Operation 37 | :member-order: bysource 38 | :members: 39 | 40 | 41 | ``Ensemble`` 42 | ------------ 43 | 44 | .. autoclass:: Ensemble 45 | :member-order: bysource 46 | :members: 47 | 48 | 49 | ``TensorEnsemble`` 50 | ------------------ 51 | 52 | .. autoclass:: TensorEnsemble 53 | :member-order: bysource 54 | :members: 55 | 56 | 57 | Functions 58 | ^^^^^^^^^ 59 | 60 | ``setup`` 61 | --------- 62 | 63 | .. autofunction:: setup 64 | 65 | 66 | ``reset`` 67 | --------- 68 | 69 | .. autofunction:: reset 70 | 71 | 72 | ``optimize`` 73 | ------------ 74 | 75 | .. autofunction:: optimize 76 | 77 | 78 | ``print_tensor`` 79 | ---------------- 80 | 81 | .. autofunction:: print_tensor 82 | 83 | 84 | ``print_op`` 85 | ------------ 86 | 87 | .. autofunction:: print_op 88 | 89 | 90 | ``print_tf_tensor`` 91 | ------------------- 92 | 93 | .. autofunction:: print_tf_tensor 94 | 95 | 96 | ``print_tf_op`` 97 | --------------- 98 | 99 | .. autofunction:: print_tf_op 100 | 101 | 102 | Other Attributes 103 | ^^^^^^^^^^^^^^^^ 104 | 105 | .. py:attribute:: IMPL_NUMPY 106 | 107 | Implementation type for ops that use numpy (the default). 108 | 109 | .. py:attribute:: IMPL_SCIPY 110 | 111 | Implementation type for ops that use scipy. 112 | 113 | .. py:attribute:: HAS_SCIPY 114 | 115 | A flag that is *True* when scipy is available on your system. 116 | 117 | 118 | Exceptions 119 | ^^^^^^^^^^ 120 | 121 | ``UnknownOperationException`` 122 | ----------------------------- 123 | 124 | .. 
autoexception:: UnknownOperationException 125 | 126 | 127 | ``OperationMismatchException`` 128 | ------------------------------ 129 | 130 | .. autoexception:: OperationMismatchException 131 | 132 | 133 | ``InvalidImplementationException`` 134 | ---------------------------------- 135 | 136 | .. autoexception:: InvalidImplementationException 137 | 138 | 139 | ``UnknownImplementationException`` 140 | ---------------------------------- 141 | 142 | .. autoexception:: UnknownImplementationException 143 | 144 | 145 | ``UnknownEnsembleMethodException`` 146 | ---------------------------------- 147 | 148 | .. autoexception:: UnknownEnsembleMethodException 149 | 150 | 151 | ``EnsembleMismatchException`` 152 | ----------------------------- 153 | 154 | .. autoexception:: EnsembleMismatchException 155 | 156 | 157 | ``ScipyOperationException`` 158 | --------------------------- 159 | 160 | .. autoexception:: ScipyOperationException 161 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/riga/tfdeploy/c519c6342997afd313b4e11e8417ce70eef7d1dd/logo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | -------------------------------------------------------------------------------- /scripts/dtype_map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Script that prints a mapping of tensorflow dtype nums to numpy dtypes, e.g.: 5 | 6 | > python dtype_map.py 7 | dtype_map = { 8 | 1: np.float32, 9 | 2: np.float64, 10 | 3: np.int32, 11 | 4: np.uint8, 12 | ... 
13 | } 14 | """ 15 | 16 | 17 | import tensorflow as tf 18 | 19 | 20 | # create the mapping 21 | dtype_map = {} 22 | 23 | # fill it 24 | types_pb2 = tf.core.framework.types_pb2 25 | for attr in dir(types_pb2): 26 | if attr.startswith("DT_"): 27 | tf_type_enum = getattr(types_pb2, attr) 28 | try: 29 | dtype_map[tf_type_enum] = "np." + tf.as_dtype(tf_type_enum).as_numpy_dtype.__name__ 30 | except: 31 | pass 32 | 33 | # print dict-like code 34 | dtype_map = "\n".join(" %s: %s," % tpl for tpl in dtype_map.items()) 35 | print("{\n" + dtype_map[:-1] + "\n}") 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | from subprocess import Popen, PIPE 6 | from distutils.core import setup 7 | import tfdeploy as td 8 | 9 | 10 | readme = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md") 11 | if os.path.isfile(readme): 12 | cmd = "pandoc --from=markdown --to=rst " + readme 13 | p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash") 14 | out, err = p.communicate() 15 | if p.returncode != 0: 16 | print("pandoc conversion failed: " + err) 17 | long_description = out 18 | else: 19 | long_description = "" 20 | 21 | keywords = [ 22 | "tensorflow", "deploy", "export", "dump", "numpy", "model", "predict", "evaluate", "function", 23 | "method" 24 | ] 25 | 26 | classifiers = [ 27 | "Programming Language :: Python", 28 | "Programming Language :: Python :: 2", 29 | "Programming Language :: Python :: 3", 30 | "Development Status :: 4 - Beta", 31 | "Operating System :: OS Independent", 32 | "License :: OSI Approved :: MIT License", 33 | "Intended Audience :: Developers", 34 | "Intended Audience :: Science/Research", 35 | "Intended Audience :: Information Technology", 36 | "Topic :: Scientific/Engineering :: Artificial Intelligence" 37 | ] 38 | 39 | 40 | setup( 41 | name = 
td.__name__, 42 | version = td.__version__, 43 | author = td.__author__, 44 | description = td.__doc__.strip(), 45 | license = td.__license__, 46 | url = td.__contact__, 47 | py_modules = [td.__name__], 48 | keywords = keywords, 49 | classifiers = classifiers, 50 | long_description = long_description or td.__doc__.strip() 51 | ) 52 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # some logs 4 | import os 5 | import sys 6 | import numpy as np 7 | import tensorflow as tf 8 | print(80 * "-") 9 | print("python : " + sys.version.split(" ")[0]) 10 | print("tensorflow: " + tf.__version__) 11 | print("numpy : " + np.version.version) 12 | try: 13 | import scipy as sp 14 | spv = sp.version.version 15 | except: 16 | spv = "NONE" 17 | print("scipy : " + spv) 18 | envkeys = [key for key in os.environ.keys() if key.startswith("TD_")] 19 | if envkeys: 20 | print("-") 21 | maxlen = max(len(key) for key in envkeys) 22 | for key in envkeys: 23 | print(key + (maxlen - len(key)) * " " + ": " + os.environ[key]) 24 | print(80 * "-") 25 | 26 | 27 | # import all tests 28 | from .core import * 29 | from .ops import * 30 | -------------------------------------------------------------------------------- /tests/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import sys 6 | import unittest 7 | 8 | import tensorflow as tf 9 | 10 | 11 | # adjust the path to import tfdeploy 12 | base = os.path.normpath(os.path.join(os.path.abspath(__file__), "../..")) 13 | sys.path.append(base) 14 | import tfdeploy as td 15 | 16 | # setup 17 | td.setup(tf) 18 | 19 | 20 | class TestCase(unittest.TestCase): 21 | 22 | def __init__(self, *args, **kwargs): 23 | super(TestCase, self).__init__(*args, **kwargs) 24 | 25 | self._cache = {} 26 | 27 | def 
get(self, model, *attrs): 28 | result = tuple() 29 | for attr in attrs: 30 | key = (model, attr) 31 | if key not in self._cache: 32 | tmp = __import__("tests.models." + model, globals(), locals(), [attr]) 33 | self._cache[key] = getattr(tmp, attr) 34 | result += (self._cache[key],) 35 | return result if len(result) > 1 else result[0] 36 | -------------------------------------------------------------------------------- /tests/core.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import numpy as np 5 | from .base import TestCase, td 6 | 7 | 8 | __all__ = ["CoreTestCase"] 9 | 10 | 11 | class CoreTestCase(TestCase): 12 | 13 | def __init__(self, *args, **kwargs): 14 | super(CoreTestCase, self).__init__(*args, **kwargs) 15 | 16 | self.simple_model = td.Model() 17 | y, sess = self.get("simple", "y", "sess") 18 | self.simple_model.add(y, tf_sess=sess) 19 | 20 | def test_tensors(self): 21 | m = self.simple_model 22 | 23 | # model has one root tensor ... 24 | self.assertEqual(len(m.roots), 1) 25 | 26 | # ... which is named "output" and can be retrieved via get 27 | outp = m.get("output") 28 | self.assertIn(outp, m.roots.values()) 29 | 30 | # the input tensor is named "input" 31 | self.assertIsNotNone(m.get("input")) 32 | 33 | def test_ops(self): 34 | m = self.simple_model 35 | op = m.get("output").op 36 | 37 | # the root tensor operator is a softmax op ... 38 | self.assertIsInstance(op, td.Softmax) 39 | 40 | # ... and has one input ... 41 | self.assertEqual(len(op.inputs), 1) 42 | 43 | # ... 
whose op is an add op 44 | self.assertIsInstance(op.inputs[0].op, td.Add) 45 | 46 | def test_eval(self): 47 | m = self.simple_model 48 | inp, outp, kp = m.get("input", "output", "keep_prob") 49 | 50 | # create an input batch 51 | examples = np.random.rand(1000, 10).astype("float32") 52 | 53 | # first, eval using tf 54 | x, y, keep_prob, sess = self.get("simple", "x", "y", "keep_prob", "sess") 55 | rtf = y.eval(session=sess, feed_dict={x: examples, keep_prob: 1.0}) 56 | 57 | # then, eval using td 58 | rtd = outp.eval({inp: examples, kp: 1.0}) 59 | 60 | # no element in the diff array should be larger than 1e-7 61 | maxdiff = np.max(np.abs(rtf - rtd)) 62 | self.assertLess(maxdiff, 1e-7) 63 | 64 | def test_ensemble_eval(self): 65 | simple_model2 = td.Model() 66 | y2, sess2 = self.get("simple2", "y", "sess") 67 | simple_model2.add(y2, tf_sess=sess2) 68 | 69 | simple_model2.get("input_1").name = "input:0" 70 | simple_model2.get("output_1").name = "output:0" 71 | simple_model2.get("keep_prob_1").name = "keep_prob:0" 72 | 73 | simple_ensemble = td.Ensemble() 74 | simple_ensemble.models = [self.simple_model, simple_model2] 75 | 76 | inp, outp, kp = simple_ensemble.get("input", "output", "keep_prob") 77 | 78 | # create an input batch 79 | examples = np.random.rand(1000, 10).astype("float32") 80 | 81 | # eval both models manually and build the mean 82 | x1, y1, keep_prob1 = self.simple_model.get("input", "output", "keep_prob") 83 | r1 = y1.eval({x1: examples, keep_prob1: 1.0}) 84 | x2, y2, keep_prob2 = simple_model2.get("input", "output", "keep_prob") 85 | r2 = y2.eval({x2: examples, keep_prob2: 1.0}) 86 | rm = np.add(r1, r2) / 2. 
87 | 88 | # then, eval the ensemble 89 | re = outp.eval({inp: examples, kp: 1.0}) 90 | 91 | # no element in the diff array should be larger than 1e-7 92 | maxdiff = np.max(np.abs(re - rm)) 93 | self.assertLess(maxdiff, 1e-7) 94 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /tests/models/simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import tensorflow as tf 5 | import tfdeploy as td 6 | 7 | 8 | sess = tf.Session() 9 | 10 | x = tf.placeholder(tf.float32, shape=[None, 10], name="input") 11 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 12 | 13 | W = tf.Variable(tf.truncated_normal([10, 5], stddev=0.05)) 14 | b = tf.Variable(tf.zeros([5])) 15 | 16 | W_drop = tf.nn.dropout(W, keep_prob) 17 | 18 | y = tf.nn.softmax(tf.matmul(x, W_drop) + b, name="output") 19 | 20 | if td._tf_version[:3] < (0, 12, 0): 21 | sess.run(tf.initialize_all_variables()) 22 | else: 23 | sess.run(tf.global_variables_initializer()) 24 | -------------------------------------------------------------------------------- /tests/models/simple2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import tensorflow as tf 5 | import tfdeploy as td 6 | 7 | 8 | sess = tf.Session() 9 | 10 | x = tf.placeholder(tf.float32, shape=[None, 10], name="input") 11 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 12 | 13 | W = tf.Variable(tf.truncated_normal([10, 5], stddev=0.05)) 14 | b = tf.Variable(tf.zeros([5])) 15 | 16 | W_drop = tf.nn.dropout(W, keep_prob) 17 | 18 | y = tf.nn.softmax(tf.matmul(x, W_drop) + b, name="output") 19 | 20 | if td._tf_version[:3] < (0, 12, 0): 21 | 
sess.run(tf.initialize_all_variables()) 22 | else: 23 | sess.run(tf.global_variables_initializer()) 24 | -------------------------------------------------------------------------------- /tests/ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import numpy as np 6 | from .base import TestCase, td 7 | import tensorflow as tf 8 | from tensorflow.python.framework import device 9 | 10 | 11 | __all__ = ["OpsTestCase"] 12 | 13 | 14 | # get device from env 15 | CPU, GPU = range(2) 16 | DEVICE = CPU 17 | if os.environ.get("TD_TEST_GPU", "").lower() in ("1", "yes", "true"): 18 | DEVICE = GPU 19 | DEVICE_ID = "/%s:0" % ["cpu", "gpu"][DEVICE] 20 | 21 | # setup td 22 | td.setup(tf) 23 | 24 | # optimize for scipy depending on env 25 | if os.environ.get("TD_TEST_SCIPY", "").lower() in ("1", "yes", "true"): 26 | td.optimize(td.IMPL_SCIPY) 27 | 28 | 29 | class OpsTestCase(TestCase): 30 | 31 | def __init__(self, *args, **kwargs): 32 | super(OpsTestCase, self).__init__(*args, **kwargs) 33 | 34 | # add the device to the "_device_function_stack" of the default graph 35 | dev = device.merge_device(DEVICE_ID) 36 | tf.get_default_graph()._device_function_stack.append(dev) 37 | 38 | # create a tf session 39 | self.sess = tf.Session() 40 | 41 | self.ndigits = 7 42 | 43 | def check(self, t, comp=None, ndigits=None, stats=False, abs=False, debug=False): 44 | if td._tf_version[:3] < (0, 12, 0): 45 | self.sess.run(tf.initialize_all_variables()) 46 | else: 47 | self.sess.run(tf.global_variables_initializer()) 48 | 49 | if not isinstance(t, tuple): 50 | t = (t,) 51 | 52 | for _t in t: 53 | rtf = _t.eval(session=self.sess) 54 | rtd = td.Tensor(_t, self.sess).eval() 55 | 56 | if debug: 57 | import pdb; pdb.set_trace() 58 | 59 | if ndigits is None: 60 | ndigits = self.ndigits 61 | 62 | if hasattr(comp, "__call__"): 63 | return comp(rtf, rtd) 64 | 65 | if isinstance(rtf, np.ndarray): 66 | 
self.assertEqual(rtf.dtype, rtd.dtype) 67 | if abs: 68 | rtf = np.abs(rtf) 69 | rtd = np.abs(rtd) 70 | if not stats: 71 | self.assertTrue(np.allclose(rtf, rtd, atol=0.1**ndigits)) 72 | else: 73 | self.assertEqual(round(rtf.sum(), ndigits), round(rtd.sum(), ndigits)) 74 | self.assertEqual(round(rtf.mean(), ndigits), round(rtd.mean(), ndigits)) 75 | elif isinstance(rtf, float): 76 | self.assertEqual(round(rtf, ndigits), round(rtd, ndigits)) 77 | else: 78 | self.assertEqual(rtf, rtd) 79 | 80 | def random(self, *shapes, **kwargs): 81 | if all(isinstance(i, int) for i in shapes): 82 | if kwargs.get("complex", False): 83 | return (self.random(*shapes) + 1j * self.random(*shapes)).astype(np.complex64) 84 | else: 85 | return np.random.rand(*shapes) 86 | else: 87 | return tuple(self.random(*shape) for shape in shapes) 88 | 89 | def test_ops_have_tests(self): 90 | tests = [attr for attr in dir(self) if attr.startswith("test_")] 91 | for type in td.OperationRegister.classes: 92 | self.assertIn("test_" + type, tests) 93 | 94 | 95 | # 96 | # sequences 97 | # 98 | 99 | def test_LinSpace(self): 100 | t = tf.linspace(0., 10., 15) 101 | self.check(t) 102 | 103 | def test_Range(self): 104 | t = tf.range(1, 10, 2) 105 | self.check(t) 106 | 107 | 108 | # 109 | # random tensors 110 | # 111 | 112 | def test_RandomStandardNormal(self): 113 | t = tf.random_normal((40, 30), dtype="float32") 114 | # compare only dtype 115 | def comp(rtf, rtd): 116 | self.assertEqual(rtf.dtype, rtd.dtype) 117 | self.check(t, comp=comp) 118 | 119 | def test_TruncatedNormal(self): 120 | t = tf.truncated_normal((40, 300), dtype="float32") 121 | # compare dtype and 2-sigma truncation 122 | def comp(rtf, rtd): 123 | self.assertEqual(rtf.dtype, rtd.dtype) 124 | self.assertLessEqual(np.max(np.abs(rtd)), 2) 125 | self.check(t, comp=comp) 126 | 127 | def test_RandomUniform(self): 128 | t = tf.random_uniform((50, 80), -2, 3, dtype="float32") 129 | # compare only min, max and dtype 130 | def comp(rtf, rtd): 131 | 
self.assertLess(np.max(rtd), 3) 132 | self.assertGreaterEqual(np.min(rtd), -2) 133 | self.assertEqual(rtd.dtype, np.float32) 134 | self.check(t, comp=comp) 135 | 136 | def test_RandomUniformInt(self): 137 | # no python interface yet, but might be something like 138 | # t = tf.random_uniform_int((50, 80), -2, 3) 139 | # # compare only min and max 140 | # def comp(rtf, rtd): 141 | # self.assertLess(np.max(rtd), 3) 142 | # self.assertGreaterEqual(np.min(rtd), -2) 143 | # self.check(t, comp=comp) 144 | pass 145 | 146 | def test_RandomShuffle(self): 147 | t = tf.random_shuffle(self.random(10, 4)) 148 | # compare only sum of first axis 149 | def comp(rtf, rtd): 150 | self.assertTrue(np.allclose(np.sum(rtf, axis=0), np.sum(rtd, axis=0))) 151 | self.check(t, comp=comp) 152 | 153 | def test_random_crop(self): 154 | t = tf.random_crop(self.random(3, 4, 8), [1, 2, 4]) 155 | # compare only shape 156 | def comp(rtf, rtd): 157 | self.assertEqual(rtf.shape, rtd.shape) 158 | self.check(t, comp=comp) 159 | 160 | 161 | # 162 | # casting 163 | # 164 | 165 | def test_Cast(self): 166 | t = tf.cast(self.random(3, 4).astype("float32"), tf.float64) 167 | self.check(t) 168 | 169 | def test_StringToNumber(self): 170 | t = tf.string_to_number(list("0123456789")) 171 | self.check(t) 172 | 173 | 174 | # 175 | # shapes and shaping 176 | # 177 | 178 | def test_Shape(self): 179 | t = tf.shape(self.random(3, 4, 5)) 180 | self.check(t) 181 | 182 | def test_Size(self): 183 | t = tf.size(self.random(3, 4)) 184 | self.check(t) 185 | 186 | def test_Rank(self): 187 | t = tf.rank(self.random(3, 3)) 188 | self.check(t) 189 | 190 | def test_Reshape(self): 191 | t = tf.reshape(self.random(3, 4, 5), (2, -1)) 192 | self.check(t) 193 | 194 | def test_Squeeze(self): 195 | t = tf.squeeze(self.random(1, 2, 1, 3, 3, 1)) 196 | self.check(t) 197 | 198 | def test_ExpandDims(self): 199 | t = tf.expand_dims(self.random(2, 3, 3, 4), -2) 200 | self.check(t) 201 | 202 | 203 | # 204 | # slicing and joining 205 | # 206 | 
207 | def test_Slice(self): 208 | t = tf.slice(np.arange(3*4*8*6).reshape(3, 4, 8, 6), [1, 1, 2, 2], 4 * [2]) 209 | self.check(t) 210 | 211 | def test_Split(self): 212 | for t in tf.split(self.random(8, 50, 10, 2), 5, 2): 213 | self.check(t) 214 | 215 | def test_SplitV(self): 216 | for t in tf.split(self.random(8, 50, 10, 2), [10, 30, 5, 1, 4], 1): 217 | self.check(t) 218 | 219 | def test_Tile(self): 220 | t = tf.tile(self.random(3, 4, 5), [1, 2, 3]) 221 | self.check(t) 222 | 223 | def test_Pad(self): 224 | t = tf.pad(self.random(3, 8, 5), [[1, 2], [2, 1], [1, 0]]) 225 | self.check(t) 226 | 227 | def test_ConcatV2(self): 228 | aaa = self.random((3, 4, 5), (3, 4, 5)) 229 | t = tf.concat(list(self.random((3, 4, 5), (3, 4, 5))), 2) 230 | self.check(t) 231 | 232 | def test_Pack(self): 233 | pass 234 | 235 | def test_Unpack(self): 236 | pass 237 | 238 | def test_Stack(self): 239 | t = tf.stack(list(self.random((3, 4, 5), (3, 4, 5))), 2) 240 | self.check(t) 241 | 242 | def test_Unstack(self): 243 | for t in tf.unstack(self.random(6, 4, 5), axis=1): 244 | self.check(t) 245 | 246 | def test_ReverseSequence(self): 247 | x = self.random(3, 4, 10) 248 | t = tf.reverse_sequence(x, [5, 0, 0, 8], seq_dim=2, batch_dim=1) 249 | self.check(t) 250 | 251 | def test_ReverseV2(self): 252 | t = tf.reverse(self.random(3, 4, 10), [1, 2]) 253 | self.check(t) 254 | 255 | def test_Transpose(self): 256 | t = tf.transpose(self.random(4, 3, 5), perm=[2, 0, 1]) 257 | self.check(t) 258 | 259 | 260 | # 261 | # arithmetic math ops 262 | # 263 | 264 | def test_Add(self): 265 | t = tf.add(*self.random((3, 4), (3, 4))) 266 | self.check(t) 267 | 268 | def test_Subtract(self): 269 | t = tf.subtract(*self.random((3, 4), (3, 4))) 270 | self.check(t) 271 | 272 | test_Sub = test_Subtract 273 | 274 | def test_Multiply(self): 275 | t = tf.multiply(*self.random((3, 5), (3, 5))) 276 | self.check(t) 277 | 278 | test_Mul = test_Multiply 279 | 280 | def test_scalar_mul(self): 281 | t = tf.scalar_mul(1, 
tf.Variable(self.random(3, 5))) 282 | self.check(t) 283 | 284 | def test_Div(self): 285 | t = tf.div(*self.random((3, 5), (3, 5))) 286 | self.check(t) 287 | 288 | def test_RealDiv(self): 289 | t = tf.div(*self.random((3, 5), (3, 5))) 290 | self.check(t) 291 | 292 | def test_TrueDiv(self): 293 | t = tf.truediv(*self.random((3, 5), (3, 5))) 294 | self.check(t) 295 | 296 | def test_FloorDiv(self): 297 | t = tf.floordiv(*self.random((3, 5), (3, 5))) 298 | self.check(t) 299 | 300 | def test_Mod(self): 301 | t = tf.mod(*self.random((4, 3), (4, 3))) 302 | self.check(t) 303 | 304 | def test_FloorMod(self): 305 | t = tf.floormod(*self.random((4, 3), (4, 3))) 306 | self.check(t) 307 | 308 | def test_Cross(self): 309 | t = tf.cross(*self.random((4, 3), (4, 3))) 310 | self.check(t) 311 | 312 | 313 | # 314 | # basic math ops 315 | # 316 | 317 | def test_AddN(self): 318 | t = tf.add_n(self.random((4, 3), (4, 3))) 319 | self.check(t) 320 | 321 | def test_Abs(self): 322 | t = tf.abs(-self.random(4, 3)) 323 | self.check(t) 324 | 325 | def test_Negative(self): 326 | t = tf.negative(self.random(4, 3)) 327 | self.check(t) 328 | 329 | test_Neg = test_Negative 330 | 331 | def test_Sign(self): 332 | t = tf.sign(self.random(4, 3) - 0.5) 333 | self.check(t) 334 | 335 | def test_Inv(self): 336 | if td._tf_version[:2] <= (0, 11): 337 | t = tf.inv(self.random(4, 3)) 338 | self.check(t) 339 | 340 | def test_Square(self): 341 | t = tf.square(self.random(4, 3)) 342 | self.check(t) 343 | 344 | def test_Round(self): 345 | t = tf.round(self.random(4, 3) - 0.5) 346 | self.check(t) 347 | 348 | def test_Sqrt(self): 349 | t = tf.sqrt(self.random(4, 3)) 350 | self.check(t) 351 | 352 | def test_Rsqrt(self): 353 | t = tf.rsqrt(self.random(4, 3)) 354 | self.check(t) 355 | 356 | def test_Pow(self): 357 | t = tf.pow(*self.random((4, 3), (4, 3))) 358 | self.check(t) 359 | 360 | def test_Exp(self): 361 | t = tf.exp(self.random(4, 3)) 362 | self.check(t) 363 | 364 | def test_Log(self): 365 | t = 
tf.log(self.random(4, 3)) 366 | self.check(t) 367 | 368 | def test_Ceil(self): 369 | t = tf.ceil(self.random(4, 3) - 0.5) 370 | self.check(t) 371 | 372 | def test_Floor(self): 373 | t = tf.floor(self.random(4, 3) - 0.5) 374 | self.check(t) 375 | 376 | def test_Maximum(self): 377 | t = tf.maximum(*self.random((4, 3), (4, 3))) 378 | self.check(t) 379 | 380 | def test_Minimum(self): 381 | t = tf.minimum(*self.random((4, 3), (4, 3))) 382 | self.check(t) 383 | 384 | def test_Cos(self): 385 | t = tf.cos(self.random(4, 3)) 386 | self.check(t) 387 | 388 | def test_Sin(self): 389 | t = tf.sin(self.random(4, 3)) 390 | self.check(t) 391 | 392 | def test_lbeta(self): 393 | t = tf.lbeta(self.random(4, 3)) 394 | self.check(t) 395 | 396 | def test_Tan(self): 397 | t = tf.tan(self.random(4, 3)) 398 | self.check(t) 399 | 400 | def test_Acos(self): 401 | t = tf.acos(self.random(4, 3)) 402 | self.check(t) 403 | 404 | def test_Asin(self): 405 | t = tf.asin(self.random(4, 3)) 406 | self.check(t) 407 | 408 | def test_Atan(self): 409 | t = tf.atan(self.random(4, 3)) 410 | self.check(t) 411 | 412 | def test_Lgamma(self): 413 | t = tf.lgamma(self.random(4, 3)) 414 | self.check(t) 415 | 416 | def test_Digamma(self): 417 | t = tf.digamma(self.random(4, 3)) 418 | self.check(t) 419 | 420 | def test_Erf(self): 421 | t = tf.erf(self.random(4, 3)) 422 | self.check(t) 423 | 424 | def test_Erfc(self): 425 | t = tf.erfc(self.random(4, 3)) 426 | self.check(t) 427 | 428 | def test_SquaredDifference(self): 429 | t = tf.squared_difference(*self.random((3, 4, 4), (3, 4, 4))) 430 | self.check(t) 431 | 432 | def test_Igamma(self): 433 | t = tf.igamma(*self.random((3, 3), (3, 3))) 434 | self.check(t) 435 | 436 | def test_Igammac(self): 437 | t = tf.igammac(*self.random((3, 3), (3, 3))) 438 | self.check(t) 439 | 440 | def test_Zeta(self): 441 | t = tf.zeta(self.random(3, 3) + 2, self.random(3, 3)) 442 | self.check(t) 443 | 444 | def test_Polygamma(self): 445 | t = tf.polygamma(np.array([1, 2, 
3]).astype("float32"), np.array([4, 5, 6]).astype("float32")) 446 | self.check(t) 447 | 448 | def test_Betainc(self): 449 | t = tf.betainc(*self.random((3, 3), (3, 3), (3, 3))) 450 | self.check(t) 451 | 452 | 453 | # 454 | # matrix math ops 455 | # 456 | 457 | def test_Diag(self): 458 | t = tf.diag(self.random(3, 3)) 459 | self.check(t) 460 | 461 | def test_DiagPart(self): 462 | t = tf.diag_part(self.random(3, 3)) 463 | self.check(t) 464 | 465 | def test_MatrixDiagPart(self): 466 | if td._tf_version[:2] >= (0, 12): 467 | t = tf.matrix_diag_part(self.random(3, 4, 4, 5)) 468 | self.check(t) 469 | 470 | def test_trace(self): 471 | t = tf.trace(self.random(3, 3)) 472 | self.check(t) 473 | 474 | def test_MatMul(self): 475 | t = tf.matmul(*self.random((4, 3), (3, 5)), transpose_a=False, transpose_b=False) 476 | self.check(t) 477 | t = tf.matmul(*self.random((3, 4), (3, 5)), transpose_a=True, transpose_b=False) 478 | self.check(t) 479 | t = tf.matmul(*self.random((4, 3), (5, 3)), transpose_a=False, transpose_b=True) 480 | self.check(t) 481 | t = tf.matmul(*self.random((3, 4), (5, 3)), transpose_a=True, transpose_b=True) 482 | self.check(t) 483 | 484 | # def test_BatchMatMul(self): 485 | # t = tf.batch_matmul(*self.random((2, 4, 4, 3), (2, 4, 3, 5)), adj_x=False, adj_y=False) 486 | # self.check(t) 487 | # t = tf.batch_matmul(*self.random((2, 4, 3, 4), (2, 4, 3, 5)), adj_x=True, adj_y=False) 488 | # self.check(t) 489 | # t = tf.batch_matmul(*self.random((2, 4, 4, 3), (2, 4, 5, 3)), adj_x=False, adj_y=True) 490 | # self.check(t) 491 | # t = tf.batch_matmul(*self.random((2, 4, 3, 4), (2, 4, 5, 3)), adj_x=True, adj_y=True) 492 | # self.check(t) 493 | 494 | def test_MatrixDeterminant(self): 495 | t = tf.matrix_determinant(self.random(2, 3, 4, 3, 3)) 496 | self.check(t) 497 | 498 | def test_MatrixInverse(self): 499 | t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=False) 500 | self.check(t) 501 | t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=True) 502 | 
self.check(t) 503 | 504 | def test_Cholesky(self): 505 | t = tf.cholesky(np.array(3 * [8, 3, 3, 8]).reshape(3, 2, 2).astype("float32")) 506 | self.check(t) 507 | 508 | def test_MatrixSolve(self): 509 | t = tf.matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False) 510 | self.check(t) 511 | t = tf.matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=True) 512 | self.check(t) 513 | 514 | def test_MatrixTriangularSolve(self): 515 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False, lower=False) 516 | self.check(t) 517 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=True, lower=False) 518 | self.check(t) 519 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False, lower=True) 520 | self.check(t) 521 | 522 | def test_MatrixSolveLs(self): 523 | t = tf.matrix_solve_ls(*self.random((2, 3, 3, 3), (2, 3, 3, 1))) 524 | self.check(t) 525 | 526 | def test_SelfAdjointEig(self): 527 | # legacy support 528 | pass 529 | 530 | def test_SelfAdjointEigV2(self): 531 | t = tf.self_adjoint_eig(np.array(3 * [3, 2, 2, 1]).reshape(3, 2, 2).astype("float32")) 532 | # the order of eigen vectors and values may differ between tf and np, so only compare sum 533 | # and mean 534 | # also, different numerical algorithms are used, so account for difference in precision by 535 | # comparing numbers with 4 digits 536 | self.check(t, ndigits=4, stats=True, abs=True) 537 | 538 | def test_Svd(self): 539 | t = tf.svd(self.random(4, 5, 3, 2).astype("float32")) 540 | self.check(t, ndigits=4, abs=True) 541 | 542 | 543 | # 544 | # complex number ops 545 | # 546 | 547 | def test_Complex(self): 548 | t = tf.complex(*self.random((3, 4), (3, 4))) 549 | self.check(t) 550 | 551 | def test_Conj(self): 552 | t = tf.conj(self.random(3, 4, complex=True)) 553 | self.check(t) 554 | 555 | def test_Imag(self): 556 | t = tf.imag(tf.Variable(self.random(3, 4, complex=True))) 557 | self.check(t) 
558 | 559 | def test_Real(self): 560 | t = tf.real(tf.Variable(self.random(3, 4, complex=True))) 561 | self.check(t) 562 | 563 | 564 | # 565 | # Fourier transform ops 566 | # 567 | 568 | def test_FFT2D(self): 569 | # only defined for gpu 570 | if DEVICE == GPU: 571 | t = tf.fft2d(self.random(3, 4, complex=True)) 572 | self.check(t) 573 | 574 | def test_IFFT2D(self): 575 | # only defined for gpu 576 | if DEVICE == GPU: 577 | t = tf.ifft2d(self.random(3, 4, complex=True)) 578 | self.check(t) 579 | 580 | def test_FFT3D(self): 581 | # only defined for gpu 582 | if DEVICE == GPU: 583 | t = tf.fft3d(self.random(3, 4, 5, complex=True)) 584 | self.check(t) 585 | 586 | def test_IFFT3D(self): 587 | # only defined for gpu 588 | if DEVICE == GPU: 589 | t = tf.ifft3d(self.random(3, 4, 5, complex=True)) 590 | self.check(t) 591 | 592 | 593 | # 594 | # reduction 595 | # 596 | 597 | def test_Sum(self): 598 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 599 | self.check(t) 600 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=(0, 1), keep_dims=True) 601 | self.check(t) 602 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=0, keep_dims=True) 603 | self.check(t) 604 | if td._tf_version[:3] >= (0, 12, 0): 605 | t = tf.reduce_sum(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 606 | self.check(t) 607 | t = tf.reduce_sum(self.random(3, 4, 5), axis=(0, 1), keep_dims=True) 608 | self.check(t) 609 | t = tf.reduce_sum(self.random(3, 4, 5), axis=0, keep_dims=True) 610 | self.check(t) 611 | 612 | def test_Prod(self): 613 | t = tf.reduce_prod(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 614 | self.check(t) 615 | if td._tf_version[:3] >= (0, 12, 0): 616 | t = tf.reduce_prod(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 617 | self.check(t) 618 | 619 | def test_Min(self): 620 | t = tf.reduce_min(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 621 | self.check(t) 622 | if td._tf_version[:3] >= (0, 
12, 0): 623 | t = tf.reduce_min(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 624 | self.check(t) 625 | 626 | def test_Max(self): 627 | t = tf.reduce_max(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 628 | self.check(t) 629 | if td._tf_version[:3] >= (0, 12, 0): 630 | t = tf.reduce_max(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 631 | self.check(t) 632 | 633 | def test_Mean(self): 634 | t = tf.reduce_mean(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 635 | self.check(t) 636 | if td._tf_version[:3] >= (0, 12, 0): 637 | t = tf.reduce_mean(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 638 | self.check(t) 639 | 640 | def test_All(self): 641 | t = tf.reduce_all(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 642 | self.check(t) 643 | if td._tf_version[:3] >= (0, 12, 0): 644 | t = tf.reduce_all(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 645 | self.check(t) 646 | 647 | def test_Any(self): 648 | t = tf.reduce_any(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True) 649 | self.check(t) 650 | if td._tf_version[:3] >= (0, 12, 0): 651 | t = tf.reduce_any(self.random(3, 4, 5), axis=[0, 1], keep_dims=True) 652 | self.check(t) 653 | 654 | 655 | # 656 | # segmentation 657 | # 658 | 659 | def test_SegmentSum(self): 660 | t = tf.segment_sum(self.random(4, 2, 3), np.array([0, 1, 1, 2])) 661 | self.check(t) 662 | 663 | def test_SegmentProd(self): 664 | t = tf.segment_prod(self.random(4, 2, 3), np.array([0, 1, 1, 2])) 665 | self.check(t) 666 | 667 | def test_SegmentMin(self): 668 | t = tf.segment_min(self.random(4, 2, 3), np.array([0, 1, 1, 2])) 669 | self.check(t) 670 | 671 | def test_SegmentMax(self): 672 | t = tf.segment_max(self.random(4, 2, 3), np.array([0, 1, 1, 2])) 673 | self.check(t) 674 | 675 | def test_SegmentMean(self): 676 | t = tf.segment_mean(self.random(4, 2, 3), np.array([0, 1, 1, 2])) 677 | self.check(t) 678 | 679 | def test_UnsortedSegmentSum(self): 680 | t = 
tf.unsorted_segment_sum(self.random(4, 2, 3), np.array([0, 2, 2, 1]), 3) 681 | self.check(t) 682 | 683 | def test_SparseSegmentSum(self): 684 | t = tf.sparse_segment_sum(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1]) 685 | self.check(t) 686 | 687 | def test_SparseSegmentMean(self): 688 | t = tf.sparse_segment_mean(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1]) 689 | self.check(t) 690 | 691 | def test_SparseSegmentSqrtN(self): 692 | t = tf.sparse_segment_sqrt_n(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1]) 693 | self.check(t) 694 | 695 | 696 | # 697 | # sequence comparison and indexing 698 | # 699 | 700 | def test_ArgMin(self): 701 | t = tf.argmin(self.random(3, 4, 2), 1) 702 | self.check(t) 703 | 704 | def test_ArgMax(self): 705 | t = tf.argmax(self.random(3, 4, 2), 1) 706 | self.check(t) 707 | 708 | def test_ListDiff(self): 709 | if td._tf_version[:2] <= (0, 11): 710 | l = np.random.randint(0, 5, 100) 711 | t1, t2 = tf.listdiff(l, l[::-2]) 712 | self.check(t1) 713 | self.check(t2) 714 | 715 | def test_Where(self): 716 | t = tf.where([[True, False], [False, False], [True, False]]) 717 | self.check(t) 718 | 719 | def test_Unique(self): 720 | t = tf.unique([9, 3, 5, 7, 3, 9, 9], out_idx=tf.int32) 721 | self.check(t) 722 | 723 | def test_InvertPermutation(self): 724 | t = tf.invert_permutation(np.random.permutation(10)) 725 | self.check(t) 726 | 727 | 728 | # 729 | # control flow ops 730 | # 731 | 732 | def test_Identity(self): 733 | t = tf.identity(self.random(3, 4)) 734 | self.check(t) 735 | 736 | 737 | # 738 | # NN activation ops 739 | # 740 | 741 | def test_Relu(self): 742 | t = tf.nn.relu(self.random(100) - 0.5) 743 | self.check(t) 744 | 745 | def test_Relu6(self): 746 | t = tf.nn.relu6((self.random(100) - 0.5) * 20) 747 | self.check(t) 748 | 749 | def test_Elu(self): 750 | t = tf.nn.elu(self.random(100) - 0.5) 751 | self.check(t) 752 | 753 | def test_Softplus(self): 754 | t = tf.nn.softplus(self.random(100) - 0.5) 755 | self.check(t) 756 | 757 | def test_Softsign(self): 
758 | t = tf.nn.softsign(self.random(100) - 0.5) 759 | self.check(t) 760 | 761 | def test_BiasAdd(self): 762 | t = tf.nn.bias_add(*self.random((4, 5), (5,))) 763 | self.check(t) 764 | 765 | def test_Sigmoid(self): 766 | t = tf.nn.sigmoid(self.random(3, 4)) 767 | self.check(t) 768 | 769 | def test_Tanh(self): 770 | t = tf.nn.tanh(self.random(3, 4)) 771 | self.check(t) 772 | 773 | def test_Softmax(self): 774 | t = tf.nn.softmax(self.random(10, 5)) 775 | self.check(t) 776 | 777 | 778 | # 779 | # NN convolution ops 780 | # 781 | 782 | def test_Conv1D(self): 783 | t = tf.nn.conv1d(np.arange(8000).reshape(1000, 2, 4).astype("float32"), 784 | np.ones(80).reshape(2, 4, 10).astype("float32"), 785 | 1, "SAME") 786 | self.check(t) 787 | t = tf.nn.conv1d(np.arange(8000).reshape(1000, 2, 4).astype("float32"), 788 | np.ones(80).reshape(2, 4, 10).astype("float32"), 789 | 2, "VALID") 790 | self.check(t) 791 | 792 | def test_Conv2D(self): 793 | t = tf.nn.conv2d(np.arange(24000).reshape(1000, 2, 3, 4).astype("float32"), 794 | np.ones(160).reshape(2, 2, 4, 10).astype("float32"), 795 | [1, 2, 3, 1], "SAME") 796 | self.check(t) 797 | t = tf.nn.conv2d(np.arange(24000).reshape(1000, 2, 3, 4).astype("float32"), 798 | np.ones(160).reshape(2, 2, 4, 10).astype("float32"), 799 | [1, 2, 5, 1], "VALID") 800 | self.check(t) 801 | 802 | def test_Conv3D(self): 803 | t = tf.nn.conv3d(np.arange(72000).reshape(1000, 2, 3, 3, 4).astype("float32"), 804 | np.ones(320).reshape(2, 2, 2, 4, 10).astype("float32"), 805 | [1, 1, 1, 1, 1], "SAME") 806 | self.check(t) 807 | t = tf.nn.conv3d(np.arange(72000).reshape(1000, 2, 3, 3, 4).astype("float32"), 808 | np.ones(320).reshape(2, 2, 2, 4, 10).astype("float32"), 809 | [1, 1, 1, 1, 1], "VALID") 810 | self.check(t) 811 | 812 | 813 | # 814 | # pooling ops 815 | # 816 | 817 | def test_AvgPool(self): 818 | t = tf.nn.avg_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"), 819 | [1, 2, 2, 1], [1, 1, 1, 1], "SAME") 820 | self.check(t) 821 | t = 
tf.nn.avg_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"), 822 | [1, 2, 2, 1], [1, 2, 2, 1], "VALID") 823 | self.check(t) 824 | 825 | def test_MaxPool(self): 826 | t = tf.nn.max_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"), 827 | [1, 2, 2, 1], [1, 1, 1, 1], "SAME") 828 | self.check(t) 829 | t = tf.nn.max_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"), 830 | [1, 2, 2, 1], [1, 2, 2, 1], "VALID") 831 | self.check(t) 832 | t = tf.nn.max_pool(np.arange(64).reshape(2, 4, 8, 1).astype("float32"), 833 | [1, 2, 2, 1], [1, 3, 2, 1], "VALID") 834 | self.check(t) 835 | 836 | def test_AvgPool3D(self): 837 | t = tf.nn.avg_pool3d(np.arange(64).reshape(1, 4, 4, 4, 1).astype("float32"), 838 | [1, 2, 2, 2, 1], [1, 1, 1, 1, 1], "SAME") 839 | self.check(t) 840 | t = tf.nn.avg_pool3d(np.arange(48).reshape(1, 4, 4, 3, 1).astype("float32"), 841 | [1, 2, 2, 1, 1], [1, 2, 2, 1, 1], "VALID") 842 | self.check(t) 843 | 844 | def test_MaxPool3D(self): 845 | t = tf.nn.max_pool3d(np.arange(64).reshape(1, 4, 4, 4, 1).astype("float32"), 846 | [1, 2, 2, 2, 1], [1, 1, 1, 1, 1], "SAME") 847 | self.check(t) 848 | t = tf.nn.max_pool3d(np.arange(48).reshape(1, 4, 4, 3, 1).astype("float32"), 849 | [1, 2, 2, 1, 1], [1, 2, 2, 1, 1], "VALID") 850 | self.check(t) 851 | -------------------------------------------------------------------------------- /tests/perf/create_plots.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Plotting script that plots all kinds of distributions using the json file created by 5 | measure_runtimes.py. 
6 | """ 7 | 8 | import os 9 | import json 10 | import itertools 11 | import collections 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from measure_runtimes import specs, output_file 16 | 17 | 18 | # 19 | # constants 20 | # 21 | 22 | plotdir = "data" 23 | 24 | 25 | # 26 | # plot function 27 | # 28 | 29 | def plot(xkey, coords, data): 30 | # determine the plotfile path 31 | labels = sorted([(xkey, 0)] + list(coords.items()), key=lambda tpl: list(specs.keys()).index(tpl[0])) 32 | title = ", ".join("%s: %s" % tpl for tpl in labels if tpl[0] != "examples" and tpl[0] != xkey) 33 | labels = [(key[0], value) for key, value in labels] 34 | plotfile = "runtime_" + xkey + "_" + "_".join("%s%s" % tpl for tpl in labels) + ".png" 35 | plotfile = os.path.join(plotdir, plotfile) 36 | 37 | # get x and y data to plot 38 | x = [d[xkey] for d in data] 39 | ykeys = ["tf_cpu", "tf_gpu", "td"] 40 | y = {ykey: [d["times"][ykey]["mean"] for d in data] for ykey in ykeys} 41 | markers = dict(zip(ykeys, ("s", "o", "D"))) 42 | 43 | # do the plot 44 | fig, axes = plt.subplots() 45 | for ykey, _y in y.items(): 46 | axes.plot(x, _y, "-" + markers[ykey], label=ykey) 47 | 48 | axes.set_xlabel(xkey) 49 | axes.set_ylabel("time per batch [s]") 50 | axes.set_title(title) 51 | if xkey != "units": 52 | axes.set_xscale("log") 53 | axes.legend(loc="best") 54 | 55 | fig.savefig(plotfile) 56 | fig.clf() 57 | 58 | 59 | # 60 | # data filter function 61 | # 62 | 63 | def filter_data(data, **kwargs): 64 | return [d for d in data if all(d[key] == value for key, value in kwargs.items())] 65 | 66 | 67 | # 68 | # main and entry hook 69 | # 70 | 71 | def main(): 72 | # check if the output file exists 73 | if not os.path.exists(output_file): 74 | raise IOError("output file '%s' does not exist, run measure_runtimes.py to generate it" % output_file) 75 | 76 | # read the data 77 | with open(output_file, "r") as f: 78 | data = json.load(f) 79 | 80 | # prepare the plot dir 81 | if not os.path.exists(plotdir): 82 | os.mkdir(plotdir) 83 | 84 | 
    # loop through all keys in specs, each of them will be used for the x-axis
    for xkey in specs:
        # do not plot examples on the x-axis
        if xkey == "examples":
            continue

        # determine the remaining keys, do the combinatorics
        keys = [key for key in specs if key != xkey]
        for combi in itertools.product(*[specs[key] for key in keys]):
            coord = collections.OrderedDict(zip(keys, combi))
            plot_data = filter_data(data, **coord)
            # sort by the quantity shown on the x-axis
            plot_data.sort(key=lambda d: d[xkey])
            plot(xkey, coord, plot_data)

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/tests/perf/measure_runtimes.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Runs a performance test for tfdeploy, tensorflow on CPU and tensorflow on GPU with different data
and network dimensions. The results are written to a json file "times.json" in the current
directory.
"""

import os
import time
import json
import uuid
import itertools
import collections

import numpy as np
import tensorflow as tf
import tfdeploy as td


#
# test specs and constants
#

output_file = "times.json"

specs = collections.OrderedDict()
specs["features"] = [10, 20, 50, 100, 200, 500, 1000]
specs["examples"] = [100000]
specs["batchsize"] = [1000, 100, 10]
specs["layers"] = [1, 2, 5, 10]
specs["units"] = [10, 20, 50, 100, 200, 500]


#
# Data class that contains cached random numbers and handles batch iteration
#

# static instance
data = None

# class definition
class Data(object):

    def __init__(self, max_features, max_examples):
        super(Data, self).__init__()

        self._max_features = max_features
        self._max_examples = max_examples

        self._data = np.random.rand(max_examples, max_features).astype(np.float32)

        self._features = None
        self._examples = None
        self._batchsize = None

    def prepare(self, features, examples, batchsize):
        if features > self._max_features:
            raise ValueError("features must not exceed max_features (%d)" % self._max_features)

        if examples > self._max_examples:
            raise ValueError("examples must not exceed max_examples (%d)" % self._max_examples)

        if examples % batchsize != 0:
            raise ValueError("batchsize must be a divisor of examples")

        self._features = features
        self._examples = examples
        self._batchsize = batchsize

    def __iter__(self):
        # floor division keeps the number of batches an integer on python 2 and 3
        for i in range(self._examples // self._batchsize):
            yield self._data[(i*self._batchsize):((i+1)*self._batchsize), :self._features]


#
# model generation helpers
#

def create_tf_model(features, layers, units, device, input_name, output_name):
    with tf.device(device):
        x = tf.placeholder(tf.float32, shape=[None, features], name=input_name)
        y = x
        for i in range(layers):
            # the first layer maps features to units, all later layers map units to units
            W = tf.Variable(tf.random_normal([features if y is x else units, units]))
            b = tf.Variable(tf.zeros([units]))
            y = tf.tanh(tf.matmul(y, W) + b, name=output_name if i == layers - 1 else None)
    return x, y


def create_models(features, layers, units):
    postfix = str(uuid.uuid4())[:8]
    input_name = "input_" + postfix
    output_name = "output_" + postfix

    tf_cpu_x, tf_cpu_y = create_tf_model(features, layers, units, "/cpu:0", input_name, output_name)
    tf_gpu_x, tf_gpu_y = create_tf_model(features, layers, units, "/gpu:0", input_name, output_name)

    tf_sess = tf.Session()
    tf_sess.run(tf.initialize_all_variables())

    td_model = td.Model()
    td_model.add(tf_cpu_y, tf_sess)
    td_x, td_y = td_model.get(input_name, output_name)

    tf_cpu_fn = lambda batch: tf_sess.run(tf_cpu_y, feed_dict={tf_cpu_x: batch})
    tf_gpu_fn = lambda batch: tf_sess.run(tf_gpu_y, feed_dict={tf_gpu_x: batch})
    td_fn = lambda batch: td_y.eval({td_x: batch})

    return collections.OrderedDict([
        ("tf_cpu", tf_cpu_fn),
        ("tf_gpu", tf_gpu_fn),
        ("td", td_fn)
    ])


#
# actual test function
#

def test(features, examples, batchsize, layers, units):
    # prepare the data for the given input dimensions
    data.prepare(features, examples, batchsize)

    # create models / evaluation functions
    models = create_models(features, layers, units)

    # storage for measured runtimes
    times = collections.OrderedDict((name, []) for name in models.keys())

    # loop through batches and evaluation functions
    for batch in data:
        for name, fn in models.items():
            t1 = time.time()
            fn(batch)
            times[name].append(time.time() - t1)

    # reduce the per-batch runtimes to summary statistics
    for name, l in times.items():
        a = np.array(l)
        times[name] = {
            "total"   : np.sum(a),
            "mean"    : np.mean(a),
            "variance": np.var(a)
        }

    return times


#
# main and entry hook
#

def main():
    combis = list(itertools.product(*specs.values()))

    # create data with maximum shape
    global data
    print("create data")
    data = Data(max(specs["features"]), max(specs["examples"]))
    print("done")

    # run actual tests
    results = []
    for i, combi in enumerate(combis):
        print("running test %d/%d" % (i + 1, len(combis)))
        d = collections.OrderedDict(zip(specs.keys(), combi))
        d["times"] = test(**d)
        results.append(d)
        # note: this stops after the first four combinations
        if i == 3: break

    with open(output_file, "w") as f:
        json.dump(results, f, indent=4)

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/tests/perf/simple.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import tensorflow as tf
import tfdeploy as td
import numpy as np


# setup tf graph
sess = tf.Session()
x = tf.placeholder("float", shape=[None, 784], name="input")
W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
b = tf.Variable(tf.zeros([100]))
y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")

# note: the trailing "and 0" makes the condition always false, so the else branch always runs
if td._tf_version[:3] < (0, 12, 0) and 0:
    sess.run(tf.initialize_all_variables())
else:
    sess.run(tf.global_variables_initializer())

# setup td model
model = td.Model()
model.add(y, sess)
inp, outp = model.get("input", "output")

# testing code
batch = np.random.rand(10000, 784)

def test_tf():
    return y.eval(session=sess, feed_dict={x: batch})

def test_td():
    return outp.eval({inp: batch})

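The script above defines `test_tf` and `test_td` but does not call or time them itself. A minimal sketch of how two such callables could be compared with the standard-library `timeit` module follows. The helper name `compare_runtimes` and its parameters are hypothetical and not part of the repository, and stand-in callables are used so that the sketch runs without tensorflow.

```python
import timeit


def compare_runtimes(fns, number=10, repeat=3):
    """Return the best per-call runtime in seconds for each named callable."""
    results = {}
    for name, fn in fns.items():
        # timeit.repeat returns one total time per repetition of `number` calls
        totals = timeit.repeat(fn, number=number, repeat=repeat)
        results[name] = min(totals) / number
    return results


if __name__ == "__main__":
    # with simple.py imported, one would pass {"tf": test_tf, "td": test_td} instead;
    # the stand-in callables below only keep this sketch self-contained
    demo = {"fast": lambda: sum(range(10)), "slow": lambda: sum(range(10000))}
    for name, seconds in compare_runtimes(demo).items():
        print("%s: %.3e s per call" % (name, seconds))
```

Taking the minimum over repetitions is a common way to reduce noise from other processes. Whether it is the right statistic for this benchmark is a judgment call.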
--------------------------------------------------------------------------------
/tfdeploy.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Deploy tensorflow graphs for fast evaluation and export to tensorflow-less environments running
numpy.
"""


__author__ = "Marcel Rieger"
__copyright__ = "Copyright 2016-2025, Marcel Rieger"
__credits__ = ["Marcel Rieger"]
__contact__ = "https://github.com/riga/tfdeploy"
__license__ = "BSD-3-Clause"
__status__ = "Development"
__version__ = "0.4.2"

__all__ = ["Model", "Tensor", "Operation", "Ensemble",
           "UnknownOperationException", "OperationMismatchException",
           "InvalidImplementationException", "UnknownImplementationException",
           "EnsembleMismatchException", "ScipyOperationException",
           "reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op",
           "IMPL_NUMPY", "IMPL_SCIPY", "IMPLS",
           "METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS",
           "HAS_SCIPY"]


# imports for core code
import os
import re
from uuid import uuid4
from functools import reduce

try:
    # python 2
    import cPickle as pickle
except ImportError:
    # python 3
    import pickle

# third-party imports
import numpy as np


# metaclass decorator from six package, credits to Benjamin Peterson
def add_metaclass(metaclass):
    def wrapper(cls):
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get("__slots__")
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop("__dict__", None)
        orig_vars.pop("__weakref__", None)
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper


class Model(object):
    """
    A trained model
    that contains one or more converted tensorflow graphs. When *path* is set, a
    previously saved model is loaded from that path. Usage:

    .. code-block:: python

       import tensorflow as tf
       import tfdeploy as td

       # build your graph, use names for input and output tensors
       sess = tf.Session()
       x = tf.placeholder("float", shape=[None, 784], name="input")
       W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
       b = tf.Variable(tf.zeros([100]))
       y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
       sess.run(tf.initialize_all_variables())

       # ... training ...

       # create a model and save it to disk
       model = td.Model()
       model.add(y, sess)
       model.save("/path/to/model.pkl")

    And then in another file:

    .. code-block:: python

       import tfdeploy as td
       import numpy as np

       model = td.Model("/path/to/model.pkl")
       inp, outp = model.get("input", "output")

       batch = np.random.rand(10000, 784)
       result = outp.eval({inp: batch})

    .. py:attribute:: roots

       Contained root tensors in a dict mapped to a key.
    """

    # raw string so that the backslashes reach the regex engine unchanged
    value_index_cre = re.compile(r"\:\d+$")
    default_value_index = 0

    def __init__(self, path=None):
        super(Model, self).__init__()

        self.roots = {}

        # load when desired
        if path is not None:
            self.load(path)

    def get(self, *names, **kwargs):
        """ get(*names, key=None)
        Returns one or more :py:class:`Tensor` instances given by *names* using a deep lookup within
        the model. If *key* is not *None*, only the root tensor with that *key* is traversed. *None*
        is returned when no tensor was found. In case a tensor is passed, its name is used for the
        lookup.
        """
        tensors = tuple(self._get(name, **kwargs) for name in names)
        return tensors[0] if len(names) == 1 else tensors

    def _get(self, name, key=None):
        if isinstance(name, Tensor):
            name = name.name

        # append the default value_index if there's none
        if not self.value_index_cre.search(name):
            name += ":%d" % self.default_value_index

        # return the first occurrence of a tensor with that name
        if key is not None:
            return self.roots[key].get(name)
        else:
            return reduce(lambda t1, t2: t1 or t2.get(name), self.roots.values(), None)

    def __getitem__(self, name):
        return self.get(name)

    def __contains__(self, name):
        return self.get(name) is not None

    def add(self, tensor, tf_sess=None, key=None, **kwargs):
        """
        Adds a new root *tensor* for a *key* which, if *None*, defaults to a consecutive number.
        When *tensor* is not an instance of :py:class:`Tensor` but an instance of
        ``tensorflow.Tensor``, it is converted first. In that case, *tf_sess* should be a valid
        tensorflow session and *kwargs* are forwarded to the :py:class:`Tensor` constructor.
        """
        if not isinstance(tensor, Tensor):
            tensor = Tensor(tensor, tf_sess, **kwargs)

        if key is None:
            if len(self.roots) == 0:
                key = 0
            else:
                key = max(self.roots.keys()) + 1

        self.roots[key] = tensor

    def load(self, path):
        """
        Loads all tensors from a file defined by *path* and adds them to the root set.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        with open(path, "rb") as f:
            roots = pickle.load(f)

        for key, tensor in roots.items():
            self.add(tensor, key=key)

    def save(self, path):
        """
        Saves all tensors of the root set to a file defined by *path*.
        """
        path = os.path.expandvars(os.path.expanduser(path))
        with open(path, "wb") as f:
            pickle.dump(self.roots, f)


class TensorRegister(type):
    """
    Meta class of :py:class:`Tensor` that performs instance caching indexed by tensorflow tensor
    instances.
    """

    instances = {}

    def __call__(cls, tf_tensor, *args, **kwargs):
        # simple caching
        if tf_tensor not in cls.instances:
            inst = super(TensorRegister, cls).__call__(tf_tensor, *args, **kwargs)
            cls.instances[tf_tensor] = inst
        return cls.instances[tf_tensor]


@add_metaclass(TensorRegister)
class Tensor(object):
    """
    Building block of a model. In *graph* terms, tensors represent connections between nodes (ops)
    of a graph. It contains information on the op it results from. The conversion uses the
    (tensorflow) instances *tf_tensor* and *tf_sess*, *tf_feed_dict* can be set to evaluate the
    tensor's current value.

    .. py:attribute:: name

       The name of the tensor.

    .. py:attribute:: value_index

       The integer value index of this tensor, i.e., the position in the op's output list.

    .. py:attribute:: op

       The op instance that defines the value of this tensor. When created from a
       ``tensorflow.Placeholder`` or a ``tensorflow.Variable/V2``, op will be *None*.

    .. py:attribute:: value

       The value of this tensor. When created from a ``tensorflow.Variable/V2``, this will be the
       value of that variable, or *None* otherwise until it is evaluated the first time.
    """

    def __init__(self, tf_tensor, tf_sess, tf_feed_dict=None):
        super(Tensor, self).__init__()

        if not tf_sess:
            raise ValueError("bad tensorflow session: %s" % tf_sess)

        self.name = tf_tensor.name
        self.value_index = tf_tensor.value_index
        self.op = None
        self.value = None
        self.last_uuid = None

        # guess the value
        # explicitly evaluate variables and constants, use feed_dict for placeholders
        if tf_tensor.op.type in ("Variable", "VariableV2", "Const"):
            self.value = tf_tensor.eval(session=tf_sess, feed_dict=tf_feed_dict)
        elif tf_tensor.op.type == "Placeholder":
            if tf_feed_dict is not None and tf_tensor in tf_feed_dict:
                self.value = tf_feed_dict[tf_tensor]

        # create the op
        # no op for variables, placeholders and constants
        if tf_tensor.op.type not in ("Variable", "VariableV2", "Const", "Placeholder"):
            self.op = Operation.new(tf_tensor.op, tf_sess, tf_feed_dict=tf_feed_dict)

    def get(self, *names):
        """
        Returns one or more tensors given by *names* using a deep lookup within the inputs of the
        op. Note that *this* tensor is returned when the name matches. *None* is returned when no
        tensor was found.
        """
        tensors = tuple(self._get(name) for name in names)
        return tensors[0] if len(names) == 1 else tensors

    def _get(self, name):
        if self.name == name:
            return self
        elif self.op is None:
            return None
        else:
            return self.op.get(name)

    def eval(self, feed_dict=None, _uuid=None):
        """ eval(feed_dict=None)
        Returns the value of this tensor based on the evaluation of all dependent ops and tensors.
        You can overwrite values of dependent tensors using *feed_dict*, a mapping of tensors to
        numpy arrays, which is passed down the evaluation chain.
        """
        # set a cache uuid for this eval call
        if _uuid is None:
            _uuid = uuid4()

        # already cached? this is important for tensors that are used multiple times within the
        # graph
        if _uuid == self.last_uuid:
            return self.value
        else:
            self.last_uuid = _uuid

        if feed_dict is None:
            feed_dict = {}

        # when _this_ tensor is in the feed_dict, return the fed value
        # otherwise, eval the op
        if self in feed_dict:
            self.value = feed_dict[self]
        elif self.op is not None:
            self.value = self.op.eval(feed_dict=feed_dict, _uuid=_uuid)[self.value_index]

        return self.value

    def __call__(self, *args, **kwargs):
        return self.eval(*args, **kwargs)


class OperationRegister(type):
    """
    Meta class of :py:class:`Operation` that performs instance caching indexed by tensorflow op
    instances. Additionally, all derived classes are registered in a mapping using their types for
    faster op class lookup.
    """

    classes = {}
    instances = {}

    def __new__(metacls, classname, bases, classdict):
        # when not set explicitly in that class, set type to the class name
        classdict.setdefault("types", (classname,))
        cls = super(OperationRegister, metacls).__new__(metacls, classname, bases, classdict)
        # register the class for each of its types
        for type in cls.types:
            metacls.classes[type] = cls
        return cls

    def __call__(cls, tf_op, *args, **kwargs):
        # simple caching
        if tf_op not in cls.instances:
            inst = super(OperationRegister, cls).__call__(tf_op, *args, **kwargs)
            cls.instances[tf_op] = inst
        return cls.instances[tf_op]


# implementation types
IMPLS = IMPL_NUMPY, IMPL_SCIPY = range(2)
IMPL_NAMES = ["numpy", "scipy"]


@add_metaclass(OperationRegister)
class Operation(object):
    """
    Building block of a model. In *graph* terms, operations (ops) represent nodes that are connected
    via tensors. It contains information on its input tensors. The conversion uses the
    (tensorflow) instance *tf_op*, all *args* and *kwargs* are forwarded to the :py:class:`Tensor`
    constructor for this op's input tensors. Op instances can have multiple implementations, i.e.,
    different methods that lead to equivalent results but might use additional third-party software
    such as *scipy*. To select a specific implementation, invoke :py:func:`use_impl`:

    .. code-block:: python

       # tell SomeOp to use the scipy implementation of its op logic
       SomeOp.use_impl(IMPL_SCIPY)

    See :py:func:`add_impl` for more info about adding new implementations.

    .. py:attribute:: types
       classmember

       A tuple containing the types of tensorflow ops that this op can represent.

    .. py:attribute:: unpack
       classmember

       If *True* (default), the values of evaluated input tensors are forwarded to *func* as single
       arguments, or, otherwise, as a list.

    .. py:attribute:: attrs
       classmember

       Names of the configuration attributes of the original tensorflow op in a tuple.

    .. py:attribute:: name

       The name of the op.

    .. py:attribute:: inputs

       Tuple of tensors that are input to this op. Their order is important as they are forwarded to
       *func* for evaluation.

    .. py:attribute:: kwargs

       Keyword arguments containing configuration values that will be passed to *func*.
    """

    impl = None
    impls = []

    types = ()
    unpack = True
    attrs = ()
    output_dtypes = False

    def __init__(self, tf_op, *args, **kwargs):
        super(Operation, self).__init__()

        # compare types as a cross check
        if tf_op.type not in self.types:
            raise OperationMismatchException("operation types do not match: %s, %s" \
                % (self.types, tf_op.type))

        self.name = tf_op.name
        self.inputs = tuple(Tensor(tf_tensor, *args, **kwargs) for tf_tensor in tf_op.inputs)

        self.value = None
        self.last_uuid = None

        # store attributes as kwargs for calls to eval
        self.kwargs = []
        for attr in self.attrs:
            try:
                value = tf_op.get_attr(attr)
            except ValueError:
                value = None
            self.kwargs.append(value)

        # store output dtypes, they are passed to func in eval when the class-level
        # output_dtypes flag is True
        self.output_dtypes = [dtype_map[dtype] for dtype in tf_op._output_types]

    @classmethod
    def new(cls, tf_op, *args, **kwargs):
        """
        Factory function that takes a tensorflow op *tf_op* and returns an instance of the
        appropriate op class. *args* and *kwargs* are forwarded to the op constructor.
        Raises an exception of type :py:exc:`UnknownOperationException` in case the requested op
        type is not known.
        """
        if tf_op.type not in cls.classes:
            raise UnknownOperationException("unknown operation: %s" % tf_op.type)

        return cls.classes[tf_op.type](tf_op, *args, **kwargs)

    def set_attr(self, attr, value):
        """
        Overwrites the value of an attribute *attr* with a new *value*.
        """
        if attr not in self.attrs:
            raise AttributeError("no attribute '%s' in op '%s'" % (attr, self.name))

        self.kwargs[self.attrs.index(attr)] = value

    def get(self, *names):
        """
        Returns one or more tensors given by *names* using a deep lookup within this op. *None* is
        returned when no tensor was found.
        """
        tensors = tuple(self._get(name) for name in names)
        return tensors[0] if len(names) == 1 else tensors

    def _get(self, name):
        return reduce(lambda t1, t2: t1 or t2.get(name), self.inputs, None)

    def eval(self, feed_dict=None, _uuid=None):
        """ eval(feed_dict=None)
        Returns the value of all output tensors in a tuple. See :py:meth:`Tensor.eval` for more
        info.
        """
        # set a cache uuid for this eval call
        if _uuid is None:
            _uuid = uuid4()

        # already cached?
        if _uuid == self.last_uuid:
            return self.value
        else:
            self.last_uuid = _uuid

        args = [t.eval(feed_dict=feed_dict, _uuid=_uuid) for t in self.inputs]
        if self.unpack:
            args.extend(self.kwargs)
        else:
            args = [args] + self.kwargs
        if self.__class__.output_dtypes:
            args.append(self.output_dtypes)

        self.value = self.func(*args)

        return self.value

    @classmethod
    def func(cls, *args):
        """
        The actual op logic. By default, the method call is forwarded to the
        implementation-specific version which is determined using *impl*.
        Overwrite this method in inheriting classes to disable this feature. Must return a tuple.
        """
        if cls.impl == IMPL_NUMPY:
            return cls.func_numpy(*args)
        elif cls.impl == IMPL_SCIPY:
            return cls.func_scipy(*args)
        else:
            raise InvalidImplementationException(cls.impl)

    @staticmethod
    def func_numpy(*args):
        """
        Numpy implementation of the op logic. Returns a tuple.
        """
        raise NotImplementedError

    @staticmethod
    def func_scipy(*args):
        """
        Scipy implementation of the op logic. Returns a tuple.
        """
        raise NotImplementedError

    @classmethod
    def factory(cls, func=None, impl=IMPL_NUMPY, **kwargs):
        """ factory(func=None, impl=IMPL_NUMPY, **kwargs)
        Returns a new op class whose static function will be set to *func*. The name of *func* will
        also be the op class name. *impl* is the default implementation type of the op. *kwargs* are
        used to update the class dict of the newly created op class.
        """
        if impl not in IMPLS:
            raise InvalidImplementationException(impl)

        def wrapper(func):
            classdict = {"impls": [], "func_" + IMPL_NAMES[impl]: staticmethod(func)}
            classdict.update(kwargs)

            cls = Operation.__class__(func.__name__, (Operation,), classdict)
            cls.__doc__ = func.__doc__
            cls.impls.append(impl)
            cls.use_impl(impl)

            return cls

        return wrapper if func is None else wrapper(func)

    @classmethod
    def use_impl(cls, impl):
        """
        Switches the implementation type to *impl*. Returns the previous type.
        """
        if impl not in cls.impls:
            raise UnknownImplementationException(impl)

        prev = cls.impl
        cls.impl = impl
        return prev

    @classmethod
    def add_impl(cls, impl):
        """
        Decorator to add an additional implementation to this op. Example:

        .. code-block:: python

           # initial implementation using factory, defaults to numpy
           @Operation.factory
           def MyOp(a, b):
               # use numpy only
               return ...

           # also add a scipy implementation
           @MyOp.add_impl(IMPL_SCIPY)
           def MyOp(a, b):
               # also use scipy
               return ...
        """
        if impl not in IMPLS:
            raise InvalidImplementationException(impl)

        def wrapper(func):
            setattr(cls, "func_" + IMPL_NAMES[impl], staticmethod(func))
            if impl not in cls.impls:
                cls.impls.append(impl)
            return cls

        return wrapper


# ensemble method types
METHODS = METHOD_MEAN, METHOD_MAX, METHOD_MIN, METHOD_CUSTOM = range(4)
METHOD_NAMES = ["mean", "max", "min", "custom"]


class Ensemble(object):
    """
    An ensemble is a wrapper around multiple models to compute ensemble values. It can be
    initialized with a list of model paths and an ensembling method that decides how to compute the
    merged value.

    .. code-block:: python

       # create the ensemble
       ensemble = Ensemble(["model1.pkl", "model2.pkl", ...], METHOD_MEAN)

       # get input and output tensors (which actually are TensorEnsemble instances)
       input, output = ensemble.get("input", "output")

       # evaluate the ensemble just like a normal model
       batch = ...
       value = output.eval({input: batch})

    If you want to use a method other than ``METHOD_MEAN``, ``METHOD_MAX`` or ``METHOD_MIN``, use
    ``METHOD_CUSTOM`` and overwrite the ``func_custom`` method of the :py:class:`TensorEnsemble`
    instance.

    .. py:attribute:: models

       A list that contains all read models.

    .. py:attribute:: method

       The ensembling method.
    """

    def __init__(self, paths=None, method=METHOD_MEAN):
        """ __init__(paths=None, method=METHOD_MEAN)
        """
        super(Ensemble, self).__init__()

        # check method
        if method not in METHODS:
            raise UnknownEnsembleMethodException(method)
        self.method = method

        # loaded models
        self.models = []

        # load when desired
        if paths is not None:
            self.load(paths)

    def get(self, *names, **kwargs):
        """ get(*names, key=None)
        Returns one or more :py:class:`TensorEnsemble` instances given by *names* using a deep
        lookup within all read models. Each returned tensor ensemble will have ``len(models)``
        tensors. If a model does not contain a specific tensor defined by a specific *name*, the
        associated ensemble tensor will contain a *None* for that model in its tensors. If *key* is
        not *None*, only the root tensors with that *key* are traversed.
        """
        # create empty tensor ensembles with our method
        tensor_ensembles = [TensorEnsemble([], self.method) for name in names]

        # loop over models, collect and add tensors
        for model in self.models:
            tensors = model.get(*names, **kwargs)
            # Model.get returns a single tensor when only one name is passed
            if not isinstance(tensors, tuple):
                tensors = (tensors,)
            for i, t in enumerate(tensors):
                tensor_ensembles[i].tensors.append(t)

        return tensor_ensembles[0] if len(names) == 1 else tuple(tensor_ensembles)

    def load(self, paths):
        """
        Loads models from a list of *paths*.
        """
        for path in paths:
            self.models.append(Model(path))


class TensorEnsemble(object):
    """
    A tensor ensemble basically contains a list of tensors that correspond to models of an
    :py:class:`Ensemble` instance.

    .. py:attribute:: tensors

       The list of contained tensors. Tensor *i* corresponds to model *i*.

    .. py:attribute:: method

       The ensembling method.
    """

    def __init__(self, tensors, method=METHOD_MEAN):
        super(TensorEnsemble, self).__init__()

        # check method
        if method not in METHODS:
            raise UnknownEnsembleMethodException(method)
        self.method = method

        self.tensors = list(tensors)

    def eval(self, feed_dict=None):
        """
        Evaluates all contained tensors using a *feed_dict* and returns the ensemble value. The keys
        of *feed_dict* must be tensor ensembles. Its values can be batches, i.e., numpy arrays, or
        lists or tuples of batches. In the latter case, these lists or tuples must have the same
        length as the list of stored tensors as they will be mapped.
        """
        # the default of None must be replaced before iterating
        if feed_dict is None:
            feed_dict = {}

        # first, check that the length of all feed_dict keys match our own length
        for tensor_ensemble in feed_dict:
            if len(tensor_ensemble.tensors) != len(self.tensors):
                raise EnsembleMismatchException("incompatible lengths of tensors: %d, %d" \
                    % (len(self.tensors), len(tensor_ensemble.tensors)))

        # create a joined uuid
        _uuid = uuid4()

        # prepare feed_dicts
        feed_dicts = [{} for _ in range(len(self.tensors))]
        for tensor_ensemble, value in feed_dict.items():
            for i, tensor in enumerate(tensor_ensemble.tensors):
                if tensor is not None:
                    feed_dicts[i][tensor] = value[i] if isinstance(value, (list, tuple)) else value

        # eval all tensors
        values = [t.eval(feed_dict=d, _uuid=_uuid) for t, d in zip(self.tensors, feed_dicts)]

        # return the computed ensemble value
        return self.func(values)

    def __call__(self, *args, **kwargs):
        return self.eval(*args, **kwargs)

    def func(self, values):
        """
        The actual ensembling logic that combines multiple *values*.
        The method call is forwarded to the ensemble method-specific variant which is determined
        using *method*.
        """
        if self.method == METHOD_MEAN:
            return self.func_mean(values)
        elif self.method == METHOD_MAX:
            return self.func_max(values)
        elif self.method == METHOD_MIN:
            return self.func_min(values)
        elif self.method == METHOD_CUSTOM:
            return self.func_custom(values)
        else:
            raise UnknownEnsembleMethodException(self.method)

    @staticmethod
    def func_mean(values):
        return np.mean(np.stack(values), axis=0)

    @staticmethod
    def func_max(values):
        return np.amax(np.stack(values), axis=0)

    @staticmethod
    def func_min(values):
        return np.amin(np.stack(values), axis=0)

    @staticmethod
    def func_custom(values):
        raise NotImplementedError


class UnknownOperationException(Exception):
    """
    An exception which is raised when trying to convert an unknown tensorflow op.
    """


class OperationMismatchException(Exception):
    """
    An exception which is raised during instantiation of an op whose type does not match the
    underlying tensorflow op.
    """


class InvalidImplementationException(Exception):
    """
    An exception which is raised when an implementation of an unknown type is registered for an
    :py:class:`Operation` class.
    """


class UnknownImplementationException(Exception):
    """
    An exception which is raised when an :py:class:`Operation` instance is requested to use an
    implementation type that was not yet added.
    """


class UnknownEnsembleMethodException(Exception):
    """
    An exception which is raised when an :py:class:`Ensemble` instance is initialised with an
    unknown ensemble method.
769 | """ 770 | 771 | 772 | class EnsembleMismatchException(Exception): 773 | """ 774 | An exception which is raised when a :py:class:`TensorEnsemble` instance is evaluated with a 775 | *feed_dict* whose keys, i.e. also :py:class:`TensorEnsemble` instances, do not match the tensor 776 | to evaluate. An example would be that a tensor ensemble with *n* tensors is evaluated with a 777 | tensor ensemble in its *feed_dict* that contains *m* tensors. 778 | """ 779 | 780 | 781 | class ScipyOperationException(Exception): 782 | """ 783 | An exception which is raised when trying to evaluate an op that uses scipy internally and scipy 784 | is not available. 785 | """ 786 | def __init__(self, attr): 787 | msg = "trying to access 'scipy.%s', but scipy is not installed on your system, " \ 788 | "install scipy to use this operation or use another implementation" % attr 789 | super(ScipyOperationException, self).__init__(msg) 790 | 791 | 792 | # parses the tf version and returns a tuple, e.g. "0.12.0-rc1" => (0, 12, 0, "rc1") 793 | def _parse_tf_version(v): 794 | parts = v.split(".", 2) 795 | if "-" in parts[2]: 796 | parts.extend(parts.pop().split("-", 1)) 797 | return tuple([int(p) for p in parts[:3]] + parts[3:]) 798 | 799 | 800 | # default (last) tf version 801 | _tf_version_string = "0.12.0-rc1" 802 | _tf_version = _parse_tf_version(_tf_version_string) 803 | 804 | 805 | def setup(tf, order=None): 806 | """ 807 | Sets up global variables (currently only the tensorflow version) to adapt to peculiarities of 808 | different tensorflow versions. This function should only be called before :py:class:`Model` 809 | creation, not for evaluation. Therefore, the tensorflow module *tf* must be passed: 810 | 811 | .. code-block:: python 812 | 813 | import tensorflow as tf 814 | import tfdeploy as td 815 | 816 | td.setup(tf) 817 | 818 | # ... 819 | 820 | Also, when *order* is not *None*, it is forwarded to :py:func:`optimize` for convenience.
821 | """ 822 | global _tf_version_string, _tf_version 823 | _tf_version_string = tf.__version__ 824 | _tf_version = _parse_tf_version(_tf_version_string) 825 | 826 | if order is not None: 827 | optimize(order) 828 | 829 | 830 | def reset(): 831 | """ 832 | Resets the instance caches of :py:class:`TensorRegister` and :py:class:`OperationRegister`. 833 | """ 834 | TensorRegister.instances.clear() 835 | OperationRegister.instances.clear() 836 | 837 | 838 | def optimize(order): 839 | """ optimize(impl) 840 | Tries to set the implementation type of all registered :py:class:`Operation` classes to *impl*. 841 | This has no effect when an op does not implement that type. 842 | 843 | The behavior is equivalent to: 844 | 845 | .. code-block:: python 846 | 847 | for op in Operation.__subclasses__(): 848 | if impl in op.impls: 849 | op.use_impl(impl) 850 | 851 | *impl* can also be a list or tuple of valid implementation types representing a preferred order. 852 | """ 853 | if not isinstance(order, (list, tuple)): 854 | order = [order] 855 | 856 | for op in Operation.__subclasses__(): 857 | for impl in order: 858 | if impl in op.impls: 859 | op.use_impl(impl) 860 | break 861 | 862 | 863 | def print_tensor(td_tensor, indent="| ", max_depth=-1, depth=0): 864 | """ print_tensor(td_tensor, indent=" ", max_depth=-1) 865 | Prints the dependency graph of a :py:class:`Tensor` *td_tensor*, where each new level is 866 | indented by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where 867 | each tensor and each op count as a level. 
868 | """ 869 | offset = depth * indent 870 | line = "td tensor: %s" % td_tensor.name 871 | if td_tensor.value is not None: 872 | line += " (%s)" % (",".join(str(i) for i in td_tensor.value.shape),) 873 | 874 | print(offset + line) 875 | 876 | if td_tensor.op and (max_depth < 0 or max_depth > depth): 877 | print_op(td_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1) 878 | 879 | 880 | def print_op(td_op, indent="| ", max_depth=-1, depth=0): 881 | """ print_op(td_op, indent=" ", max_depth=-1) 882 | Prints the dependency graph of a :py:class:`Operation` *td_op*, where each new level is indented 883 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each 884 | tensor and each op count as a level. 885 | """ 886 | offset = depth * indent 887 | line = "td op: %s (%s)" % (td_op.name, ",".join(td_op.types)) 888 | 889 | print(offset + line) 890 | 891 | if max_depth < 0 or max_depth > depth: 892 | for td_tensor in td_op.inputs: 893 | print_tensor(td_tensor, indent=indent, max_depth=max_depth, depth=depth+1) 894 | 895 | 896 | def print_tf_tensor(tf_tensor, indent="| ", max_depth=-1, depth=0): 897 | """ print_tf_tensor(tf_tensor, indent=" ", max_depth=-1) 898 | Prints the dependency graph of a tensorflow tensor *tf_tensor*, where each new level is indented 899 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each 900 | tensor and each op count as a level. 
901 | """ 902 | offset = depth * indent 903 | shape = tuple(int(i) for i in tf_tensor.get_shape()) 904 | line = "tf tensor: %s (%s)" % (tf_tensor.name, ",".join(str(i) for i in shape)) 905 | 906 | print(offset + line) 907 | 908 | if tf_tensor.op and (max_depth < 0 or max_depth > depth): 909 | print_tf_op(tf_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1) 910 | 911 | 912 | def print_tf_op(tf_op, indent="| ", max_depth=-1, depth=0): 913 | """ print_tf_op(tf_op, indent=" ", max_depth=-1) 914 | Prints the dependency graph of a tensorflow operation *tf_op*, where each new level is indented 915 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each 916 | tensor and each op count as a level. 917 | """ 918 | offset = depth * indent 919 | line = "tf op: %s (%s)" % (tf_op.name, tf_op.type) 920 | 921 | print(offset + line) 922 | 923 | if max_depth < 0 or max_depth > depth: 924 | for tf_tensor in tf_op.inputs: 925 | print_tf_tensor(tf_tensor, indent=indent, max_depth=max_depth, depth=depth+1) 926 | 927 | 928 | # imports exclusively for ops 929 | from operator import mul 930 | from itertools import product 931 | from collections import defaultdict 932 | 933 | # optional import of scipy 934 | try: 935 | if os.environ.get("TD_REFUSE_SCIPY", "").lower() in ("1", "true", "yes"): 936 | raise ImportError 937 | 938 | import scipy as sp 939 | import scipy.special 940 | HAS_SCIPY = True 941 | except ImportError: 942 | class ScipyDummy(object): 943 | def __getattr__(self, attr): 944 | raise ScipyOperationException(attr) 945 | sp = ScipyDummy() 946 | HAS_SCIPY = False 947 | 948 | 949 | # mapping of tf dtypes to np dtypes 950 | dtype_map = { 951 | 1: np.float32, 952 | 2: np.float64, 953 | 3: np.int32, 954 | 4: np.uint8, 955 | 5: np.int16, 956 | 6: np.int8, 957 | 7: np.object, 958 | 8: np.complex64, 959 | 9: np.int64, 960 | 10: np.bool, 961 | 14: np.uint16, 962 | 17: np.uint16, 963 | 18: np.complex128, 964 | 19: np.float16, 965 | 101:
np.float32, 966 | 102: np.float64, 967 | 103: np.int32, 968 | 104: np.uint8, 969 | 105: np.int16, 970 | 106: np.int8, 971 | 107: np.object, 972 | 108: np.complex64, 973 | 109: np.int64, 974 | 110: np.bool, 975 | 114: np.uint16, 976 | 117: np.uint16, 977 | 118: np.complex128, 978 | 119: np.float16 979 | } 980 | 981 | 982 | lgamma_vec = np.vectorize(np.math.lgamma) 983 | erf_vec = np.vectorize(np.math.erf) 984 | erfc_vec = np.vectorize(np.math.erfc) 985 | 986 | def _transpose(a, dim=2): 987 | if dim <= 0: 988 | axes = None 989 | else: 990 | axes = list(range(a.ndim)) 991 | axes.append(axes.pop(-1 * dim)) 992 | return np.transpose(a, axes=axes) 993 | 994 | def _adjoint(a, dim=2): 995 | return np.conj(_transpose(a, dim=dim)) 996 | 997 | 998 | # 999 | # sequences 1000 | # 1001 | 1002 | @Operation.factory 1003 | def LinSpace(start, stop, num): 1004 | """ 1005 | Linspace op. 1006 | """ 1007 | return np.linspace(start, stop, num=num, dtype=np.float32), 1008 | 1009 | 1010 | @Operation.factory 1011 | def Range(start, limit, delta): 1012 | """ 1013 | Range op. 1014 | """ 1015 | return np.arange(start, limit, delta, dtype=np.int32), 1016 | 1017 | 1018 | # 1019 | # random tensors 1020 | # 1021 | 1022 | @Operation.factory(attrs=("dtype", "seed")) 1023 | def RandomStandardNormal(shape, dtype, seed): 1024 | """ 1025 | Standard (mu=0, sigma=1) gaussian op. 1026 | """ 1027 | if seed: 1028 | np.random.seed(seed) 1029 | return np.random.normal(size=reduce(mul, shape)).reshape(shape).astype(dtype_map[dtype]), 1030 | 1031 | 1032 | @Operation.factory(attrs=("dtype", "seed")) 1033 | def TruncatedNormal(shape, dtype, seed): 1034 | """ 1035 | Standard (mu=0, sigma=1) gaussian op with truncation above 2 sigma. 
1036 | """ 1037 | if seed: 1038 | np.random.seed(seed) 1039 | n = reduce(mul, shape) 1040 | r = np.empty(n, dtype=dtype_map[dtype]) 1041 | idxs = np.ones(n, dtype=np.bool) 1042 | while n: 1043 | r[idxs] = np.random.normal(size=n) 1044 | idxs = np.abs(r) > 2 1045 | n = np.sum(idxs) 1046 | return r.reshape(shape), 1047 | 1048 | 1049 | @Operation.factory(attrs=("dtype", "seed")) 1050 | def RandomUniform(shape, dtype, seed): 1051 | """ 1052 | Random uniform op. 1053 | """ 1054 | if seed: 1055 | np.random.seed(seed) 1056 | return np.random.uniform(size=shape).astype(dtype_map[dtype]), 1057 | 1058 | 1059 | @Operation.factory(attrs=("seed",)) 1060 | def RandomUniformInt(shape, minval, maxval, seed): 1061 | """ 1062 | Random uniform int op. 1063 | """ 1064 | if seed: 1065 | np.random.seed(seed) 1066 | return np.random.randint(minval, maxval, size=shape), 1067 | 1068 | 1069 | @Operation.factory(attrs=("seed",)) 1070 | def RandomShuffle(a, seed): 1071 | """ 1072 | Random shuffle op. 1073 | """ 1074 | if seed: 1075 | np.random.seed(seed) 1076 | r = a.copy() 1077 | np.random.shuffle(r) 1078 | return r, 1079 | 1080 | 1081 | # 1082 | # casting 1083 | # 1084 | 1085 | @Operation.factory(types=("Cast", "StringToNumber"), output_dtypes=True) 1086 | def Cast(a, output_dtypes): 1087 | """ 1088 | Cast op. 1089 | """ 1090 | return np.copy(a).astype(output_dtypes[0]), 1091 | 1092 | 1093 | # 1094 | # shapes and shaping 1095 | # 1096 | 1097 | @Operation.factory 1098 | def Shape(a): 1099 | """ 1100 | Shape op. 1101 | """ 1102 | return np.array(a.shape, dtype=np.int32), 1103 | 1104 | 1105 | @Operation.factory 1106 | def Size(a): 1107 | """ 1108 | Size op. 1109 | """ 1110 | return np.array([a.size], dtype=np.int32), 1111 | 1112 | 1113 | @Operation.factory 1114 | def Rank(a): 1115 | """ 1116 | Rank op. 1117 | """ 1118 | return np.array([len(a.shape)], dtype=np.int32), 1119 | 1120 | 1121 | @Operation.factory 1122 | def Reshape(a, shape): 1123 | """ 1124 | Reshape op.
1125 | """ 1126 | return np.copy(a).reshape(shape), 1127 | 1128 | 1129 | @Operation.factory(attrs=("squeeze_dims",)) 1130 | def Squeeze(a, squeeze_dims): 1131 | """ 1132 | Squeeze op, i.e. removes singular axes. 1133 | """ 1134 | if not squeeze_dims: 1135 | squeeze_dims = list(range(len(a.shape))) 1136 | slices = [(0 if (dim == 1 and i in squeeze_dims) else slice(None)) \ 1137 | for i, dim in enumerate(a.shape)] 1138 | return np.copy(a)[tuple(slices)], 1139 | 1140 | 1141 | @Operation.factory 1142 | def ExpandDims(a, dim): 1143 | """ 1144 | Expand dim op, i.e. add singular axis at dim. 1145 | """ 1146 | shape = list(a.shape) 1147 | if dim >= 0: 1148 | shape.insert(dim, 1) 1149 | else: 1150 | shape.insert(len(shape) + dim + 1, 1) 1151 | return np.copy(a).reshape(*shape), 1152 | 1153 | 1154 | # 1155 | # slicing and joining 1156 | # 1157 | 1158 | @Operation.factory 1159 | def Slice(a, begin, size): 1160 | """ 1161 | Slicing op. 1162 | """ 1163 | return np.copy(a)[tuple(slice(*tpl) for tpl in zip(begin, begin+size))], 1164 | 1165 | 1166 | @Operation.factory(attrs=("num_split",)) 1167 | def Split(axis, a, n): 1168 | """ 1169 | Split op with n splits. 1170 | """ 1171 | return tuple(np.split(np.copy(a), n, axis=axis)) 1172 | 1173 | 1174 | @Operation.factory 1175 | def SplitV(a, splits, axis): 1176 | """ 1177 | Split op with multiple split sizes. 1178 | """ 1179 | return tuple(np.split(np.copy(a), np.cumsum(splits), axis=axis)) 1180 | 1181 | 1182 | @Operation.factory 1183 | def Tile(a, n): 1184 | """ 1185 | Tile op. 1186 | """ 1187 | return np.tile(a, n), 1188 | 1189 | 1190 | @Operation.factory 1191 | def Pad(a, paddings): 1192 | """ 1193 | Zero padding op. 1194 | """ 1195 | return np.pad(a, paddings, mode="constant", constant_values=0), 1196 | 1197 | 1198 | @Operation.factory(unpack=False) 1199 | def ConcatV2(inputs): 1200 | """ 1201 | Concat op.
1202 | """ 1203 | axis = inputs.pop() 1204 | return np.concatenate(inputs, axis=axis), 1205 | 1206 | 1207 | @Operation.factory(attrs=("axis",), unpack=False) 1208 | def Pack(inputs, axis): 1209 | """ 1210 | Pack op. 1211 | """ 1212 | return np.stack(inputs, axis=axis), 1213 | 1214 | 1215 | @Operation.factory(attrs=("num", "axis")) 1216 | def Unpack(a, num, axis): 1217 | """ 1218 | Unpack op. 1219 | """ 1220 | return tuple(np.squeeze(b, axis=axis) for b in np.split(a, num, axis=axis)) 1221 | 1222 | 1223 | @Operation.factory(attrs=("seq_dim", "batch_dim")) 1224 | def ReverseSequence(a, seq_lengths, seq_dim, batch_dim): 1225 | """ 1226 | Sequential reverse op. 1227 | """ 1228 | r = np.copy(a) 1229 | invidxs = (len(r.shape) - 1) * [slice(None)] 1230 | if seq_dim < batch_dim: 1231 | invidxs[seq_dim] = slice(None, None, -1) 1232 | else: 1233 | invidxs[seq_dim - 1] = slice(None, None, -1) 1234 | _invidxs = tuple(invidxs) 1235 | selidxs = len(r.shape) * [slice(None)] 1236 | for i, l in enumerate(seq_lengths): 1237 | if not l: 1238 | continue 1239 | selidxs[batch_dim] = i 1240 | selidxs[seq_dim] = slice(0, l) 1241 | _selidxs = tuple(selidxs) 1242 | r[_selidxs] = a[_selidxs][_invidxs] 1243 | return r, 1244 | 1245 | 1246 | @Operation.factory 1247 | def ReverseV2(a, axes): 1248 | """ 1249 | Reverse op. 1250 | """ 1251 | idxs = tuple(slice(None, None, 2 * int(i not in axes) - 1) for i in range(len(a.shape))) 1252 | return np.copy(a[idxs]), 1253 | 1254 | 1255 | @Operation.factory 1256 | def Transpose(a, perm=None): 1257 | """ 1258 | Transpose op. 1259 | """ 1260 | return np.transpose(a, axes=perm), 1261 | 1262 | 1263 | # 1264 | # arithmetic math ops 1265 | # 1266 | 1267 | @Operation.factory(types=("Add", "BiasAdd")) 1268 | def Add(a, b): 1269 | """ 1270 | Addition op. 1271 | """ 1272 | return np.add(a, b), 1273 | 1274 | 1275 | @Operation.factory(types=("Subtract", "Sub")) 1276 | def Subtract(a, b): 1277 | """ 1278 | Subtraction op. 
1279 | """ 1280 | return np.subtract(a, b), 1281 | 1282 | 1283 | @Operation.factory(types=("Multiply", "Mul")) 1284 | def Multiply(a, b): 1285 | """ 1286 | Multiplication op. 1287 | """ 1288 | return np.multiply(a, b), 1289 | 1290 | 1291 | @Operation.factory(types=("Div", "RealDiv")) 1292 | def Div(a, b): 1293 | """ 1294 | Division op. 1295 | """ 1296 | return np.divide(a, b), 1297 | 1298 | 1299 | @Operation.factory 1300 | def FloorDiv(a, b): 1301 | """ 1302 | Floor division op, i.e., a // b. 1303 | """ 1304 | return np.floor_divide(a, b), 1305 | 1306 | 1307 | @Operation.factory(types=("Mod", "FloorMod")) 1308 | def Mod(a, b): 1309 | """ 1310 | Modulo op. 1311 | """ 1312 | return np.mod(a, b), 1313 | 1314 | 1315 | @Operation.factory 1316 | def Cross(a, b): 1317 | """ 1318 | Cross product op. 1319 | """ 1320 | return np.cross(a, b), 1321 | 1322 | 1323 | # 1324 | # basic math ops 1325 | # 1326 | 1327 | @Operation.factory(unpack=False) 1328 | def AddN(inputs): 1329 | """ 1330 | Multi add op. 1331 | """ 1332 | return reduce(np.add, inputs), 1333 | 1334 | 1335 | @Operation.factory 1336 | def Abs(a): 1337 | """ 1338 | Abs op. 1339 | """ 1340 | return np.abs(a), 1341 | 1342 | 1343 | @Operation.factory(types=("Negative", "Neg")) 1344 | def Negative(a): 1345 | """ 1346 | Negative op. 1347 | """ 1348 | return np.negative(a), 1349 | 1350 | 1351 | @Operation.factory 1352 | def Sign(a): 1353 | """ 1354 | Sign op. 1355 | """ 1356 | return np.sign(a), 1357 | 1358 | 1359 | @Operation.factory 1360 | def Inv(a): 1361 | """ 1362 | Reciprocal op. 1363 | """ 1364 | return np.reciprocal(a), 1365 | 1366 | 1367 | @Operation.factory 1368 | def Square(a): 1369 | """ 1370 | Square op. 1371 | """ 1372 | return np.square(a), 1373 | 1374 | 1375 | @Operation.factory 1376 | def Round(a): 1377 | """ 1378 | Round op. 1379 | """ 1380 | return np.round(a), 1381 | 1382 | 1383 | @Operation.factory 1384 | def Sqrt(a): 1385 | """ 1386 | Square root op. 
1387 | """ 1388 | return np.sqrt(a), 1389 | 1390 | 1391 | @Operation.factory 1392 | def Rsqrt(a): 1393 | """ 1394 | Reciprocal square root op. 1395 | """ 1396 | return np.reciprocal(np.sqrt(a)), 1397 | 1398 | 1399 | @Operation.factory 1400 | def Pow(a, b): 1401 | """ 1402 | Power op. 1403 | """ 1404 | return np.power(a, b), 1405 | 1406 | 1407 | @Operation.factory 1408 | def Exp(a): 1409 | """ 1410 | Exponential op. 1411 | """ 1412 | return np.exp(a), 1413 | 1414 | 1415 | @Operation.factory 1416 | def Log(a): 1417 | """ 1418 | Logarithm op. 1419 | """ 1420 | return np.log(a), 1421 | 1422 | 1423 | @Operation.factory 1424 | def Ceil(a): 1425 | """ 1426 | Ceil round op. 1427 | """ 1428 | return np.ceil(a), 1429 | 1430 | 1431 | @Operation.factory 1432 | def Floor(a): 1433 | """ 1434 | Floor round op. 1435 | """ 1436 | return np.floor(a), 1437 | 1438 | 1439 | @Operation.factory 1440 | def Maximum(a, b): 1441 | """ 1442 | Maximum op. 1443 | """ 1444 | return np.maximum(a, b), 1445 | 1446 | 1447 | @Operation.factory 1448 | def Minimum(a, b): 1449 | """ 1450 | Minimum op. 1451 | """ 1452 | return np.minimum(a, b), 1453 | 1454 | 1455 | @Operation.factory 1456 | def Cos(a): 1457 | """ 1458 | Cos op. 1459 | """ 1460 | return np.cos(a), 1461 | 1462 | 1463 | @Operation.factory 1464 | def Sin(a): 1465 | """ 1466 | Sin op. 1467 | """ 1468 | return np.sin(a), 1469 | 1470 | 1471 | @Operation.factory 1472 | def Tan(a): 1473 | """ 1474 | Tan op. 1475 | """ 1476 | return np.tan(a), 1477 | 1478 | 1479 | @Operation.factory 1480 | def Acos(a): 1481 | """ 1482 | Acos op. 1483 | """ 1484 | return np.arccos(a), 1485 | 1486 | 1487 | @Operation.factory 1488 | def Asin(a): 1489 | """ 1490 | Asin op. 1491 | """ 1492 | return np.arcsin(a), 1493 | 1494 | 1495 | @Operation.factory 1496 | def Atan(a): 1497 | """ 1498 | Atan op. 1499 | """ 1500 | return np.arctan(a), 1501 | 1502 | 1503 | @Operation.factory 1504 | def Lgamma(a): 1505 | """ 1506 | lgamma op. 
1507 | """ 1508 | return lgamma_vec(a), 1509 | 1510 | @Lgamma.add_impl(IMPL_SCIPY) 1511 | def Lgamma(a): 1512 | return sp.special.gammaln(a), 1513 | 1514 | 1515 | @Operation.factory(impl=IMPL_SCIPY) 1516 | def Digamma(a): 1517 | """ 1518 | Digamma op. 1519 | """ 1520 | return sp.special.digamma(a), 1521 | 1522 | 1523 | @Operation.factory 1524 | def Erf(a): 1525 | """ 1526 | Gaussian error function op. 1527 | """ 1528 | return erf_vec(a), 1529 | 1530 | @Erf.add_impl(IMPL_SCIPY) 1531 | def Erf(a): 1532 | return sp.special.erf(a), 1533 | 1534 | 1535 | @Operation.factory 1536 | def Erfc(a): 1537 | """ 1538 | Complementary gaussian error function op. 1539 | """ 1540 | return erfc_vec(a), 1541 | 1542 | @Erfc.add_impl(IMPL_SCIPY) 1543 | def Erfc(a): 1544 | return sp.special.erfc(a), 1545 | 1546 | 1547 | @Operation.factory 1548 | def SquaredDifference(a, b): 1549 | """ 1550 | Squared diff op, i.e. (a-b)**2 1551 | """ 1552 | return (a - b)**2, 1553 | 1554 | 1555 | @Operation.factory(impl=IMPL_SCIPY) 1556 | def Igamma(a, b): 1557 | """ 1558 | Incomplete gamma op. 1559 | """ 1560 | return sp.special.gammainc(a, b), 1561 | 1562 | 1563 | @Operation.factory(impl=IMPL_SCIPY) 1564 | def Igammac(a, b): 1565 | """ 1566 | Complemented, incomplete gamma op. 1567 | """ 1568 | return sp.special.gammaincc(a, b), 1569 | 1570 | 1571 | @Operation.factory(impl=IMPL_SCIPY) 1572 | def Zeta(a, b): 1573 | """ 1574 | Zeta op. 1575 | """ 1576 | return sp.special.zeta(a, b), 1577 | 1578 | 1579 | @Operation.factory(impl=IMPL_SCIPY) 1580 | def Polygamma(a, b): 1581 | """ 1582 | Polygamma op. 1583 | """ 1584 | return sp.special.polygamma(a, b), 1585 | 1586 | 1587 | @Operation.factory(impl=IMPL_SCIPY) 1588 | def Betainc(a, b, x): 1589 | """ 1590 | Regularized incomplete beta op. 1591 | """ 1592 | return sp.special.betainc(a, b, x), 1593 | 1594 | 1595 | # 1596 | # matrix math ops 1597 | # 1598 | 1599 | @Operation.factory 1600 | def Diag(a): 1601 | """ 1602 | Diag op.
1603 | """ 1604 | r = np.zeros(2 * a.shape, dtype=a.dtype) 1605 | for idx, v in np.ndenumerate(a): 1606 | r[2 * idx] = v 1607 | return r, 1608 | 1609 | 1610 | @Operation.factory 1611 | def DiagPart(a): 1612 | """ 1613 | Diag op that returns only the diagonal elements. 1614 | """ 1615 | return np.diagonal(a), 1616 | 1617 | 1618 | @Operation.factory 1619 | def MatrixDiagPart(a): 1620 | """ 1621 | Batched diag op that returns only the diagonal elements. 1622 | """ 1623 | r = np.zeros(a.shape[:-2] + (min(a.shape[-2:]),)) 1624 | for coord in np.ndindex(a.shape[:-2]): 1625 | pos = coord + (Ellipsis,) 1626 | r[pos] = np.diagonal(a[pos]) 1627 | return r, 1628 | 1629 | 1630 | @Operation.factory(attrs=("transpose_a", "transpose_b")) 1631 | def MatMul(a, b, transpose_a, transpose_b): 1632 | """ 1633 | Matrix multiplication op. 1634 | """ 1635 | return np.dot(a if not transpose_a else np.transpose(a), 1636 | b if not transpose_b else np.transpose(b)), 1637 | 1638 | 1639 | @Operation.factory 1640 | def MatrixDeterminant(a): 1641 | """ 1642 | Matrix det op. 1643 | """ 1644 | return np.linalg.det(a), 1645 | 1646 | 1647 | @Operation.factory(attrs=("adjoint",)) 1648 | def MatrixInverse(a, adj): 1649 | """ 1650 | Matrix inversion op. 1651 | """ 1652 | return np.linalg.inv(a if not adj else _adjoint(a)), 1653 | 1654 | 1655 | @Operation.factory 1656 | def Cholesky(a): 1657 | """ 1658 | Cholesky decomposition op. 1659 | """ 1660 | return np.linalg.cholesky(a), 1661 | 1662 | 1663 | @Operation.factory(attrs=("adjoint",)) 1664 | def MatrixSolve(a, rhs, adj): 1665 | """ 1666 | Matrix solve op. 1667 | """ 1668 | return np.linalg.solve(a if not adj else _adjoint(a), rhs), 1669 | 1670 | 1671 | @Operation.factory(attrs=("lower", "adjoint"), impl=IMPL_SCIPY) 1672 | def MatrixTriangularSolve(a, rhs, lower, adj): 1673 | """ 1674 | Matrix triangular solve op. 
1675 | """ 1676 | trans = 0 if not adj else 2 1677 | 1678 | r = np.empty(rhs.shape).astype(a.dtype) 1679 | for coord in np.ndindex(a.shape[:-2]): 1680 | pos = coord + (Ellipsis,) 1681 | r[pos] = sp.linalg.solve_triangular(a[pos] if not adj else np.conj(a[pos]), rhs[pos], 1682 | trans=trans, lower=lower) 1683 | 1684 | return r, 1685 | 1686 | 1687 | @Operation.factory 1688 | def MatrixSolveLs(a, rhs, l2_reg): 1689 | """ 1690 | Matrix least-squares solve op. 1691 | """ 1692 | r = np.empty(rhs.shape).astype(a.dtype) 1693 | for coord in np.ndindex(a.shape[:-2]): 1694 | pos = coord + (Ellipsis,) 1695 | r[pos] = np.linalg.lstsq(a[pos], rhs[pos])[0] 1696 | 1697 | return r, 1698 | 1699 | 1700 | @Operation.factory 1701 | def SelfAdjointEig(a): 1702 | """ 1703 | Eigen decomp op. 1704 | """ 1705 | shape = list(a.shape) 1706 | shape[-2] += 1 1707 | return np.append(*np.linalg.eig(a)).reshape(*shape), 1708 | 1709 | 1710 | @Operation.factory 1711 | def SelfAdjointEigV2(a): 1712 | """ 1713 | Eigen decomp op. 1714 | """ 1715 | return np.linalg.eig(a) 1716 | 1717 | 1718 | @Operation.factory(attrs=("compute_uv", "full_matrices")) 1719 | def Svd(a, uv, full): 1720 | """ 1721 | Singular value decomp op. 1722 | """ 1723 | u, s, v = np.linalg.svd(a, full_matrices=full, compute_uv=uv) 1724 | return s, u, v 1725 | 1726 | 1727 | # 1728 | # complex number ops 1729 | # 1730 | 1731 | @Operation.factory 1732 | def Complex(a, b): 1733 | """ 1734 | Complex number op. 1735 | """ 1736 | return np.add(a, np.multiply(b, 1j)), 1737 | 1738 | 1739 | @Operation.factory 1740 | def Conj(a): 1741 | """ 1742 | Complex conjugate op. 1743 | """ 1744 | return np.conj(a), 1745 | 1746 | 1747 | @Operation.factory 1748 | def Imag(a): 1749 | """ 1750 | Complex imag op. 1751 | """ 1752 | return np.imag(a), 1753 | 1754 | 1755 | @Operation.factory 1756 | def Real(a): 1757 | """ 1758 | Complex real op.
1759 | """ 1760 | return np.real(a), 1761 | 1762 | 1763 | # 1764 | # Fourier transform ops 1765 | # 1766 | 1767 | @Operation.factory 1768 | def FFT2D(a): 1769 | """ 1770 | Discrete 2D FT op. 1771 | """ 1772 | return np.fft.fft2(a), 1773 | 1774 | 1775 | @Operation.factory 1776 | def IFFT2D(a): 1777 | """ 1778 | Discrete inverse 2D FT op. 1779 | """ 1780 | return np.fft.ifft2(a), 1781 | 1782 | 1783 | @Operation.factory 1784 | def FFT3D(a): 1785 | """ 1786 | Discrete 3D FT op. 1787 | """ 1788 | return np.fft.fftn(a), 1789 | 1790 | 1791 | @Operation.factory 1792 | def IFFT3D(a): 1793 | """ 1794 | Discrete inverse 3D FT op. 1795 | """ 1796 | return np.fft.ifftn(a), 1797 | 1798 | 1799 | # 1800 | # reduction 1801 | # 1802 | 1803 | @Operation.factory(attrs=("keep_dims",)) 1804 | def Sum(a, axis, keep_dims): 1805 | """ 1806 | Sum reduction op. 1807 | """ 1808 | return np.sum(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1809 | keepdims=keep_dims), 1810 | 1811 | 1812 | @Operation.factory(attrs=("keep_dims",)) 1813 | def Prod(a, axis, keep_dims): 1814 | """ 1815 | Prod reduction op. 1816 | """ 1817 | return np.prod(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1818 | keepdims=keep_dims), 1819 | 1820 | 1821 | @Operation.factory(attrs=("keep_dims",)) 1822 | def Min(a, axis, keep_dims): 1823 | """ 1824 | Min reduction op. 1825 | """ 1826 | return np.amin(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1827 | keepdims=keep_dims), 1828 | 1829 | 1830 | @Operation.factory(attrs=("keep_dims",)) 1831 | def Max(a, axis, keep_dims): 1832 | """ 1833 | Max reduction op. 1834 | """ 1835 | return np.amax(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1836 | keepdims=keep_dims), 1837 | 1838 | 1839 | @Operation.factory(attrs=("keep_dims",)) 1840 | def Mean(a, axis, keep_dims): 1841 | """ 1842 | Mean reduction op. 
1843 | """ 1844 | return np.mean(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1845 | keepdims=keep_dims), 1846 | 1847 | 1848 | @Operation.factory(attrs=("keep_dims",)) 1849 | def All(a, axis, keep_dims): 1850 | """ 1851 | All reduction op. 1852 | """ 1853 | return np.all(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1854 | keepdims=keep_dims), 1855 | 1856 | 1857 | @Operation.factory(attrs=("keep_dims",)) 1858 | def Any(a, axis, keep_dims): 1859 | """ 1860 | Any reduction op. 1861 | """ 1862 | return np.any(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis), 1863 | keepdims=keep_dims), 1864 | 1865 | 1866 | # 1867 | # segmentation 1868 | # 1869 | 1870 | def seg_map(func, a, ids): 1871 | m = defaultdict(list) 1872 | for i, e in enumerate(ids): 1873 | m[e].append(i) 1874 | r = np.empty((len(m),) + a.shape[1:], dtype=a.dtype) 1875 | for i, idxs in m.items(): 1876 | r[i] = func(idxs) 1877 | return r 1878 | 1879 | 1880 | @Operation.factory(types=("SegmentSum", "UnsortedSegmentSum")) 1881 | def SegmentSum(a, ids, *args): 1882 | """ 1883 | Segmented sum op. 1884 | """ 1885 | func = lambda idxs: reduce(np.add, a[idxs]) 1886 | return seg_map(func, a, ids), 1887 | 1888 | 1889 | @Operation.factory 1890 | def SegmentProd(a, ids): 1891 | """ 1892 | Segmented prod op. 1893 | """ 1894 | func = lambda idxs: reduce(np.multiply, a[idxs]) 1895 | return seg_map(func, a, ids), 1896 | 1897 | 1898 | @Operation.factory 1899 | def SegmentMin(a, ids): 1900 | """ 1901 | Segmented min op. 1902 | """ 1903 | func = lambda idxs: np.amin(a[idxs], axis=0) 1904 | return seg_map(func, a, ids), 1905 | 1906 | 1907 | @Operation.factory 1908 | def SegmentMax(a, ids): 1909 | """ 1910 | Segmented max op. 1911 | """ 1912 | func = lambda idxs: np.amax(a[idxs], axis=0) 1913 | return seg_map(func, a, ids), 1914 | 1915 | 1916 | @Operation.factory 1917 | def SegmentMean(a, ids): 1918 | """ 1919 | Segmented mean op. 
1920 | """ 1921 | func = lambda idxs: np.mean(a[idxs], axis=0) 1922 | return seg_map(func, a, ids), 1923 | 1924 | 1925 | @Operation.factory 1926 | def SparseSegmentSum(a, idxs, ids): 1927 | """ 1928 | Sparse segmented sum op. 1929 | """ 1930 | return SegmentSum.func(a[idxs], ids) 1931 | 1932 | 1933 | @Operation.factory 1934 | def SparseSegmentMean(a, idxs, ids): 1935 | """ 1936 | Sparse segmented mean op. 1937 | """ 1938 | return SegmentMean.func(a[idxs], ids) 1939 | 1940 | 1941 | @Operation.factory 1942 | def SparseSegmentSqrtN(a, idxs, ids): 1943 | """ 1944 | Sparse segmented sum / sqrt(n=len(idxs)) op. 1945 | """ 1946 | func = lambda _idxs: np.divide(reduce(np.add, a[idxs][_idxs]), np.math.sqrt(len(_idxs))) 1947 | return seg_map(func, a, ids), 1948 | 1949 | 1950 | # 1951 | # sequence comparison and indexing 1952 | # 1953 | 1954 | @Operation.factory 1955 | def ArgMin(a, dim): 1956 | """ 1957 | Argmin op. 1958 | """ 1959 | return np.argmin(a, axis=dim), 1960 | 1961 | 1962 | @Operation.factory 1963 | def ArgMax(a, dim): 1964 | """ 1965 | Argmax op. 1966 | """ 1967 | return np.argmax(a, axis=dim), 1968 | 1969 | 1970 | @Operation.factory 1971 | def ListDiff(a, b): 1972 | """ 1973 | List diff op. 1974 | """ 1975 | d = np.setdiff1d(a, b) 1976 | return d, np.searchsorted(a, d).astype(np.int32) 1977 | 1978 | 1979 | @Operation.factory 1980 | def Where(a): 1981 | """ 1982 | Boolean where op. 1983 | """ 1984 | return np.argwhere(a), 1985 | 1986 | 1987 | @Operation.factory(attrs=("out_idx",)) 1988 | def Unique(a, t): 1989 | """ 1990 | Unique op. 1991 | """ 1992 | _, idxs, inv = np.unique(a, return_index=True, return_inverse=True) 1993 | return np.copy(a)[np.sort(idxs)], idxs[inv].astype(dtype_map[t]) 1994 | 1995 | 1996 | @Operation.factory 1997 | def InvertPermutation(a): 1998 | """ 1999 | Invert perm op. 
2000 | """ 2001 | return np.argsort(a).astype(np.int32), 2002 | 2003 | 2004 | # 2005 | # control flow ops 2006 | # 2007 | 2008 | @Operation.factory 2009 | def Identity(a): 2010 | """ 2011 | Identity op. 2012 | """ 2013 | return np.copy(a), 2014 | 2015 | 2016 | # 2017 | # NN activation ops 2018 | # 2019 | 2020 | @Operation.factory 2021 | def Relu(a): 2022 | """ 2023 | Relu op. 2024 | """ 2025 | return np.maximum(a, 0), 2026 | 2027 | 2028 | @Operation.factory 2029 | def Relu6(a): 2030 | """ 2031 | Relu6 op. 2032 | """ 2033 | return np.clip(a, 0, 6), 2034 | 2035 | 2036 | @Operation.factory 2037 | def Elu(a): 2038 | """ 2039 | Elu op. 2040 | """ 2041 | return np.where(a < 0, np.subtract(np.exp(a), 1), a), 2042 | 2043 | 2044 | @Operation.factory 2045 | def Softplus(a): 2046 | """ 2047 | Softplus op. 2048 | """ 2049 | return np.log(np.add(np.exp(a), 1)), 2050 | 2051 | 2052 | @Operation.factory 2053 | def Softsign(a): 2054 | """ 2055 | Softsign op. 2056 | """ 2057 | return np.divide(a, np.add(np.abs(a), 1)), 2058 | 2059 | 2060 | @Operation.factory 2061 | def Sigmoid(a): 2062 | """ 2063 | Sigmoid (logistic) op. 2064 | """ 2065 | return np.reciprocal(np.add(1, np.exp(-a))), 2066 | 2067 | 2068 | @Operation.factory 2069 | def Tanh(a): 2070 | """ 2071 | Tanh op. 2072 | """ 2073 | return np.tanh(a), 2074 | 2075 | 2076 | @Operation.factory 2077 | def Softmax(a): 2078 | """ 2079 | Softmax op.
2080 | """ 2081 | e = np.exp(a) 2082 | return np.divide(e, np.sum(e, axis=-1, keepdims=True)), 2083 | 2084 | 2085 | # 2086 | # NN convolution ops 2087 | # 2088 | 2089 | def _prepare_patches(a, f, strides, padding, padmode): 2090 | v = np.array((0,) + (a.ndim - 2) * (1,) + (0,)) 2091 | w = np.array((0,) + f.shape[:-2] + (0,)) 2092 | 2093 | src = a 2094 | if padding == "SAME": 2095 | out_shape = np.ceil(np.array(a.shape).astype(np.float) / strides).astype(np.int) 2096 | pad = ((out_shape - v) * strides + w - a.shape).clip(min=0) 2097 | pad_start = pad // 2 2098 | if np.any(pad): 2099 | src = np.pad(a, list(zip(pad_start, pad - pad_start)), padmode) 2100 | else: # VALID 2101 | out_shape = np.ceil((np.array(a.shape).astype(np.float) - w + v) \ 2102 | / strides).astype(np.int) 2103 | pad = np.zeros(len(a.shape)) 2104 | 2105 | return out_shape, src 2106 | 2107 | 2108 | def _conv_patches(a, f, strides, padding): 2109 | out_shape, src = _prepare_patches(a, f, strides, padding, "constant") 2110 | 2111 | patches = np.empty(tuple(out_shape)[:-1] + f.shape).astype(a.dtype) 2112 | 2113 | s = (slice(None),) 2114 | e = (Ellipsis,) 2115 | en = (Ellipsis, np.newaxis) 2116 | for coord in np.ndindex(*out_shape[1:-1]): 2117 | pos = np.array(strides[1:-1]) * coord 2118 | patches[s + coord + e] = \ 2119 | src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-2]))][en] * f 2120 | 2121 | return patches 2122 | 2123 | 2124 | @Operation.factory(attrs=("strides", "padding", "data_format")) 2125 | def Conv1D(a, f, strides, padding, data_format): 2126 | """ 2127 | 1D conv op. 
    """
    if data_format.decode("ascii") == "NCHW":
        # channels-first -> channels-last
        a = np.rollaxis(a, 1, a.ndim)

    patches = _conv_patches(a, f, 3 * [strides], padding.decode("ascii"))
    conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))

    if data_format.decode("ascii") == "NCHW":
        # channels-last -> channels-first
        conv = np.rollaxis(conv, -1, 1)

    return conv,


@Operation.factory(attrs=("strides", "padding", "data_format"))
def Conv2D(a, f, strides, padding, data_format):
    """
    2D conv op.
    """
    if data_format.decode("ascii") == "NCHW":
        # channels-first -> channels-last
        a = np.rollaxis(a, 1, a.ndim)

    patches = _conv_patches(a, f, strides, padding.decode("ascii"))
    conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))

    if data_format.decode("ascii") == "NCHW":
        # channels-last -> channels-first
        conv = np.rollaxis(conv, -1, 1)

    return conv,


@Operation.factory(attrs=("strides", "padding"))
def Conv3D(a, f, strides, padding):
    """
    3D conv op.
    """
    patches = _conv_patches(a, f, strides, padding.decode("ascii"))
    return np.sum(patches, axis=tuple(range(-f.ndim, -1))),


#
# NN pooling ops
#

def _pool_patches(a, k, strides, padding):
    f = np.ones(k[1:] + [a.shape[-1]])

    out_shape, src = _prepare_patches(a, f, strides, padding, "edge")

    patches = np.empty(tuple(out_shape) + f.shape).astype(a.dtype)

    s = (slice(None),)
    e = (Ellipsis,)
    en = (Ellipsis, np.newaxis)
    for coord in np.ndindex(*out_shape[1:]):
        pos = np.array(strides[1:]) * coord
        patches[s + coord + e] = \
            src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-1]))][en] * f

    return patches


@Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
def AvgPool(a, k, strides, padding, data_format):
    """
    Average pooling op.
    """
    if data_format.decode("ascii") == "NCHW":
        # channels-first -> channels-last
        a = np.rollaxis(a, 1, a.ndim)

    patches = _pool_patches(a, k, strides, padding.decode("ascii"))
    pool = np.average(patches, axis=tuple(range(-len(k), 0)))

    if data_format.decode("ascii") == "NCHW":
        # channels-last -> channels-first
        pool = np.rollaxis(pool, -1, 1)

    return pool,


@Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
def MaxPool(a, k, strides, padding, data_format):
    """
    Maximum pooling op.
    """
    if data_format.decode("ascii") == "NCHW":
        # channels-first -> channels-last
        a = np.rollaxis(a, 1, a.ndim)

    patches = _pool_patches(a, k, strides, padding.decode("ascii"))
    pool = np.amax(patches, axis=tuple(range(-len(k), 0)))

    if data_format.decode("ascii") == "NCHW":
        # channels-last -> channels-first
        pool = np.rollaxis(pool, -1, 1)

    return pool,


@Operation.factory(attrs=("ksize", "strides", "padding"))
def AvgPool3D(a, k, strides, padding):
    """
    Average 3D pooling op.
    """
    patches = _pool_patches(a, k, strides, padding.decode("ascii"))
    return np.average(patches, axis=tuple(range(-len(k), 0))),


@Operation.factory(attrs=("ksize", "strides", "padding"))
def MaxPool3D(a, k, strides, padding):
    """
    Maximum 3D pooling op.
    """
    patches = _pool_patches(a, k, strides, padding.decode("ascii"))
    return np.amax(patches, axis=tuple(range(-len(k), 0))),

--------------------------------------------------------------------------------
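The "SAME" and "VALID" branches of `_prepare_patches` in tfdeploy.py reduce to simple per-axis arithmetic on the input size, kernel size and stride. The following is a minimal sketch of that arithmetic for a single spatial axis, written to make the output size and the split of the padding easy to check by hand. The helper names `same_padding_1d` and `valid_out_size_1d` are hypothetical and not part of tfdeploy.

```python
import math


def same_padding_1d(size, kernel, stride):
    """Hypothetical helper: (out_size, pad_before, pad_after) for "SAME" padding.

    Mirrors the per-axis arithmetic of the "SAME" branch for one spatial axis.
    """
    out_size = math.ceil(size / stride)
    # total padding needed so that the last window still fits, never negative
    total = max((out_size - 1) * stride + kernel - size, 0)
    # split evenly; an odd remainder goes to the end of the axis
    before = total // 2
    return out_size, before, total - before


def valid_out_size_1d(size, kernel, stride):
    """Hypothetical helper: output size for "VALID" padding (no padding added)."""
    return math.ceil((size - kernel + 1) / stride)


if __name__ == "__main__":
    print(same_padding_1d(5, 3, 2))
    print(same_padding_1d(4, 3, 2))
    print(valid_out_size_1d(5, 3, 2))
```

With "SAME" padding the output size depends only on the input size and the stride, while with "VALID" padding it also shrinks with the kernel size.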