├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── style.css
│   ├── _templates
│   │   └── layout.html
│   ├── conf.py
│   └── index.rst
├── logo.png
├── requirements.txt
├── scripts
│   └── dtype_map.py
├── setup.py
├── tests
│   ├── __init__.py
│   ├── base.py
│   ├── core.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── simple.py
│   │   └── simple2.py
│   ├── ops.py
│   └── perf
│       ├── create_plots.py
│       ├── measure_runtimes.py
│       └── simple.py
└── tfdeploy.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.sublime-project
2 | *.sublime-workspace
3 | *.pyc
4 | *.log
5 | *.DS_Store
6 | *.pkl
7 | dist
8 | MANIFEST
9 | docs/_build
10 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 |
3 | dist: trusty
4 |
5 | language: python
6 |
7 | services:
8 | - docker
9 |
10 | matrix:
11 | include:
12 | - env: TF_TAG=1.0.1 TD_TEST_SCIPY=0 TD_TEST_GPU=0
13 | - env: TF_TAG=1.0.1 TD_TEST_SCIPY=1 TD_TEST_GPU=0
14 | - env: TF_TAG=1.0.1-py3 TD_TEST_SCIPY=0 TD_TEST_GPU=0
15 |
16 | install:
17 | - docker pull tensorflow/tensorflow:$TF_TAG
18 |
19 | script: docker run -t --rm -v `pwd`:/root/tfdeploy -w /root/tfdeploy -e TD_TEST_SCIPY=$TD_TEST_SCIPY -e TD_TEST_GPU=$TD_TEST_GPU tensorflow/tensorflow:$TF_TAG python -m unittest tests
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016-2025, Marcel Rieger
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its contributors
15 | may be used to endorse or promote products derived from this software without
16 | specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [Build Status](https://travis-ci.org/riga/tfdeploy) [Documentation Status](http://tfdeploy.readthedocs.org/en/latest/?badge=latest) [PyPI version](https://badge.fury.io/py/tfdeploy)
4 |
5 | Deploy [tensorflow](https://www.tensorflow.org) graphs for *fast* evaluation and export to *tensorflow-less* environments running [NumPy](http://www.numpy.org).
6 |
7 | > [!NOTE]
8 | > This project started as a personal playground to get an in-depth understanding of TensorFlow's operations and kernels.
9 | > Up to a certain version, the NumPy-based operations in tfdeploy provided full feature parity, but such a project obviously cannot keep up with the vast development speed driven by the TensorFlow developers and the open-source community.
10 | >
11 | > Therefore, tfdeploy is **no longer actively maintained**.
12 | > However, the code base remains available as an easy-to-read reference implementation for most of the kernels that constitute the heart of today's ML landscape.
13 |
14 |
15 | ##### Evaluation usage
16 |
17 | ```python
18 | import tfdeploy as td
19 | import numpy as np
20 |
21 | model = td.Model("/path/to/model.pkl")
22 | inp, outp = model.get("input", "output")
23 |
24 | batch = np.random.rand(10000, 784)
25 | result = outp.eval({inp: batch})
26 | ```
27 |
28 |
29 | ##### Installation and dependencies
30 |
31 | Via [pip](https://pypi.python.org/pypi/tfdeploy)
32 |
33 | ```bash
34 | pip install tfdeploy
35 | ```
36 |
37 | or by simply copying the file into your project.
38 |
39 | NumPy ≥ 1.10 should be installed on your system. [SciPy](http://www.scipy.org/) is optional. See [optimization](#optimization) for more info on optional packages.
40 |
41 | By design, TensorFlow is required when creating a model.
42 |
43 |
44 | ### Content
45 |
46 | - [Why?](#why)
47 | - [How?](#how)
48 | - [Convert your graph](#convert-your-graph)
49 | - [Load the model and evaluate](#load-the-model-and-evaluate)
50 | - [Write your own operation](#write-your-own-operation)
51 | - [Ensembles](#ensembles)
52 | - [Optimization](#optimization)
53 | - [Performance](#performance)
54 | - [Contributing](#contributing)
55 | - [Development](#development)
56 | - [Authors](#authors)
57 | - [License](#license)
58 |
59 | ## Why?
60 |
61 | Working with TensorFlow is awesome. Model definition and training are simple yet powerful, and the range of built-in features is striking.
62 |
63 | When deploying models to environments that cannot run TensorFlow, however, things can get difficult (**note** that tfdeploy was developed before TensorFlow Lite existed).
64 |
65 | To boil it down, tfdeploy
66 |
67 | - is lightweight. A single file with < 150 lines of core code. Just copy it to your project.
68 | - is [faster](#performance) than using TensorFlow's `Tensor.eval`.
69 | - **does not need TensorFlow** during evaluation.
70 | - only depends on NumPy.
71 | - can load one or more models from a single file.
72 | - does not support GPUs (maybe [gnumpy](http://www.cs.toronto.edu/~tijmen/gnumpy.html) is worth a try here).
73 |
74 |
75 | ## How?
76 |
77 | The central class is `tfdeploy.Model`. The following two examples demonstrate how a model can be created from a TensorFlow graph, saved to and loaded from disk, and eventually evaluated.
78 |
79 | ##### Convert your graph
80 |
81 | ```python
82 | import tensorflow as tf
83 | import tfdeploy as td
84 |
85 | # setup tfdeploy (only when creating models)
86 | td.setup(tf)
87 |
88 | # build your graph
89 | sess = tf.Session()
90 |
91 | # use names for input and output layers
92 | x = tf.placeholder("float", shape=[None, 784], name="input")
93 | W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
94 | b = tf.Variable(tf.zeros([100]))
95 | y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
96 |
97 | sess.run(tf.global_variables_initializer())
98 |
99 | # ... training ...
100 |
101 | # create a tfdeploy model and save it to disk
102 | model = td.Model()
103 | model.add(y, sess) # y and all its ops and related tensors are added recursively
104 | model.save("model.pkl")
105 | ```
106 |
107 | ##### Load the model and evaluate
108 |
109 | ```python
110 | import numpy as np
111 | import tfdeploy as td
112 |
113 | model = td.Model("model.pkl")
114 |
115 | # shorthand to x and y
116 | x, y = model.get("input", "output")
117 |
118 | # evaluate
119 | batch = np.random.rand(10000, 784)
120 | result = y.eval({x: batch})
121 | ```
122 |
123 | ##### Write your own `Operation`
124 |
125 | tfdeploy supports most of the `Operation`s [implemented in tensorflow](https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html). However, if you miss one (in that case, submit a PR or an issue ;) ) or if you're using custom ops, you might want to extend tfdeploy by defining a new op class that inherits from `tfdeploy.Operation`:
126 |
127 | ```python
128 | import tensorflow as tf
129 | import tfdeploy as td
130 | import numpy as np
131 |
132 | # setup tfdeploy (only when creating models)
133 | td.setup(tf)
134 |
135 | # ... write your model here ...
136 |
137 | # let's assume your final tensor "y" relies on an op of type "InvertedSoftmax"
138 | # before creating the td.Model, you should add that op to tfdeploy
139 |
140 | class InvertedSoftmax(td.Operation):
141 | @staticmethod
142 | def func(a):
143 | e = np.exp(-a)
144 | # ops should return a tuple
145 | return np.divide(e, np.sum(e, axis=-1, keepdims=True)),
146 |
147 | # this is equivalent to
148 | # @td.Operation.factory
149 | # def InvertedSoftmax(a):
150 | # e = np.exp(-a)
151 | # return np.divide(e, np.sum(e, axis=-1, keepdims=True)),
152 |
153 | # now we're good to go
154 | model = td.Model()
155 | model.add(y, sess)
156 | model.save("model.pkl")
157 | ```
158 |
159 | When writing new ops, three things are important (a sketch follows below):
160 |
161 | - Try to avoid loops; prefer NumPy vectorization.
162 | - Return a tuple.
163 | - Don't change incoming tensors/arrays in-place; always work on and return copies.
164 |
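A minimal sketch of an op following all three rules (the `ScaledShift` op below is hypothetical and only serves as an illustration):

```python
import numpy as np
import tfdeploy as td

# hypothetical op for illustration: y = 2 * a + 1
@td.Operation.factory
def ScaledShift(a):
    # vectorized, no Python loops; np.multiply and np.add return new arrays,
    # so the incoming array "a" is never modified in-place
    return np.add(np.multiply(a, 2.0), 1.0),  # trailing comma: ops return tuples
```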
165 |
166 | ## Ensembles
167 |
168 | tfdeploy provides a helper class to evaluate an ensemble of models: `Ensemble`. It can load multiple models, evaluate them and combine their output values using different methods.
169 |
170 | ```python
171 | # create the ensemble
172 | ensemble = td.Ensemble(["model1.pkl", "model2.pkl", ...], method=td.METHOD_MEAN)
173 |
174 | # get input and output tensors (which actually are TensorEnsemble instances)
175 | input, output = ensemble.get("input", "output")
176 |
177 | # evaluate the ensemble just like a normal model
178 | batch = ...
179 | value = output.eval({input: batch})
180 | ```
181 |
182 | The return value of `get()` is a `TensorEnsemble` instance. It is basically a wrapper around multiple tensors and is used as a key in the `feed_dict` of the `eval()` call.
183 |
184 | You can choose between `METHOD_MEAN` (the default), `METHOD_MAX` and `METHOD_MIN`. If you want to use a custom ensembling method, use `METHOD_CUSTOM` and overwrite the static `func_custom()` method of the `TensorEnsemble` instance, as sketched below.
185 |
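A minimal sketch of such a custom method, assuming `func_custom()` receives the list of per-model output arrays (the model files and tensor name below are placeholders):

```python
import numpy as np
import tfdeploy as td

# create the ensemble with the custom method
ensemble = td.Ensemble(["model1.pkl", "model2.pkl"], method=td.METHOD_CUSTOM)
output = ensemble.get("output")

# overwrite the static func_custom() method, e.g. with a per-element median
def median(values):
    return np.median(np.stack(values), axis=0)

output.func_custom = median
```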
186 |
187 | ## Optimization
188 |
189 | Most ops are written using plain NumPy. However, multiple implementations of the same op are allowed; they may use additional third-party Python packages that provide faster functionality in some situations.
190 |
191 | For example, NumPy does not provide a vectorized *lgamma* function. Thus, the standard `tfdeploy.Lgamma` op uses `math.lgamma`, vectorized via `numpy.vectorize`. For these situations, additional implementations of the same op are possible (the *lgamma* example is quite academic, but this definitely makes sense for more sophisticated ops like pooling). We can simply tell the op to use its SciPy implementation instead:
192 |
193 | ```python
194 | td.Lgamma.use_impl(td.IMPL_SCIPY)
195 | ```
196 |
197 | Currently, allowed implementation types are NumPy (`IMPL_NUMPY`, the default) and SciPy (`IMPL_SCIPY`).
198 |
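For instance, guarding the switch on the `HAS_SCIPY` flag exported by tfdeploy keeps the code working on systems without SciPy (a sketch):

```python
import tfdeploy as td

# only switch when SciPy is actually available on this system
if td.HAS_SCIPY:
    td.Lgamma.use_impl(td.IMPL_SCIPY)
```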
199 |
200 | ##### Adding additional implementations
201 |
202 | Additional implementations can be added by setting the `impl` attribute of the op factory or by using the `add_impl` decorator of existing operations. The first registered implementation will be the default one.
203 |
204 | ```python
205 | # create the default lgamma op with numpy implementation
206 | lgamma_vec = np.vectorize(math.lgamma)
207 |
208 | @td.Operation.factory
209 | # equivalent to
210 | # @td.Operation.factory(impl=td.IMPL_NUMPY)
211 | def Lgamma(a):
212 | return lgamma_vec(a),
213 |
214 | # add a scipy-based implementation
215 | @Lgamma.add_impl(td.IMPL_SCIPY)
216 | def Lgamma(a):
217 | return sp.special.gammaln(a),
218 | ```
219 |
220 |
221 | ##### Auto-optimization
222 |
223 | If SciPy is available on your system, it is reasonable to use the SciPy implementation of every op that provides one. Configure this via the second argument of the `setup` function, before you create any model from TensorFlow objects:
224 |
225 | ```python
226 | td.setup(tf, td.IMPL_SCIPY)
227 | ```
228 |
229 | Ops that do not implement `IMPL_SCIPY` stick with the NumPy version (`IMPL_NUMPY`).
230 |
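The test suite applies the same preference after `setup` using the `optimize` function (see `tests/ops.py`):

```python
import tensorflow as tf
import tfdeploy as td

td.setup(tf)
# ops without a SciPy implementation keep their NumPy version
td.optimize(td.IMPL_SCIPY)
```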
231 |
232 | ## Performance
233 |
234 | tfdeploy is lightweight (1 file, < 150 lines of core code) and fast. Internal evaluation calls have very little overhead, and tensor operations use NumPy vectorization. The actual performance depends on the ops in your graph. While most TensorFlow ops have a NumPy equivalent or can be constructed from NumPy functions, a few ops require additional Python-based loops (e.g. `BatchMatMul`). But in many cases (and for small to medium graphs) it is potentially faster than using TensorFlow's `Tensor.eval`.
235 |
236 | This is a comparison for a basic graph where all ops are vectorized (basically `Add`, `MatMul` and `Softmax`):
237 |
238 | ```bash
239 | > ipython -i tests/perf/simple.py
240 |
241 | In [1]: %timeit -n 100 test_tf()
242 | 100 loops, best of 3: 109 ms per loop
243 |
244 | In [2]: %timeit -n 100 test_td()
245 | 100 loops, best of 3: 60.5 ms per loop
246 | ```
247 |
248 | ## Contributing
249 |
250 | If you want to contribute with new ops and features, I'm happy to receive pull requests. Just make sure to add a new test case to `tests/core.py` or `tests/ops.py` and run them via:
251 |
252 | ```bash
253 | > python -m unittest tests
254 | ```
255 |
256 |
257 | ##### Test grid
258 |
259 | In general, tests should be run for different environments:
260 |
261 | | Variation | Values |
262 | | ------------------ | ------- |
263 | | tensorflow version | `1.0.1` |
264 | | python version | 2, 3 |
265 | | `TD_TEST_SCIPY` | 0, 1 |
266 | | `TD_TEST_GPU` | 0, 1 |
267 |
268 |
269 | ##### Docker
270 |
271 | For testing purposes, it is convenient to use docker. Fortunately, the official [tensorflow images](https://hub.docker.com/r/tensorflow/tensorflow/) contain all we need:
272 |
273 | ```bash
274 | git clone https://github.com/riga/tfdeploy.git
275 | cd tfdeploy
276 |
277 | docker run --rm -v `pwd`:/root/tfdeploy -w /root/tfdeploy -e "TD_TEST_SCIPY=1" tensorflow/tensorflow:1.0.1 python -m unittest tests
278 | ```
279 |
280 |
281 | ## Development
282 |
283 | - Source hosted at [GitHub](https://github.com/riga/tfdeploy)
284 | - Report issues, questions, feature requests on [GitHub Issues](https://github.com/riga/tfdeploy/issues)
285 |
286 |
287 | ## Authors
288 |
289 | - [Marcel R.](https://github.com/riga)
290 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 |
15 | # Internal variables.
16 | PAPEROPT_a4 = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 |
22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext
23 |
24 | help:
25 | @echo "Please use \`make ' where is one of"
26 | @echo " html to make standalone HTML files"
27 | @echo " dirhtml to make HTML files named index.html in directories"
28 | @echo " singlehtml to make a single large HTML file"
29 | @echo " pickle to make pickle files"
30 | @echo " json to make JSON files"
31 | @echo " htmlhelp to make HTML files and a HTML help project"
32 | @echo " qthelp to make HTML files and a qthelp project"
33 | @echo " applehelp to make an Apple Help Book"
34 | @echo " devhelp to make HTML files and a Devhelp project"
35 | @echo " epub to make an epub"
36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
37 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
39 | @echo " text to make text files"
40 | @echo " man to make manual pages"
41 | @echo " texinfo to make Texinfo files"
42 | @echo " info to make Texinfo files and run them through makeinfo"
43 | @echo " gettext to make PO message catalogs"
44 | @echo " changes to make an overview of all changed/added/deprecated items"
45 | @echo " xml to make Docutils-native XML files"
46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
47 | @echo " linkcheck to check all external links for integrity"
48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
49 | @echo " coverage to run coverage check of the documentation (if enabled)"
50 |
51 | clean:
52 | rm -rf $(BUILDDIR)/*
53 |
54 | html:
55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
56 | @echo
57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
58 |
59 | dirhtml:
60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
61 | @echo
62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
63 |
64 | singlehtml:
65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
66 | @echo
67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
68 |
69 | pickle:
70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
71 | @echo
72 | @echo "Build finished; now you can process the pickle files."
73 |
74 | json:
75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
76 | @echo
77 | @echo "Build finished; now you can process the JSON files."
78 |
79 | htmlhelp:
80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
81 | @echo
82 | @echo "Build finished; now you can run HTML Help Workshop with the" \
83 | ".hhp project file in $(BUILDDIR)/htmlhelp."
84 |
85 | qthelp:
86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
87 | @echo
88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/tfdeploy.qhcp"
91 | @echo "To view the help file:"
92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/tfdeploy.qhc"
93 |
94 | applehelp:
95 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
96 | @echo
97 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
98 | @echo "N.B. You won't be able to view it unless you put it in" \
99 | "~/Library/Documentation/Help or install it in your application" \
100 | "bundle."
101 |
102 | devhelp:
103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
104 | @echo
105 | @echo "Build finished."
106 | @echo "To view the help file:"
107 | @echo "# mkdir -p $$HOME/.local/share/devhelp/tfdeploy"
108 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/tfdeploy"
109 | @echo "# devhelp"
110 |
111 | epub:
112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
113 | @echo
114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
115 |
116 | latex:
117 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
118 | @echo
119 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
120 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
121 | "(use \`make latexpdf' here to do that automatically)."
122 |
123 | latexpdf:
124 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
125 | @echo "Running LaTeX files through pdflatex..."
126 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
127 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
128 |
129 | latexpdfja:
130 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
131 | @echo "Running LaTeX files through platex and dvipdfmx..."
132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
134 |
135 | text:
136 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
137 | @echo
138 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
139 |
140 | man:
141 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
142 | @echo
143 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
144 |
145 | texinfo:
146 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
147 | @echo
148 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
149 | @echo "Run \`make' in that directory to run these through makeinfo" \
150 | "(use \`make info' here to do that automatically)."
151 |
152 | info:
153 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
154 | @echo "Running Texinfo files through makeinfo..."
155 | make -C $(BUILDDIR)/texinfo info
156 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
157 |
158 | gettext:
159 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
160 | @echo
161 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
162 |
163 | changes:
164 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
165 | @echo
166 | @echo "The overview file is in $(BUILDDIR)/changes."
167 |
168 | linkcheck:
169 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
170 | @echo
171 | @echo "Link check complete; look for any errors in the above output " \
172 | "or in $(BUILDDIR)/linkcheck/output.txt."
173 |
174 | doctest:
175 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
176 | @echo "Testing of doctests in the sources finished, look at the " \
177 | "results in $(BUILDDIR)/doctest/output.txt."
178 |
179 | coverage:
180 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
181 | @echo "Testing of coverage in the sources finished, look at the " \
182 | "results in $(BUILDDIR)/coverage/python.txt."
183 |
184 | xml:
185 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
186 | @echo
187 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
188 |
189 | pseudoxml:
190 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
191 | @echo
192 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
193 |
--------------------------------------------------------------------------------
/docs/_static/style.css:
--------------------------------------------------------------------------------
1 | h1.logo,
2 | div.sphinxsidebarwrapper > h3 {
3 | display: none;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {# layout.html #}
2 | {% extends "!layout.html" %}
3 |
4 | {% set css_files = css_files + ['_static/style.css'] %}
5 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import sys
5 | import os
6 | import shlex
7 |
8 |
9 | sys.path.insert(0, os.path.abspath(".."))
10 | import tfdeploy as td
11 |
12 |
13 | project = "tfdeploy"
14 | author = td.__author__
15 | copyright = td.__copyright__
16 | version = td.__version__
17 | release = td.__version__
18 |
19 |
20 | templates_path = ["_templates"]
21 | html_static_path = ["_static"]
22 | master_doc = "index"
23 | source_suffix = ".rst"
24 |
25 |
26 | exclude_patterns = []
27 | pygments_style = "sphinx"
28 | html_logo = "../logo.png"
29 | html_theme = "alabaster"
30 | html_sidebars = {"**": [
31 | "about.html",
32 | "localtoc.html",
33 | "searchbox.html"]
34 | }
35 | html_theme_options = {
36 | "github_user": "riga",
37 | "github_repo": "tfdeploy",
38 | "travis_button": True
39 | }
40 |
41 |
42 | extensions = [
43 | "sphinx.ext.autodoc"
44 | ]
45 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | tfdeploy
2 | ========
3 |
4 | .. centered:: This page contains only API docs. For more info, visit `tfdeploy on GitHub <https://github.com/riga/tfdeploy>`_.
5 |
6 |
7 | .. toctree::
8 | :maxdepth: 2
9 |
10 |
11 | .. automodule:: tfdeploy
12 |
13 |
14 | Classes
15 | ^^^^^^^
16 |
17 | ``Model``
18 | ---------
19 |
20 | .. autoclass:: Model
21 | :member-order: bysource
22 | :members:
23 |
24 |
25 | ``Tensor``
26 | ----------
27 |
28 | .. autoclass:: Tensor
29 | :member-order: bysource
30 | :members:
31 |
32 |
33 | ``Operation``
34 | -------------
35 |
36 | .. autoclass:: Operation
37 | :member-order: bysource
38 | :members:
39 |
40 |
41 | ``Ensemble``
42 | ------------
43 |
44 | .. autoclass:: Ensemble
45 | :member-order: bysource
46 | :members:
47 |
48 |
49 | ``TensorEnsemble``
50 | ------------------
51 |
52 | .. autoclass:: TensorEnsemble
53 | :member-order: bysource
54 | :members:
55 |
56 |
57 | Functions
58 | ^^^^^^^^^
59 |
60 | ``setup``
61 | ---------
62 |
63 | .. autofunction:: setup
64 |
65 |
66 | ``reset``
67 | ---------
68 |
69 | .. autofunction:: reset
70 |
71 |
72 | ``optimize``
73 | ------------
74 |
75 | .. autofunction:: optimize
76 |
77 |
78 | ``print_tensor``
79 | ----------------
80 |
81 | .. autofunction:: print_tensor
82 |
83 |
84 | ``print_op``
85 | ------------
86 |
87 | .. autofunction:: print_op
88 |
89 |
90 | ``print_tf_tensor``
91 | -------------------
92 |
93 | .. autofunction:: print_tf_tensor
94 |
95 |
96 | ``print_tf_op``
97 | ---------------
98 |
99 | .. autofunction:: print_tf_op
100 |
101 |
102 | Other Attributes
103 | ^^^^^^^^^^^^^^^^
104 |
105 | .. py:attribute:: IMPL_NUMPY
106 |
107 | Implementation type for ops that use numpy (the default).
108 |
109 | .. py:attribute:: IMPL_SCIPY
110 |
111 | Implementation type for ops that use scipy.
112 |
113 | .. py:attribute:: HAS_SCIPY
114 |
115 | A flag that is *True* when scipy is available on your system.
116 |
117 |
118 | Exceptions
119 | ^^^^^^^^^^
120 |
121 | ``UnknownOperationException``
122 | -----------------------------
123 |
124 | .. autoexception:: UnknownOperationException
125 |
126 |
127 | ``OperationMismatchException``
128 | ------------------------------
129 |
130 | .. autoexception:: OperationMismatchException
131 |
132 |
133 | ``InvalidImplementationException``
134 | ----------------------------------
135 |
136 | .. autoexception:: InvalidImplementationException
137 |
138 |
139 | ``UnknownImplementationException``
140 | ----------------------------------
141 |
142 | .. autoexception:: UnknownImplementationException
143 |
144 |
145 | ``UnknownEnsembleMethodException``
146 | ----------------------------------
147 |
148 | .. autoexception:: UnknownEnsembleMethodException
149 |
150 |
151 | ``EnsembleMismatchException``
152 | -----------------------------
153 |
154 | .. autoexception:: EnsembleMismatchException
155 |
156 |
157 | ``ScipyOperationException``
158 | ---------------------------
159 |
160 | .. autoexception:: ScipyOperationException
161 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/riga/tfdeploy/c519c6342997afd313b4e11e8417ce70eef7d1dd/logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 |
--------------------------------------------------------------------------------
/scripts/dtype_map.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Script that prints a mapping of tensorflow dtype nums to numpy dtypes, e.g.:
5 |
6 | > python dtype_map.py
7 | dtype_map = {
8 | 1: np.float32,
9 | 2: np.float64,
10 | 3: np.int32,
11 | 4: np.uint8,
12 | ...
13 | }
14 | """
15 |
16 |
17 | import tensorflow as tf
18 |
19 |
20 | # create the mapping
21 | dtype_map = {}
22 |
23 | # fill it
24 | types_pb2 = tf.core.framework.types_pb2
25 | for attr in dir(types_pb2):
26 | if attr.startswith("DT_"):
27 | tf_type_enum = getattr(types_pb2, attr)
28 | try:
29 | dtype_map[tf_type_enum] = "np." + tf.as_dtype(tf_type_enum).as_numpy_dtype.__name__
30 | except Exception:  # not every proto enum maps to a numpy dtype
31 | pass
32 |
33 | # print dict-like code
34 | dtype_map = "\n".join(" %s: %s," % tpl for tpl in dtype_map.items())
35 | print("{\n" + dtype_map[:-1] + "\n}")
36 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import os
5 | from subprocess import Popen, PIPE
6 | from distutils.core import setup
7 | import tfdeploy as td
8 |
9 |
10 | readme = os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md")
11 | if os.path.isfile(readme):
12 | cmd = "pandoc --from=markdown --to=rst " + readme
13 | p = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable="/bin/bash")
14 | out, err = p.communicate()
15 | if p.returncode != 0:
16 | print("pandoc conversion failed: " + err)
17 | long_description = out
18 | else:
19 | long_description = ""
20 |
21 | keywords = [
22 | "tensorflow", "deploy", "export", "dump", "numpy", "model", "predict", "evaluate", "function",
23 | "method"
24 | ]
25 |
26 | classifiers = [
27 | "Programming Language :: Python",
28 | "Programming Language :: Python :: 2",
29 | "Programming Language :: Python :: 3",
30 | "Development Status :: 4 - Beta",
31 | "Operating System :: OS Independent",
32 | "License :: OSI Approved :: MIT License",
33 | "Intended Audience :: Developers",
34 | "Intended Audience :: Science/Research",
35 | "Intended Audience :: Information Technology",
36 | "Topic :: Scientific/Engineering :: Artificial Intelligence"
37 | ]
38 |
39 |
40 | setup(
41 | name = td.__name__,
42 | version = td.__version__,
43 | author = td.__author__,
44 | description = td.__doc__.strip(),
45 | license = td.__license__,
46 | url = td.__contact__,
47 | py_modules = [td.__name__],
48 | keywords = keywords,
49 | classifiers = classifiers,
50 | long_description = long_description or td.__doc__.strip()
51 | )
52 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # some logs
4 | import os
5 | import sys
6 | import numpy as np
7 | import tensorflow as tf
8 | print(80 * "-")
9 | print("python : " + sys.version.split(" ")[0])
10 | print("tensorflow: " + tf.__version__)
11 | print("numpy : " + np.version.version)
12 | try:
13 | import scipy as sp
14 | spv = sp.version.version
15 | except ImportError:
16 | spv = "NONE"
17 | print("scipy : " + spv)
18 | envkeys = [key for key in os.environ.keys() if key.startswith("TD_")]
19 | if envkeys:
20 | print("-")
21 | maxlen = max(len(key) for key in envkeys)
22 | for key in envkeys:
23 | print(key + (maxlen - len(key)) * " " + ": " + os.environ[key])
24 | print(80 * "-")
25 |
26 |
27 | # import all tests
28 | from .core import *
29 | from .ops import *
30 |
--------------------------------------------------------------------------------
/tests/base.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import os
5 | import sys
6 | import unittest
7 |
8 | import tensorflow as tf
9 |
10 |
11 | # adjust the path to import tfdeploy
12 | base = os.path.normpath(os.path.join(os.path.abspath(__file__), "../.."))
13 | sys.path.append(base)
14 | import tfdeploy as td
15 |
16 | # setup
17 | td.setup(tf)
18 |
19 |
20 | class TestCase(unittest.TestCase):
21 |
22 | def __init__(self, *args, **kwargs):
23 | super(TestCase, self).__init__(*args, **kwargs)
24 |
25 | self._cache = {}
26 |
27 | def get(self, model, *attrs):
28 | result = tuple()
29 | for attr in attrs:
30 | key = (model, attr)
31 | if key not in self._cache:
32 | tmp = __import__("tests.models." + model, globals(), locals(), [attr])
33 | self._cache[key] = getattr(tmp, attr)
34 | result += (self._cache[key],)
35 | return result if len(result) > 1 else result[0]
36 |
--------------------------------------------------------------------------------
/tests/core.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import numpy as np
5 | from .base import TestCase, td
6 |
7 |
8 | __all__ = ["CoreTestCase"]
9 |
10 |
11 | class CoreTestCase(TestCase):
12 |
13 | def __init__(self, *args, **kwargs):
14 | super(CoreTestCase, self).__init__(*args, **kwargs)
15 |
16 | self.simple_model = td.Model()
17 | y, sess = self.get("simple", "y", "sess")
18 | self.simple_model.add(y, tf_sess=sess)
19 |
20 | def test_tensors(self):
21 | m = self.simple_model
22 |
23 | # model has one root tensor ...
24 | self.assertEqual(len(m.roots), 1)
25 |
26 | # ... which is named "output" and can be retrieved via get
27 | outp = m.get("output")
28 | self.assertIn(outp, m.roots.values())
29 |
30 | # the input tensor is named "input"
31 | self.assertIsNotNone(m.get("input"))
32 |
33 | def test_ops(self):
34 | m = self.simple_model
35 | op = m.get("output").op
36 |
37 | # the root tensor operator is a softmax op ...
38 | self.assertIsInstance(op, td.Softmax)
39 |
40 | # ... and has one input ...
41 | self.assertEqual(len(op.inputs), 1)
42 |
43 | # ... whose op is an add op
44 | self.assertIsInstance(op.inputs[0].op, td.Add)
45 |
46 | def test_eval(self):
47 | m = self.simple_model
48 | inp, outp, kp = m.get("input", "output", "keep_prob")
49 |
50 | # create an input batch
51 | examples = np.random.rand(1000, 10).astype("float32")
52 |
53 | # first, eval using tf
54 | x, y, keep_prob, sess = self.get("simple", "x", "y", "keep_prob", "sess")
55 | rtf = y.eval(session=sess, feed_dict={x: examples, keep_prob: 1.0})
56 |
57 | # then, eval using td
58 | rtd = outp.eval({inp: examples, kp: 1.0})
59 |
60 | # no element in the diff array should be larger than 1e-7
61 | maxdiff = np.max(np.abs(rtf - rtd))
62 | self.assertLess(maxdiff, 1e-7)
63 |
64 | def test_ensemble_eval(self):
65 | simple_model2 = td.Model()
66 | y2, sess2 = self.get("simple2", "y", "sess")
67 | simple_model2.add(y2, tf_sess=sess2)
68 |
69 | simple_model2.get("input_1").name = "input:0"
70 | simple_model2.get("output_1").name = "output:0"
71 | simple_model2.get("keep_prob_1").name = "keep_prob:0"
72 |
73 | simple_ensemble = td.Ensemble()
74 | simple_ensemble.models = [self.simple_model, simple_model2]
75 |
76 | inp, outp, kp = simple_ensemble.get("input", "output", "keep_prob")
77 |
78 | # create an input batch
79 | examples = np.random.rand(1000, 10).astype("float32")
80 |
81 | # eval both models manually and build the mean
82 | x1, y1, keep_prob1 = self.simple_model.get("input", "output", "keep_prob")
83 | r1 = y1.eval({x1: examples, keep_prob1: 1.0})
84 | x2, y2, keep_prob2 = simple_model2.get("input", "output", "keep_prob")
85 | r2 = y2.eval({x2: examples, keep_prob2: 1.0})
86 | rm = np.add(r1, r2) / 2.
87 |
88 | # then, eval the ensemble
89 | re = outp.eval({inp: examples, kp: 1.0})
90 |
91 | # no element in the diff array should be larger than 1e-7
92 | maxdiff = np.max(np.abs(re - rm))
93 | self.assertLess(maxdiff, 1e-7)
94 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/tests/models/simple.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import tensorflow as tf
5 | import tfdeploy as td
6 |
7 |
8 | sess = tf.Session()
9 |
10 | x = tf.placeholder(tf.float32, shape=[None, 10], name="input")
11 | keep_prob = tf.placeholder(tf.float32, name="keep_prob")
12 |
13 | W = tf.Variable(tf.truncated_normal([10, 5], stddev=0.05))
14 | b = tf.Variable(tf.zeros([5]))
15 |
16 | W_drop = tf.nn.dropout(W, keep_prob)
17 |
18 | y = tf.nn.softmax(tf.matmul(x, W_drop) + b, name="output")
19 |
20 | if td._tf_version[:3] < (0, 12, 0):
21 | sess.run(tf.initialize_all_variables())
22 | else:
23 | sess.run(tf.global_variables_initializer())
24 |
--------------------------------------------------------------------------------
/tests/models/simple2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import tensorflow as tf
5 | import tfdeploy as td
6 |
7 |
8 | sess = tf.Session()
9 |
10 | x = tf.placeholder(tf.float32, shape=[None, 10], name="input")
11 | keep_prob = tf.placeholder(tf.float32, name="keep_prob")
12 |
13 | W = tf.Variable(tf.truncated_normal([10, 5], stddev=0.05))
14 | b = tf.Variable(tf.zeros([5]))
15 |
16 | W_drop = tf.nn.dropout(W, keep_prob)
17 |
18 | y = tf.nn.softmax(tf.matmul(x, W_drop) + b, name="output")
19 |
20 | if td._tf_version[:3] < (0, 12, 0):
21 | sess.run(tf.initialize_all_variables())
22 | else:
23 | sess.run(tf.global_variables_initializer())
24 |
--------------------------------------------------------------------------------
/tests/ops.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import os
5 | import numpy as np
6 | from .base import TestCase, td
7 | import tensorflow as tf
8 | from tensorflow.python.framework import device
9 |
10 |
11 | __all__ = ["OpsTestCase"]
12 |
13 |
14 | # get device from env
15 | CPU, GPU = range(2)
16 | DEVICE = CPU
17 | if os.environ.get("TD_TEST_GPU", "").lower() in ("1", "yes", "true"):
18 | DEVICE = GPU
19 | DEVICE_ID = "/%s:0" % ["cpu", "gpu"][DEVICE]
20 |
21 | # setup td
22 | td.setup(tf)
23 |
24 | # optimize for scipy depending on env
25 | if os.environ.get("TD_TEST_SCIPY", "").lower() in ("1", "yes", "true"):
26 | td.optimize(td.IMPL_SCIPY)
27 |
28 |
29 | class OpsTestCase(TestCase):
30 |
31 | def __init__(self, *args, **kwargs):
32 | super(OpsTestCase, self).__init__(*args, **kwargs)
33 |
34 | # add the device to the "_device_function_stack" of the default graph
35 | dev = device.merge_device(DEVICE_ID)
36 | tf.get_default_graph()._device_function_stack.append(dev)
37 |
38 | # create a tf session
39 | self.sess = tf.Session()
40 |
41 | self.ndigits = 7
42 |
43 | def check(self, t, comp=None, ndigits=None, stats=False, abs=False, debug=False):
44 | if td._tf_version[:3] < (0, 12, 0):
45 | self.sess.run(tf.initialize_all_variables())
46 | else:
47 | self.sess.run(tf.global_variables_initializer())
48 |
49 | if not isinstance(t, tuple):
50 | t = (t,)
51 |
52 | for _t in t:
53 | rtf = _t.eval(session=self.sess)
54 | rtd = td.Tensor(_t, self.sess).eval()
55 |
56 | if debug:
57 | import pdb; pdb.set_trace()
58 |
59 | if ndigits is None:
60 | ndigits = self.ndigits
61 |
62 | if hasattr(comp, "__call__"):
63 | return comp(rtf, rtd)
64 |
65 | if isinstance(rtf, np.ndarray):
66 | self.assertEqual(rtf.dtype, rtd.dtype)
67 | if abs:
68 | rtf = np.abs(rtf)
69 | rtd = np.abs(rtd)
70 | if not stats:
71 | self.assertTrue(np.allclose(rtf, rtd, atol=0.1**ndigits))
72 | else:
73 | self.assertEqual(round(rtf.sum(), ndigits), round(rtd.sum(), ndigits))
74 | self.assertEqual(round(rtf.mean(), ndigits), round(rtd.mean(), ndigits))
75 | elif isinstance(rtf, float):
76 | self.assertEqual(round(rtf, ndigits), round(rtd, ndigits))
77 | else:
78 | self.assertEqual(rtf, rtd)
79 |
80 | def random(self, *shapes, **kwargs):
81 | if all(isinstance(i, int) for i in shapes):
82 | if kwargs.get("complex", False):
83 | return (self.random(*shapes) + 1j * self.random(*shapes)).astype(np.complex64)
84 | else:
85 | return np.random.rand(*shapes)
86 | else:
87 | return tuple(self.random(*shape) for shape in shapes)
88 |
89 | def test_ops_have_tests(self):
90 | tests = [attr for attr in dir(self) if attr.startswith("test_")]
91 | for type in td.OperationRegister.classes:
92 | self.assertIn("test_" + type, tests)
93 |
94 |
95 | #
96 | # sequences
97 | #
98 |
99 | def test_LinSpace(self):
100 | t = tf.linspace(0., 10., 15)
101 | self.check(t)
102 |
103 | def test_Range(self):
104 | t = tf.range(1, 10, 2)
105 | self.check(t)
106 |
107 |
108 | #
109 | # random tensors
110 | #
111 |
112 | def test_RandomStandardNormal(self):
113 | t = tf.random_normal((40, 30), dtype="float32")
114 | # compare only dtype
115 | def comp(rtf, rtd):
116 | self.assertEqual(rtf.dtype, rtd.dtype)
117 | self.check(t, comp=comp)
118 |
119 | def test_TruncatedNormal(self):
120 | t = tf.truncated_normal((40, 300), dtype="float32")
121 | # compare dtype and 2-sigma truncation
122 | def comp(rtf, rtd):
123 | self.assertEqual(rtf.dtype, rtd.dtype)
124 | self.assertLessEqual(np.max(np.abs(rtd)), 2)
125 | self.check(t, comp=comp)
126 |
127 | def test_RandomUniform(self):
128 | t = tf.random_uniform((50, 80), -2, 3, dtype="float32")
129 | # compare only min, max and dtype
130 | def comp(rtf, rtd):
131 | self.assertLess(np.max(rtd), 3)
132 | self.assertGreaterEqual(np.min(rtd), -2)
133 | self.assertEqual(rtd.dtype, np.float32)
134 | self.check(t, comp=comp)
135 |
136 | def test_RandomUniformInt(self):
137 | # no python interface yet, but might be something like
138 | # t = tf.random_uniform_int((50, 80), -2, 3)
139 | # # compare only min and max
140 | # def comp(rtf, rtd):
141 | # self.assertLess(np.max(rtd), 3)
142 | # self.assertGreaterEqual(np.min(rtd), -2)
143 | # self.check(t, comp=comp)
144 | pass
145 |
146 | def test_RandomShuffle(self):
147 | t = tf.random_shuffle(self.random(10, 4))
148 | # compare only sum of first axis
149 | def comp(rtf, rtd):
150 | self.assertTrue(np.allclose(np.sum(rtf, axis=0), np.sum(rtd, axis=0)))
151 | self.check(t, comp=comp)
152 |
153 | def test_random_crop(self):
154 | t = tf.random_crop(self.random(3, 4, 8), [1, 2, 4])
155 | # compare only shape
156 | def comp(rtf, rtd):
157 | self.assertEqual(rtf.shape, rtd.shape)
158 | self.check(t, comp=comp)
159 |
160 |
161 | #
162 | # casting
163 | #
164 |
165 | def test_Cast(self):
166 | t = tf.cast(self.random(3, 4).astype("float32"), tf.float64)
167 | self.check(t)
168 |
169 | def test_StringToNumber(self):
170 | t = tf.string_to_number(list("0123456789"))
171 | self.check(t)
172 |
173 |
174 | #
175 | # shapes and shaping
176 | #
177 |
178 | def test_Shape(self):
179 | t = tf.shape(self.random(3, 4, 5))
180 | self.check(t)
181 |
182 | def test_Size(self):
183 | t = tf.size(self.random(3, 4))
184 | self.check(t)
185 |
186 | def test_Rank(self):
187 | t = tf.rank(self.random(3, 3))
188 | self.check(t)
189 |
190 | def test_Reshape(self):
191 | t = tf.reshape(self.random(3, 4, 5), (2, -1))
192 | self.check(t)
193 |
194 | def test_Squeeze(self):
195 | t = tf.squeeze(self.random(1, 2, 1, 3, 3, 1))
196 | self.check(t)
197 |
198 | def test_ExpandDims(self):
199 | t = tf.expand_dims(self.random(2, 3, 3, 4), -2)
200 | self.check(t)
201 |
202 |
203 | #
204 | # slicing and joining
205 | #
206 |
207 | def test_Slice(self):
208 | t = tf.slice(np.arange(3*4*8*6).reshape(3, 4, 8, 6), [1, 1, 2, 2], 4 * [2])
209 | self.check(t)
210 |
211 | def test_Split(self):
212 | for t in tf.split(self.random(8, 50, 10, 2), 5, 2):
213 | self.check(t)
214 |
215 | def test_SplitV(self):
216 | for t in tf.split(self.random(8, 50, 10, 2), [10, 30, 5, 1, 4], 1):
217 | self.check(t)
218 |
219 | def test_Tile(self):
220 | t = tf.tile(self.random(3, 4, 5), [1, 2, 3])
221 | self.check(t)
222 |
223 | def test_Pad(self):
224 | t = tf.pad(self.random(3, 8, 5), [[1, 2], [2, 1], [1, 0]])
225 | self.check(t)
226 |
227 | def test_ConcatV2(self):
228 |
229 | t = tf.concat(list(self.random((3, 4, 5), (3, 4, 5))), 2)
230 | self.check(t)
231 |
232 | def test_Pack(self):
233 | pass
234 |
235 | def test_Unpack(self):
236 | pass
237 |
238 | def test_Stack(self):
239 | t = tf.stack(list(self.random((3, 4, 5), (3, 4, 5))), 2)
240 | self.check(t)
241 |
242 | def test_Unstack(self):
243 | for t in tf.unstack(self.random(6, 4, 5), axis=1):
244 | self.check(t)
245 |
246 | def test_ReverseSequence(self):
247 | x = self.random(3, 4, 10)
248 | t = tf.reverse_sequence(x, [5, 0, 0, 8], seq_dim=2, batch_dim=1)
249 | self.check(t)
250 |
251 | def test_ReverseV2(self):
252 | t = tf.reverse(self.random(3, 4, 10), [1, 2])
253 | self.check(t)
254 |
255 | def test_Transpose(self):
256 | t = tf.transpose(self.random(4, 3, 5), perm=[2, 0, 1])
257 | self.check(t)
258 |
259 |
260 | #
261 | # arithmetic math ops
262 | #
263 |
264 | def test_Add(self):
265 | t = tf.add(*self.random((3, 4), (3, 4)))
266 | self.check(t)
267 |
268 | def test_Subtract(self):
269 | t = tf.subtract(*self.random((3, 4), (3, 4)))
270 | self.check(t)
271 |
272 | test_Sub = test_Subtract
273 |
274 | def test_Multiply(self):
275 | t = tf.multiply(*self.random((3, 5), (3, 5)))
276 | self.check(t)
277 |
278 | test_Mul = test_Multiply
279 |
280 | def test_scalar_mul(self):
281 | t = tf.scalar_mul(1, tf.Variable(self.random(3, 5)))
282 | self.check(t)
283 |
284 | def test_Div(self):
285 | t = tf.div(*self.random((3, 5), (3, 5)))
286 | self.check(t)
287 |
288 | def test_RealDiv(self):
289 | t = tf.div(*self.random((3, 5), (3, 5)))
290 | self.check(t)
291 |
292 | def test_TrueDiv(self):
293 | t = tf.truediv(*self.random((3, 5), (3, 5)))
294 | self.check(t)
295 |
296 | def test_FloorDiv(self):
297 | t = tf.floordiv(*self.random((3, 5), (3, 5)))
298 | self.check(t)
299 |
300 | def test_Mod(self):
301 | t = tf.mod(*self.random((4, 3), (4, 3)))
302 | self.check(t)
303 |
304 | def test_FloorMod(self):
305 | t = tf.floormod(*self.random((4, 3), (4, 3)))
306 | self.check(t)
307 |
308 | def test_Cross(self):
309 | t = tf.cross(*self.random((4, 3), (4, 3)))
310 | self.check(t)
311 |
312 |
313 | #
314 | # basic math ops
315 | #
316 |
317 | def test_AddN(self):
318 | t = tf.add_n(self.random((4, 3), (4, 3)))
319 | self.check(t)
320 |
321 | def test_Abs(self):
322 | t = tf.abs(-self.random(4, 3))
323 | self.check(t)
324 |
325 | def test_Negative(self):
326 | t = tf.negative(self.random(4, 3))
327 | self.check(t)
328 |
329 | test_Neg = test_Negative
330 |
331 | def test_Sign(self):
332 | t = tf.sign(self.random(4, 3) - 0.5)
333 | self.check(t)
334 |
335 | def test_Inv(self):
336 | if td._tf_version[:2] <= (0, 11):
337 | t = tf.inv(self.random(4, 3))
338 | self.check(t)
339 |
340 | def test_Square(self):
341 | t = tf.square(self.random(4, 3))
342 | self.check(t)
343 |
344 | def test_Round(self):
345 | t = tf.round(self.random(4, 3) - 0.5)
346 | self.check(t)
347 |
348 | def test_Sqrt(self):
349 | t = tf.sqrt(self.random(4, 3))
350 | self.check(t)
351 |
352 | def test_Rsqrt(self):
353 | t = tf.rsqrt(self.random(4, 3))
354 | self.check(t)
355 |
356 | def test_Pow(self):
357 | t = tf.pow(*self.random((4, 3), (4, 3)))
358 | self.check(t)
359 |
360 | def test_Exp(self):
361 | t = tf.exp(self.random(4, 3))
362 | self.check(t)
363 |
364 | def test_Log(self):
365 | t = tf.log(self.random(4, 3))
366 | self.check(t)
367 |
368 | def test_Ceil(self):
369 | t = tf.ceil(self.random(4, 3) - 0.5)
370 | self.check(t)
371 |
372 | def test_Floor(self):
373 | t = tf.floor(self.random(4, 3) - 0.5)
374 | self.check(t)
375 |
376 | def test_Maximum(self):
377 | t = tf.maximum(*self.random((4, 3), (4, 3)))
378 | self.check(t)
379 |
380 | def test_Minimum(self):
381 | t = tf.minimum(*self.random((4, 3), (4, 3)))
382 | self.check(t)
383 |
384 | def test_Cos(self):
385 | t = tf.cos(self.random(4, 3))
386 | self.check(t)
387 |
388 | def test_Sin(self):
389 | t = tf.sin(self.random(4, 3))
390 | self.check(t)
391 |
392 | def test_lbeta(self):
393 | t = tf.lbeta(self.random(4, 3))
394 | self.check(t)
395 |
396 | def test_Tan(self):
397 | t = tf.tan(self.random(4, 3))
398 | self.check(t)
399 |
400 | def test_Acos(self):
401 | t = tf.acos(self.random(4, 3))
402 | self.check(t)
403 |
404 | def test_Asin(self):
405 | t = tf.asin(self.random(4, 3))
406 | self.check(t)
407 |
408 | def test_Atan(self):
409 | t = tf.atan(self.random(4, 3))
410 | self.check(t)
411 |
412 | def test_Lgamma(self):
413 | t = tf.lgamma(self.random(4, 3))
414 | self.check(t)
415 |
416 | def test_Digamma(self):
417 | t = tf.digamma(self.random(4, 3))
418 | self.check(t)
419 |
420 | def test_Erf(self):
421 | t = tf.erf(self.random(4, 3))
422 | self.check(t)
423 |
424 | def test_Erfc(self):
425 | t = tf.erfc(self.random(4, 3))
426 | self.check(t)
427 |
428 | def test_SquaredDifference(self):
429 | t = tf.squared_difference(*self.random((3, 4, 4), (3, 4, 4)))
430 | self.check(t)
431 |
432 | def test_Igamma(self):
433 | t = tf.igamma(*self.random((3, 3), (3, 3)))
434 | self.check(t)
435 |
436 | def test_Igammac(self):
437 | t = tf.igammac(*self.random((3, 3), (3, 3)))
438 | self.check(t)
439 |
440 | def test_Zeta(self):
441 | t = tf.zeta(self.random(3, 3) + 2, self.random(3, 3))
442 | self.check(t)
443 |
444 | def test_Polygamma(self):
445 | t = tf.polygamma(np.array([1, 2, 3]).astype("float32"), np.array([4, 5, 6]).astype("float32"))
446 | self.check(t)
447 |
448 | def test_Betainc(self):
449 | t = tf.betainc(*self.random((3, 3), (3, 3), (3, 3)))
450 | self.check(t)
451 |
452 |
453 | #
454 | # matrix math ops
455 | #
456 |
457 | def test_Diag(self):
458 | t = tf.diag(self.random(3, 3))
459 | self.check(t)
460 |
461 | def test_DiagPart(self):
462 | t = tf.diag_part(self.random(3, 3))
463 | self.check(t)
464 |
465 | def test_MatrixDiagPart(self):
466 | if td._tf_version[:2] >= (0, 12):
467 | t = tf.matrix_diag_part(self.random(3, 4, 4, 5))
468 | self.check(t)
469 |
470 | def test_trace(self):
471 | t = tf.trace(self.random(3, 3))
472 | self.check(t)
473 |
474 | def test_MatMul(self):
475 | t = tf.matmul(*self.random((4, 3), (3, 5)), transpose_a=False, transpose_b=False)
476 | self.check(t)
477 | t = tf.matmul(*self.random((3, 4), (3, 5)), transpose_a=True, transpose_b=False)
478 | self.check(t)
479 | t = tf.matmul(*self.random((4, 3), (5, 3)), transpose_a=False, transpose_b=True)
480 | self.check(t)
481 | t = tf.matmul(*self.random((3, 4), (5, 3)), transpose_a=True, transpose_b=True)
482 | self.check(t)
483 |
484 | # def test_BatchMatMul(self):
485 | # t = tf.batch_matmul(*self.random((2, 4, 4, 3), (2, 4, 3, 5)), adj_x=False, adj_y=False)
486 | # self.check(t)
487 | # t = tf.batch_matmul(*self.random((2, 4, 3, 4), (2, 4, 3, 5)), adj_x=True, adj_y=False)
488 | # self.check(t)
489 | # t = tf.batch_matmul(*self.random((2, 4, 4, 3), (2, 4, 5, 3)), adj_x=False, adj_y=True)
490 | # self.check(t)
491 | # t = tf.batch_matmul(*self.random((2, 4, 3, 4), (2, 4, 5, 3)), adj_x=True, adj_y=True)
492 | # self.check(t)
493 |
494 | def test_MatrixDeterminant(self):
495 | t = tf.matrix_determinant(self.random(2, 3, 4, 3, 3))
496 | self.check(t)
497 |
498 | def test_MatrixInverse(self):
499 | t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=False)
500 | self.check(t)
501 | t = tf.matrix_inverse(self.random(2, 3, 4, 3, 3), adjoint=True)
502 | self.check(t)
503 |
504 | def test_Cholesky(self):
505 | t = tf.cholesky(np.array(3 * [8, 3, 3, 8]).reshape(3, 2, 2).astype("float32"))
506 | self.check(t)
507 |
508 | def test_MatrixSolve(self):
509 | t = tf.matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False)
510 | self.check(t)
511 | t = tf.matrix_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=True)
512 | self.check(t)
513 |
514 | def test_MatrixTriangularSolve(self):
515 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False, lower=False)
516 | self.check(t)
517 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=True, lower=False)
518 | self.check(t)
519 | t = tf.matrix_triangular_solve(*self.random((2, 3, 3, 3), (2, 3, 3, 1)), adjoint=False, lower=True)
520 | self.check(t)
521 |
522 | def test_MatrixSolveLs(self):
523 | t = tf.matrix_solve_ls(*self.random((2, 3, 3, 3), (2, 3, 3, 1)))
524 | self.check(t)
525 |
526 | def test_SelfAdjointEig(self):
527 | # legacy support
528 | pass
529 |
530 | def test_SelfAdjointEigV2(self):
531 | t = tf.self_adjoint_eig(np.array(3 * [3, 2, 2, 1]).reshape(3, 2, 2).astype("float32"))
532 | # the order of eigen vectors and values may differ between tf and np, so only compare sum
533 | # and mean
534 | # also, different numerical algorithms are used, so account for difference in precision by
535 | # comparing numbers with 4 digits
536 | self.check(t, ndigits=4, stats=True, abs=True)
537 |
538 | def test_Svd(self):
539 | t = tf.svd(self.random(4, 5, 3, 2).astype("float32"))
540 | self.check(t, ndigits=4, abs=True)
541 |
542 |
543 | #
544 | # complex number ops
545 | #
546 |
547 | def test_Complex(self):
548 | t = tf.complex(*self.random((3, 4), (3, 4)))
549 | self.check(t)
550 |
551 | def test_Conj(self):
552 | t = tf.conj(self.random(3, 4, complex=True))
553 | self.check(t)
554 |
555 | def test_Imag(self):
556 | t = tf.imag(tf.Variable(self.random(3, 4, complex=True)))
557 | self.check(t)
558 |
559 | def test_Real(self):
560 | t = tf.real(tf.Variable(self.random(3, 4, complex=True)))
561 | self.check(t)
562 |
563 |
564 | #
565 | # Fourier transform ops
566 | #
567 |
568 | def test_FFT2D(self):
569 | # only defined for gpu
570 | if DEVICE == GPU:
571 | t = tf.fft2d(self.random(3, 4, complex=True))
572 | self.check(t)
573 |
574 | def test_IFFT2D(self):
575 | # only defined for gpu
576 | if DEVICE == GPU:
577 | t = tf.ifft2d(self.random(3, 4, complex=True))
578 | self.check(t)
579 |
580 | def test_FFT3D(self):
581 | # only defined for gpu
582 | if DEVICE == GPU:
583 | t = tf.fft3d(self.random(3, 4, 5, complex=True))
584 | self.check(t)
585 |
586 | def test_IFFT3D(self):
587 | # only defined for gpu
588 | if DEVICE == GPU:
589 | t = tf.ifft3d(self.random(3, 4, 5, complex=True))
590 | self.check(t)
591 |
592 |
593 | #
594 | # reduction
595 | #
596 |
597 | def test_Sum(self):
598 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
599 | self.check(t)
600 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=(0, 1), keep_dims=True)
601 | self.check(t)
602 | t = tf.reduce_sum(self.random(3, 4, 5), reduction_indices=0, keep_dims=True)
603 | self.check(t)
604 | if td._tf_version[:3] >= (0, 12, 0):
605 | t = tf.reduce_sum(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
606 | self.check(t)
607 | t = tf.reduce_sum(self.random(3, 4, 5), axis=(0, 1), keep_dims=True)
608 | self.check(t)
609 | t = tf.reduce_sum(self.random(3, 4, 5), axis=0, keep_dims=True)
610 | self.check(t)
611 |
612 | def test_Prod(self):
613 | t = tf.reduce_prod(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
614 | self.check(t)
615 | if td._tf_version[:3] >= (0, 12, 0):
616 | t = tf.reduce_prod(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
617 | self.check(t)
618 |
619 | def test_Min(self):
620 | t = tf.reduce_min(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
621 | self.check(t)
622 | if td._tf_version[:3] >= (0, 12, 0):
623 | t = tf.reduce_min(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
624 | self.check(t)
625 |
626 | def test_Max(self):
627 | t = tf.reduce_max(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
628 | self.check(t)
629 | if td._tf_version[:3] >= (0, 12, 0):
630 | t = tf.reduce_max(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
631 | self.check(t)
632 |
633 | def test_Mean(self):
634 | t = tf.reduce_mean(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
635 | self.check(t)
636 | if td._tf_version[:3] >= (0, 12, 0):
637 | t = tf.reduce_mean(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
638 | self.check(t)
639 |
640 | def test_All(self):
641 | t = tf.reduce_all(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
642 | self.check(t)
643 | if td._tf_version[:3] >= (0, 12, 0):
644 | t = tf.reduce_all(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
645 | self.check(t)
646 |
647 | def test_Any(self):
648 | t = tf.reduce_any(self.random(3, 4, 5), reduction_indices=[0, 1], keep_dims=True)
649 | self.check(t)
650 | if td._tf_version[:3] >= (0, 12, 0):
651 | t = tf.reduce_any(self.random(3, 4, 5), axis=[0, 1], keep_dims=True)
652 | self.check(t)
653 |
654 |
655 | #
656 | # segmentation
657 | #
658 |
659 | def test_SegmentSum(self):
660 | t = tf.segment_sum(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
661 | self.check(t)
662 |
663 | def test_SegmentProd(self):
664 | t = tf.segment_prod(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
665 | self.check(t)
666 |
667 | def test_SegmentMin(self):
668 | t = tf.segment_min(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
669 | self.check(t)
670 |
671 | def test_SegmentMax(self):
672 | t = tf.segment_max(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
673 | self.check(t)
674 |
675 | def test_SegmentMean(self):
676 | t = tf.segment_mean(self.random(4, 2, 3), np.array([0, 1, 1, 2]))
677 | self.check(t)
678 |
679 | def test_UnsortedSegmentSum(self):
680 | t = tf.unsorted_segment_sum(self.random(4, 2, 3), np.array([0, 2, 2, 1]), 3)
681 | self.check(t)
682 |
683 | def test_SparseSegmentSum(self):
684 | t = tf.sparse_segment_sum(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1])
685 | self.check(t)
686 |
687 | def test_SparseSegmentMean(self):
688 | t = tf.sparse_segment_mean(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1])
689 | self.check(t)
690 |
691 | def test_SparseSegmentSqrtN(self):
692 | t = tf.sparse_segment_sqrt_n(self.random(4, 3, 2), [0, 2, 3], [0, 1, 1])
693 | self.check(t)
694 |
695 |
696 | #
697 | # sequence comparison and indexing
698 | #
699 |
700 | def test_ArgMin(self):
701 | t = tf.argmin(self.random(3, 4, 2), 1)
702 | self.check(t)
703 |
704 | def test_ArgMax(self):
705 | t = tf.argmax(self.random(3, 4, 2), 1)
706 | self.check(t)
707 |
708 | def test_ListDiff(self):
709 | if td._tf_version[:2] <= (0, 11):
710 | l = np.random.randint(0, 5, 100)
711 | t1, t2 = tf.listdiff(l, l[::-2])
712 | self.check(t1)
713 | self.check(t2)
714 |
715 | def test_Where(self):
716 | t = tf.where([[True, False], [False, False], [True, False]])
717 | self.check(t)
718 |
719 | def test_Unique(self):
720 | t = tf.unique([9, 3, 5, 7, 3, 9, 9], out_idx=tf.int32)
721 | self.check(t)
722 |
723 | def test_InvertPermutation(self):
724 | t = tf.invert_permutation(np.random.permutation(10))
725 | self.check(t)
726 |
727 |
728 | #
729 | # control flow ops
730 | #
731 |
732 | def test_Identity(self):
733 | t = tf.identity(self.random(3, 4))
734 | self.check(t)
735 |
736 |
737 | #
738 | # NN activation ops
739 | #
740 |
741 | def test_Relu(self):
742 | t = tf.nn.relu(self.random(100) - 0.5)
743 | self.check(t)
744 |
745 | def test_Relu6(self):
746 | t = tf.nn.relu6((self.random(100) - 0.5) * 20)
747 | self.check(t)
748 |
749 | def test_Elu(self):
750 | t = tf.nn.elu(self.random(100) - 0.5)
751 | self.check(t)
752 |
753 | def test_Softplus(self):
754 | t = tf.nn.softplus(self.random(100) - 0.5)
755 | self.check(t)
756 |
757 | def test_Softsign(self):
758 | t = tf.nn.softsign(self.random(100) - 0.5)
759 | self.check(t)
760 |
761 | def test_BiasAdd(self):
762 | t = tf.nn.bias_add(*self.random((4, 5), (5,)))
763 | self.check(t)
764 |
765 | def test_Sigmoid(self):
766 | t = tf.nn.sigmoid(self.random(3, 4))
767 | self.check(t)
768 |
769 | def test_Tanh(self):
770 | t = tf.nn.tanh(self.random(3, 4))
771 | self.check(t)
772 |
773 | def test_Softmax(self):
774 | t = tf.nn.softmax(self.random(10, 5))
775 | self.check(t)
776 |
777 |
778 | #
779 | # NN convolution ops
780 | #
781 |
782 | def test_Conv1D(self):
783 | t = tf.nn.conv1d(np.arange(8000).reshape(1000, 2, 4).astype("float32"),
784 | np.ones(80).reshape(2, 4, 10).astype("float32"),
785 | 1, "SAME")
786 | self.check(t)
787 | t = tf.nn.conv1d(np.arange(8000).reshape(1000, 2, 4).astype("float32"),
788 | np.ones(80).reshape(2, 4, 10).astype("float32"),
789 | 2, "VALID")
790 | self.check(t)
791 |
792 | def test_Conv2D(self):
793 | t = tf.nn.conv2d(np.arange(24000).reshape(1000, 2, 3, 4).astype("float32"),
794 | np.ones(160).reshape(2, 2, 4, 10).astype("float32"),
795 | [1, 2, 3, 1], "SAME")
796 | self.check(t)
797 | t = tf.nn.conv2d(np.arange(24000).reshape(1000, 2, 3, 4).astype("float32"),
798 | np.ones(160).reshape(2, 2, 4, 10).astype("float32"),
799 | [1, 2, 5, 1], "VALID")
800 | self.check(t)
801 |
802 | def test_Conv3D(self):
803 | t = tf.nn.conv3d(np.arange(72000).reshape(1000, 2, 3, 3, 4).astype("float32"),
804 | np.ones(320).reshape(2, 2, 2, 4, 10).astype("float32"),
805 | [1, 1, 1, 1, 1], "SAME")
806 | self.check(t)
807 | t = tf.nn.conv3d(np.arange(72000).reshape(1000, 2, 3, 3, 4).astype("float32"),
808 | np.ones(320).reshape(2, 2, 2, 4, 10).astype("float32"),
809 | [1, 1, 1, 1, 1], "VALID")
810 | self.check(t)
811 |
812 |
813 | #
814 | # pooling ops
815 | #
816 |
817 | def test_AvgPool(self):
818 | t = tf.nn.avg_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"),
819 | [1, 2, 2, 1], [1, 1, 1, 1], "SAME")
820 | self.check(t)
821 | t = tf.nn.avg_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"),
822 | [1, 2, 2, 1], [1, 2, 2, 1], "VALID")
823 | self.check(t)
824 |
825 | def test_MaxPool(self):
826 | t = tf.nn.max_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"),
827 | [1, 2, 2, 1], [1, 1, 1, 1], "SAME")
828 | self.check(t)
829 | t = tf.nn.max_pool(np.arange(16).reshape(1, 4, 4, 1).astype("float32"),
830 | [1, 2, 2, 1], [1, 2, 2, 1], "VALID")
831 | self.check(t)
832 | t = tf.nn.max_pool(np.arange(64).reshape(2, 4, 8, 1).astype("float32"),
833 | [1, 2, 2, 1], [1, 3, 2, 1], "VALID")
834 | self.check(t)
835 |
836 | def test_AvgPool3D(self):
837 | t = tf.nn.avg_pool3d(np.arange(64).reshape(1, 4, 4, 4, 1).astype("float32"),
838 | [1, 2, 2, 2, 1], [1, 1, 1, 1, 1], "SAME")
839 | self.check(t)
840 | t = tf.nn.avg_pool3d(np.arange(48).reshape(1, 4, 4, 3, 1).astype("float32"),
841 | [1, 2, 2, 1, 1], [1, 2, 2, 1, 1], "VALID")
842 | self.check(t)
843 |
844 | def test_MaxPool3D(self):
845 | t = tf.nn.max_pool3d(np.arange(64).reshape(1, 4, 4, 4, 1).astype("float32"),
846 | [1, 2, 2, 2, 1], [1, 1, 1, 1, 1], "SAME")
847 | self.check(t)
848 | t = tf.nn.max_pool3d(np.arange(48).reshape(1, 4, 4, 3, 1).astype("float32"),
849 | [1, 2, 2, 1, 1], [1, 2, 2, 1, 1], "VALID")
850 | self.check(t)
851 |
--------------------------------------------------------------------------------
/tests/perf/create_plots.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Plotting script that creates runtime comparison plots from the json file written by
5 | measure_runtimes.py.
6 | """
7 |
8 | import os
9 | import json
10 | import itertools
11 | import collections
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from measure_runtimes import specs, output_file
16 |
17 |
18 | #
19 | # constants
20 | #
21 |
22 | plotdir = "data"
23 |
24 |
25 | #
26 | # plot function
27 | #
28 |
29 | def plot(xkey, coords, data):
30 | # determine the plotfile path
31 | labels = sorted([(xkey, 0)] + list(coords.items()), key=lambda tpl: list(specs.keys()).index(tpl[0]))
32 | title = ", ".join("%s: %s" % tpl for tpl in labels if tpl[0] != "examples" and tpl[0] != xkey)
33 | labels = [(key[0], value) for key, value in labels]
34 | plotfile = "runtime_" + xkey + "_" + "_".join("%s%s" % tpl for tpl in labels) + ".png"
35 | plotfile = os.path.join(plotdir, plotfile)
36 |
37 | # get x and y data to plot
38 | x = [d[xkey] for d in data]
39 | ykeys = ["tf_cpu", "tf_gpu", "td"]
40 | y = {ykey: [d["times"][ykey]["mean"] for d in data] for ykey in ykeys}
41 | markers = dict(zip(ykeys, ("s", "o", "D")))
42 |
43 | # do the plot
44 | fig, axes = plt.subplots()
45 | for ykey, _y in y.items():
46 | axes.plot(x, _y, "-" + markers[ykey], label=ykey)
47 |
48 | axes.set_xlabel(xkey)
49 | axes.set_ylabel("time per batch [s]")
50 | axes.set_title(title)
51 | if xkey != "units":
52 | axes.set_xscale("log")
53 | axes.legend(loc="best")
54 |
55 | fig.savefig(plotfile)
56 | fig.clf()
57 |
58 |
59 | #
60 | # data filter function
61 | #
62 |
63 | def filter_data(data, **kwargs):
64 | return [d for d in data if all(d[key] == value for key, value in kwargs.items())]
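
# minimal usage sketch (hypothetical values; keyword names must match the spec
# keys from measure_runtimes.py):
#
#   filter_data(data, layers=2, units=100)
#
# keeps only the measurements taken with exactly those spec values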
65 |
66 |
67 | #
68 | # main and entry hook
69 | #
70 |
71 | def main():
72 | # check if the output file exists
73 | if not os.path.exists(output_file):
74 | raise IOError("output file '%s' does not exist, run measure_runtimes.py to generate it" % output_file)
75 |
76 | # read the data
77 | with open(output_file, "r") as f:
78 | data = json.load(f)
79 |
80 | # prepare the plot dir
81 | if not os.path.exists(plotdir):
82 | os.mkdir(plotdir)
83 |
84 | # loop through all keys in specs, each of them will be used for the x-axis
85 | for xkey in specs:
86 | # do not plot examples on the x-axis
87 | if xkey == "examples":
88 | continue
89 |
90 | # determine the remaining keys, do the combinatorics
91 | keys = [key for key in specs if key != xkey]
92 | for combi in itertools.product(*[specs[key] for key in keys]):
93 | coord = collections.OrderedDict(zip(keys, combi))
94 | plot_data = filter_data(data, **coord)
95 | plot_data.sort(key=lambda d: d[xkey])
96 | plot(xkey, coord, plot_data)
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/tests/perf/measure_runtimes.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Runs a performance test for tfdeploy, tensorflow@CPU and tensorflow@GPU with different data and
5 | network dimensions. The results are written to a json file "times.json" in the current directory.
6 | """
7 |
8 | import os
9 | import time
10 | import json
11 | import uuid
12 | import itertools
13 | import collections
14 |
15 | import numpy as np
16 | import tensorflow as tf
17 | import tfdeploy as td
18 |
19 |
20 | #
21 | # test specs and constants
22 | #
23 |
24 | output_file = "times.json"
25 |
26 | specs = collections.OrderedDict()
27 | specs["features"] = [10, 20, 50, 100, 200, 500, 1000]
28 | specs["examples"] = [100000]
29 | specs["batchsize"] = [1000, 100, 10]
30 | specs["layers"] = [1, 2, 5, 10]
31 | specs["units"] = [10, 20, 50, 100, 200, 500]
32 |
33 |
34 | #
35 | # Data class that contains cached random numbers and handles batch iteration
36 | #
37 |
38 | # static instance
39 | data = None
40 |
41 | # class definition
42 | class Data(object):
43 |
44 | def __init__(self, max_features, max_examples):
45 | super(Data, self).__init__()
46 |
47 | self._max_features = max_features
48 | self._max_examples = max_examples
49 |
50 | self._data = np.random.rand(max_examples, max_features).astype(np.float32)
51 |
52 | self._features = None
53 | self._examples = None
54 | self._batchsize = None
55 |
56 | def prepare(self, features, examples, batchsize):
57 | if features > self._max_features:
58 | raise ValueError("features must be lower than max_features (%d)" % self._max_features)
59 |
60 | if examples > self._max_examples:
61 | raise ValueError("examples must be lower than max_examples (%d)" % self._max_examples)
62 |
63 | if examples % batchsize != 0:
64 | raise ValueError("batchsize must be a divider of examples")
65 |
66 | self._features = features
67 | self._examples = examples
68 | self._batchsize = batchsize
69 |
70 | def __iter__(self):
71 | for i in range(self._examples // self._batchsize):
72 | yield self._data[(i*self._batchsize):((i+1)*self._batchsize), :self._features]
73 |
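# minimal usage sketch of the Data class (hypothetical values):
#
#   data = Data(max(specs["features"]), max(specs["examples"]))
#   data.prepare(features=10, examples=100000, batchsize=1000)
#   for batch in data:
#       ...  # 100 batches of shape (1000, 10)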
74 |
75 | #
76 | # model generation helpers
77 | #
78 |
79 | def create_tf_model(features, layers, units, device, input_name, output_name):
80 | with tf.device(device):
81 | x = tf.placeholder(tf.float32, shape=[None, features], name=input_name)
82 | y = x
83 | for i in range(layers):
84 | W = tf.Variable(tf.random_normal([features if y is x else units, units]))
85 | b = tf.Variable(tf.zeros([units]))
86 | y = tf.tanh(tf.matmul(y, W) + b, name=output_name if i == layers - 1 else None)
87 | return x, y
88 |
89 |
90 | def create_models(features, layers, units):
91 | postfix = str(uuid.uuid4())[:8]
92 | input_name = "input_" + postfix
93 | output_name = "output_" + postfix
94 |
95 | tf_cpu_x, tf_cpu_y = create_tf_model(features, layers, units, "/cpu:0", input_name, output_name)
96 | tf_gpu_x, tf_gpu_y = create_tf_model(features, layers, units, "/gpu:0", input_name, output_name)
97 |
98 | tf_sess = tf.Session()
99 | tf_sess.run(tf.initialize_all_variables())
100 |
101 | td_model = td.Model()
102 | td_model.add(tf_cpu_y, tf_sess)
103 | td_x, td_y = td_model.get(input_name, output_name)
104 |
105 | tf_cpu_fn = lambda batch: tf_sess.run(tf_cpu_y, feed_dict={tf_cpu_x: batch})
106 | tf_gpu_fn = lambda batch: tf_sess.run(tf_gpu_y, feed_dict={tf_gpu_x: batch})
107 | td_fn = lambda batch: td_y.eval({td_x: batch})
108 |
109 | return collections.OrderedDict([
110 | ("tf_cpu", tf_cpu_fn),
111 | ("tf_gpu", tf_gpu_fn),
112 | ("td", td_fn)
113 | ])
114 |
115 |
116 | #
117 | # actual test function
118 | #
119 |
120 | def test(features, examples, batchsize, layers, units):
121 | # prepare the data for the given input dimensions
122 | data.prepare(features, examples, batchsize)
123 |
124 | # create models / evaluation functions
125 | models = create_models(features, layers, units)
126 |
127 | # storage for measured runtimes
128 | times = collections.OrderedDict((name, []) for name in models.keys())
129 |
130 | # loop through batches and evaluation functions
131 | for batch in data:
132 | for name, fn in models.items():
133 | t1 = time.time()
134 | fn(batch)
135 | times[name].append(time.time() - t1)
136 |
137 | for name, l in times.items():
138 | a = np.array(l)
139 | times[name] = {
140 | "total" : np.sum(a),
141 | "mean" : np.mean(a),
142 | "variance": np.var(a)
143 | }
144 |
145 | return times
146 |
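# the mapping returned by test() has one entry per evaluation function, e.g. a
# sketch of its shape:
#
#   {"tf_cpu": {"total": ..., "mean": ..., "variance": ...},
#    "tf_gpu": {...},
#    "td": {...}}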
147 |
148 | #
149 | # main and entry hook
150 | #
151 |
152 | def main():
153 | combis = list(itertools.product(*specs.values()))
154 |
155 | # create data with maximum shape
156 | global data
157 | print("create data")
158 | data = Data(max(specs["features"]), max(specs["examples"]))
159 | print("done")
160 |
161 | # run actual tests
162 | results = []
163 | for i, combi in enumerate(combis):
164 | print("running test %d/%d" % (i + 1, len(combis)))
165 | d = collections.OrderedDict(zip(specs.keys(), combi))
166 | d["times"] = test(**d)
167 | results.append(d)
169 |
170 | with open(output_file, "w") as f:
171 | json.dump(results, f, indent=4)
172 |
173 | if __name__ == "__main__":
174 | main()
175 |
--------------------------------------------------------------------------------
/tests/perf/simple.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import tensorflow as tf
4 | import tfdeploy as td
5 | import numpy as np
6 |
7 |
8 | # setup tf graph
9 | sess = tf.Session()
10 | x = tf.placeholder("float", shape=[None, 784], name="input")
11 | W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
12 | b = tf.Variable(tf.zeros([100]))
13 | y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
14 |
15 | if td._tf_version[:3] < (0, 12, 0):
16 | sess.run(tf.initialize_all_variables())
17 | else:
18 | sess.run(tf.global_variables_initializer())
19 |
20 | # setup td model
21 | model = td.Model()
22 | model.add(y, sess)
23 | inp, outp = model.get("input", "output")
24 |
25 | # testing code
26 | batch = np.random.rand(10000, 784)
27 |
28 | def test_tf():
29 | return y.eval(session=sess, feed_dict={x: batch})
30 |
31 | def test_td():
32 | return outp.eval({inp: batch})
33 |
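# intended for interactive timing comparisons, e.g. (a sketch, not executed here):
#
#   python -m timeit -s "from simple import test_tf" "test_tf()"
#   python -m timeit -s "from simple import test_td" "test_td()"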
--------------------------------------------------------------------------------
/tfdeploy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Deploy tensorflow graphs for fast evaluation and export to tensorflow-less environments running
5 | numpy.
6 | """
7 |
8 |
9 | __author__ = "Marcel Rieger"
10 | __copyright__ = "Copyright 2016-2025, Marcel Rieger"
11 | __credits__ = ["Marcel Rieger"]
12 | __contact__ = "https://github.com/riga/tfdeploy"
13 | __license__ = "BSD-3-Clause"
14 | __status__ = "Development"
15 | __version__ = "0.4.2"
16 |
17 | __all__ = ["Model", "Tensor", "Operation", "Ensemble",
18 | "UnknownOperationException", "OperationMismatchException",
19 | "InvalidImplementationException", "UnknownImplementationException",
20 | "EnsembleMismatchException", "ScipyOperationException",
21 | "reset", "optimize", "print_tensor", "print_op", "print_tf_tensor", "print_tf_op",
22 | "IMPL_NUMPY", "IMPL_SCIPY", "IMPLS",
23 | "METHOD_MEAN", "METHOD_MAX", "METHOD_MIN", "METHOD_CUSTOM", "METHODS",
24 | "HAS_SCIPY"]
25 |
26 |
27 | # imports for core code
28 | import os
29 | import re
30 | from uuid import uuid4
31 | from functools import reduce
32 |
33 | try:
34 | # python 2
35 | import cPickle as pickle
36 | except ImportError:
37 | # python 3
38 | import pickle
39 |
40 | # third-party imports
41 | import numpy as np
42 |
43 |
44 | # metaclass decorator from six package, credits to Benjamin Peterson
45 | def add_metaclass(metaclass):
46 | def wrapper(cls):
47 | orig_vars = cls.__dict__.copy()
48 | slots = orig_vars.get("__slots__")
49 | if slots is not None:
50 | if isinstance(slots, str):
51 | slots = [slots]
52 | for slots_var in slots:
53 | orig_vars.pop(slots_var)
54 | orig_vars.pop("__dict__", None)
55 | orig_vars.pop("__weakref__", None)
56 | return metaclass(cls.__name__, cls.__bases__, orig_vars)
57 | return wrapper
58 |
59 |
60 | class Model(object):
61 | """
62 | A trained model that contains one or more converted tensorflow graphs. When *path* is set, a
63 | previously saved model is loaded from that path. Usage:
64 |
65 | .. code-block:: python
66 |
67 | import tensorflow as tf
68 | import tfdeploy as td
69 |
70 | # build your graph, use names for input and output tensors
71 | sess = tf.Session()
72 | x = tf.placeholder("float", shape=[None, 784], name="input")
73 | W = tf.Variable(tf.truncated_normal([784, 100], stddev=0.05))
74 | b = tf.Variable(tf.zeros([100]))
75 | y = tf.nn.softmax(tf.matmul(x, W) + b, name="output")
76 | sess.run(tf.initialize_all_variables())
77 |
78 | # ... training ...
79 |
80 | # create a model and save it to disk
81 | model = td.Model()
82 | model.add(y, sess)
83 | model.save("/path/to/model.pkl")
84 |
85 | And then in another file:
86 |
87 | .. code-block:: python
88 |
89 | import tfdeploy as td
90 | import numpy as np
91 |
92 | model = td.Model("/path/to/model.pkl")
93 | inp, outp = model.get("input", "output")
94 |
95 | batch = np.random.rand(10000, 784)
96 | result = outp.eval({inp: batch})
97 |
98 | .. py:attribute:: roots
99 |
100 | Dictionary that maps keys to the contained root tensors.
101 | """
102 |
103 | value_index_cre = re.compile(r":\d+$")
104 | default_value_index = 0
105 |
106 | def __init__(self, path=None):
107 | super(Model, self).__init__()
108 |
109 | self.roots = {}
110 |
111 | # load when desired
112 | if path is not None:
113 | self.load(path)
114 |
115 | def get(self, *names, **kwargs):
116 | """ get(*names, key=None)
117 | Returns one or more :py:class:`Tensor` instances given by *names* using a deep lookup within
118 | the model. If *key* is not *None*, only the root tensor with that *key* is traversed. *None*
119 | is returned when no tensor was found. In case a tensor is passed, its name is used for the
120 | lookup.
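
A minimal sketch, assuming a model that contains tensors named "input" and "output":

.. code-block:: python

    inp, outp = model.get("input", "output")

    # a missing value index defaults to 0, so this is equivalent
    inp, outp = model.get("input:0", "output:0")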
121 | """
122 | tensors = tuple(self._get(name, **kwargs) for name in names)
123 | return tensors[0] if len(names) == 1 else tensors
124 |
125 | def _get(self, name, key=None):
126 | if isinstance(name, Tensor):
127 | name = name.name
128 |
129 | # append the default value_index if there's none
130 | if not self.value_index_cre.search(name):
131 | name += ":%d" % self.default_value_index
132 |
133 | # return the first occurrence of a tensor with that name
134 | if key is not None:
135 | return self.roots[key].get(name)
136 | else:
137 | return reduce(lambda t1, t2: t1 or t2.get(name), self.roots.values(), None)
138 |
139 | def __getitem__(self, name):
140 | return self.get(name)
141 |
142 | def __contains__(self, name):
143 | return self.get(name) is not None
144 |
145 | def add(self, tensor, tf_sess=None, key=None, **kwargs):
146 | """
147 | Adds a new root *tensor* for a *key* which, if *None*, defaults to a consecutive number.
148 | When *tensor* is not an instance of :py:class:`Tensor` but an instance of
149 | ``tensorflow.Tensor``, it is converted first. In that case, *tf_sess* should be a valid
150 | tensorflow session and *kwargs* are forwarded to the :py:class:`Tensor` constructor.
151 | """
152 | if not isinstance(tensor, Tensor):
153 | tensor = Tensor(tensor, tf_sess, **kwargs)
154 |
155 | if key is None:
156 | if len(self.roots) == 0:
157 | key = 0
158 | else:
159 | key = max(self.roots.keys()) + 1
160 |
161 | self.roots[key] = tensor
162 |
163 | def load(self, path):
164 | """
165 | Loads all tensors from a file defined by *path* and adds them to the root set.
166 | """
167 | path = os.path.expandvars(os.path.expanduser(path))
168 | with open(path, "rb") as f:
169 | roots = pickle.load(f)
170 |
171 | for key, tensor in roots.items():
172 | self.add(tensor, key=key)
173 |
174 | def save(self, path):
175 | """
176 | Saves all tensors of the root set to a file defined by *path*.
177 | """
178 | path = os.path.expandvars(os.path.expanduser(path))
179 | with open(path, "wb") as f:
180 | pickle.dump(self.roots, f)
181 |
182 |
183 | class TensorRegister(type):
184 | """
185 | Meta class of :py:class:`Tensor` that performs instance caching indexed by tensorflow tensor
186 | instances.
187 | """
188 |
189 | instances = {}
190 |
191 | def __call__(cls, tf_tensor, *args, **kwargs):
192 | # simple caching
193 | if tf_tensor not in cls.instances:
194 | inst = super(TensorRegister, cls).__call__(tf_tensor, *args, **kwargs)
195 | cls.instances[tf_tensor] = inst
196 | return cls.instances[tf_tensor]
197 |
198 |
199 | @add_metaclass(TensorRegister)
200 | class Tensor(object):
201 | """
202 | Building block of a model. In *graph* terms, tensors represent connections between nodes (ops)
203 | of a graph. It contains information on the op it results from. The conversion uses the
204 | (tensorflow) instances *tf_tensor* and *tf_sess*, *tf_feed_dict* can be set to evaluate the
205 | tensor's current value.
206 |
207 | .. py:attribute:: name
208 |
209 | The name of the tensor.
210 |
211 | .. py:attribute:: value_index
212 |
213 | The integer value index of this tensor, i.e., the position in the op's output list.
214 |
215 | .. py:attribute:: op
216 |
217 | The op instance that defines the value of this tensor. When created from a
218 | ``tensorflow.Placeholder`` or a ``tensorflow.Variable/V2``, op will be *None*.
219 |
220 | .. py:attribute:: value
221 |
222 | The value of this tensor. When created from a ``tensorflow.Variable/V2``, this is the value
223 | of that variable; otherwise it is *None* until the tensor is evaluated for the first time.
224 | """
225 |
226 | def __init__(self, tf_tensor, tf_sess, tf_feed_dict=None):
227 | super(Tensor, self).__init__()
228 |
229 | if not tf_sess:
230 | raise ValueError("bad tensorflow session: %s" % tf_sess)
231 |
232 | self.name = tf_tensor.name
233 | self.value_index = tf_tensor.value_index
234 | self.op = None
235 | self.value = None
236 | self.last_uuid = None
237 |
238 | # guess the value
239 | # explicitly evaluate variables and constants, use feed_dict for placeholders
240 | if tf_tensor.op.type in ("Variable", "VariableV2", "Const"):
241 | self.value = tf_tensor.eval(session=tf_sess, feed_dict=tf_feed_dict)
242 | elif tf_tensor.op.type == "Placeholder":
243 | if tf_feed_dict is not None and tf_tensor in tf_feed_dict:
244 | self.value = tf_feed_dict[tf_tensor]
245 |
246 | # create the op
247 | # no op for variables, placeholders and constants
248 | if tf_tensor.op.type not in ("Variable", "VariableV2", "Const", "Placeholder"):
249 | self.op = Operation.new(tf_tensor.op, tf_sess, tf_feed_dict=tf_feed_dict)
250 |
251 | def get(self, *names):
252 | """
253 | Returns one or more tensors given by *names* using a deep lookup within the inputs of the
254 | op. Note that *this* tensor is returned when the name matches. *None* is returned when no
255 | tensor was found.
256 | """
257 | tensors = tuple(self._get(name) for name in names)
258 | return tensors[0] if len(names) == 1 else tensors
259 |
260 | def _get(self, name):
261 | if self.name == name:
262 | return self
263 | elif self.op is None:
264 | return None
265 | else:
266 | return self.op.get(name)
267 |
268 | def eval(self, feed_dict=None, _uuid=None):
269 | """ eval(feed_dict=None)
270 | Returns the value of this tensor based on the evaluation of all dependent ops and tensors.
271 | You can overwrite values of dependent tensors using *feed_dict*, a mapping of tensors to
272 | numpy arrays, which is passed down the evaluation chain.
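
A minimal sketch, assuming *inp* and *outp* were obtained via :py:meth:`Model.get`:

.. code-block:: python

    batch = np.random.rand(10000, 784)
    result = outp.eval({inp: batch})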
273 | """
274 | # set a cache uuid for this eval call
275 | if _uuid is None:
276 | _uuid = uuid4()
277 |
278 | # already cached? this is important for tensors that are used multiple times within the graph
279 | if _uuid == self.last_uuid:
280 | return self.value
281 | else:
282 | self.last_uuid = _uuid
283 |
284 | if feed_dict is None:
285 | feed_dict = {}
286 |
287 | # when _this_ tensor is in the feed_dict, return the fed value
288 | # otherwise, eval the op
289 | if self in feed_dict:
290 | self.value = feed_dict[self]
291 | elif self.op is not None:
292 | self.value = self.op.eval(feed_dict=feed_dict, _uuid=_uuid)[self.value_index]
293 |
294 | return self.value
295 |
296 | def __call__(self, *args, **kwargs):
297 | return self.eval(*args, **kwargs)
298 |
299 |
300 | class OperationRegister(type):
301 | """
302 | Meta class of :py:class:`Operation` that performs instance caching indexed by tensorflow op
303 | instances. Additionally, all derived classes are registered in a mapping using their types for
304 | faster op class lookup.
305 | """
306 |
307 | classes = {}
308 | instances = {}
309 |
310 | def __new__(metacls, classname, bases, classdict):
311 | # when not set explicitly in that class, set types to the class name
312 | classdict.setdefault("types", (classname,))
313 | cls = super(OperationRegister, metacls).__new__(metacls, classname, bases, classdict)
314 | # register the class for each of its types
315 | for type in cls.types:
316 | metacls.classes[type] = cls
317 | return cls
318 |
319 | def __call__(cls, tf_op, *args, **kwargs):
320 | # simple caching
321 | if tf_op not in cls.instances:
322 | inst = super(OperationRegister, cls).__call__(tf_op, *args, **kwargs)
323 | cls.instances[tf_op] = inst
324 | return cls.instances[tf_op]
325 |
326 |
327 | # implementation types
328 | IMPLS = IMPL_NUMPY, IMPL_SCIPY = range(2)
329 | IMPL_NAMES = ["numpy", "scipy"]
330 |
331 |
332 | @add_metaclass(OperationRegister)
333 | class Operation(object):
334 | """
335 | Building block of a model. In *graph* terms, operations (ops) represent nodes that are connected
336 | via tensors. It contains information on its input tensors. The conversion uses the
337 | (tensorflow) instance *tf_op*; all *args* and *kwargs* are forwarded to the :py:class:`Tensor`
338 | constructor for this op's input tensors. Op instances can have multiple implementations, i.e.,
339 | different methods that lead to equivalent results but might use additional third-party software
340 | such as *scipy*. To select a specific implementation, invoke :py:func:`use_impl`:
341 |
342 | .. code-block:: python
343 |
344 | # tell SomeOp to use the scipy implementation of its op logic
345 | SomeOp.use_impl(IMPL_SCIPY)
346 |
347 | See :py:func:`add_impl` for more info about adding new implementations.
348 |
349 | .. py:attribute:: types
350 | classmember
351 |
352 | A tuple containing the types of tensorflow ops that this op can represent.
353 |
354 | .. py:attribute:: unpack
355 | classmember
356 |
357 | If *True* (default), the values of evaluated input tensors are forwarded to *func* as
358 | individual arguments; otherwise, they are passed as a single list.
359 |
360 | .. py:attribute:: attrs
361 | classmember
362 |
363 | Names of the configuration attributes of the original tensorflow op in a tuple.
364 |
365 | .. py:attribute:: name
366 |
367 | The name of the op.
368 |
369 | .. py:attribute:: inputs
370 |
371 | Tuple of tensors that are input to this op. Their order is important as they are forwarded to
372 | *func* for evaluation.
373 |
374 | .. py:attribute:: kwargs
375 |
376 | Keyword arguments containing configuration values that will be passed to *func*.
377 | """
378 |
379 | impl = None
380 | impls = []
381 |
382 | types = ()
383 | unpack = True
384 | attrs = ()
385 | output_dtypes = False
386 |
387 | def __init__(self, tf_op, *args, **kwargs):
388 | super(Operation, self).__init__()
389 |
390 | # compare types as a cross check
391 | if tf_op.type not in self.types:
392 | raise OperationMismatchException("operation types do not match: %s, %s" \
393 | % (self.types, tf_op.type))
394 |
395 | self.name = tf_op.name
396 | self.inputs = tuple(Tensor(tf_tensor, *args, **kwargs) for tf_tensor in tf_op.inputs)
397 |
398 | self.value = None
399 | self.last_uuid = None
400 |
401 | # store attributes as kwargs for calls to eval
402 | self.kwargs = []
403 | for attr in self.attrs:
404 | try:
405 | value = tf_op.get_attr(attr)
406 | except ValueError:
407 | value = None
408 | self.kwargs.append(value)
409 |
410 | # store output dtypes for calls to eval when output_dtypes is True
411 | self.output_dtypes = [dtype_map[dtype] for dtype in tf_op._output_types]
412 |
413 | @classmethod
414 | def new(cls, tf_op, *args, **kwargs):
415 | """
416 | Factory function that takes a tensorflow op *tf_op* and returns an instance of the
417 | appropriate op class. *args* and *kwargs* are forwarded to the op constructor. Raises an
418 | exception of type :py:exc:`UnknownOperationException` in case the requested op type is not
419 | known.
420 | """
421 | if tf_op.type not in cls.classes:
422 | raise UnknownOperationException("unknown operation: %s" % tf_op.type)
423 |
424 | return cls.classes[tf_op.type](tf_op, *args, **kwargs)
425 |
426 | def set_attr(self, attr, value):
427 | """
428 | Overwrites the value of an attribute *attr* with a new *value*.
429 | """
430 | if attr not in self.attrs:
431 | raise AttributeError("no attribute '%s' in op '%s'" % (attr, self.name))
432 |
433 | self.kwargs[self.attrs.index(attr)] = value
434 |
435 | def get(self, *names):
436 | """
437 | Returns one or more tensors given by *names* using a deep lookup within this op. *None* is
438 | returned when no tensor was found.
439 | """
440 | tensors = tuple(self._get(name) for name in names)
441 | return tensors[0] if len(names) == 1 else tensors
442 |
443 | def _get(self, name):
444 | return reduce(lambda t1, t2: t1 or t2.get(name), self.inputs, None)
445 |
446 | def eval(self, feed_dict=None, _uuid=None):
447 | """ eval(feed_dict=None)
448 | Returns the value of all output tensors in a tuple. See :py:meth:`Tensor.eval` for more
449 | info.
450 | """
451 | # set a cache uuid for this eval call
452 | if _uuid is None:
453 | _uuid = uuid4()
454 |
455 | # already cached?
456 | if _uuid == self.last_uuid:
457 | return self.value
458 | else:
459 | self.last_uuid = _uuid
460 |
461 | args = [t.eval(feed_dict=feed_dict, _uuid=_uuid) for t in self.inputs]
462 | if self.unpack:
463 | args.extend(self.kwargs)
464 | else:
465 | args = [args] + self.kwargs
466 | if self.__class__.output_dtypes:
467 | args.append(self.output_dtypes)
468 |
469 | self.value = self.func(*args)
470 |
471 | return self.value
472 |
473 | @classmethod
474 | def func(cls, *args):
475 | """
476 | The actual op logic. By default, the method call is forwarded to the
477 | implementation-specific version which is determined using *impl*. Overwrite this method in
478 | inheriting classes to disable this feature. Must return a tuple.
479 | """
480 | if cls.impl == IMPL_NUMPY:
481 | return cls.func_numpy(*args)
482 | elif cls.impl == IMPL_SCIPY:
483 | return cls.func_scipy(*args)
484 | else:
485 | raise InvalidImplementationException(cls.impl)
486 |
487 | @staticmethod
488 | def func_numpy(*args):
489 | """
490 | Numpy implementation of the op logic. Returns a tuple.
491 | """
492 | raise NotImplementedError
493 |
494 | @staticmethod
495 | def func_scipy(*args):
496 | """
497 | Scipy implementation of the op logic. Returns a tuple.
498 | """
499 | raise NotImplementedError
500 |
501 | @classmethod
502 | def factory(cls, func=None, impl=IMPL_NUMPY, **kwargs):
503 | """ factory(func=None, impl=IMPL_NUMPY, **kwargs)
504 | Returns a new op class whose static function will be set to *func*. The name of *func* will
505 | also be the op class name. *impl* is the default implementation type of the op. *kwargs* are
506 | used to update the class dict of the newly created op class.
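
A minimal sketch of a hypothetical op created via the factory (op functions must return a
tuple):

.. code-block:: python

    @Operation.factory
    def MyScale(a):
        # numpy-based op logic, defaults to IMPL_NUMPY
        return 2 * a,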
507 | """
508 | if impl not in IMPLS:
509 | raise InvalidImplementationException(impl)
510 |
511 | def wrapper(func):
512 | classdict = {"impls": [], "func_" + IMPL_NAMES[impl]: staticmethod(func)}
513 | classdict.update(kwargs)
514 |
515 | cls = Operation.__class__(func.__name__, (Operation,), classdict)
516 | cls.__doc__ = func.__doc__
517 | cls.impls.append(impl)
518 | cls.use_impl(impl)
519 |
520 | return cls
521 |
522 | return wrapper if func is None else wrapper(func)
523 |
524 | @classmethod
525 | def use_impl(cls, impl):
526 | """
527 | Switches the implementation type to *impl*. Returns the previous type.
528 | """
529 | if impl not in cls.impls:
530 | raise UnknownImplementationException(impl)
531 |
532 | prev = cls.impl
533 | cls.impl = impl
534 | return prev
535 |
536 | @classmethod
537 | def add_impl(cls, impl):
538 | """
539 | Decorator to add an additional implementation to this op. Example:
540 |
541 | .. code-block:: python
542 |
543 | # initial implementation using factory, defaults to numpy
544 | @Operation.factory
545 | def MyOp(a, b):
546 | # use numpy only
547 | return ...
548 |
549 | # also add a scipy implementation
550 | @MyOp.add_impl(IMPL_SCIPY)
551 | def MyOp(a, b):
552 | # also use scipy
553 | return ...
554 | """
555 | if impl not in IMPLS:
556 | raise InvalidImplementationException(impl)
557 |
558 | def wrapper(func):
559 | setattr(cls, "func_" + IMPL_NAMES[impl], staticmethod(func))
560 | if impl not in cls.impls:
561 | cls.impls.append(impl)
562 | return cls
563 |
564 | return wrapper
565 |
566 |
567 | # ensemble method types
568 | METHODS = METHOD_MEAN, METHOD_MAX, METHOD_MIN, METHOD_CUSTOM = range(4)
569 | METHOD_NAMES = ["mean", "max", "min", "custom"]
570 |
571 |
572 | class Ensemble(object):
573 | """
574 | An ensemble is a wrapper around multiple models to compute ensemble values. It can be initialized
575 | with a list of model paths and an ensembling method that decides how to compute the merged
576 | value.
577 |
578 | .. code-block:: python
579 |
580 | # create the ensemble
581 | ensemble = Ensemble(["model1.pkl", "model2.pkl", ...], METHOD_MEAN)
582 |
583 | # get input and output tensors (which actually are TensorEnsemble instances)
584 | input, output = ensemble.get("input", "output")
585 |
586 | # evaluate the ensemble just like a normal model
587 | batch = ...
588 | value = output.eval({input: batch})
589 |
590 | If you want to use another method than ``METHOD_MEAN``, ``METHOD_MAX`` or ``METHOD_MIN``, use
591 | ``METHOD_CUSTOM`` and overwrite the ``func_custom`` method of the :py:class:`TensorEnsemble`
592 | instance.
593 |
594 | .. py:attribute:: models
595 |
596 | A list that contains all read models.
597 |
598 | .. py:attribute:: method
599 |
600 | The ensembling method.
601 | """
602 |
603 | def __init__(self, paths=None, method=METHOD_MEAN):
604 | """ __init__(paths=None, method=METHOD_MEAN)
605 | """
606 | super(Ensemble, self).__init__()
607 |
608 | # check method
609 | if method not in METHODS:
610 | raise UnknownEnsembleMethodException(method)
611 | self.method = method
612 |
613 | # loaded models
614 | self.models = []
615 |
616 | # load when desired
617 | if paths is not None:
618 | self.load(paths)
619 |
620 | def get(self, *names, **kwargs):
621 | """ get(*names, key=None)
622 | Returns one or more :py:class:`TensorEnsemble` instances given by *names* using a deep
623 | lookup within all read models. Each returned tensor ensemble will have ``len(models)``
624 | tensors. If a model does not contain a specific tensor defined by a specific *name*, the
625 | associated ensemble tensor will contain a *None* for that model in its tensors. If *key* is
626 | not *None*, only the root tensors with that *key* are traversed.
627 | """
628 | # create empty tensor ensembles with our method
629 | tensor_ensembles = [TensorEnsemble([], self.method) for name in names]
630 |
631 | # loop over models, collect and add tensors
632 | for model in self.models:
633 | tensors = model.get(*names, **kwargs)
634 | if not isinstance(tensors, tuple):
635 | tensors = (tensors,)
636 | for i, t in enumerate(tensors):
637 | tensor_ensembles[i].tensors.append(t)
638 |
639 | return tensor_ensembles[0] if len(names) == 1 else tuple(tensor_ensembles)
640 |
641 | def load(self, paths):
642 | """
643 | Loads models from a list of *paths*.
644 | """
645 | for path in paths:
646 | self.models.append(Model(path))
647 |
648 |
649 | class TensorEnsemble(object):
650 | """
651 | A tensor ensemble basically contains a list of tensors that correspond to models of an
652 | :py:class:`Ensemble` instance.
653 |
654 | .. py:attribute: tensors
655 |
656 | The list of contained tensors. Tensor *i* corresponds to model *i*.
657 |
658 | .. py:attribute: method
659 |
660 | The ensembling method.
661 | """
662 |
663 | def __init__(self, tensors, method=METHOD_MEAN):
664 | super(TensorEnsemble, self).__init__()
665 |
666 | # check method
667 | if method not in METHODS:
668 | raise UnknownEnsembleMethodException(method)
669 | self.method = method
670 |
671 | self.tensors = list(tensors)
672 |
673 | def eval(self, feed_dict=None):
674 | """
675 | Evaluates all contained tensors using a *feed_dict* and returns the ensemble value. The keys
676 | of *feed_dict* must be tensor ensembles. Its values can be batches, i.e., numpy arrays, or
677 | lists or tuples of batches. In the latter case, these lists or tuples must have the same
678 | length as the list of stored tensors as they will be mapped.
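
A minimal sketch, assuming *inp* and *outp* are :py:class:`TensorEnsemble` instances obtained
via :py:meth:`Ensemble.get`; a single batch is fed to every contained tensor:

.. code-block:: python

    value = outp.eval({inp: batch})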
679 | """
680 | # first, check that the lengths of all feed_dict keys match our own length
681 | for tensor_ensemble in (feed_dict or {}):
682 | if len(tensor_ensemble.tensors) != len(self.tensors):
683 | raise EnsembleMismatchException("incompatible lengths of tensors: %d, %d" \
684 | % (len(self.tensors), len(tensor_ensemble.tensors)))
685 |
686 | # create a joined uuid
687 | _uuid = uuid4()
688 |
689 | # prepare feed_dicts
690 | feed_dicts = [{} for _ in range(len(self.tensors))]
691 | for tensor_ensemble, value in (feed_dict or {}).items():
692 | for i, tensor in enumerate(tensor_ensemble.tensors):
693 | if tensor is not None:
694 | feed_dicts[i][tensor] = value[i] if isinstance(value, (list, tuple)) else value
695 |
696 | # eval all tensors
697 | values = [t.eval(feed_dict=d, _uuid=_uuid) for t, d in zip(self.tensors, feed_dicts)]
698 |
699 | # return the computed ensemble value
700 | return self.func(values)
701 |
702 | def __call__(self, *args, **kwargs):
703 | return self.eval(*args, **kwargs)
704 |
705 | def func(self, values):
706 | """
707 | The actual ensembling logic that combines multiple *values*. The method call is forwarded
708 | to the ensemble method-specific variant which is determined using *method*.
709 | """
710 | if self.method == METHOD_MEAN:
711 | return self.func_mean(values)
712 | elif self.method == METHOD_MAX:
713 | return self.func_max(values)
714 | elif self.method == METHOD_MIN:
715 | return self.func_min(values)
716 | elif self.method == METHOD_CUSTOM:
717 | return self.func_custom(values)
718 | else:
719 | raise UnknownEnsembleMethodException(self.method)
720 |
721 | @staticmethod
722 | def func_mean(values):
723 | return np.mean(np.stack(values), axis=0)
724 |
725 | @staticmethod
726 | def func_max(values):
727 | return np.amax(np.stack(values), axis=0)
728 |
729 | @staticmethod
730 | def func_min(values):
731 | return np.amin(np.stack(values), axis=0)
732 |
733 | @staticmethod
734 | def func_custom(values):
735 | raise NotImplementedError
736 |
737 |
738 | class UnknownOperationException(Exception):
739 | """
740 | An exception which is raised when trying to convert an unknown tensorflow operation.
741 | """
742 |
743 |
744 | class OperationMismatchException(Exception):
745 | """
746 | An exception which is raised during instantiation of an op whose type does not match the
747 | underlying tensorflow op.
748 | """
749 |
750 |
751 | class InvalidImplementationException(Exception):
752 | """
753 | An exception which is raised when an implementation of an unknown type is registered for an
754 | :py:class:`Operation` class.
755 | """
756 |
757 |
758 | class UnknownImplementationException(Exception):
759 | """
760 | An exception which is raised when an :py:class:`Operation` instance is requested to use an
761 | implementation type that was not yet added.
762 | """
763 |
764 |
765 | class UnknownEnsembleMethodException(Exception):
766 | """
767 | An exception which is raised when an :py:class:`Ensemble` instance is initialised with an
768 | unknown ensemble method.
769 | """
770 |
771 |
772 | class EnsembleMismatchException(Exception):
773 | """
774 | An exception which is raised when a :py:class:`TensorEnsemble` instance is evaluated with a
775 | *feed_dict* whose keys, i.e. also :py:class:`TensorEnsemble` instances, do not match the tensor
776 | to evaluate. An example would be that a tensor ensemble with *n* tensors is evaluated with a
777 | tensor ensemble in its *feed_dict* that contains *m* tensors.
778 | """
779 |
780 |
781 | class ScipyOperationException(Exception):
782 | """
783 | An exception which is raised when trying to evaluate an op that uses scipy internally and scipy
784 | is not available.
785 | """
786 | def __init__(self, attr):
787 | msg = "trying to access 'scipy.%s', but scipy is not installed on your system, " \
788 | "install scipy to use this operation or use an other implementation" % attr
789 | super(ScipyOperationException, self).__init__(msg)
790 |
791 |
792 | # parses the tf version and returns a tuple, e.g. "0.12.0-rc1" => (0, 12, 0, "rc1")
793 | def _parse_tf_version(v):
794 | parts = v.split(".", 2)
795 | if "-" in parts[2]:
796 | parts.extend(parts.pop().split("-", 1))
797 | return tuple([int(p) for p in parts[:3]] + parts[3:])
798 |
799 |
800 | # default (latest) tf version
801 | _tf_version_string = "0.12.0-rc1"
802 | _tf_version = _parse_tf_version(_tf_version_string)
803 |
804 |
805 | def setup(tf, order=None):
806 | """
807 | Sets up global variables (currently only the tensorflow version) to adapt to peculiarities of
808 | different tensorflow versions. This function should only be called before :py:class:`Model`
809 | creation, not for evaluation. Therefore, the tensorflow module *tf* must be passed:
810 |
811 | .. code-block:: python
812 |
813 | import tensorflow as tf
814 | import tfdeploy as td
815 |
816 | td.setup(tf)
817 |
818 | # ...
819 |
820 | Also, when *order* is not *None*, it is forwarded to :py:func:`optimize` for convenience.
821 | """
822 | global _tf_version_string, _tf_version
823 | _tf_version_string = tf.__version__
824 | _tf_version = _parse_tf_version(_tf_version_string)
825 |
826 | if order is not None:
827 | optimize(order)
828 |
829 |
830 | def reset():
831 | """
832 | Resets the instance caches of :py:class:`TensorRegister` and :py:class:`OperationRegister`.
833 | """
834 | TensorRegister.instances.clear()
835 | OperationRegister.instances.clear()
836 |
837 |
838 | def optimize(order):
839 | """ optimize(impl)
840 | Tries to set the implementation type of all registered :py:class:`Operation` classes to *impl*.
841 | This has no effect when an op does not implement that type.
842 |
843 | The behavior is equivalent to:
844 |
845 | .. code-block:: python
846 |
847 | for op in Operation.__subclasses__():
848 | if impl in op.impls:
849 | op.use_impl(impl)
850 |
851 | *impl* can also be a list or tuple of valid implementation types representing a preferred order.
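
For example:

.. code-block:: python

    # prefer scipy implementations where available, keep numpy otherwise
    optimize([IMPL_SCIPY, IMPL_NUMPY])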
852 | """
853 | if not isinstance(order, (list, tuple)):
854 | order = [order]
855 |
856 | for op in Operation.__subclasses__():
857 | for impl in order:
858 | if impl in op.impls:
859 | op.use_impl(impl)
860 | break
861 |
862 |
863 | def print_tensor(td_tensor, indent="| ", max_depth=-1, depth=0):
864 | """ print_tensor(td_tensor, indent=" ", max_depth=-1)
865 | Prints the dependency graph of a :py:class:`Tensor` *td_tensor*, where each new level is
866 | indented by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where
867 | each tensor and each op count as a level.
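
A minimal sketch, assuming a loaded model (see :py:class:`Model`):

.. code-block:: python

    print_tensor(model.get("output"), max_depth=4)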
868 | """
869 | offset = depth * indent
870 | line = "td tensor: %s" % td_tensor.name
871 | if td_tensor.value is not None:
872 | line += " (%s)" % (",".join(str(i) for i in td_tensor.value.shape),)
873 |
874 | print(offset + line)
875 |
876 | if td_tensor.op and (max_depth < 0 or max_depth > depth):
877 | print_op(td_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
878 |
879 |
880 | def print_op(td_op, indent="| ", max_depth=-1, depth=0):
881 | """ print_op(td_op, indent=" ", max_depth=-1)
882 | Prints the dependency graph of a :py:class:`Operation` *td_op*, where each new level is indented
883 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
884 | tensor and each op count as a level.
885 | """
886 | offset = depth * indent
887 | line = "td op: %s (%s)" % (td_op.name, ",".join(td_op.types))
888 |
889 | print(offset + line)
890 |
891 | if max_depth < 0 or max_depth > depth:
892 | for td_tensor in td_op.inputs:
893 | print_tensor(td_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
894 |
895 |
896 | def print_tf_tensor(tf_tensor, indent="| ", max_depth=-1, depth=0):
897 | """ print_tf_tensor(tf_tensor, indent=" ", max_depth=-1)
898 | Prints the dependency graph of a tensorflow tensor *tf_tensor*, where each new level is indented
899 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
900 | tensor and each op count as a level.
901 | """
902 | offset = depth * indent
903 | shape = tuple(int(i) for i in tf_tensor.get_shape())
904 | line = "tf tensor: %s (%s)" % (tf_tensor.name, ",".join(str(i) for i in shape))
905 |
906 | print(offset + line)
907 |
908 | if tf_tensor.op and (max_depth < 0 or max_depth > depth):
909 | print_tf_op(tf_tensor.op, indent=indent, max_depth=max_depth, depth=depth+1)
910 |
911 |
912 | def print_tf_op(tf_op, indent="| ", max_depth=-1, depth=0):
913 | """ print_tf_op(tf_tensor, indent=" ", max_depth=-1)
914 | Prints the dependency graph of a tensorflow operation *tf_op*, where each new level is indented
915 | by *indent*. When *max_depth* is positive, the graph is truncated at that depth, where each
916 | tensor and each op count as a level.
917 | """
918 | offset = depth * indent
919 | line = "tf op: %s (%s)" % (tf_op.name, tf_op.type)
920 |
921 | print(offset + line)
922 |
923 | if max_depth < 0 or max_depth > depth:
924 | for tf_tensor in tf_op.inputs:
925 | print_tf_tensor(tf_tensor, indent=indent, max_depth=max_depth, depth=depth+1)
926 |
927 |
928 | # imports exclusively for ops
929 | from operator import mul
930 | from itertools import product
931 | from collections import defaultdict
932 |
933 | # optional import of scipy
934 | try:
935 | if os.environ.get("TD_REFUSE_SCIPY", "").lower() in ("1", "true", "yes"):
936 | raise ImportError
937 |
938 | import scipy as sp
939 | import scipy.special
940 | HAS_SCIPY = True
941 | except ImportError:
942 | class ScipyDummy(object):
943 | def __getattr__(self, attr):
944 | raise ScipyOperationException(attr)
945 | sp = ScipyDummy()
946 | HAS_SCIPY = False
947 |
948 |
949 | # mapping of tf dtypes to np dtypes
950 | dtype_map = {
951 | 1: np.float32,
952 | 2: np.float64,
953 | 3: np.int32,
954 | 4: np.uint8,
955 | 5: np.int16,
956 | 6: np.int8,
957 | 7: np.object,
958 | 8: np.complex64,
959 | 9: np.int64,
960 | 10: np.bool,
961 | 14: np.uint16,
962 | 17: np.uint16,
963 | 18: np.complex128,
964 | 19: np.float16,
965 | 101: np.float32,
966 | 102: np.float64,
967 | 103: np.int32,
968 | 104: np.uint8,
969 | 105: np.int16,
970 | 106: np.int8,
971 | 107: np.object,
972 | 108: np.complex64,
973 | 109: np.int64,
974 | 110: np.bool,
975 | 114: np.uint16,
976 | 117: np.uint16,
977 | 118: np.complex128,
978 | 119: np.float16
979 | }
980 |
981 |
982 | lgamma_vec = np.vectorize(np.math.lgamma)
983 | erf_vec = np.vectorize(np.math.erf)
984 | erfc_vec = np.vectorize(np.math.erfc)
985 |
986 | def _transpose(a, dim=2):
987 | if dim <= 0:
988 | axes = None
989 | else:
990 | axes = list(range(a.ndim))
991 | axes.append(axes.pop(-1 * dim))
992 | return np.transpose(a, axes=axes)
993 |
994 | def _adjoint(a, dim=2):
995 | return np.conj(_transpose(a, dim=dim))
996 |
997 |
998 | #
999 | # sequences
1000 | #
1001 |
1002 | @Operation.factory
1003 | def LinSpace(start, stop, num):
1004 | """
1005 | Linspace op.
1006 | """
1007 | return np.linspace(start, stop, num=num, dtype=np.float32),
1008 |
1009 |
1010 | @Operation.factory
1011 | def Range(start, limit, delta):
1012 | """
1013 | Range op.
1014 | """
1015 | return np.arange(start, limit, delta, dtype=np.int32),
1016 |
1017 |
1018 | #
1019 | # random tensors
1020 | #
1021 |
1022 | @Operation.factory(attrs=("dtype", "seed"))
1023 | def RandomStandardNormal(shape, dtype, seed):
1024 | """
1025 | Standard (mu=0, sigma=1) gaussian op.
1026 | """
1027 | if seed:
1028 | np.random.seed(seed)
1029 | return np.random.normal(size=reduce(mul, shape)).reshape(shape).astype(dtype_map[dtype]),
1030 |
1031 |
1032 | @Operation.factory(attrs=("dtype", "seed"))
1033 | def TruncatedNormal(shape, dtype, seed):
1034 | """
1035 | Standard (mu=0, sigma=1) gaussian op with truncation above 2 sigma.
1036 | """
1037 | if seed:
1038 | np.random.seed(seed)
1039 | n = reduce(mul, shape)
1040 | r = np.empty(n, dtype=dtype_map[dtype])
1041 | idxs = np.ones(n, dtype=np.bool)
1042 | while n:
1043 | r[idxs] = np.random.normal(size=n)
1044 | idxs = np.abs(r) > 2
1045 | n = np.sum(idxs)
1046 | return r.reshape(shape),
1047 |
1048 |
1049 | @Operation.factory(attrs=("dtype", "seed"))
1050 | def RandomUniform(shape, dtype, seed):
1051 | """
1052 | Random uniform op.
1053 | """
1054 | if seed:
1055 | np.random.seed(seed)
1056 | return np.random.uniform(size=shape).astype(dtype_map[dtype]),
1057 |
1058 |
1059 | @Operation.factory(attrs=("seed",))
1060 | def RandomUniformInt(shape, minval, maxval, seed):
1061 | """
1062 | Random uniform int op.
1063 | """
1064 | if seed:
1065 | np.random.seed(seed)
1066 | return np.random.randint(minval, maxval, size=shape),
1067 |
1068 |
1069 | @Operation.factory(attrs=("seed",))
1070 | def RandomShuffle(a, seed):
1071 | """
1072 | Random shuffle op.
1073 | """
1074 | if seed:
1075 | np.random.seed(seed)
1076 | r = a.copy()
1077 | np.random.shuffle(r)
1078 | return r,
1079 |
1080 |
1081 | #
1082 | # casting
1083 | #
1084 |
1085 | @Operation.factory(types=("Cast", "StringToNumber"), output_dtypes=True)
1086 | def Cast(a, output_dtypes):
1087 | """
1088 | Cast op.
1089 | """
1090 | return np.copy(a).astype(output_dtypes[0]),
1091 |
1092 |
1093 | #
1094 | # shapes and shaping
1095 | #
1096 |
1097 | @Operation.factory
1098 | def Shape(a):
1099 | """
1100 | Shape op.
1101 | """
1102 | return np.array(a.shape, dtype=np.int32),
1103 |
1104 |
1105 | @Operation.factory
1106 | def Size(a):
1107 | """
1108 | Size op.
1109 | """
1110 | return np.array([a.size], dtype=np.int32),
1111 |
1112 |
1113 | @Operation.factory
1114 | def Rank(a):
1115 | """
1116 | Rank op.
1117 | """
1118 | return np.array([len(a.shape)], dtype=np.int32),
1119 |
1120 |
1121 | @Operation.factory
1122 | def Reshape(a, shape):
1123 | """
1124 | Reshape op.
1125 | """
1126 | return np.copy(a).reshape(shape),
1127 |
1128 |
1129 | @Operation.factory(attrs=("squeeze_dims",))
1130 | def Squeeze(a, squeeze_dims):
1131 | """
1132 | Squeeze op, i.e. removes singular axes.
1133 | """
1134 | if not squeeze_dims:
1135 | squeeze_dims = list(range(len(a.shape)))
1136 | slices = [(0 if (dim == 1 and i in squeeze_dims) else slice(None)) \
1137 | for i, dim in enumerate(a.shape)]
1138 | return np.copy(a)[tuple(slices)],
1139 |
1140 |
1141 | @Operation.factory
1142 | def ExpandDims(a, dim):
1143 | """
1144 | Expand dim op, i.e. add singular axis at dim.
1145 | """
1146 | shape = list(a.shape)
1147 | if dim >= 0:
1148 | shape.insert(dim, 1)
1149 | else:
1150 | shape.insert(len(shape) + dim + 1, 1)
1151 | return np.copy(a).reshape(*shape),
1152 |
1153 |
1154 | #
1155 | # slicing and joining
1156 | #
1157 |
1158 | @Operation.factory
1159 | def Slice(a, begin, size):
1160 | """
1161 | Slicing op.
1162 | """
1163 | return np.copy(a)[tuple(slice(*tpl) for tpl in zip(begin, begin + size))],
1164 |
1165 |
1166 | @Operation.factory(attrs=("num_split",))
1167 | def Split(axis, a, n):
1168 | """
1169 | Split op with n splits.
1170 | """
1171 | return tuple(np.split(np.copy(a), n, axis=axis))
1172 |
1173 |
1174 | @Operation.factory
1175 | def SplitV(a, splits, axis):
1176 | """
1177 | Split op with multiple split sizes.
1178 | """
1179 | return tuple(np.split(np.copy(a), np.cumsum(splits), axis=axis))
1180 |
1181 |
1182 | @Operation.factory
1183 | def Tile(a, n):
1184 | """
1185 | Tile op.
1186 | """
1187 | return np.tile(a, n),
1188 |
1189 |
1190 | @Operation.factory
1191 | def Pad(a, paddings):
1192 | """
1193 | Zero padding op.
1194 | """
1195 | return np.pad(a, paddings, mode="constant", constant_values=0),
1196 |
1197 |
1198 | @Operation.factory(unpack=False)
1199 | def ConcatV2(inputs):
1200 | """
1201 | Concat op.
1202 | """
1203 | axis = inputs.pop()
1204 | return np.concatenate(inputs, axis=axis),
1205 |
1206 |
1207 | @Operation.factory(attrs=("axis",), unpack=False)
1208 | def Pack(inputs, axis):
1209 | """
1210 | Pack op.
1211 | """
1212 | return np.stack(inputs, axis=axis),
1213 |
1214 |
1215 | @Operation.factory(attrs=("num", "axis"))
1216 | def Unpack(a, num, axis):
1217 | """
1218 | Unpack op.
1219 | """
1220 | return tuple(np.squeeze(b, axis=axis) for b in np.split(a, num, axis=axis))
1221 |
1222 |
1223 | @Operation.factory(attrs=("seq_dim", "batch_dim"))
1224 | def ReverseSequence(a, seq_lengths, seq_dim, batch_dim):
1225 | """
1226 | Sequential reverse op.
1227 | """
1228 | r = np.copy(a)
1229 | invidxs = (len(r.shape) - 1) * [slice(None)]
1230 | if seq_dim < batch_dim:
1231 | invidxs[seq_dim] = slice(None, None, -1)
1232 | else:
1233 | invidxs[seq_dim - 1] = slice(None, None, -1)
1234 | _invidxs = tuple(invidxs)
1235 | selidxs = len(r.shape) * [slice(None)]
1236 | for i, l in enumerate(seq_lengths):
1237 | if not l:
1238 | continue
1239 | selidxs[batch_dim] = i
1240 | selidxs[seq_dim] = slice(0, l)
1241 | _selidxs = tuple(selidxs)
1242 | r[_selidxs] = a[_selidxs][_invidxs]
1243 | return r,
1244 |
1245 |
1246 | @Operation.factory
1247 | def ReverseV2(a, axes):
1248 | """
1249 | Reverse op.
1250 | """
1251 | idxs = tuple(slice(None, None, 2 * int(i not in axes) - 1) for i in range(len(a.shape)))
1252 | return np.copy(a[idxs]),
1253 |
1254 |
1255 | @Operation.factory
1256 | def Transpose(a, perm=None):
1257 | """
1258 | Transpose op.
1259 | """
1260 | return np.transpose(a, axes=perm),
1261 |
1262 |
1263 | #
1264 | # arithmetic math ops
1265 | #
1266 |
1267 | @Operation.factory(types=("Add", "BiasAdd"))
1268 | def Add(a, b):
1269 | """
1270 | Addition op.
1271 | """
1272 | return np.add(a, b),
1273 |
1274 |
1275 | @Operation.factory(types=("Subtract", "Sub"))
1276 | def Subtract(a, b):
1277 | """
1278 | Subtraction op.
1279 | """
1280 | return np.subtract(a, b),
1281 |
1282 |
1283 | @Operation.factory(types=("Multiply", "Mul"))
1284 | def Multiply(a, b):
1285 | """
1286 | Multiplication op.
1287 | """
1288 | return np.multiply(a, b),
1289 |
1290 |
1291 | @Operation.factory(types=("Div", "RealDiv"))
1292 | def Div(a, b):
1293 | """
1294 | Division op.
1295 | """
1296 | return np.divide(a, b),
1297 |
1298 |
1299 | @Operation.factory
1300 | def FloorDiv(a, b):
1301 | """
1302 | Floor division op, i.e., a // b.
1303 | """
1304 | return np.floor_divide(a, b),
1305 |
1306 |
1307 | @Operation.factory(types=("Mod", "FloorMod"))
1308 | def Mod(a, b):
1309 | """
1310 | Modulo op.
1311 | """
1312 | return np.mod(a, b),
1313 |
1314 |
1315 | @Operation.factory
1316 | def Cross(a, b):
1317 | """
1318 | Cross product op.
1319 | """
1320 | return np.cross(a, b),
1321 |
1322 |
1323 | #
1324 | # basic math ops
1325 | #
1326 |
1327 | @Operation.factory(unpack=False)
1328 | def AddN(inputs):
1329 | """
1330 | Multi add op.
1331 | """
1332 | return reduce(np.add, inputs),
1333 |
1334 |
1335 | @Operation.factory
1336 | def Abs(a):
1337 | """
1338 | Abs op.
1339 | """
1340 | return np.abs(a),
1341 |
1342 |
1343 | @Operation.factory(types=("Negative", "Neg"))
1344 | def Negative(a):
1345 | """
1346 | Negative op.
1347 | """
1348 | return np.negative(a),
1349 |
1350 |
1351 | @Operation.factory
1352 | def Sign(a):
1353 | """
1354 | Sign op.
1355 | """
1356 | return np.sign(a),
1357 |
1358 |
1359 | @Operation.factory
1360 | def Inv(a):
1361 | """
1362 | Reciprocal op.
1363 | """
1364 | return np.reciprocal(a),
1365 |
1366 |
1367 | @Operation.factory
1368 | def Square(a):
1369 | """
1370 | Square op.
1371 | """
1372 | return np.square(a),
1373 |
1374 |
1375 | @Operation.factory
1376 | def Round(a):
1377 | """
1378 | Round op.
1379 | """
1380 | return np.round(a),
1381 |
1382 |
1383 | @Operation.factory
1384 | def Sqrt(a):
1385 | """
1386 | Square root op.
1387 | """
1388 | return np.sqrt(a),
1389 |
1390 |
1391 | @Operation.factory
1392 | def Rsqrt(a):
1393 | """
1394 | Reciprocal square root op.
1395 | """
1396 | return np.reciprocal(np.sqrt(a)),
1397 |
1398 |
1399 | @Operation.factory
1400 | def Pow(a, b):
1401 | """
1402 | Power op.
1403 | """
1404 | return np.power(a, b),
1405 |
1406 |
1407 | @Operation.factory
1408 | def Exp(a):
1409 | """
1410 | Exponential op.
1411 | """
1412 | return np.exp(a),
1413 |
1414 |
1415 | @Operation.factory
1416 | def Log(a):
1417 | """
1418 | Logarithm op.
1419 | """
1420 | return np.log(a),
1421 |
1422 |
1423 | @Operation.factory
1424 | def Ceil(a):
1425 | """
1426 | Ceil round op.
1427 | """
1428 | return np.ceil(a),
1429 |
1430 |
1431 | @Operation.factory
1432 | def Floor(a):
1433 | """
1434 | Floor round op.
1435 | """
1436 | return np.floor(a),
1437 |
1438 |
1439 | @Operation.factory
1440 | def Maximum(a, b):
1441 | """
1442 | Maximum op.
1443 | """
1444 | return np.maximum(a, b),
1445 |
1446 |
1447 | @Operation.factory
1448 | def Minimum(a, b):
1449 | """
1450 | Minimum op.
1451 | """
1452 | return np.minimum(a, b),
1453 |
1454 |
1455 | @Operation.factory
1456 | def Cos(a):
1457 | """
1458 | Cos op.
1459 | """
1460 | return np.cos(a),
1461 |
1462 |
1463 | @Operation.factory
1464 | def Sin(a):
1465 | """
1466 | Sin op.
1467 | """
1468 | return np.sin(a),
1469 |
1470 |
1471 | @Operation.factory
1472 | def Tan(a):
1473 | """
1474 | Tan op.
1475 | """
1476 | return np.tan(a),
1477 |
1478 |
1479 | @Operation.factory
1480 | def Acos(a):
1481 | """
1482 | Acos op.
1483 | """
1484 | return np.arccos(a),
1485 |
1486 |
1487 | @Operation.factory
1488 | def Asin(a):
1489 | """
1490 | Asin op.
1491 | """
1492 | return np.arcsin(a),
1493 |
1494 |
1495 | @Operation.factory
1496 | def Atan(a):
1497 | """
1498 | Atan op.
1499 | """
1500 | return np.arctan(a),
1501 |
1502 |
1503 | @Operation.factory
1504 | def Lgamma(a):
1505 | """
1506 | Lgamma op.
1507 | """
1508 | return lgamma_vec(a),
1509 |
1510 | @Lgamma.add_impl(IMPL_SCIPY)
1511 | def Lgamma(a):
1512 | return sp.special.gammaln(a),
1513 |
1514 |
1515 | @Operation.factory(impl=IMPL_SCIPY)
1516 | def Digamma(a):
1517 | """
1518 | Digamma op.
1519 | """
1520 | return sp.special.digamma(a),
1521 |
1522 |
1523 | @Operation.factory
1524 | def Erf(a):
1525 | """
1526 | Gaussian error function op.
1527 | """
1528 | return erf_vec(a),
1529 |
1530 | @Erf.add_impl(IMPL_SCIPY)
1531 | def Erf(a):
1532 | return sp.special.erf(a),
1533 |
1534 |
1535 | @Operation.factory
1536 | def Erfc(a):
1537 | """
1538 | Complementary Gaussian error function op.
1539 | """
1540 | return erfc_vec(a),
1541 |
1542 | @Erfc.add_impl(IMPL_SCIPY)
1543 | def Erfc(a):
1544 | return sp.special.erfc(a),
1545 |
1546 |
1547 | @Operation.factory
1548 | def SquaredDifference(a, b):
1549 | """
1550 | Squared difference op, i.e., (a - b)**2.
1551 | """
1552 | return (a - b)**2,
1553 |
1554 |
1555 | @Operation.factory(impl=IMPL_SCIPY)
1556 | def Igamma(a, b):
1557 | """
1558 | Incomplete gamma op.
1559 | """
1560 | return sp.special.gammainc(a, b),
1561 |
1562 |
1563 | @Operation.factory(impl=IMPL_SCIPY)
1564 | def Igammac(a, b):
1565 | """
1566 | Complemented incomplete gamma op.
1567 | """
1568 | return sp.special.gammaincc(a, b),
1569 |
1570 |
1571 | @Operation.factory(impl=IMPL_SCIPY)
1572 | def Zeta(a, b):
1573 | """
1574 | Zeta op.
1575 | """
1576 | return sp.special.zeta(a, b),
1577 |
1578 |
1579 | @Operation.factory(impl=IMPL_SCIPY)
1580 | def Polygamma(a, b):
1581 | """
1582 | Polygamma op.
1583 | """
1584 | return sp.special.polygamma(a, b),
1585 |
1586 |
1587 | @Operation.factory(impl=IMPL_SCIPY)
1588 | def Betainc(a, b, x):
1589 | """
1590 | Regularized incomplete beta op.
1591 | """
1592 | return sp.special.betainc(a, b, x),
1593 |
1594 |
1595 | #
1596 | # matrix math ops
1597 | #
1598 |
1599 | @Operation.factory
1600 | def Diag(a):
1601 | """
1602 | Diag op.
1603 | """
1604 | r = np.zeros(2 * a.shape, dtype=a.dtype)
1605 | for idx, v in np.ndenumerate(a):
1606 | r[2 * idx] = v
1607 | return r,
1608 |
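# Editor's sketch (illustration only): doubling the index tuple (2 * idx) places
# each value on the generalized diagonal of a tensor with doubled shape.
#   >>> a = np.array([1, 2])
#   >>> r = np.zeros(2 * a.shape, dtype=a.dtype)  # (2,) * 2 -> shape (2, 2)
#   >>> for idx, v in np.ndenumerate(a):
#   ...     r[2 * idx] = v                        # (0,) * 2 -> (0, 0), etc.
#   >>> r
#   array([[1, 0],
#          [0, 2]])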
1609 |
1610 | @Operation.factory
1611 | def DiagPart(a):
1612 | """
1613 | Diag op that returns only the diagonal elements.
1614 | """
1615 | return np.diagonal(a),
1616 |
1617 |
1618 | @Operation.factory
1619 | def MatrixDiagPart(a):
1620 | """
1621 | Batched diag op that returns only the diagonal elements.
1622 | """
1623 | r = np.zeros(a.shape[:-2] + (min(a.shape[-2:]),), dtype=a.dtype)
1624 | for coord in np.ndindex(a.shape[:-2]):
1625 | pos = coord + (Ellipsis,)
1626 | r[pos] = np.diagonal(a[pos])
1627 | return r,
1628 |
1629 |
1630 | @Operation.factory(attrs=("transpose_a", "transpose_b"))
1631 | def MatMul(a, b, transpose_a, transpose_b):
1632 | """
1633 | Matrix multiplication op.
1634 | """
1635 | return np.dot(a if not transpose_a else np.transpose(a),
1636 | b if not transpose_b else np.transpose(b)),
1637 |
1638 |
1639 | @Operation.factory
1640 | def MatrixDeterminant(a):
1641 | """
1642 | Matrix determinant op.
1643 | """
1644 | return np.linalg.det(a),
1645 |
1646 |
1647 | @Operation.factory(attrs=("adjoint",))
1648 | def MatrixInverse(a, adj):
1649 | """
1650 | Matrix inversion op.
1651 | """
1652 | return np.linalg.inv(a if not adj else _adjoint(a)),
1653 |
1654 |
1655 | @Operation.factory
1656 | def Cholesky(a):
1657 | """
1658 | Cholesky decomposition op.
1659 | """
1660 | return np.linalg.cholesky(a),
1661 |
1662 |
1663 | @Operation.factory(attrs=("adjoint",))
1664 | def MatrixSolve(a, rhs, adj):
1665 | """
1666 | Matrix solve op.
1667 | """
1668 | return np.linalg.solve(a if not adj else _adjoint(a), rhs),
1669 |
1670 |
1671 | @Operation.factory(attrs=("lower", "adjoint"), impl=IMPL_SCIPY)
1672 | def MatrixTriangularSolve(a, rhs, lower, adj):
1673 | """
1674 | Matrix triangular solve op.
1675 | """
1676 | trans = 0 if not adj else 2
1677 |
1678 | r = np.empty(rhs.shape).astype(a.dtype)
1679 | for coord in np.ndindex(a.shape[:-2]):
1680 | pos = coord + (Ellipsis,)
1681 | # trans=2 ("C") already makes LAPACK apply the adjoint, so a is passed unmodified
1682 | r[pos] = sp.linalg.solve_triangular(a[pos], rhs[pos], trans=trans, lower=lower)
1683 |
1684 | return r,
1685 |
1686 |
1687 | @Operation.factory
1688 | def MatrixSolveLs(a, rhs, l2_reg):
1689 | """
1690 | Matrix least-squares solve op.
1691 | """
1692 | r = np.empty(rhs.shape).astype(a.dtype)
1693 | for coord in np.ndindex(a.shape[:-2]):
1694 | pos = coord + (Ellipsis,)
1695 | r[pos] = np.linalg.lstsq(a[pos], rhs[pos], rcond=None)[0]  # note: l2_reg is not applied
1696 |
1697 | return r,
1698 |
1699 |
1700 | @Operation.factory
1701 | def SelfAdjointEig(a):
1702 | """
1703 | Eigendecomposition op (eigenvalues stacked on top of the eigenvectors).
1704 | """
1705 | shape = list(a.shape)
1706 | shape[-2] += 1
1707 | return np.append(*np.linalg.eigh(a)).reshape(*shape),  # eigh, since the input is self-adjoint
1708 |
1709 |
1710 | @Operation.factory
1711 | def SelfAdjointEigV2(a):
1712 | """
1713 | Eigendecomposition op (v2), returning eigenvalues and eigenvectors separately.
1714 | """
1715 | return np.linalg.eigh(a)  # eigh: input is self-adjoint, eigenvalues ascending
1716 |
1717 |
1718 | @Operation.factory(attrs=("compute_uv", "full_matrices"))
1719 | def Svd(a, uv, full):
1720 | """
1721 | Singular value decomposition op.
1722 | """
1723 | u, s, vh = np.linalg.svd(a, full_matrices=full, compute_uv=True)  # always True, np returns only s otherwise
1724 | return s, u, np.conj(np.swapaxes(vh, -2, -1))  # np yields the adjoint of v, TF expects v
1725 |
1726 |
1727 | #
1728 | # complex number ops
1729 | #
1730 |
1731 | @Operation.factory
1732 | def Complex(a, b):
1733 | """
1734 | Complex number op.
1735 | """
1736 | return np.add(a, np.multiply(b, 1j)),
1737 |
1738 |
1739 | @Operation.factory
1740 | def Conj(a):
1741 | """
1742 | Complex conjugate op.
1743 | """
1744 | return np.conj(a),
1745 |
1746 |
1747 | @Operation.factory
1748 | def Imag(a):
1749 | """
1750 | Complex imag op.
1751 | """
1752 | return np.imag(a),
1753 |
1754 |
1755 | @Operation.factory
1756 | def Real(a):
1757 | """
1758 | Complex real op.
1759 | """
1760 | return np.real(a),
1761 |
1762 |
1763 | #
1764 | # Fourier transform ops
1765 | #
1766 |
1767 | @Operation.factory
1768 | def FFT2D(a):
1769 | """
1770 | Discrete 2D FT op.
1771 | """
1772 | return np.fft.fft2(a),
1773 |
1774 |
1775 | @Operation.factory
1776 | def IFFT2D(a):
1777 | """
1778 | Discrete inverse 2D FT op.
1779 | """
1780 | return np.fft.ifft2(a),
1781 |
1782 |
1783 | @Operation.factory
1784 | def FFT3D(a):
1785 | """
1786 | Discrete 3D FT op.
1787 | """
1788 | return np.fft.fftn(a),
1789 |
1790 |
1791 | @Operation.factory
1792 | def IFFT3D(a):
1793 | """
1794 | Discrete inverse 3D FT op.
1795 | """
1796 | return np.fft.ifftn(a),
1797 |
1798 |
1799 | #
1800 | # reduction
1801 | #
1802 |
1803 | @Operation.factory(attrs=("keep_dims",))
1804 | def Sum(a, axis, keep_dims):
1805 | """
1806 | Sum reduction op.
1807 | """
1808 | return np.sum(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1809 | keepdims=keep_dims),
1810 |
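# Editor's note: TensorFlow feeds the reduction indices in as a tensor, while
# numpy expects an int or a tuple of ints, hence the isinstance conversion used
# by this and all following reduction ops.
#   >>> np.sum(np.ones((2, 3, 4)), axis=tuple(np.array([0, 2])), keepdims=False).shape
#   (3,)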
1811 |
1812 | @Operation.factory(attrs=("keep_dims",))
1813 | def Prod(a, axis, keep_dims):
1814 | """
1815 | Prod reduction op.
1816 | """
1817 | return np.prod(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1818 | keepdims=keep_dims),
1819 |
1820 |
1821 | @Operation.factory(attrs=("keep_dims",))
1822 | def Min(a, axis, keep_dims):
1823 | """
1824 | Min reduction op.
1825 | """
1826 | return np.amin(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1827 | keepdims=keep_dims),
1828 |
1829 |
1830 | @Operation.factory(attrs=("keep_dims",))
1831 | def Max(a, axis, keep_dims):
1832 | """
1833 | Max reduction op.
1834 | """
1835 | return np.amax(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1836 | keepdims=keep_dims),
1837 |
1838 |
1839 | @Operation.factory(attrs=("keep_dims",))
1840 | def Mean(a, axis, keep_dims):
1841 | """
1842 | Mean reduction op.
1843 | """
1844 | return np.mean(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1845 | keepdims=keep_dims),
1846 |
1847 |
1848 | @Operation.factory(attrs=("keep_dims",))
1849 | def All(a, axis, keep_dims):
1850 | """
1851 | All reduction op.
1852 | """
1853 | return np.all(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1854 | keepdims=keep_dims),
1855 |
1856 |
1857 | @Operation.factory(attrs=("keep_dims",))
1858 | def Any(a, axis, keep_dims):
1859 | """
1860 | Any reduction op.
1861 | """
1862 | return np.any(a, axis=axis if not isinstance(axis, np.ndarray) else tuple(axis),
1863 | keepdims=keep_dims),
1864 |
1865 |
1866 | #
1867 | # segmentation
1868 | #
1869 |
1870 | def seg_map(func, a, ids):
1871 | m = defaultdict(list)
1872 | for i, e in enumerate(ids):
1873 | m[e].append(i)
1874 | r = np.empty((len(m),) + a.shape[1:], dtype=a.dtype)
1875 | for i, idxs in m.items():
1876 | r[i] = func(idxs)
1877 | return r
1878 |
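# Editor's sketch (illustration only): seg_map groups the row indices of "a" by
# segment id and writes func(indices) into one output row per segment.
#   >>> a = np.array([[1., 2.], [3., 4.], [5., 6.]])
#   >>> seg_map(lambda idxs: np.sum(a[idxs], axis=0), a, [0, 0, 1])
#   array([[4., 6.],
#          [5., 6.]])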
1879 |
1880 | @Operation.factory(types=("SegmentSum", "UnsortedSegmentSum"))
1881 | def SegmentSum(a, ids, *args):
1882 | """
1883 | Segmented sum op.
1884 | """
1885 | func = lambda idxs: reduce(np.add, a[idxs])
1886 | return seg_map(func, a, ids),
1887 |
1888 |
1889 | @Operation.factory
1890 | def SegmentProd(a, ids):
1891 | """
1892 | Segmented prod op.
1893 | """
1894 | func = lambda idxs: reduce(np.multiply, a[idxs])
1895 | return seg_map(func, a, ids),
1896 |
1897 |
1898 | @Operation.factory
1899 | def SegmentMin(a, ids):
1900 | """
1901 | Segmented min op.
1902 | """
1903 | func = lambda idxs: np.amin(a[idxs], axis=0)
1904 | return seg_map(func, a, ids),
1905 |
1906 |
1907 | @Operation.factory
1908 | def SegmentMax(a, ids):
1909 | """
1910 | Segmented max op.
1911 | """
1912 | func = lambda idxs: np.amax(a[idxs], axis=0)
1913 | return seg_map(func, a, ids),
1914 |
1915 |
1916 | @Operation.factory
1917 | def SegmentMean(a, ids):
1918 | """
1919 | Segmented mean op.
1920 | """
1921 | func = lambda idxs: np.mean(a[idxs], axis=0)
1922 | return seg_map(func, a, ids),
1923 |
1924 |
1925 | @Operation.factory
1926 | def SparseSegmentSum(a, idxs, ids):
1927 | """
1928 | Sparse segmented sum op.
1929 | """
1930 | return SegmentSum.func(a[idxs], ids)
1931 |
1932 |
1933 | @Operation.factory
1934 | def SparseSegmentMean(a, idxs, ids):
1935 | """
1936 | Sparse segmented mean op.
1937 | """
1938 | return SegmentMean.func(a[idxs], ids)
1939 |
1940 |
1941 | @Operation.factory
1942 | def SparseSegmentSqrtN(a, idxs, ids):
1943 | """
1944 | Sparse segmented sum op, scaled by 1 / sqrt(n) with n the segment size.
1945 | """
1946 | func = lambda _idxs: np.divide(reduce(np.add, a[idxs][_idxs]), np.sqrt(len(_idxs)))
1947 | return seg_map(func, a, ids),
1948 |
1949 |
1950 | #
1951 | # sequence comparison and indexing
1952 | #
1953 |
1954 | @Operation.factory
1955 | def ArgMin(a, dim):
1956 | """
1957 | Argmin op.
1958 | """
1959 | return np.argmin(a, axis=dim),
1960 |
1961 |
1962 | @Operation.factory
1963 | def ArgMax(a, dim):
1964 | """
1965 | Argmax op.
1966 | """
1967 | return np.argmax(a, axis=dim),
1968 |
1969 |
1970 | @Operation.factory
1971 | def ListDiff(a, b):
1972 | """
1973 | List diff op.
1974 | """
1975 | mask = np.in1d(a, b, invert=True)  # keep order and duplicates of a, as TF does
1976 | return a[mask], np.nonzero(mask)[0].astype(np.int32)
1977 |
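# Editor's sketch (illustration only): the boolean mask keeps the order (and any
# duplicates) of "a", matching TensorFlow's ListDiff semantics.
#   >>> a, b = np.array([6, 4, 6, 2]), np.array([4])
#   >>> mask = np.in1d(a, b, invert=True)
#   >>> a[mask], np.nonzero(mask)[0]
#   (array([6, 6, 2]), array([0, 2, 3]))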
1978 |
1979 | @Operation.factory
1980 | def Where(a):
1981 | """
1982 | Boolean where op.
1983 | """
1984 | return np.argwhere(a),
1985 |
1986 |
1987 | @Operation.factory(attrs=("out_idx",))
1988 | def Unique(a, t):
1989 | """
1990 | Unique op.
1991 | """
1992 | _, idxs, inv = np.unique(a, return_index=True, return_inverse=True)
1993 | return np.copy(a)[np.sort(idxs)], np.argsort(np.argsort(idxs))[inv].astype(dtype_map[t])
1994 |
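# Editor's note -- a worked example of the reordering above: for a = [9, 7, 9, 8],
# np.unique sorts to [7, 8, 9] with first-occurrence indices idxs = [1, 3, 0] and
# inv = [2, 0, 2, 1]. Sorting idxs restores appearance order, giving y = [9, 7, 8],
# and the double argsort ranks each element's first occurrence, giving
# idx = [0, 1, 0, 2] -- the index of every element of "a" within y.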
1995 |
1996 | @Operation.factory
1997 | def InvertPermutation(a):
1998 | """
1999 | Invert permutation op.
2000 | """
2001 | return np.argsort(a).astype(np.int32),
2002 |
2003 |
2004 | #
2005 | # control flow ops
2006 | #
2007 |
2008 | @Operation.factory
2009 | def Identity(a):
2010 | """
2011 | Identity op.
2012 | """
2013 | return np.copy(a),
2014 |
2015 |
2016 | #
2017 | # NN activation ops
2018 | #
2019 |
2020 | @Operation.factory
2021 | def Relu(a):
2022 | """
2023 | Relu op.
2024 | """
2025 | return np.maximum(a, 0),
2026 |
2027 |
2028 | @Operation.factory
2029 | def Relu6(a):
2030 | """
2031 | Relu6 op.
2032 | """
2033 | return np.clip(a, 0, 6),
2034 |
2035 |
2036 | @Operation.factory
2037 | def Elu(a):
2038 | """
2039 | Elu op.
2040 | """
2041 | return np.where(a < 0, np.subtract(np.exp(a), 1), a),
2042 |
2043 |
2044 | @Operation.factory
2045 | def Softplus(a):
2046 | """
2047 | Softplus op.
2048 | """
2049 | return np.logaddexp(a, 0),  # == log(exp(a) + 1), avoids overflow for large a
2050 |
2051 |
2052 | @Operation.factory
2053 | def Softsign(a):
2054 | """
2055 | Softsign op.
2056 | """
2057 | return np.divide(a, np.add(np.abs(a), 1)),
2058 |
2059 |
2060 | @Operation.factory
2061 | def Sigmoid(a):
2062 | """
2063 | Sigmoid (logistic) op.
2064 | """
2065 | return np.reciprocal(np.add(1, np.exp(-a))),
2066 |
2067 |
2068 | @Operation.factory
2069 | def Tanh(a):
2070 | """
2071 | Tanh op.
2072 | """
2073 | return np.tanh(a),
2074 |
2075 |
2076 | @Operation.factory
2077 | def Softmax(a):
2078 | """
2079 | Softmax op.
2080 | """
2081 | e = np.exp(a - np.amax(a, axis=-1, keepdims=True))  # shift by the max for numerical stability
2082 | return np.divide(e, np.sum(e, axis=-1, keepdims=True)),
2083 |
2084 |
2085 | #
2086 | # NN convolution ops
2087 | #
2088 |
2089 | def _prepare_patches(a, f, strides, padding, padmode):
2090 | v = np.array((0,) + (a.ndim - 2) * (1,) + (0,))
2091 | w = np.array((0,) + f.shape[:-2] + (0,))
2092 |
2093 | src = a
2094 | if padding == "SAME":
2095 | out_shape = np.ceil(np.array(a.shape).astype(float) / strides).astype(int)
2096 | pad = ((out_shape - v) * strides + w - a.shape).clip(min=0)
2097 | pad_start = pad // 2
2098 | if np.any(pad):
2099 | src = np.pad(a, list(zip(pad_start, pad - pad_start)), padmode)
2100 | else: # VALID
2101 | out_shape = np.ceil((np.array(a.shape).astype(float) - w + v) \
2102 | / strides).astype(int)
2103 | pad = np.zeros(len(a.shape))
2104 |
2105 | return out_shape, src
2106 |
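# Editor's note on the shape arithmetic above, per spatial dim with input size n,
# window w and stride s (matching TensorFlow's padding rules):
#   SAME:  out = ceil(n / s),           pad_total = max((out - 1) * s + w - n, 0)
#   VALID: out = ceil((n - w + 1) / s), no padding
# e.g. n=5, w=3, s=2 -> SAME: out=3, pad_total=2 (1 before, 1 after); VALID: out=2.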
2107 |
2108 | def _conv_patches(a, f, strides, padding):
2109 | out_shape, src = _prepare_patches(a, f, strides, padding, "constant")
2110 |
2111 | patches = np.empty(tuple(out_shape)[:-1] + f.shape).astype(a.dtype)
2112 |
2113 | s = (slice(None),)
2114 | e = (Ellipsis,)
2115 | en = (Ellipsis, np.newaxis)
2116 | for coord in np.ndindex(*out_shape[1:-1]):
2117 | pos = np.array(strides[1:-1]) * coord
2118 | patches[s + coord + e] = \
2119 | src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-2]))][en] * f
2120 |
2121 | return patches
2122 |
2123 |
2124 | @Operation.factory(attrs=("strides", "padding", "data_format"))
2125 | def Conv1D(a, f, strides, padding, data_format):
2126 | """
2127 | 1D conv op.
2128 | """
2129 | if data_format.decode("ascii") == "NCHW":
2130 | a = np.rollaxis(a, 1, -1)
2131 |
2132 | patches = _conv_patches(a, f, 3 * [strides], padding.decode("ascii"))
2133 | conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))
2134 |
2135 | if data_format.decode("ascii") == "NCHW":
2136 | conv = np.rollaxis(conv, -1, 1)
2137 |
2138 | return conv,
2139 |
2140 |
2141 | @Operation.factory(attrs=("strides", "padding", "data_format"))
2142 | def Conv2D(a, f, strides, padding, data_format):
2143 | """
2144 | 2D conv op.
2145 | """
2146 | if data_format.decode("ascii") == "NCHW":
2147 | a = np.rollaxis(a, 1, -1)
2148 |
2149 | patches = _conv_patches(a, f, strides, padding.decode("ascii"))
2150 | conv = np.sum(patches, axis=tuple(range(-f.ndim, -1)))
2151 |
2152 | if data_format.decode("ascii") == "NCHW":
2153 | conv = np.rollaxis(conv, -1, 1)
2154 |
2155 | return conv,
2156 |
2157 |
2158 | @Operation.factory(attrs=("strides", "padding"))
2159 | def Conv3D(a, f, strides, padding):
2160 | """
2161 | 3D conv op.
2162 | """
2163 | patches = _conv_patches(a, f, strides, padding.decode("ascii"))
2164 | return np.sum(patches, axis=tuple(range(-f.ndim, -1))),
2165 |
2166 |
2167 | #
2168 | # NN pooling ops
2169 | #
2170 |
2171 | def _pool_patches(a, k, strides, padding):
2172 | f = np.ones(k[1:] + [a.shape[-1]])
2173 |
2174 | out_shape, src = _prepare_patches(a, f, strides, padding, "edge")
2175 |
2176 | patches = np.empty(tuple(out_shape) + f.shape).astype(a.dtype)
2177 |
2178 | s = (slice(None),)
2179 | e = (Ellipsis,)
2180 | en = (Ellipsis, np.newaxis)
2181 | for coord in np.ndindex(*out_shape[1:]):
2182 | pos = np.array(strides[1:]) * coord
2183 | patches[s + coord + e] = \
2184 | src[s + tuple(slice(*tpl) for tpl in zip(pos, pos + f.shape[:-1]))][en] * f
2185 |
2186 | return patches
2187 |
2188 |
2189 | @Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
2190 | def AvgPool(a, k, strides, padding, data_format):
2191 | """
2192 | Average pooling op.
2193 | """
2194 | if data_format.decode("ascii") == "NCHW":
2195 | a = np.rollaxis(a, 1, -1)
2196 |
2197 | patches = _pool_patches(a, k, strides, padding.decode("ascii"))
2198 | pool = np.average(patches, axis=tuple(range(-len(k), 0)))
2199 |
2200 | if data_format.decode("ascii") == "NCHW":
2201 | pool = np.rollaxis(pool, -1, 1)
2202 |
2203 | return pool,
2204 |
2205 |
2206 | @Operation.factory(attrs=("ksize", "strides", "padding", "data_format"))
2207 | def MaxPool(a, k, strides, padding, data_format):
2208 | """
2209 | Maximum pooling op.
2210 | """
2211 | if data_format.decode("ascii") == "NCHW":
2212 | a = np.rollaxis(a, 1, -1)
2213 |
2214 | patches = _pool_patches(a, k, strides, padding.decode("ascii"))
2215 | pool = np.amax(patches, axis=tuple(range(-len(k), 0)))
2216 |
2217 | if data_format.decode("ascii") == "NCHW":
2218 | pool = np.rollaxis(pool, -1, 1)
2219 |
2220 | return pool,
2221 |
2222 |
2223 | @Operation.factory(attrs=("ksize", "strides", "padding"))
2224 | def AvgPool3D(a, k, strides, padding):
2225 | """
2226 | Average 3D pooling op.
2227 | """
2228 | patches = _pool_patches(a, k, strides, padding.decode("ascii"))
2229 | return np.average(patches, axis=tuple(range(-len(k), 0))),
2230 |
2231 |
2232 | @Operation.factory(attrs=("ksize", "strides", "padding"))
2233 | def MaxPool3D(a, k, strides, padding):
2234 | """
2235 | Maximum 3D pooling op.
2236 | """
2237 | patches = _pool_patches(a, k, strides, padding.decode("ascii"))
2238 | return np.amax(patches, axis=tuple(range(-len(k), 0))),
2239 |
--------------------------------------------------------------------------------