├── .dockerignore ├── .gitignore ├── .travis.yml ├── Dockerfile ├── LICENSE.txt ├── README.md ├── docs ├── README.md ├── SUMMARY.md └── start │ ├── README.md │ └── train.md ├── requirements.txt ├── setup.cfg ├── setup.py ├── snapcraft.yaml ├── tensorcraft ├── __init__.py ├── arglib.py ├── asynclib.py ├── backend │ ├── __init__.py │ ├── experiment.py │ ├── httpapi │ │ ├── __init__.py │ │ ├── experiment.py │ │ ├── httplib.py │ │ ├── model.py │ │ ├── routing.py │ │ └── server.py │ ├── model.py │ └── saving.py ├── callbacks.py ├── client.py ├── errors.py ├── experiment.py ├── logging.py ├── server.py ├── shell │ ├── __init__.py │ ├── commands.py │ ├── main.py │ └── termlib.py ├── signal.py └── tlslib.py └── tests ├── __init__.py ├── asynctest.py ├── clienttest.py ├── cryptotest.py ├── kerastest.py ├── test_cache.py ├── test_callbacks.py ├── test_client.py ├── test_command.py ├── test_saving.py ├── test_server.py ├── test_server_extra.py └── test_server_ssl.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .venv 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | script: 5 | - pytest 6 | notifications: 7 | email: false 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6-slim-stretch as builder 2 | 3 | RUN mkdir /src 4 | COPY . 
/src 5 | WORKDIR /src 6 | 7 | RUN python setup.py bdist_wheel 8 | RUN pip install dist/* 9 | 10 | 11 | FROM python:3.6-slim-stretch 12 | 13 | COPY --from=builder /usr/local/lib/python3.6/site-packages /usr/local/lib/python3.6/site-packages 14 | COPY --from=builder /usr/local/bin/tensorcraft /usr/local/bin/tensorcraft 15 | EXPOSE 5678/tcp 16 | 17 | CMD ["tensorcraft", "server", "--host", "0.0.0.0"] 18 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2019 Yasha Bubnov

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TensorCraft

[![Build Status][BuildStatus]](https://travis-ci.org/netrack/tensorcraft)
[![tensorcraft][SnapCraft]](https://snapcraft.io/tensorcraft)

TensorCraft is an HTTP server that serves [Keras](https://github.com/keras-team/keras)
models using the TensorFlow runtime.

_Currently TensorCraft is in beta; the client and server API may change in
future versions_.

This server solves such problems as:

* Versioning of models.
* Warehousing of models.
* Enabling CI/CD for machine-learning models.

## Installation

### Installation Using Snap

This is the recommended way to install `tensorcraft`. Simply run the following
command:
```bash
snap install tensorcraft --devmode --edge
snap start tensorcraft
```

### Installation Using Docker

TensorCraft can be used as a Docker container. The major caveat of this
approach is that the `tensorflow` library installed into the Docker image is
not compiled with support for AVX instructions or a GPU.
```bash
docker pull netrack/tensorcraft:latest
```

In order to start the container, run the following command:
```bash
docker run -it -p 5678:5678/tcp netrack/tensorcraft
```

You can optionally specify a volume to persist models between restarts of the
container:
```bash
docker run -it -p 5678:5678/tcp -v tensorcraft:/var/run/tensorcraft netrack/tensorcraft
```

### Installation Using PyPI

Install the latest version from the PyPI repository.
```bash
pip install tensorcraft
```

## Using TensorCraft

### Keras Requirements

One of the possible ways of using `tensorcraft` is publishing model snapshots
to the server on each epoch end.
```py
from keras.models import Sequential
from keras.layers import Dense, Activation
from tensorcraft.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))

model.compile(optimizer='sgd', loss='binary_crossentropy')
model.fit(x_train, y_train, callbacks=[ModelCheckpoint(verbose=1)], epochs=100)
```

Currently, `tensorcraft` supports only models in the TensorFlow SavedModel
format, therefore in order to publish a Keras model, it must first be saved
as a SavedModel.

Consider the following Keras model:
```py
from tensorflow import keras
from tensorflow.keras import layers

inputs = keras.Input(shape=(8,), name='digits')
x = layers.Dense(4, activation='relu', name='dense_1')(inputs)
x = layers.Dense(4, activation='relu', name='dense_2')(x)
outputs = layers.Dense(2, activation='softmax', name='predictions')(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')
```

Save it using the `export_saved_model` function from the TensorFlow 2.0 API:
```py
keras.experimental.export_saved_model(model, "3_layer_mlp")
```

### Starting Server

To start the server, run the `server` command:
```sh
sudo tensorcraft server
```

By default it starts listening on an _unsecured_ port on localhost at
`http://localhost:5678`.

The default configuration saves models to the `/var/lib/tensorcraft`
directory. Apart from that, the server requires access to the `/var/run`
directory in order to save its pid file there.

### Pushing New Model

Note that both the client and the server of the `tensorcraft` application
share the same code base. This implies the need to install a lot of server
dependencies for a client. This will be improved in upcoming versions.

Once the model is saved in a directory, pack it into a TAR archive.
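
If you prefer to do this from Python rather than the shell, here is a minimal
sketch with the standard `tarfile` module (the `3_layer_mlp` directory name
comes from the previous example):
```py
import tarfile

# Pack the exported SavedModel directory into a TAR archive
# that can be pushed to the tensorcraft server.
with tarfile.open("3_layer_mlp.tar", "w") as tar:
    tar.add("3_layer_mlp")
```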
For instance, with the `tar` utility this is how it looks for the
`3_layer_mlp` model from the previous example:
```sh
tar -cf 3_layer_mlp.tar 3_layer_mlp
```

Now the model packed into the archive can be pushed to the server under an
arbitrary tag:
```sh
tensorcraft push --name 3_layer_mlp --tag 0.0.1 3_layer_mlp.tar
```

### Listing Available Models

You can list all available models on the server using the following command:
```sh
tensorcraft list
```

After the execution of the `list` command you'll see two available models:
```sh
3_layer_mlp:0.0.1
3_layer_mlp:latest
```

This is a feature of the `tensorcraft` server: each published model name
results in the creation of a _model group_. Each model group has its `latest`
tag, which references the _latest pushed model_.

### Removing Model

An unused model can be removed using the `remove` command:
```sh
tensorcraft remove --name 3_layer_mlp --tag 0.0.1
```

Execution of the `remove` command results in the removal of the model itself,
and of the model group when it is the last model in the group.

### Using Model

In order to use the pushed model, `tensorcraft` exposes a REST API. Since the
server listens on an unsecured port by default, an example query to the
server looks like this:
```sh
curl -X POST http://localhost:5678/models/3_layer_mlp/0.0.1/predict -d \
    '{"x": [[1.0, 2.1, 1.43, 4.43, 12.1, 3.2, 1.44, 2.3]]}'
```
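
The same prediction request can be made from Python. A sketch using only the
standard library (the host and port assume the default server configuration;
the response follows the `{"y": [...]}` shape returned by the server):
```py
import json
import urllib.request

# One feature vector matching the input shape of the example model.
payload = json.dumps({"x": [[1.0, 2.1, 1.43, 4.43, 12.1, 3.2, 1.44, 2.3]]})

request = urllib.request.Request(
    "http://localhost:5678/models/3_layer_mlp/0.0.1/predict",
    data=payload.encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(request) as response:
    print(json.load(response))  # e.g. {"y": [[0.37, 0.63]]}
```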

# License

The code and docs are released under the [Apache 2.0 license](LICENSE).

[BuildStatus]: https://travis-ci.org/netrack/tensorcraft.svg?branch=master
[SnapCraft]: https://snapcraft.io/tensorcraft/badge.svg
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
# Introduction

TensorCraft is an HTTP server that serves [Keras](https://github.com/keras-team/keras)
models using the TensorFlow runtime.

This server solves such problems as:

* Versioning of models.
* Warehousing of models.
* Enabling CI/CD for machine-learning models.
--------------------------------------------------------------------------------
/docs/SUMMARY.md:
--------------------------------------------------------------------------------
# Summary

* [TensorCraft Documentation](README.md)
* [Getting Started](start/README.md)
* [Saving Keras Model During Training](start/train.md)
--------------------------------------------------------------------------------
/docs/start/README.md:
--------------------------------------------------------------------------------
# Overview

This walkthrough will take you through all of the basics of using TensorCraft.
Within this section, you will learn how to set up the TensorCraft server and
deploy your models during the training and inference steps.

# What's next?

TBD
--------------------------------------------------------------------------------
/docs/start/train.md:
--------------------------------------------------------------------------------
# Saving Keras Model During Training

One of the possible ways of using `tensorcraft` is publishing model snapshots
to the server on each epoch end.
```py
from keras.models import Sequential
from keras.layers import Dense, Activation
from tensorcraft.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))

model.compile(optimizer='sgd', loss='binary_crossentropy')
model.fit(x_train, y_train, callbacks=[ModelCheckpoint(verbose=1)], epochs=100)
```
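
Since the callback publishes a snapshot at the end of every epoch, a training
run leaves fresh versions on the server, which you can inspect afterwards
with the client (a usage sketch; the output depends on your model name):
```sh
tensorcraft list
```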
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.7.1
aiofiles==0.4.0
aiohttp==3.7.4
aiojobs==0.2.2
aiorwlock==0.6.0
astor==0.8.0
async-timeout==3.0.1
attrs==19.1.0
flagparse==0.0.3
chardet==3.0.4
cryptography==3.3.2
gast==0.2.2
google-pasta==0.1.7
grpcio==1.21.1
h5py==2.9.0
humanize==0.5.1
idna==2.8
idna-ssl==1.1.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
multidict==4.5.2
numpy==1.16.4
pid==2.2.3
protobuf==3.8.0
pyyaml==5.4
semver==2.8.1
six==1.12.0
tb-nightly==1.14.0a20190603
tensorflow==2.5.1
termcolor==1.1.0
tf-estimator-nightly==1.14.0.dev2019060501
tinydb==3.13.0
typing-extensions==3.7.2
Werkzeug==0.15.4
wrapt==1.11.1
yarl==1.3.0
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[bdist_wheel]
universal = 1

[metadata]
description-file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import tensorcraft
import os
import setuptools


here = os.path.abspath(os.path.dirname(__file__))

# Get the long description from the README file.
with open(os.path.join(here, "README.md"), encoding="utf-8") as md:
    long_description = md.read()


setuptools.setup(
    name="tensorcraft",
    version=tensorcraft.__version__,

    long_description=long_description,
    long_description_content_type="text/markdown",
    description="TensorCraft is a server for Keras models",

    url="https://github.com/netrack/tensorcraft",
    author="Yasha Bubnov",
    author_email="girokompass@gmail.com",

    classifiers=[
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
    ],

    packages=setuptools.find_packages(exclude=["tests"]),
    tests_require=[
        "pytest-aiohttp>=0.3.0",
        "cryptography>=2.7",
    ],
    install_requires=[
        "aiojobs>=0.2.2",
        "aiofiles>=0.4.0",
        "aiorwlock>=0.6.0",
        "aiohttp>=3.5.4",
        "flagparse>=0.0.2",
        "humanize>=0.5.1",
        "numpy>=1.16.3",
        "pid>=2.2.3",
        "pyyaml>=5.1.1",
        "semver>=2.8.1",
        "tensorflow>=2.0.0a0",
        "tinydb>=3.13.0",
    ],

    entry_points={
        "console_scripts": ["tensorcraft = tensorcraft.shell.main:main"],
    },
)
--------------------------------------------------------------------------------
/snapcraft.yaml:
--------------------------------------------------------------------------------
name: tensorcraft
version: git
version-script: git describe --tags

architectures:
  - build-on: [amd64]
    run-on: [amd64]

summary: Server for Keras models
description: |
  TensorCraft is an HTTP server that serves Keras models using the TensorFlow
  runtime.

base: core18
grade: devel
confinement: devmode

parts:
  tensorcraft:
    plugin: python
    python-version: python3
    source: .

apps:
  tensorcraft:
    daemon: simple
    stop-timeout: 10s
    command: bin/tensorcraft server
--------------------------------------------------------------------------------
/tensorcraft/__init__.py:
--------------------------------------------------------------------------------
import pathlib

__version__ = "0.0.1b5"
__apiversion__ = "1.0.0"


homepath = pathlib.Path.home().joinpath(".tensorcraft")
--------------------------------------------------------------------------------
/tensorcraft/arglib.py:
--------------------------------------------------------------------------------
import inspect


def filter_callable_arguments(callable, **kwargs):
    argnames = inspect.getfullargspec(callable)
    return {k: v for k, v in kwargs.items() if k in argnames.args}
--------------------------------------------------------------------------------
/tensorcraft/asynclib.py:
--------------------------------------------------------------------------------
import aiofiles
import asyncio
import io
import pathlib
import tarfile
import shutil

from typing import IO, AsyncIterator


def run(main):
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(main)
    finally:
        loop.close()


async def reader(path: pathlib.Path,
                 chunk_size=64*1024) -> AsyncIterator[bytes]:
    async with aiofiles.open(str(path), "rb") as f:
        chunk = await f.read(chunk_size)
        while len(chunk):
            yield chunk
            chunk = await f.read(chunk_size)


class AsyncIO:

    def __init__(self, io: IO):
        self.io = io

    async def read(self, size=-1):
        return self.io.read(size)

    async def write(self, b):
        return self.io.write(b)


async def extract_tar(fileobj: io.IOBase, dest: str) -> None:
    """Extract content of the TAR archive into the given directory."""
    with tarfile.open(fileobj=fileobj, mode="r") as tf:
        tf.extractall(dest)


async def create_tar(fileobj: io.IOBase, path: str) -> None:
    """Create TAR archive with the data specified by path."""
    with tarfile.open(fileobj=fileobj, mode="w") as tf:
        tf.add(path, arcname="")


async def remove_dir(path: pathlib.Path, ignore_errors: bool = False):
    shutil.rmtree(path, ignore_errors=ignore_errors)


class _AsyncContextManager:

    def __init__(self, async_generator):
        self.agen = async_generator.__aiter__()

    async def __aenter__(self):
        return await self.agen.__anext__()

    async def __aexit__(self, typ, value, traceback):
        try:
            await self.agen.__anext__()
        except StopAsyncIteration:
            return False


def asynccontextmanager(func):
    """Simple implementation of async context manager decorator."""
    def _f(*args, **kwargs):
        async_generator = func(*args, **kwargs)
        return _AsyncContextManager(async_generator)
    return _f


# Prefer the run function from the standard library over the custom
# implementation.
run = asyncio.run if hasattr(asyncio, "run") else run
--------------------------------------------------------------------------------
/tensorcraft/backend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/netrack/tensorcraft/15e0c54b795f4ce527cc5e2c46bbb7da434ac036/tensorcraft/backend/__init__.py
--------------------------------------------------------------------------------
/tensorcraft/backend/experiment.py:
--------------------------------------------------------------------------------
import uuid

from abc import ABCMeta, abstractmethod
from typing import Dict, NamedTuple, Sequence, Union


class Metric(NamedTuple):
    """Metric of the model's training."""

    name: str
    value: float

    def __repr__(self) -> str:
        return f"<Metric name={self.name} value={self.value}>"

    @classmethod
    def from_dict(cls, **kwargs):
        return cls(**kwargs)

    def asdict(self) -> Dict:
        return dict(name=self.name, value=self.value)


class Epoch(NamedTuple):
    """Epoch is an iteration of model fitting (training).

    Attributes:
        metrics -- list of model's metrics
    """

    metrics: Sequence[Metric]

    @classmethod
    def new(cls, metrics: Sequence[Dict] = ()):
        return cls([Metric.from_dict(**m) for m in metrics])

    @classmethod
    def from_dict(cls, **kwargs):
        return cls([Metric.from_dict(**m) for m in kwargs.pop("metrics", [])])

    def asdict(self) -> Dict:
        return dict(metrics=[m.asdict() for m in self.metrics])


class Experiment:
    """Machine-learning experiment.

    Attributes:
        id -- unique experiment identifier
        name -- name of the experiment
        epochs -- a list of experiment epochs
    """

    @classmethod
    def new(cls, name: str, epochs: Sequence[Dict] = ()) -> 'Experiment':
        experiment_id = uuid.uuid4()
        epochs = [Epoch.from_dict(**e) for e in epochs]
        return cls(uid=experiment_id, name=name, epochs=epochs)

    @classmethod
    def from_dict(cls, **kwargs) -> 'Experiment':
        epochs = [Epoch.from_dict(**e) for e in kwargs.pop("epochs", [])]
        return cls(epochs=epochs, **kwargs)

    def __init__(self,
                 uid: Union[uuid.UUID, str],
                 name: str,
                 epochs: Sequence[Epoch]):
        self.id = uuid.UUID(str(uid))
        self.name = name
        self.epochs = epochs

    def __repr__(self) -> str:
        return (f"<Experiment id={self.id} "
                f"name={self.name} epochs={len(self.epochs)}>")

    def asdict(self) -> Dict:
        return dict(id=self.id.hex,
                    name=self.name,
                    epochs=[e.asdict() for e in self.epochs])
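

# A usage sketch for illustration; the names and values here are made up:
#
#     e = Experiment.new(name="mnist", epochs=[
#         {"metrics": [{"name": "loss", "value": 0.42}]},
#     ])
#     e.asdict()
#     # {"id": "...", "name": "mnist",
#     #  "epochs": [{"metrics": [{"name": "loss", "value": 0.42}]}]}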


class AbstractStorage(metaclass=ABCMeta):
    """Storage used to persist experiments."""

    @abstractmethod
    async def save(self, e: Experiment) -> None:
        """Save the experiment.

        The persistence guarantee is defined by the implementation.
        """

    @abstractmethod
    async def save_epoch(self, name: str, epoch: Epoch) -> None:
        """Save the epoch with metrics.

        Add a new epoch to the experiment; after execution, the count of
        epochs for the experiment referenced by name should be increased
        by one.

        Args:
            name -- experiment name
            epoch -- experiment epoch
        """

    @abstractmethod
    async def load(self, name: str) -> Experiment:
        """Load the experiment.

        Args:
            name -- experiment name
        """

    @abstractmethod
    async def all(self) -> Sequence[Experiment]:
        """Load all experiments."""
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/__init__.py:
--------------------------------------------------------------------------------
from .model import ModelView
from .server import ServerView
from .experiment import ExperimentView


__all__ = ["ModelView", "ServerView", "ExperimentView"]
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/experiment.py:
--------------------------------------------------------------------------------
from aiohttp import web

from tensorcraft.backend import experiment
from tensorcraft.backend.httpapi import routing


class ExperimentView:
    """View to handle actions related to experiments.

    Attributes:
        experiments -- container of experiments
    """

    def __init__(self, experiments: experiment.AbstractStorage) -> None:
        self.experiments = experiments

    @routing.urlto("/experiments")
    async def create(self, req: web.Request) -> web.Response:
        """HTTP handler to update (or create when missing) the experiment.

        Args:
            req -- request with an experiment
        """
        if not req.can_read_body:
            raise web.HTTPBadRequest(text="request has no body")

        body = await req.json()
        e = experiment.Experiment.new(**body)

        await self.experiments.save(e)
        return web.json_response(status=web.HTTPCreated.status_code)

    @routing.urlto("/experiments")
    async def list(self, req: web.Request) -> web.Response:
        experiments = [e.asdict() async for e in self.experiments.all()]
        return web.json_response(list(experiments))

    @routing.urlto("/experiments/{name}")
    async def get(self, req: web.Request) -> web.Response:
        name = req.match_info.get("name")
        e = await self.experiments.load(name)

        return web.json_response(e.asdict())

    @routing.urlto("/experiments/{name}/epochs")
    async def create_epoch(self, req: web.Request) -> web.Response:
        name = req.match_info.get("name")

        if not req.can_read_body:
            raise web.HTTPBadRequest(text="request has no body")

        body = await req.json()
        epoch = experiment.Epoch.new(**body)

        await self.experiments.save_epoch(name, epoch)
        return web.json_response(status=web.HTTPOk.status_code)
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/httplib.py:
--------------------------------------------------------------------------------
from aiohttp import web
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/model.py:
--------------------------------------------------------------------------------
import io
import json

from aiohttp import web
from typing import Union

from tensorcraft import errors
from tensorcraft.backend import model
from tensorcraft.backend.httpapi import routing


_ConflictReason = Union[errors.DuplicateError, errors.LatestTagError]


def make_error_response(exc_class, model_exc=errors.ModelError, text=""):
    """Return an exception with a specific error code.

    Use this function to construct an exception with custom headers
    that specify the concrete error generated by the server.
    """
    return exc_class(text=str(text),
                     headers={"Error-Code": f"{model_exc.error_code}"})


def make_bad_request_response(text: str) -> web.HTTPException:
    """Return HTTP "bad request" exception."""
    return make_error_response(web.HTTPBadRequest, errors.ModelError, text)


def make_conflict_response(reason: _ConflictReason) -> web.HTTPException:
    """Return HTTP "conflict" exception."""
    return make_error_response(web.HTTPConflict, reason, str(reason))


def make_not_found_response(reason: errors.NotFoundError) -> web.HTTPException:
    """Return HTTP "not found" exception."""
    return make_error_response(web.HTTPNotFound, reason, str(reason))


class ModelView:
    """View to handle actions related to models.

    Attributes:
        models -- container of models
    """

    def __init__(self, models: model.AbstractStorage) -> None:
        self.models = models

    @routing.urlto("/models/{name}/{tag}")
    async def save(self, req: web.Request) -> web.Response:
        """HTTP handler to save the model.

        Args:
            req -- request with a model tar archive
        """
        name = req.match_info.get("name")
        tag = req.match_info.get("tag")

        if not req.can_read_body:
            raise make_bad_request_response(text="request has no body")

        try:
            model_stream = io.BytesIO(await req.read())
            await self.models.save(name, tag, model_stream)
        except errors.ModelError as e:
            raise make_conflict_response(reason=e)

        return web.Response(status=web.HTTPCreated.status_code)

    @routing.urlto("/models/{name}/{tag}/predict")
    async def predict(self, req: web.Request) -> web.Response:
        """HTTP handler to calculate model predictions.

        Feed the model with feature vectors and calculate predictions.

        Args:
            req -- request with a list of feature-vectors
        """
        name = req.match_info.get("name")
        tag = req.match_info.get("tag")

        if not req.can_read_body:
            raise make_bad_request_response(text="request has no body")

        try:
            body = await req.json()
            model = await self.models.load(name, tag)

            predictions = model.predict(x=body["x"])
        except (errors.InputShapeError, json.decoder.JSONDecodeError) as e:
            raise make_bad_request_response(text=str(e))
        except errors.NotFoundError as e:
            raise make_not_found_response(reason=e)

        return web.json_response(dict(y=predictions))

    @routing.urlto("/models")
    async def list(self, req: web.Request) -> web.Response:
        """HTTP handler to list available models.

        List available models in the storage.

        Args:
            req -- empty request
        """
        models = [m.to_dict() async for m in self.models.all()]
        return web.json_response(list(models))

    @routing.urlto("/models/{name}/{tag}")
    async def delete(self, req: web.Request) -> web.Response:
        """Handler that removes a model."""
        name = req.match_info.get("name")
        tag = req.match_info.get("tag")

        try:
            await self.models.delete(name, tag)
        except errors.NotFoundError as e:
            raise make_not_found_response(reason=e)
        return web.Response(status=web.HTTPOk.status_code)

    @routing.urlto("/models/{name}/{tag}")
    async def export(self, req: web.Request) -> web.Response:
        name = req.match_info.get("name")
        tag = req.match_info.get("tag")

        try:
            writer = io.BytesIO()
            await self.models.export(name, tag, writer)

            return web.Response(body=writer.getvalue())
        except errors.NotFoundError as e:
            raise make_not_found_response(reason=e)
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/routing.py:
--------------------------------------------------------------------------------
from typing import Callable


def urlto(path: str) -> Callable:
    def _to(func):
        func.url = path
        return func
    return _to
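

# A usage sketch for illustration: ``urlto`` only attaches the route template
# to the handler, so the application wiring can read it back later, roughly:
#
#     class View:
#         @urlto("/models/{name}")
#         async def get(self, req): ...
#
#     View.get.url  # "/models/{name}"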
--------------------------------------------------------------------------------
/tensorcraft/backend/httpapi/server.py:
--------------------------------------------------------------------------------
import tensorcraft

from aiohttp import web

from tensorcraft.backend import model
from tensorcraft.backend.httpapi import routing


class ServerView:
    """Server view to handle actions related to server."""

    def __init__(self, models: model.AbstractStorage) -> None:
        self.models = models

    @routing.urlto("/status")
    async def status(self, req: web.Request) -> web.Response:
        """Handler that returns server status."""
        return web.json_response(dict(
            models=len([m async for m in self.models.all()]),
            server_version=tensorcraft.__version__,
            api_version=tensorcraft.__apiversion__,
            root_path=str(self.models.root_path),
        ))
--------------------------------------------------------------------------------
/tensorcraft/backend/model.py:
--------------------------------------------------------------------------------
import aiorwlock
import enum
import contextlib
import copy
import io
import logging
import numpy
import pathlib
import tensorflow as tf
import uuid

from abc import ABCMeta, abstractmethod
from datetime import datetime
from typing import Sequence, Union

from tensorcraft import errors
from tensorcraft import signal
from tensorcraft.logging import internal_logger


class Strategy(enum.Enum):
    """Strategy is an execution strategy of the model."""

    No = "no"
    Mirrored = "mirrored"
    MultiWorkerMirrored = "multi_worker_mirrored"


class Tag(enum.Enum):
    """Magic tags of the models."""

    Latest = "latest"


class NoStrategy:
    """A strategy that does nothing additional to the loaded model.

    This strategy is used when the computation strategy is not specified.
    """

    def scope(self):
        return contextlib.contextmanager(lambda: (yield None))()


class Loader:
    """Load the model with the specific computation strategy."""

    strategies = {
        Strategy.No: NoStrategy,
        Strategy.Mirrored: tf.distribute.MirroredStrategy,
        Strategy.MultiWorkerMirrored: (
            tf.distribute.experimental.MultiWorkerMirroredStrategy),
    }

    def __init__(self, strategy: str,
                 logger: logging.Logger = internal_logger):
        if Strategy(strategy) not in self.strategies:
            raise ValueError("unknown strategy {0}".format(strategy))

        strategy_class = self.strategies[Strategy(strategy)]
        logger.info("Using '%s' execution strategy", strategy)

        self.logger = logger
        self.strategy = strategy_class()

    def load(self, path: Union[str, pathlib.Path]):
        """Load the model by the given path."""
        with self.strategy.scope():
            m = tf.keras.experimental.load_from_saved_model(str(path))
            self.logger.debug("Model loaded from path %s", path)
            return m


class Model:
    """Machine-learning model.

    Attributes:
        id -- unique model identifier
        name -- the name of the model
        tag -- the tag of the model
        path -- the location of the model on file system
        loader -- the model loader
    """

    @classmethod
    def from_dict(cls, **kwargs):
        return cls(**kwargs)

    @classmethod
    def new(cls, name: str, tag: str, root: pathlib.Path,
            loader: Loader = None):
        model_id = uuid.uuid4()
        model_path = root.joinpath(model_id.hex)
        model_created_at = datetime.utcnow().timestamp()

        return cls(uid=model_id, name=name, tag=tag,
                   created_at=model_created_at,
                   path=model_path, loader=loader)

    def to_dict(self):
        return dict(id=self.id.hex,
                    name=self.name,
                    tag=self.tag,
                    created_at=self.created_at)

    def __init__(self, uid: Union[uuid.UUID, str],
                 name: str, tag: str, created_at: float,
                 path: str = None, loader: Loader = None):
        self.id = uuid.UUID(str(uid))
        self.name = name
        self.tag = tag
        self.created_at = created_at

        self.loader = loader
        self.path = path
        self.model = None

    def copy(self):
        return copy.copy(self)

    @property
    def key(self):
        return (self.name, self.tag)

    @property
    def loaded(self):
        """True when the model is loaded and False otherwise."""
        return self.model is not None

    def load(self):
        """Load the execution model."""
        self.model = self.loader.load(self.path)
        return self

    def predict(self, x):
        if not self.model:
            raise errors.NotLoadedError(self.name, self.tag)

        x = numpy.array(x)

        # This check makes sense only for models with defined input shapes
        # (for example, when the layer is Dense).
        if hasattr(self.model, "input_shape"):
            # Calculate the shape of the input data and validate it with the
            # model parameters. This exception is handled by the server in
            # order to return an appropriate error to the client.
            _, *expected_dims = self.model.input_shape
            _, *actual_dims = x.shape

            if expected_dims != actual_dims:
                raise errors.InputShapeError(expected_dims, actual_dims)

        return self.model.predict(x).tolist()

    def __str__(self):
        return "{0}:{1}".format(self.name, self.tag)
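

# A shape-checking sketch for illustration (values made up): for a model
# whose first layer is Dense(32, input_dim=784), ``input_shape`` is
# (None, 784), so ``predict`` expects a batch of width-784 vectors:
#
#     m.predict(x=[[0.0] * 784])   # ok: one feature vector of width 784
#     m.predict(x=[[0.0, 1.0]])    # raises errors.InputShapeError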


class AbstractStorage(metaclass=ABCMeta):
    """Storage used to persist models (TAR archives)."""

    @property
    @abstractmethod
    def on_save(self) -> signal.Signal:
        """A list of saving callbacks.

        Each callback in the list will be executed asynchronously on model
        saving.

        Returns:
            A list of callbacks as :class:`tensorcraft.signal.Signal`.
        """

    @property
    @abstractmethod
    def on_delete(self) -> signal.Signal:
        """A list of deletion callbacks.

        Each callback in the list will be executed asynchronously on model
        deletion.

        Returns:
            A list of callbacks as :class:`tensorcraft.signal.Signal`.
        """

    @property
    @abstractmethod
    def root_path(self) -> pathlib.Path:
        """Root path of the storage.

        The returned path specifies the path where all models and related
        metadata is persisted.

        Returns:
            Data root path as :class:`pathlib.Path`.
        """

    @abstractmethod
    async def all(self) -> Sequence[Model]:
        """List all existing models.

        The returned models are not necessarily loaded, for the sake of
        performance.

        Returns:
            Sequence of :class:`Model`.
        """

    @abstractmethod
    async def save(self, name: str, tag: str, stream: io.IOBase) -> Model:
        """Save the model archive.

        The persistence guarantee is provided by the implementation.

        Args:
            name (str): Model name.
            tag (str): Model tag.

        Returns:
            Saved instance of :class:`Model`.
        """

    @abstractmethod
    async def delete(self, name: str, tag: str) -> None:
        """Delete the model.

        After the deletion the model should not be addressable anymore.

        Args:
            name (str): Model name.
            tag (str): Model tag.
        """

    @abstractmethod
    async def load(self, name: str, tag: str) -> Model:
        """Load the model.

        Load the model into the memory from the storage. Implementation must
        consider concurrent requests to load the same model.

        Args:
            name (str): Model name.
            tag (str): Model tag.

        Returns:
            Loaded :class:`Model`.
        """

    @abstractmethod
    async def export(self, name: str, tag: str, writer: io.IOBase) -> None:
        """Export the model to the writer.

        Write the model's archive into the stream. Implementation must
        consider concurrent requests to export the same model.

        Args:
            name (str): Model name.
            tag (str): Model tag.
            writer (io.IOBase): Destination writer instance.
        """


class Cache:
    """Cache of models used to speed up model loading time.

    The cache saves models in memory and delegates calls to the parent
    storage when a model is not found locally.
    """

    @classmethod
    async def new(cls,
                  storage: AbstractStorage,
                  preload: bool = False,
                  logger: logging.Logger = internal_logger):
        self = cls()
        self.logger = logger
        self.storage = storage
        self.lock = aiorwlock.RWLock()
        self.models = {}

        self.storage.on_save.append(self.save_to_cache)
        self.storage.on_delete.append(self.delete_from_cache)

        if not preload:
            return self

        async for m in self.all():
            logger.info("Loading {0} model".format(m))
            await self.unsafe_load(m.name, m.tag)

        return self

    @property
    def root_path(self) -> pathlib.Path:
        return self.storage.root_path

    async def all(self) -> Sequence[Model]:
        """List all available models.

        The call puts all retrieved models into the cache. These models are
        not loaded yet, so before using them, they must be loaded.
        """
        async with self.lock.reader_lock:
            async for m in self.storage.all():
                if m.key not in self.models:
                    self.models[m.key] = m
                yield m

    async def save(self, name: str, tag: str, model: io.IOBase) -> Model:
        """Save the model and load it into the memory.

        Most likely the saved model will be used in a short period of time,
        therefore it is beneficial to load it right after the save.
        """
        m = await self.storage.save(name, tag, model)
        await self.save_to_cache(m)
        return m

    async def save_to_cache(self, m: Model) -> None:
        async with self.lock.writer_lock:
            self.models[(m.name, m.tag)] = m

    async def delete(self, name: str, tag: str) -> None:
        # It is totally fine to lose the data from the cache but
        # leave it in the storage (due to an unexpected error).
        await self.delete_from_cache(name, tag)
        await self.storage.delete(name, tag)

    async def delete_from_cache(self, name: str, tag: str) -> None:
        async with self.lock.writer_lock:
            key = (name, tag)
            if key in self.models:
                del self.models[key]

    async def unsafe_load(self, name: str, tag: str) -> Model:
        """Load the model into the internal cache without acquiring a lock."""
        key = (name, tag)
        if ((key not in self.models) or not self.models[key].loaded):
            self.models[key] = await self.storage.load(name, tag)
        return self.models[key]

    async def load(self, name: str, tag: str) -> Model:
        # Load the model from the parent storage when
        # it is missing in the cache.
        async with self.lock.writer_lock:
            return await self.unsafe_load(name, tag)

    async def export(self, name: str, tag: str, writer: io.IOBase) -> None:
        return await self.storage.export(name, tag, writer)
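

# An assembly sketch for illustration (the storage class lives in
# tensorcraft.backend.saving; the names here are assumptions):
#
#     storage = saving.FsModelsStorage.new(path=root_path, loader=loader)
#     cache = await Cache.new(storage, preload=True)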
--------------------------------------------------------------------------------
/tensorcraft/backend/saving.py:
--------------------------------------------------------------------------------
import aiorwlock
import asyncio
import concurrent.futures
import io
import logging
import operator
import pathlib
import tinydb
import uuid

import tensorcraft.logging

from typing import Dict, Coroutine, Sequence, Union

from tensorcraft import arglib
from tensorcraft import asynclib
from tensorcraft import errors
from tensorcraft import signal
from tensorcraft.backend import model
from tensorcraft.backend import experiment


def query_by_name_and_tag(name: str, tag: str):
    """Query documents by name and tag."""
    q = tinydb.Query()
    return (q.name == name) & (q.tag == tag)


def query_by_name(name: str):
    """Query documents by name."""
    return tinydb.Query().name == name


def query_by_id(uid: uuid.UUID):
    """Query the document by unique identifier."""
    return tinydb.Query().id == uid


class FsModelsMetadata:
    """A file-based database with JSON encoding for model metadata."""

    @classmethod
    def new(cls, path: pathlib.Path):
        self = cls()
        self._rw_lock = aiorwlock.RWLock()
        self._db = tinydb.TinyDB(path=path.joinpath("metadata.json"),
                                 default_table="metadata")
        return self

    async def close(self) -> None:
        async with self._rw_lock.writer_lock:
            self._db.close()

    async def get(self, cond) -> Dict:
        async with self._rw_lock.reader_lock:
            return self._db.get(cond)

    async def search(self, cond) -> Dict:
        async with self._rw_lock.reader_lock:
            return self._db.search(cond)

    async def all(self) -> Sequence[Dict]:
        async with self._rw_lock.reader_lock:
            return self._db.all()

    async def insert(self, document: Dict) -> None:
        async with self._rw_lock.writer_lock:
            self._db.insert(document)

    async def upsert(self, document: Dict, cond) -> None:
        async with self._rw_lock.writer_lock:
            self._db.upsert(document, cond)

    async def remove(self, cond) -> None:
        async with self._rw_lock.writer_lock:
            self._db.remove(cond)

    async def latest(self, cond, key) -> Union[Dict, None]:
        async with self._rw_lock.reader_lock:
            documents = self._db.search(cond)

            # Sort in place so the last element is the latest one by key.
            documents.sort(key=key)
            return documents.pop() if documents else None

    @asynclib.asynccontextmanager
    async def write_locked(self):
        async with self._rw_lock.writer_lock:
            db = FsModelsMetadata()
            db._db = self._db
            db._rw_lock = aiorwlock.RWLock(fast=True)
            yield db


class FsModelsStorage(model.AbstractStorage):
    """Storage of models based on ordinary file system.

    Implementation saves the models as an unpacked TensorFlow SavedModel
    under the data root path.
    """

    @classmethod
    def new(cls,
            path: pathlib.Path,
            loader: model.Loader,
            logger: logging.Logger = tensorcraft.logging.internal_logger):

        self = cls()
        logger.info("Using file storage backing engine")

        self.logger = logger
        self.loader = loader
        self.meta = FsModelsMetadata.new(path)
        self.models_path = path.joinpath("models")

        self._on_delete = signal.Signal()
        self._on_save = signal.Signal()

        self.models_path.mkdir(parents=True, exist_ok=True)
        self.executor = concurrent.futures.ThreadPoolExecutor()

        return self

    async def close(self) -> None:
        """Clean up resources allocated by the storage."""
        await self.meta.close()

    @property
    def on_delete(self) -> signal.Signal:
        return self._on_delete

    @property
    def on_save(self) -> signal.Signal:
        return self._on_save

    @property
    def root_path(self) -> pathlib.Path:
        return self.models_path

    def build_model_from_document(self, doc: Dict) -> model.Model:
        path = self.models_path.joinpath(doc["id"])
        d = arglib.filter_callable_arguments(model.Model, uid=doc["id"], **doc)
        return model.Model(path=path, loader=self.loader, **d)

    def await_in_thread(self, coro: Coroutine):
        """Run the given coroutine within the instance executor."""
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(self.executor, asynclib.run, coro)

    async def all(self) -> Sequence[model.Model]:
        """List available models and their tags.

        The method returns a list of unloaded models, therefore before using
        them (e.g. for prediction), the models must be loaded.
        """
        for document in await self.meta.all():
            yield self.build_model_from_document(document)

    async def save_to_meta(self, m: model.Model) -> None:
        async with self.meta.write_locked() as meta:
            if await meta.get(query_by_name_and_tag(m.name, m.tag)):
                self.logger.debug("Model %s already exists", m)

                raise errors.DuplicateError(m.name, m.tag)

            # Insert the model metadata, and update the latest model link.
            await meta.insert(m.to_dict())
            await self.on_save.send(m)

            # Since the saving is happening right now, the latest model
            # will obviously be the current one.
            latest = m.copy()

            latest.tag = model.Tag.Latest.value
            latest.id = m.id

            latest_query = query_by_name_and_tag(latest.name, latest.tag)
            await meta.upsert(latest.to_dict(), latest_query)
            await self.on_save.send(latest)

    async def save(self, name: str, tag: str,
                   stream: io.IOBase) -> model.Model:
        """Save the model into the local storage.

        Extracts the TAR archive into the data directory.
        """
        # Raise an error on an attempt to save a model with the latest tag.
        if tag == model.Tag.Latest.value:
            raise errors.LatestTagError(name, tag)

        m = model.Model.new(name, tag, self.models_path, self.loader)

        try:
            coro = asynclib.extract_tar(fileobj=stream, dest=m.path)
            await self.await_in_thread(coro)

            # Now load the model into the memory, to pass all validations.
197 |             self.logger.debug("Ensuring the model has the correct format")
198 | 
199 |             coro = asyncio.coroutine(m.load)()
200 |             m = await self.await_in_thread(coro)
201 | 
202 |             await self.save_to_meta(m)
203 | 
204 |             # The model loaded successfully, so now it can be moved to the
205 |             # original data root directory.
206 |             self.logger.info("Pushing model %s to %s", m, m.path)
207 |             return m
208 | 
209 |         except Exception:
210 |             # In case of an exception, remove the model from the directory
211 |             # and ensure the metadata database does not store any information.
212 |             #
213 |             # The caller has to ensure atomicity of this operation.
214 |             await self.meta.remove(query_by_id(m.id))
215 | 
216 |             coro = asynclib.remove_dir(m.path, ignore_errors=True)
217 |             await self.await_in_thread(coro)
218 |             raise
219 | 
220 |     async def delete_from_meta(self, name: str, tag: str) -> model.Model:
221 |         # Remove the model's metadata from the database.
222 |         async with self.meta.write_locked() as meta:
223 |             m = await self.load_from_meta(name, tag)
224 |             await meta.remove(query_by_id(m.id))
225 |             await self.on_delete.send(m.name, m.tag)
226 | 
227 |             # Remove the "latest" model link.
228 |             query = query_by_name_and_tag(m.name, model.Tag.Latest.value)
229 |             await meta.remove(query)
230 | 
231 |             # Promote the most recently created model to "latest".
232 |             key = operator.itemgetter("created_at")
233 |             document = await meta.latest(query_by_name(m.name), key)
234 |             if document is None:
235 |                 return m  # no models left under this name to promote
236 | 
237 |             latest = self.build_model_from_document(document)
238 |             latest.tag = model.Tag.Latest.value
239 |             await meta.insert(latest.to_dict())
240 |             await self.on_delete.send(m.name, model.Tag.Latest.value)
241 |             await self.on_save.send(latest)
242 |             return m
243 | 
244 |     async def delete(self, name: str, tag: str) -> None:
245 |         """Remove the model with the given name and tag."""
246 |         if tag == model.Tag.Latest.value:
247 |             raise errors.NotFoundError(name, tag)
248 | 
249 |         try:
250 |             # Remove the model from the metadata database.
251 |             m = await self.delete_from_meta(name, tag)
252 | 
253 |             # Remove the model data from the file system.
254 |             await self.await_in_thread(asynclib.remove_dir(m.path))
255 | 
256 |             self.logger.info("Removed model %s:%s", name, tag)
257 |         except FileNotFoundError:
258 |             raise errors.NotFoundError(name, tag)
259 | 
260 |     async def load_from_meta(self, name: str, tag: str):
261 |         document = await self.meta.get(query_by_name_and_tag(name, tag))
262 |         if not document:
263 |             raise errors.NotFoundError(name, tag)
264 |         return self.build_model_from_document(document)
265 | 
266 |     async def load(self, name: str, tag: str) -> model.Model:
267 |         """Load the model with the given name and tag."""
268 |         m = await self.load_from_meta(name, tag)
269 |         return await self.await_in_thread(asyncio.coroutine(m.load)())
270 | 
271 |     async def export(self, name: str, tag: str, writer: io.IOBase) -> None:
272 |         """Export a serialized model.
273 | 
274 |         Writes the model as a TAR archive to the given stream.
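
        A short sketch (illustrative; assumes a storage created as in the
        class docstring, and hypothetical model name and tag):

            buf = io.BytesIO()
            await storage.export("mnist", "1.0.0", buf)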
275 | """ 276 | m = await self.load_from_meta(name, tag) 277 | 278 | coro = asynclib.create_tar(fileobj=writer, path=m.path) 279 | await self.await_in_thread(coro) 280 | 281 | 282 | class FsExperimentsStorage(experiment.AbstractStorage): 283 | 284 | @classmethod 285 | def new(cls, 286 | path: pathlib.Path, 287 | logger: logging.Logger = tensorcraft.logging.internal_logger): 288 | self = cls() 289 | logger.info("Using file storage experiment engine") 290 | 291 | self.logger = logger 292 | self.rwlock = aiorwlock.RWLock() 293 | self.db = tinydb.TinyDB(path=path.joinpath("experiments.json"), 294 | default_table="experiments") 295 | return self 296 | 297 | async def close(self) -> None: 298 | async with self.rwlock.writer_lock: 299 | self.db.close() 300 | 301 | async def save(self, e: experiment.Experiment) -> None: 302 | """Save the given experiment.""" 303 | async with self.rwlock.writer_lock: 304 | self.db.upsert(e.asdict(), query_by_name(e.name)) 305 | 306 | def build_experiment_from_document(self, 307 | doc: Dict) -> experiment.Experiment: 308 | return experiment.Experiment.from_dict(uid=doc.pop("id"), **doc) 309 | 310 | def query_experiment(self, name: str) -> experiment.Experiment: 311 | doc = self.db.get(query_by_name(name)) 312 | if doc is None: 313 | raise Exception(f"experiment '{name}' not found") 314 | return self.build_experiment_from_document(doc) 315 | 316 | async def save_epoch(self, name: str, epoch: experiment.Epoch) -> None: 317 | """Add epoch with generated metrics to the experiment.""" 318 | async with self.rwlock.writer_lock: 319 | e = self.query_experiment(name) 320 | e.epochs.append(epoch) 321 | self.db.upsert(e.asdict(), query_by_name(e.name)) 322 | 323 | async def load(self, name: str) -> experiment.Experiment: 324 | """Load experiment with the given name.""" 325 | async with self.rwlock.reader_lock: 326 | return self.query_experiment(name) 327 | 328 | async def all(self) -> Sequence[experiment.Experiment]: 329 | async with self.rwlock.reader_lock: 330 | for doc in self.db.all(): 331 | yield self.build_experiment_from_document(doc) 332 | -------------------------------------------------------------------------------- /tensorcraft/callbacks.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pathlib 3 | import semver 4 | import tarfile 5 | import tempfile 6 | 7 | from tensorflow import keras 8 | from tensorflow.keras import callbacks 9 | 10 | from tensorcraft import asynclib 11 | from tensorcraft import client 12 | 13 | 14 | class _RemoteCallback(callbacks.Callback): 15 | """Callback with default session for communication with remote server. 
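
    Subclasses obtain a connected client.Session in on_train_begin, and the
    session is closed again in on_train_end; the TLS options mirror the
    tensorcraft command-line flags.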
 16 | 
 17 |     Args:
 18 |         service_url -- endpoint to the server
 19 |         tls -- true to use TLS
 20 |         tlsverify -- use TLS and verify remote
 21 |         tlscacert -- trust certs signed only by this CA
 22 |         tlscert -- path to TLS certificate file
 23 |         tlskey -- path to TLS key file
 24 |     """
 25 | 
 26 |     def __init__(self,
 27 |                  service_url: str = "localhost:5678",
 28 |                  tls: bool = False,
 29 |                  tlsverify: bool = False,
 30 |                  tlscacert: pathlib.Path = "cacert.pem",
 31 |                  tlscert: pathlib.Path = "cert.pem",
 32 |                  tlskey: pathlib.Path = "key.pem"):
 33 |         super().__init__()
 34 | 
 35 |         self.service_url = service_url
 36 |         self.tls = tls
 37 |         self.tlsverify = tlsverify
 38 |         self.tlscacert = tlscacert
 39 |         self.tlscert = tlscert
 40 |         self.tlskey = tlskey
 41 | 
 42 |     def new_session(self):
 43 |         return client.Session.new(
 44 |             service_url=self.service_url, tls=self.tls,
 45 |             tlsverify=self.tlsverify, tlscacert=self.tlscacert,
 46 |             tlscert=self.tlscert, tlskey=self.tlskey)
 47 | 
 48 |     def on_train_begin(self, logs=None) -> None:
 49 |         self.loop = asyncio.get_event_loop()
 50 |         self.session = self.loop.run_until_complete(self.new_session())
 51 | 
 52 |     def on_train_end(self, logs=None) -> None:
 53 |         self.loop.run_until_complete(self.session.close())
 54 | 
 55 | 
 56 | class ModelCheckpoint(_RemoteCallback):
 57 |     """Publish the model to the server after every epoch.
 58 | 
 59 |     Args:
 60 |         name -- name of the model; when not given, the name attribute
 61 |             of the model is used
 62 |         tag -- tag of the model, "0.0.0" by default; every epoch bumps
 63 |             the build version, so after the first epoch the tag becomes
 64 |             "0.0.0+build.1"; the tag must be a valid semantic version
 65 |     """
 66 | 
 67 |     def __init__(self, name: str = None, tag: str = "0.0.0",
 68 |                  verbose: int = 0, **kwargs) -> None:
 69 |         super().__init__(**kwargs)
 70 | 
 71 |         self.name = name
 72 |         self.tag = tag
 73 |         self.verbose = verbose
 74 | 
 75 |     def on_train_begin(self, logs=None) -> None:
 76 |         super().on_train_begin(logs)
 77 |         self.models = client.Model(self.session)
 78 | 
 79 |     def on_epoch_end(self, epoch, logs=None) -> None:
 80 |         with tempfile.TemporaryDirectory() as td:
 81 |             modelpath = pathlib.Path(td, "model")
 82 | 
 83 |             # Pure Keras models (in contrast to tf.keras) must be saved in
 84 |             # the HDF5 format first and loaded back with the TF model loader.
 85 |             h5path = modelpath.with_suffix(".h5")
 86 |             self.model.save(h5path)
 87 | 
 88 |             # Now the model can be translated into TF entities and saved
 89 |             # in the serving format.
 90 |             model = keras.models.load_model(h5path)
 91 |             keras.experimental.export_saved_model(model, str(modelpath))
 92 | 
 93 |             tarpath = modelpath.with_suffix(".tar")
 94 |             with tarfile.open(str(tarpath), mode="w") as tar:
 95 |                 tar.add(str(modelpath), arcname="")
 96 | 
 97 |             asyncreader = asynclib.reader(tarpath)
 98 | 
 99 |             # Use the explicit name when set, the model's own name otherwise.
100 |             name = self.name or self.model.name
101 |             tag = semver.bump_build(self.tag)
102 | 
103 |             if self.verbose > 0:
104 |                 print("\nEpoch {0:5d}: pushing model {1}:{2}".
105 |                       format(epoch + 1, name, tag))
106 | 
107 |             task = self.models.push(name, tag, asyncreader)
108 |             self.loop.run_until_complete(task)
109 | 
110 |             # Update the tag after a successful model publish.
111 |             self.tag = tag
112 | 
113 | 
114 | class ExperimentCallback(_RemoteCallback):
115 |     """Publish metrics of the model at the end of each epoch.
116 | 
117 |     Args:
118 |         experiment_name -- name of the experiment used to trace metrics.
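
    A usage sketch (illustrative; hypothetical experiment name, and the
    server must be reachable at service_url):

        model.fit(x, y, epochs=10,
                  callbacks=[ExperimentCallback("mnist-baseline")])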
119 |     """
120 | 
121 |     def __init__(self, experiment_name: str, **kwargs) -> None:
122 |         super().__init__(**kwargs)
123 | 
124 |         self.experiment_name = experiment_name
125 | 
126 |     def on_train_begin(self, logs=None) -> None:
127 |         super().on_train_begin(logs)
128 |         self.experiments = client.Experiment(self.session)
129 | 
130 |     def on_epoch_end(self, epoch, logs=None) -> None:
131 |         # TODO: add support for non-eager execution.
132 |         metrics = [dict(name=m.name, value=m.result().numpy())
133 |                    for m in self.model.metrics]
134 | 
135 |         task = self.experiments.trace(self.experiment_name, metrics)
136 |         self.loop.run_until_complete(task)
--------------------------------------------------------------------------------
/tensorcraft/client.py:
--------------------------------------------------------------------------------
  1 | import aiohttp
  2 | import aiohttp.web
  3 | import numpy
  4 | import ssl
  5 | 
  6 | import tensorcraft
  7 | import tensorcraft.asynclib
  8 | 
  9 | from tensorcraft import arglib
 10 | from tensorcraft import errors
 11 | from tensorcraft import tlslib
 12 | 
 13 | from types import TracebackType
 14 | from typing import Dict, IO, NamedTuple, Optional, Sequence, Union, Type
 15 | from urllib.parse import urlparse, urlunparse
 16 | 
 17 | 
 18 | class Session:
 19 | 
 20 |     default_headers = {"Accept-Version":
 21 |                        ">={0}".format(tensorcraft.__apiversion__)}
 22 | 
 23 |     def __init__(self, service_url: str,
 24 |                  ssl_context: Union[ssl.SSLContext, None] = None):
 25 | 
 26 |         # Change the protocol to "https" if an SSL context is given.
 27 |         if ssl_context:
 28 |             url = urlparse(service_url)
 29 |             _, *parts = url
 30 |             service_url = urlunparse(["https"] + parts)
 31 | 
 32 |         self.service_url = service_url
 33 |         self.session = aiohttp.ClientSession(
 34 |             connector=aiohttp.TCPConnector(ssl_context=ssl_context),
 35 |             headers=self.default_headers,
 36 |         )
 37 | 
 38 |     @property
 39 |     def default_headers(self) -> Dict:
 40 |         return {"Accept-Version": f">={tensorcraft.__apiversion__}"}
 41 | 
 42 |     async def __aenter__(self) -> aiohttp.ClientSession:
 43 |         return await self.session.__aenter__()
 44 | 
 45 |     async def __aexit__(self,
 46 |                         exc_type: Optional[Type[BaseException]],
 47 |                         exc_val: Optional[BaseException],
 48 |                         exc_tb: Optional[TracebackType]) -> None:
 49 |         await self.session.__aexit__(exc_type, exc_val, exc_tb)
 50 | 
 51 |     def url(self, path: str) -> str:
 52 |         return f"{self.service_url}/{path}"
 53 | 
 54 |     async def close(self) -> None:
 55 |         """Close the session and stop communication with the remote server."""
 56 |         await self.session.close()
 57 | 
 58 |     @classmethod
 59 |     async def new(cls, **kwargs):
 60 |         ssl_args = arglib.filter_callable_arguments(
 61 |             tlslib.create_client_ssl_context, **kwargs)
 62 | 
 63 |         ssl_context = tlslib.create_client_ssl_context(**ssl_args)
 64 |         self = cls(kwargs.get("service_url"), ssl_context)
 65 |         return self
 66 | 
 67 | 
 68 | class Model:
 69 |     """A client for basic model operations on a remote server.
 70 | 
 71 |     An asynchronous client used to publish, remove and list
 72 |     available models.
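
    A short usage sketch (illustrative; assumes a server reachable at the
    given URL):

        models = await Model.new(service_url="http://localhost:5678")
        async with models:
            print(await models.list())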
 73 | 
 74 |     Attributes:
 75 |         session -- connection to the remote server
 76 |     """
 77 | 
 78 | 
 79 |     def __init__(self, session: Session) -> None:
 80 |         self.session = session
 81 | 
 82 |     async def __aenter__(self) -> "Model":
 83 |         return self
 84 | 
 85 |     async def __aexit__(self,
 86 |                         exc_type: Optional[Type[BaseException]],
 87 |                         exc_val: Optional[BaseException],
 88 |                         exc_tb: Optional[TracebackType]) -> None:
 89 |         await self.session.close()
 90 | 
 91 |     @classmethod
 92 |     async def new(cls, **kwargs):
 93 |         return cls(await Session.new(**kwargs))
 94 | 
 95 |     def make_error_from_response(self,
 96 |                                  resp: aiohttp.web.Response,
 97 |                                  success_status=200) -> Optional[Exception]:
 98 |         if resp.status != success_status:
 99 |             error_code = resp.headers.get("Error-Code", "")
100 |             return errors.ModelError.from_error_code(error_code)
101 |         return None
102 | 
103 |     async def push(self, name: str, tag: str, reader: IO) -> None:
104 |         """Push the model to the server.
105 | 
106 |         The model is expected to be a TAR archive in the SavedModel
107 |         format.
108 |         """
109 |         async with self.session as session:
110 |             url = self.session.url(f"models/{name}/{tag}")
111 |             resp = await session.put(url, data=reader)
112 | 
113 |             error_class = self.make_error_from_response(resp,
114 |                                                         success_status=201)
115 |             if error_class:
116 |                 raise error_class(name, tag)
117 | 
118 |     async def remove(self, name: str, tag: str) -> None:
119 |         """Remove the model from the server.
120 | 
121 |         The method raises an error when the model is missing.
122 |         """
123 |         async with self.session as session:
124 |             url = self.session.url(f"models/{name}/{tag}")
125 |             resp = await session.delete(url)
126 | 
127 |             error_class = self.make_error_from_response(resp)
128 |             if error_class:
129 |                 raise error_class(name, tag)
130 | 
131 |     async def list(self):
132 |         """List available models on the server."""
133 |         async with self.session as session:
134 |             async with session.get(self.session.url("models")) as resp:
135 |                 return await resp.json()
136 | 
137 |     async def export(self, name: str, tag: str, writer: IO) -> None:
138 |         """Export the model from the server."""
139 |         async with self.session as session:
140 |             resp = await session.get(self.session.url(f"models/{name}/{tag}"))
141 | 
142 |             error_class = self.make_error_from_response(resp)
143 |             if error_class:
144 |                 raise error_class(name, tag)
145 | 
146 |             await writer.write(await resp.read())
147 | 
148 |     async def predict(self, name: str, tag: str,
149 |                       x_pred: Union[numpy.ndarray, list]) -> numpy.ndarray:
150 |         """Feed the X array to the given model and retrieve a prediction."""
151 |         async with self.session as session:
152 |             url = self.session.url(f"models/{name}/{tag}/predict")
153 |             async with session.post(url, json=dict(x=x_pred)) as resp:
154 |                 error_class = self.make_error_from_response(resp)
155 |                 if error_class:
156 |                     raise error_class(name, tag)
157 | 
158 |                 resp_data = await resp.json()
159 |                 return numpy.array(resp_data.get("y"))
160 | 
161 |     async def status(self) -> Dict[str, str]:
162 |         async with self.session as session:
163 |             resp = await session.get(self.session.url("status"))
164 |             return await resp.json()
165 | 
166 | 
167 | class _Metric(NamedTuple):
168 |     name: str
169 |     value: float
170 | 
171 | 
172 | class Experiment:
173 |     """A client for basic experiment operations on a remote server.
174 | 
175 |     An asynchronous client used to create, remove and update experiments.
176 | 
177 |     Attributes:
178 |         session -- connection to the remote server
179 |     """
180 | 
181 |     def __init__(self, session: Session) -> None:
182 |         self.session = session
183 | 
184
| async def create(self, name: str) -> None: 185 | async with self.session as session: 186 | await session.post(self.session.url("experiments"), 187 | json=dict(name=name)) 188 | 189 | async def trace(self, 190 | experiment_name: str, 191 | metrics: Sequence[_Metric]) -> None: 192 | async with self.session as session: 193 | url = self.session.url(f"experiments/{experiment_name}/epochs") 194 | await session.post(url, json=dict(metrics=metrics)) 195 | -------------------------------------------------------------------------------- /tensorcraft/errors.py: -------------------------------------------------------------------------------- 1 | class InputShapeError(Exception): 2 | """Exception raised for invalid model input shape 3 | 4 | Attributes: 5 | expected_dims -- model's dimensions 6 | actual_dims -- input dimensions 7 | """ 8 | 9 | def __init__(self, expected_dims, actual_dims): 10 | self.expected_dims = tuple(expected_dims) 11 | self.actual_dims = tuple(actual_dims) 12 | 13 | def __str__(self): 14 | return "Input shape is {0}, while {1} is given.".format( 15 | self.expected_dims, self.actual_dims) 16 | 17 | 18 | class _ModelErrorMeta(type): 19 | 20 | error_mapping = {} 21 | 22 | def __new__(cls, *args, **kwargs): 23 | klass = super().__new__(cls, *args, **kwargs) 24 | if hasattr(klass, "error_code"): 25 | cls.error_mapping[klass.error_code] = klass 26 | return klass 27 | 28 | def from_error_code(self, error_code: str): 29 | return self.error_mapping.get(error_code, ModelError) 30 | 31 | 32 | class ModelError(Exception, metaclass=_ModelErrorMeta): 33 | 34 | error_code = "Model Error" 35 | 36 | def __init__(self, name: str, tag: str): 37 | self.name = name 38 | self.tag = tag 39 | 40 | 41 | class NotFoundError(ModelError): 42 | """Exception raised on missing model.""" 43 | 44 | error_code = "Model Not Found" 45 | 46 | def __str__(self): 47 | return f"Model {self.name}:{self.tag} not found" 48 | 49 | 50 | class NotLoadedError(ModelError): 51 | """Exception raised on attempt to use not loaded model.""" 52 | 53 | error_code = "Model Not Loaded" 54 | 55 | def __str__(self): 56 | return f"Model {self.name}:{self.tag} is not loaded" 57 | 58 | 59 | class DuplicateError(ModelError): 60 | """Exception raised on attempt to save model with the same name and tag.""" 61 | 62 | error_code = "Model Duplicate" 63 | 64 | def __str__(self): 65 | return f"Model {self.name}:{self.tag} is a duplicate" 66 | 67 | 68 | class LatestTagError(ModelError): 69 | """Exception raised on attempt to save model with latest tag.""" 70 | 71 | error_code = "Model Latest Tag" 72 | 73 | def __str__(self): 74 | return f"Model {self.name}:{self.tag} cannot be latest" 75 | -------------------------------------------------------------------------------- /tensorcraft/experiment.py: -------------------------------------------------------------------------------- 1 | class Experiment: 2 | pass 3 | -------------------------------------------------------------------------------- /tensorcraft/logging.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import os 4 | 5 | 6 | # Disable logging from TensorFlow CPP files. 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 8 | 9 | # Disable asyncio and TensorFlow logging. 10 | logging.getLogger("asyncio").disabled = True 11 | logging.getLogger("tensorflow").disabled = True 12 | 13 | 14 | try: 15 | # This import is done in a hope that TensorFlow will drop this 16 | # dependency in the future versions. 
17 | absl_logging = importlib.import_module("absl.logging") 18 | 19 | # ABSL logging removes all root handlers and puts itself to the list of 20 | # root handlers, this dependency comes with TensorFlow, and we want to 21 | # modify this behaviour to make logs pretty and consistent. 22 | for h in logging.root.handlers: 23 | if isinstance(h, absl_logging.ABSLHandler): 24 | logging.root.removeHandler(h) 25 | except ModuleNotFoundError: 26 | pass 27 | # Nothing to do. 28 | 29 | 30 | internal_format = "{asctime} {levelname} - {message}" 31 | 32 | internal_handler = logging.StreamHandler() 33 | internal_handler.setFormatter(logging.Formatter(internal_format, style="{")) 34 | 35 | internal_logger = logging.getLogger("tensorcraft") 36 | internal_logger.addHandler(internal_handler) 37 | internal_logger.setLevel(logging.DEBUG) 38 | -------------------------------------------------------------------------------- /tensorcraft/server.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import aiohttp.web 3 | import asyncio 4 | import logging 5 | import pathlib 6 | import pid 7 | import semver 8 | 9 | import tensorcraft 10 | 11 | from aiojobs.aiohttp import atomic, setup 12 | from functools import partial 13 | from typing import Awaitable 14 | 15 | from tensorcraft import arglib 16 | from tensorcraft import tlslib 17 | from tensorcraft.backend import httpapi 18 | from tensorcraft.backend import model 19 | from tensorcraft.backend import saving 20 | from tensorcraft.logging import internal_logger 21 | 22 | 23 | class Server: 24 | """Serve the models.""" 25 | 26 | @classmethod 27 | async def new(cls, data_root: str, pidfile: str, 28 | host: str = None, port: str = None, 29 | preload: bool = False, 30 | close_timeout: int = 10, 31 | strategy: str = model.Strategy.No.value, 32 | logger: logging.Logger = internal_logger): 33 | """Create new instance of the server.""" 34 | 35 | self = cls() 36 | 37 | pidfile = pathlib.Path(pidfile) 38 | self.pid = pid.PidFile(piddir=pidfile.parent, pidname=pidfile.name) 39 | 40 | # Create a data root directory where all server data is persisted. 41 | data_root = pathlib.Path(data_root) 42 | data_root.mkdir(parents=True, exist_ok=True) 43 | 44 | # TODO: use different execution strategies for models and 45 | # fallback to the server-default execution strategy. 46 | loader = model.Loader(strategy=strategy, logger=logger) 47 | 48 | storage = saving.FsModelsStorage.new(path=data_root, loader=loader) 49 | models = await model.Cache.new(storage=storage, preload=preload) 50 | 51 | # Experiments storage based on regular file system. 52 | experiments = saving.FsExperimentsStorage.new(path=data_root) 53 | 54 | self.app = aiohttp.web.Application(client_max_size=1024**10) 55 | 56 | self.app.on_startup.append(cls.app_callback(self.pid.create)) 57 | self.app.on_response_prepare.append(self._prepare_response) 58 | self.app.on_shutdown.append(cls.app_callback(storage.close)) 59 | self.app.on_shutdown.append(cls.app_callback(experiments.close)) 60 | self.app.on_shutdown.append(cls.app_callback(self.pid.close)) 61 | 62 | route = partial(route_to, api_version=tensorcraft.__apiversion__) 63 | 64 | models_view = httpapi.ModelView(models) 65 | server_view = httpapi.ServerView(models) 66 | experiments_view = httpapi.ExperimentView(experiments) 67 | 68 | self.app.add_routes([ 69 | # Model-related endpoints. 
70 | aiohttp.web.get(models_view.list.url, route(models_view.list)), 71 | aiohttp.web.put(models_view.save.url, route(models_view.save)), 72 | aiohttp.web.get(models_view.export.url, route(models_view.export)), 73 | aiohttp.web.delete(models_view.delete.url, 74 | route(models_view.delete)), 75 | aiohttp.web.post(models_view.predict.url, 76 | route(models_view.predict)), 77 | 78 | # Experiment-related endpoints. 79 | aiohttp.web.post(experiments_view.create.url, 80 | route(experiments_view.create)), 81 | aiohttp.web.post(experiments_view.create_epoch.url, 82 | route(experiments_view.create_epoch)), 83 | aiohttp.web.get(experiments_view.get.url, 84 | route(experiments_view.get)), 85 | aiohttp.web.get(experiments_view.list.url, 86 | route(experiments_view.list)), 87 | 88 | # Server-related endpoints. 89 | aiohttp.web.get(server_view.status.url, route(server_view.status)), 90 | # aiohttp.web.static("/ui", "static"), 91 | ]) 92 | 93 | setup(self.app) 94 | logger.info("Server initialization completed") 95 | 96 | return self 97 | 98 | async def _prepare_response(self, request, response): 99 | server = "TensorCraft/{0}".format(tensorcraft.__version__) 100 | response.headers["Server"] = server 101 | response.headers["Access-Control-Allow-Origin"] = "*" 102 | 103 | @classmethod 104 | def start(cls, **kwargs): 105 | """Start serving the models. 106 | 107 | Run event loop to handle the requests. 108 | """ 109 | application_args = arglib.filter_callable_arguments(cls.new, **kwargs) 110 | 111 | async def application_factory(): 112 | s = await cls.new(**application_args) 113 | return s.app 114 | 115 | ssl_args = arglib.filter_callable_arguments( 116 | tlslib.create_server_ssl_context, **kwargs) 117 | ssl_context = tlslib.create_server_ssl_context(**ssl_args) 118 | 119 | aiohttp.web.run_app(application_factory(), 120 | print=None, 121 | ssl_context=ssl_context, 122 | host=kwargs.get("host"), port=kwargs.get("port")) 123 | 124 | @classmethod 125 | def app_callback(cls, awaitable): 126 | async def on_signal(app): 127 | coroutine = awaitable() 128 | if asyncio.iscoroutine(coroutine): 129 | await coroutine 130 | return on_signal 131 | 132 | 133 | def handle_accept_version(req: aiohttp.web.Request, api_version: str): 134 | default_version = "=={0}".format(api_version) 135 | req_version = req.headers.get("Accept-Version", default_version) 136 | 137 | try: 138 | match = semver.match(api_version, req_version) 139 | except ValueError as e: 140 | raise aiohttp.web.HTTPNotAcceptable(text=str(e)) 141 | else: 142 | if not match: 143 | text = ("accept version {0} does not match API version {1}" 144 | ).format(req_version, api_version) 145 | raise aiohttp.web.HTTPNotAcceptable(text=text) 146 | 147 | 148 | def accept_version(handler: Awaitable, api_version: str) -> Awaitable: 149 | async def _f(req: aiohttp.web.Request) -> aiohttp.web.Response: 150 | handle_accept_version(req, api_version) 151 | return await handler(req) 152 | return _f 153 | 154 | 155 | def route_to(handler: Awaitable, api_version: str) -> Awaitable: 156 | """Create a route with the API version validation. 157 | 158 | Returns handler decorated with API version check. 
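
    For example (mirroring the registration in Server.new above; ``handler``
    stands for any aiohttp request handler):

        route = partial(route_to, api_version=tensorcraft.__apiversion__)
        app.add_routes([aiohttp.web.get("/status", route(handler))])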
159 | """ 160 | return atomic(accept_version(handler, api_version)) 161 | -------------------------------------------------------------------------------- /tensorcraft/shell/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/netrack/tensorcraft/15e0c54b795f4ce527cc5e2c46bbb7da434ac036/tensorcraft/shell/__init__.py -------------------------------------------------------------------------------- /tensorcraft/shell/commands.py: -------------------------------------------------------------------------------- 1 | import aiofiles 2 | import argparse 3 | import flagparse 4 | import importlib 5 | import pathlib 6 | import tarfile 7 | import yaml 8 | 9 | import tensorcraft.errors 10 | 11 | from tensorcraft import asynclib 12 | from tensorcraft import client 13 | from tensorcraft.shell import termlib 14 | 15 | 16 | class AsyncSubCommand(flagparse.SubCommand): 17 | 18 | async def async_handle(self, args: flagparse.Namespace) -> None: 19 | pass 20 | 21 | def handle(self, args: flagparse.Namespace) -> None: 22 | asynclib.run(self.async_handle(args)) 23 | 24 | 25 | class Server(flagparse.SubCommand): 26 | """Server shell command used to run a server.""" 27 | 28 | name = "server" 29 | aliases = ["s"] 30 | help = "run server" 31 | 32 | description = "Start serving models." 33 | 34 | arguments = [ 35 | (["-H", "--host"], 36 | dict(metavar="HOST", 37 | help="address to listen to", 38 | default="localhost")), 39 | (["-p", "--port"], 40 | dict(metavar="PORT", 41 | help="port to listen to", 42 | default="5678")), 43 | (["--data-root"], 44 | dict(metavar="PATH", 45 | help="root directory of persistent state", 46 | default="/var/lib/tensorcraft")), 47 | (["--pidfile"], 48 | dict(metavar="PIDFILE", 49 | help="path to use for daemon pid file", 50 | default="/var/run/tensorcraft.pid")), 51 | (["--strategy"], 52 | dict(metavar="STRATEGY", 53 | choices=["mirrored", "multi_worker_mirrored", "no"], 54 | default="mirrored", 55 | help="model execution strategy")), 56 | (["--preload"], 57 | dict(action="store_true", 58 | default=False, 59 | help="preload all models into the memory before start"))] 60 | 61 | def handle(self, args: flagparse.Namespace) -> None: 62 | try: 63 | server = importlib.import_module("tensorcraft.server") 64 | server.Server.start(**args.__dict__) 65 | except Exception as e: 66 | raise flagparse.ExitError(1, f"Failed to start server. {e}.") 67 | 68 | 69 | class Push(AsyncSubCommand): 70 | """Shell command to push model to the server.""" 71 | 72 | name = "push" 73 | aliases = ["p"] 74 | help = "push model" 75 | 76 | description = "Push a model image to the repository." 
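
    # Example invocation (hypothetical model name and tag):
    #
    #     tensorcraft push --name mnist --tag 1.0.0 model.tar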
77 | 78 | arguments = [ 79 | (["-n", "--name"], 80 | dict(metavar="NAME", 81 | type=str, 82 | required=True, 83 | default=argparse.SUPPRESS, 84 | help="model name")), 85 | (["-t", "--tag"], 86 | dict(metavar="TAG", 87 | type=str, 88 | required=True, 89 | default=argparse.SUPPRESS, 90 | help="model tag")), 91 | (["path"], 92 | dict(metavar="PATH", 93 | type=pathlib.Path, 94 | default=argparse.SUPPRESS, 95 | help="model location"))] 96 | 97 | async def async_handle(self, args: flagparse.Namespace) -> None: 98 | print(f"loading model {args.name}:{args.tag}") 99 | 100 | try: 101 | if not args.path.exists(): 102 | raise ValueError(f"{args.path} does not exist") 103 | if not tarfile.is_tarfile(str(args.path)): 104 | raise ValueError(f"{args.path} is not a tar file") 105 | 106 | asyncreader = asynclib.reader(args.path) 107 | reader = termlib.async_progress(args.path, asyncreader) 108 | 109 | models_client = await client.Model.new(**args.__dict__) 110 | async with models_client as models: 111 | await models.push(args.name, args.tag, reader) 112 | except Exception as e: 113 | raise flagparse.ExitError(1, f"Failed to push model. {e}") 114 | 115 | 116 | class Remove(AsyncSubCommand): 117 | """Shell command to remove the model from server.""" 118 | 119 | name = "remove" 120 | aliases = ["rm"] 121 | help = "remove model" 122 | 123 | description = "Remove a model from the repository." 124 | 125 | arguments = [ 126 | (["-n", "--name"], 127 | dict(metavar="NAME", 128 | type=str, 129 | help="model name")), 130 | (["-q", "--quiet"], 131 | dict(action="store_true", 132 | help="do not return error on missing model")), 133 | (["-t", "--tag"], 134 | dict(metavar="TAG", 135 | type=str, 136 | help="model tag"))] 137 | 138 | async def async_handle(self, args: flagparse.Namespace) -> None: 139 | try: 140 | models_client = await client.Model.new(**args.__dict__) 141 | async with models_client as models: 142 | await models.remove(args.name, args.tag) 143 | except tensorcraft.errors.NotFoundError as e: 144 | if not args.quiet: 145 | raise flagparse.ExitError(1, f"{e}") 146 | except Exception as e: 147 | raise flagparse.ExitError(1, f"Failed to remove model. {e}.") 148 | 149 | 150 | class List(AsyncSubCommand): 151 | """Shell command to list models from the server.""" 152 | 153 | name = "list" 154 | aliases = ["ls"] 155 | help = "list models" 156 | 157 | description = "List available models." 158 | 159 | arguments = [] 160 | 161 | async def async_handle(self, args: flagparse.Namespace) -> None: 162 | try: 163 | models_client = await client.Model.new(**args.__dict__) 164 | async with models_client as models: 165 | for model in await models.list(): 166 | print("{name}:{tag}".format(**model)) 167 | except Exception as e: 168 | raise flagparse.ExitError(1, f"Failed to list models. {e}.") 169 | 170 | 171 | class Export(AsyncSubCommand): 172 | """Shell command to export model from the server.""" 173 | 174 | name = "export" 175 | aliases = [] 176 | help = "export model tar" 177 | 178 | description = "Export model as TAR." 
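
    # Example invocation (hypothetical model name and tag):
    #
    #     tensorcraft export --name mnist --tag 1.0.0 model.tar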
179 | 
180 |     arguments = [
181 |         (["-n", "--name"],
182 |          dict(metavar="NAME",
183 |               type=str,
184 |               required=True,
185 |               default=argparse.SUPPRESS,
186 |               help="model name")),
187 |         (["-t", "--tag"],
188 |          dict(metavar="TAG",
189 |               type=str,
190 |               required=True,
191 |               default=argparse.SUPPRESS,
192 |               help="model tag")),
193 |         (["path"],
194 |          dict(metavar="PATH",
195 |               type=pathlib.Path,
196 |               default=argparse.SUPPRESS,
197 |               help="file location"))]
198 | 
199 |     async def async_handle(self, args: flagparse.Namespace) -> None:
200 |         try:
201 |             async with aiofiles.open(args.path, "wb+") as writer:
202 |                 models_client = await client.Model.new(**args.__dict__)
203 |                 async with models_client as models:
204 |                     await models.export(args.name, args.tag, writer)
205 |         except Exception as e:
206 |             raise flagparse.ExitError(1, f"Failed to export model. {e}")
207 | 
208 | 
209 | class Status(AsyncSubCommand):
210 |     """Shell command to retrieve server status information."""
211 | 
212 |     name = "status"
213 |     aliases = []
214 |     help = "server status"
215 | 
216 |     description = "Retrieve server status."
217 | 
218 |     arguments = []
219 | 
220 |     async def async_handle(self, args: flagparse.Namespace) -> None:
221 |         try:
222 |             models_client = await client.Model.new(**args.__dict__)
223 |             async with models_client as models:
224 |                 status = await models.status()
225 |                 print(yaml.dump(status), end="")
226 |         except Exception as e:
227 |             raise flagparse.ExitError(1, f"Failed to retrieve status. {e}")
--------------------------------------------------------------------------------
/tensorcraft/shell/main.py:
--------------------------------------------------------------------------------
  1 | import flagparse
  2 | 
  3 | import tensorcraft
  4 | 
  5 | from tensorcraft.shell import commands
  6 | 
  7 | 
  8 | class Command(flagparse.Command):
  9 | 
 10 |     name = "tensorcraft"
 11 | 
 12 |     arguments = [
 13 |         (["-s", "--service-url"],
 14 |          dict(help="service endpoint",
 15 |               default="http://localhost:5678")),
 16 |         (["--tls"],
 17 |          dict(action="store_true",
 18 |               default=False,
 19 |               help="use TLS")),
 20 |         (["--tlsverify"],
 21 |          dict(action="store_true",
 22 |               default=False,
 23 |               help="use TLS and verify remote")),
 24 |         (["--tlscacert"],
 25 |          dict(metavar="TLS_CACERT",
 26 |               default=tensorcraft.homepath.joinpath("cacert.pem"),
 27 |               help="trust certs signed only by this CA")),
 28 |         (["--tlscert"],
 29 |          dict(metavar="TLS_CERT",
 30 |               default=tensorcraft.homepath.joinpath("cert.pem"),
 31 |               help="path to TLS certificate file")),
 32 |         (["--tlskey"],
 33 |          dict(metavar="TLS_KEY",
 34 |               default=tensorcraft.homepath.joinpath("key.pem"),
 35 |               help="path to TLS key file")),
 36 |         (["-v", "--version"],
 37 |          dict(help="print version and exit",
 38 |               action="version",
 39 |               version="%(prog)s {0}".format(tensorcraft.__version__)))]
 40 | 
 41 |     def handle(self, args: flagparse.Namespace) -> None:
 42 |         self.parser.print_help()
 43 | 
 44 | 
 45 | def main():
 46 |     Command([commands.Server,
 47 |              commands.Push,
 48 |              commands.Remove,
 49 |              commands.List,
 50 |              commands.Export,
 51 |              commands.Status]).parse(trace=True)
 52 | 
 53 | 
 54 | if __name__ == "__main__":
 55 |     main()
--------------------------------------------------------------------------------
/tensorcraft/shell/termlib.py:
--------------------------------------------------------------------------------
  1 | import humanize
  2 | import pathlib
  3 | 
  4 | from typing import Coroutine
  5 | 
  6 | 
  7 | async def async_progress(path: pathlib.Path, reader: Coroutine) -> bytes:
  8 |     def progress(loaded, total, bar_len=30):
  9 |         filled_len =
int(round(bar_len * loaded / total)) 10 | empty_len = bar_len - filled_len 11 | 12 | loaded = humanize.naturalsize(loaded).replace(" ", "") 13 | total = humanize.naturalsize(total).replace(" ", "") 14 | 15 | bar = "=" * filled_len + " " * empty_len 16 | print(f"[{bar}] {loaded}/{total}\r", end="", flush=True) 17 | 18 | total = path.stat().st_size 19 | loaded = 0 20 | 21 | progress(loaded, total) 22 | async for chunk in reader: 23 | yield chunk 24 | loaded += len(chunk) 25 | progress(loaded, total) 26 | 27 | progress(loaded, total) 28 | print("", flush=True) 29 | -------------------------------------------------------------------------------- /tensorcraft/signal.py: -------------------------------------------------------------------------------- 1 | class Signal: 2 | """Registers the signal subscribers and delivers signal when necessary.""" 3 | 4 | def __init__(self): 5 | self.receivers = [] 6 | 7 | def append(self, receiver): 8 | self.receivers.append(receiver) 9 | 10 | async def send(self, *args, **kwargs): 11 | for receiver in frozenset(self.receivers): 12 | await receiver(*args, **kwargs) 13 | -------------------------------------------------------------------------------- /tensorcraft/tlslib.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pathlib 3 | import ssl 4 | 5 | from tensorcraft.logging import internal_logger 6 | 7 | 8 | def create_server_ssl_context(tls: bool = False, 9 | tlsverify: bool = False, 10 | tlscert: str = None, 11 | tlskey: str = None, 12 | tlscacert: str = None, 13 | logger: logging.Logger = internal_logger): 14 | """Create server SSL context with the given TLS parameters.""" 15 | if not tls and not tlsverify: 16 | return None 17 | 18 | tlscert = pathlib.Path(tlscert).resolve(strict=True) 19 | tlskey = pathlib.Path(tlskey).resolve(strict=True) 20 | 21 | ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) 22 | ssl_context.load_cert_chain(tlscert, tlskey) 23 | ssl_context.verify_mode = ssl.CERT_NONE 24 | 25 | logger.info("Using transport layer security") 26 | 27 | if not tlsverify: 28 | return ssl_context 29 | 30 | tlscacert = pathlib.Path(tlscacert).resolve(strict=True) 31 | 32 | ssl_context.verify_mode = ssl.CERT_REQUIRED 33 | ssl_context.load_verify_locations(cafile=tlscacert) 34 | logger.info("Using peer certificates validation") 35 | 36 | return ssl_context 37 | 38 | 39 | def create_client_ssl_context(tls: bool = False, 40 | tlsverify: bool = False, 41 | tlscert: str = None, 42 | tlskey: str = None, 43 | tlscacert: str = None): 44 | """Create client SSL context with the given TLS parameters.""" 45 | if not tls and not tlsverify: 46 | return None 47 | 48 | tlscert = pathlib.Path(tlscert).resolve(strict=True) 49 | tlskey = pathlib.Path(tlskey).resolve(strict=True) 50 | 51 | ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) 52 | ssl_context.load_cert_chain(tlscert, tlskey) 53 | 54 | if tlsverify: 55 | tlscacert = pathlib.Path(tlscacert).resolve(strict=True) 56 | 57 | ssl_context.load_verify_locations(cafile=tlscacert) 58 | ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH) 59 | else: 60 | ssl_context.check_hostname = False 61 | ssl_context.verify_mode = ssl.CERT_NONE 62 | 63 | return ssl_context 64 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/netrack/tensorcraft/15e0c54b795f4ce527cc5e2c46bbb7da434ac036/tests/__init__.py 
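
# Usage sketch for tensorcraft.tlslib above (illustrative, not part of the
# package): a verifying client SSL context built from the same flags the
# CLI exposes; the certificate file names are hypothetical.
#
#     ssl_context = tlslib.create_client_ssl_context(
#         tls=True, tlsverify=True,
#         tlscert="cert.pem", tlskey="key.pem", tlscacert="cacert.pem")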
-------------------------------------------------------------------------------- /tests/asynctest.py: -------------------------------------------------------------------------------- 1 | import aiohttp.web 2 | import asyncio 3 | import typing 4 | import unittest 5 | import unittest.mock 6 | 7 | 8 | class AsyncMagicMock(unittest.mock.MagicMock): 9 | 10 | async def __call__(self, *args, **kwargs): 11 | return super().__call__(*args, **kwargs) 12 | 13 | 14 | class AsyncGeneratorMock(unittest.mock.MagicMock): 15 | """Mock async generator type 16 | 17 | This type allows to pass a regular sequence of items in order 18 | to mimic asynchronous generator. 19 | """ 20 | 21 | def __init__(self, *args, return_value: typing.Sequence = [], **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.iter = return_value.__iter__() 24 | self.return_value = self 25 | 26 | def __aiter__(self) -> typing.AsyncGenerator: 27 | return self 28 | 29 | async def __anext__(self): 30 | try: 31 | return self.iter.__next__() 32 | except StopIteration: 33 | raise StopAsyncIteration 34 | 35 | 36 | class AsyncTestCase(unittest.TestCase): 37 | 38 | def setUp(self): 39 | self.__loop = asyncio.get_event_loop() 40 | self.__loop.run_until_complete(self.setUpAsync()) 41 | 42 | def tearDown(self): 43 | self.__loop.run_until_complete(self.tearDownAsync()) 44 | 45 | async def setUpAsync(self) -> None: 46 | pass 47 | 48 | async def tearDownAsync(self) -> None: 49 | pass 50 | 51 | 52 | def unittest_run_loop(coroutine): 53 | def test(*args, **kwargs): 54 | loop = asyncio.get_event_loop() 55 | return loop.run_until_complete(coroutine(*args, **kwargs)) 56 | return test 57 | 58 | 59 | def unittest_handler(awaitable): 60 | async def _handler(req: aiohttp.web.Request) -> aiohttp.web.Response: 61 | return await awaitable() 62 | return _handler 63 | -------------------------------------------------------------------------------- /tests/clienttest.py: -------------------------------------------------------------------------------- 1 | import unittest.mock 2 | 3 | from tensorcraft.client import Model 4 | from tests import asynctest 5 | 6 | 7 | def unittest_mock_model_client(method: str): 8 | return unittest.mock.patch.object(Model, method, 9 | new_callable=asynctest.AsyncMagicMock) 10 | -------------------------------------------------------------------------------- /tests/cryptotest.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import math 3 | import numpy 4 | import pathlib 5 | import random 6 | import string 7 | 8 | from cryptography import x509 9 | from cryptography.hazmat.backends import default_backend 10 | from cryptography.hazmat.primitives import hashes 11 | from cryptography.hazmat.primitives import serialization 12 | from cryptography.hazmat.primitives.asymmetric import rsa 13 | from cryptography.x509.oid import NameOID 14 | from typing import Tuple 15 | 16 | 17 | def random_string(length=5): 18 | multiplier = math.ceil(length/len(string.ascii_letters)) 19 | return "".join(random.sample(string.ascii_letters*multiplier, length)) 20 | 21 | 22 | def random_bytes(length=1024): 23 | return bytes(random_string(length), "utf-8") 24 | 25 | 26 | def random_dict(items=10): 27 | return {random_string(): random_string() for _ in range(items)} 28 | 29 | 30 | def random_array(length=10) -> list: 31 | return numpy.random.uniform(0, 1, size=length).tolist() 32 | 33 | 34 | def create_self_signed_cert(path: pathlib.Path) -> Tuple[pathlib.Path, 35 | pathlib.Path]: 36 | key = 
rsa.generate_private_key( 37 | public_exponent=65537, 38 | key_size=2048, 39 | backend=default_backend(), 40 | ) 41 | 42 | subject = issuer = x509.Name([ 43 | x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), 44 | x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"), 45 | x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), 46 | x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Company"), 47 | x509.NameAttribute(NameOID.COMMON_NAME, "test.org"), 48 | ]) 49 | 50 | cert = x509.CertificateBuilder().subject_name( 51 | subject 52 | ).issuer_name( 53 | issuer 54 | ).public_key( 55 | key.public_key() 56 | ).serial_number( 57 | x509.random_serial_number() 58 | ).not_valid_before( 59 | datetime.datetime.utcnow() 60 | ).not_valid_after( 61 | # Certificate will be valid for 10 days 62 | datetime.datetime.utcnow() + datetime.timedelta(days=10) 63 | ).add_extension( 64 | x509.SubjectAlternativeName([x509.DNSName(u"localhost")]), 65 | critical=False, # Sign certificate with our private key. 66 | ).sign(key, hashes.SHA256(), default_backend()) 67 | 68 | keypath = path.joinpath("key.pem") 69 | with open(path.joinpath(keypath), "wb") as key_pem: 70 | key_pem.write(key.private_bytes( 71 | encoding=serialization.Encoding.PEM, 72 | format=serialization.PrivateFormat.TraditionalOpenSSL, 73 | encryption_algorithm=serialization.NoEncryption(), 74 | )) 75 | 76 | certpath = path.joinpath("cert.pem") 77 | with open(path.joinpath(certpath), "wb") as cert_pem: 78 | cert_pem.write(cert.public_bytes(serialization.Encoding.PEM)) 79 | 80 | return keypath, certpath 81 | -------------------------------------------------------------------------------- /tests/kerastest.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pathlib 3 | import tarfile 4 | import tensorflow as tf 5 | import tempfile 6 | import uuid 7 | 8 | from collections import namedtuple 9 | 10 | from tensorcraft import asynclib 11 | from tensorcraft.backend import model 12 | from tests import cryptotest 13 | 14 | 15 | Model = namedtuple("Model", ["name", "tag", "tarpath", "url"]) 16 | 17 | 18 | @asynclib.asynccontextmanager 19 | async def crossentropy_model_tar(name: str, tag: str): 20 | with tempfile.TemporaryDirectory() as workdir: 21 | model = tf.keras.models.Sequential() 22 | model.add(tf.keras.layers.Activation("tanh")) 23 | model.compile(optimizer="sgd", loss="binary_crossentropy") 24 | 25 | n = 1000 26 | x = numpy.random.uniform(0, numpy.pi/2, (n, 1)) 27 | y = numpy.random.randint(2, size=(n, 1)) 28 | 29 | model.fit(x, y) 30 | 31 | workpath = pathlib.Path(workdir) 32 | dest = workpath.joinpath(uuid.uuid4().hex) 33 | tf.keras.experimental.export_saved_model(model, str(dest)) 34 | 35 | # Ensure that model has been created. 
36 | assert dest.exists() 37 | 38 | tarpath = dest.with_suffix(".tar") 39 | with tarfile.open(str(tarpath), mode="w") as tar: 40 | tar.add(str(dest), arcname="") 41 | 42 | yield tarpath 43 | 44 | 45 | def new_model(name: str = None, 46 | tag: str = None): 47 | return model.Model.new(name=name or cryptotest.random_string(), 48 | tag=tag or cryptotest.random_string(), 49 | root=pathlib.Path("/")) 50 | -------------------------------------------------------------------------------- /tests/test_cache.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import unittest.mock 3 | 4 | from tensorcraft.backend.model import Cache, AbstractStorage 5 | from tests import asynctest 6 | from tests import kerastest 7 | 8 | 9 | class TestCache(asynctest.AsyncTestCase): 10 | 11 | async def setUpAsync(self) -> None: 12 | self.storage = unittest.mock.create_autospec(AbstractStorage) 13 | 14 | @asynctest.unittest_run_loop 15 | async def test_all(self): 16 | # Create loaded model. 17 | m = kerastest.new_model() 18 | m.model = unittest.mock.MagicMock() 19 | 20 | self.storage.all = asynctest.AsyncGeneratorMock(return_value=[m]) 21 | cache = await Cache.new(storage=self.storage) 22 | 23 | # Put unloaded model into the cache and ensure that it won't be 24 | # replaced the call to "all". 25 | cache.models[m.key] = m 26 | models = [m async for m in cache.all()] 27 | 28 | self.storage.all.assert_called() 29 | self.assertEqual(models, [m]) 30 | # Ensure all returned models are loaded. 31 | self.assertTrue(all(map(lambda m: m.loaded, models))) 32 | 33 | @asynctest.unittest_run_loop 34 | async def test_save(self): 35 | m1 = kerastest.new_model() 36 | 37 | self.storage.save = asynctest.AsyncMagicMock(return_value=m1) 38 | 39 | cache = await Cache.new(storage=self.storage) 40 | m2 = await cache.save(m1.name, m1.tag, None) 41 | 42 | self.storage.save.assert_called() 43 | self.assertEqual(m1, m2) 44 | self.assertIn(m1.key, cache.models) 45 | 46 | @asynctest.unittest_run_loop 47 | async def test_delete(self): 48 | m = kerastest.new_model() 49 | 50 | self.storage.delete = asynctest.AsyncMagicMock() 51 | 52 | cache = await Cache.new(storage=self.storage) 53 | cache.models[m.key] = m 54 | 55 | await cache.delete(m.name, m.tag) 56 | 57 | self.storage.delete.assert_called() 58 | self.assertNotIn(m.key, cache.models) 59 | 60 | @asynctest.unittest_run_loop 61 | async def test_delete_not_found(self): 62 | m = kerastest.new_model() 63 | 64 | self.storage.delete = asynctest.AsyncMagicMock() 65 | 66 | cache = await Cache.new(storage=self.storage) 67 | await cache.delete(m.name, m.tag) 68 | 69 | self.storage.delete.assert_called() 70 | self.assertNotIn(m.key, cache.models) 71 | 72 | @asynctest.unittest_run_loop 73 | async def test_load(self): 74 | m1 = kerastest.new_model() 75 | m1.model = unittest.mock.MagicMock() 76 | 77 | self.storage.load = asynctest.AsyncMagicMock() 78 | 79 | cache = await Cache.new(storage=self.storage) 80 | cache.models[m1.key] = m1 81 | 82 | m2 = await cache.load(m1.name, m1.tag) 83 | 84 | self.storage.load.assert_not_called() 85 | self.assertIn(m1.key, cache.models) 86 | self.assertEqual(m1, m2) 87 | 88 | @asynctest.unittest_run_loop 89 | async def test_load_not_found(self): 90 | m1 = kerastest.new_model() 91 | 92 | self.storage.load = asynctest.AsyncMagicMock(return_value=m1) 93 | 94 | cache = await Cache.new(storage=self.storage) 95 | m2 = await cache.load(m1.name, m1.tag) 96 | 97 | self.storage.load.assert_called() 98 | self.assertIn(m1.key, 
cache.models) 99 | 100 | self.assertEqual(m1, m2) 101 | 102 | 103 | if __name__ == "__main__": 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /tests/test_callbacks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import unittest 4 | 5 | from tensorcraft import callbacks 6 | from tests import asynctest 7 | from tests import clienttest 8 | from tests import kerastest 9 | 10 | 11 | class TestCallbacks(asynctest.AsyncTestCase): 12 | 13 | @clienttest.unittest_mock_model_client("push") 14 | def test_on_epoch_end(self, push_mock): 15 | cb = callbacks.ModelCheckpoint(verbose=1) 16 | 17 | model = tf.keras.models.Sequential() 18 | model.add(tf.keras.layers.Dense(1, input_shape=(1,))) 19 | model.compile(optimizer="sgd", loss="binary_crossentropy") 20 | 21 | x, y = np.array([[1.0]]), np.array([[1.0]]) 22 | model.fit(x, y, callbacks=[cb], epochs=3) 23 | 24 | push_mock.assert_called() 25 | 26 | 27 | if __name__ == "__main__": 28 | unittest.main() 29 | -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | import aiohttp.test_utils as aiohttptest 2 | import aiohttp.web 3 | import io 4 | import numpy 5 | import pathlib 6 | import unittest 7 | 8 | from tensorcraft import asynclib 9 | from tensorcraft import errors 10 | from tensorcraft import client 11 | from tests import asynctest 12 | from tests import cryptotest 13 | from tests import kerastest 14 | 15 | 16 | class TestModelClient(asynctest.AsyncTestCase): 17 | 18 | @asynclib.asynccontextmanager 19 | async def handle_request(self, 20 | method: str, 21 | path: str, 22 | resp: aiohttp.web.Response = None, 23 | ) -> client.Model: 24 | resp = aiohttp.web.Response() if resp is None else resp 25 | handler_mock = asynctest.AsyncMagicMock(return_value=resp) 26 | 27 | app = aiohttp.web.Application() 28 | route = aiohttp.web.RouteDef( 29 | method, path, asynctest.unittest_handler(handler_mock), {}) 30 | 31 | app.add_routes([route]) 32 | 33 | async with aiohttptest.TestServer(app) as server: 34 | service_url = str(server.make_url("")) 35 | yield client.Model(client.Session(service_url)) 36 | 37 | handler_mock.assert_called() 38 | 39 | @asynctest.unittest_run_loop 40 | async def test_list(self): 41 | want_value = cryptotest.random_dict() 42 | resp = aiohttp.web.json_response(want_value) 43 | 44 | async with self.handle_request("GET", "/models", resp) as client: 45 | recv_value = await client.list() 46 | self.assertEqual(want_value, recv_value) 47 | 48 | @asynctest.unittest_run_loop 49 | async def test_remove(self): 50 | m = kerastest.new_model() 51 | path = f"/models/{m.name}/{m.tag}" 52 | 53 | async with self.handle_request("DELETE", path) as client: 54 | self.assertIsNone(await client.remove(m.name, m.tag)) 55 | 56 | @asynctest.unittest_run_loop 57 | async def test_remove_not_found(self): 58 | m = kerastest.new_model() 59 | path = f"/models/{m.name}/{m.tag}" 60 | resp = aiohttp.web.Response(status=404, 61 | headers={"Error-Code": "Model Not Found"}) 62 | 63 | async with self.handle_request("DELETE", path, resp) as client: 64 | with self.assertRaises(errors.NotFoundError): 65 | await client.remove(m.name, m.tag) 66 | 67 | @asynctest.unittest_run_loop 68 | async def test_remove_latest(self): 69 | m = kerastest.new_model() 70 | path = f"/models/{m.name}/latest" 71 | 72 | resp = 
aiohttp.web.Response(status=409, 73 | headers={"Error-Code": "Model Latest Tag"}) 74 | 75 | async with self.handle_request("DELETE", path, resp) as client: 76 | with self.assertRaises(errors.LatestTagError): 77 | await client.remove(m.name, "latest") 78 | 79 | @asynctest.unittest_run_loop 80 | async def test_status(self): 81 | want_value = cryptotest.random_dict() 82 | resp = aiohttp.web.json_response(want_value) 83 | 84 | async with self.handle_request("GET", "/status", resp) as client: 85 | recv_value = await client.status() 86 | self.assertEqual(want_value, recv_value) 87 | 88 | @asynctest.unittest_run_loop 89 | async def test_export_not_found(self): 90 | m = kerastest.new_model() 91 | path = f"/models/{m.name}/{m.tag}" 92 | resp = aiohttp.web.Response(status=404, 93 | headers={"Error-Code": "Model Not Found"}) 94 | 95 | async with self.handle_request("GET", path, resp) as client: 96 | with self.assertRaises(errors.NotFoundError): 97 | await client.export(m.name, m.tag, pathlib.Path("/")) 98 | 99 | @asynctest.unittest_run_loop 100 | async def test_export(self): 101 | m = kerastest.new_model() 102 | path = f"/models/{m.name}/{m.tag}" 103 | 104 | want_value = cryptotest.random_bytes() 105 | resp = aiohttp.web.Response(body=want_value) 106 | 107 | async with self.handle_request("GET", path, resp) as client: 108 | writer = io.BytesIO() 109 | await client.export(m.name, m.tag, asynclib.AsyncIO(writer)) 110 | 111 | self.assertEqual(want_value, writer.getvalue()) 112 | 113 | @asynctest.unittest_run_loop 114 | async def test_predict(self): 115 | m = kerastest.new_model() 116 | path = f"/models/{m.name}/{m.tag}/predict" 117 | 118 | y_true = [cryptotest.random_array()] 119 | resp = aiohttp.web.json_response(data=dict(y=y_true)) 120 | 121 | async with self.handle_request("POST", path, resp) as client: 122 | x_pred = cryptotest.random_array() 123 | y_pred = await client.predict(m.name, m.tag, [x_pred]) 124 | 125 | self.assertTrue(numpy.array_equal(y_true, y_pred)) 126 | 127 | @asynctest.unittest_run_loop 128 | async def test_push(self): 129 | m = kerastest.new_model() 130 | path = f"/models/{m.name}/{m.tag}" 131 | resp = aiohttp.web.Response(status=201) 132 | 133 | async with self.handle_request("PUT", path, resp) as client: 134 | b = cryptotest.random_bytes() 135 | 136 | resp = await client.push(m.name, m.tag, io.BytesIO(b)) 137 | self.assertIsNone(resp) 138 | 139 | @asynctest.unittest_run_loop 140 | async def test_push_latest(self): 141 | m = kerastest.new_model() 142 | path = f"/models/{m.name}/latest" 143 | resp = aiohttp.web.Response(status=409, 144 | headers={"Error-Code": "Model Latest Tag"}) 145 | 146 | async with self.handle_request("PUT", path, resp) as client: 147 | with self.assertRaises(errors.LatestTagError): 148 | b = cryptotest.random_bytes() 149 | resp = await client.push(m.name, "latest", io.BytesIO(b)) 150 | 151 | @asynctest.unittest_run_loop 152 | async def test_push_duplicate(self): 153 | m = kerastest.new_model() 154 | path = f"/models/{m.name}/{m.tag}" 155 | resp = aiohttp.web.Response(status=409, 156 | headers={"Error-Code": "Model Duplicate"}) 157 | 158 | async with self.handle_request("PUT", path, resp) as client: 159 | with self.assertRaises(errors.DuplicateError): 160 | b = cryptotest.random_bytes() 161 | await client.push(m.name, m.tag, io.BytesIO(b)) 162 | 163 | 164 | if __name__ == "__main__": 165 | unittest.main() 166 | -------------------------------------------------------------------------------- /tests/test_command.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import flagparse 3 | import pathlib 4 | import tarfile 5 | import tempfile 6 | import unittest 7 | import unittest.mock 8 | 9 | from tensorcraft import errors 10 | from tensorcraft.shell import commands 11 | from tests import clienttest 12 | from tests import cryptotest 13 | from tests import kerastest 14 | from tests import asynctest 15 | 16 | 17 | 18 | class TestCommand(unittest.TestCase): 19 | 20 | @clienttest.unittest_mock_model_client("list") 21 | @unittest.mock.patch("builtins.print") 22 | def test_list(self, print_mock, list_mock): 23 | m = kerastest.new_model() 24 | list_mock.return_value = [m.to_dict()] 25 | 26 | command = commands.List(unittest.mock.Mock()) 27 | command.handle(flagparse.Namespace()) 28 | print_mock.assert_called_with(str(m)) 29 | 30 | @clienttest.unittest_mock_model_client("remove") 31 | def test_remove(self, remove_mock): 32 | m = kerastest.new_model() 33 | 34 | command = commands.Remove(unittest.mock.Mock()) 35 | command.handle(flagparse.Namespace(name=m.name, tag=m.tag)) 36 | remove_mock.assert_called_with(m.name, m.tag) 37 | 38 | @clienttest.unittest_mock_model_client("remove") 39 | def test_remove_not_found(self, remove_mock): 40 | m = kerastest.new_model() 41 | remove_mock.side_effect = errors.NotFoundError(m.name, m.tag) 42 | 43 | with self.assertRaises(flagparse.ExitError): 44 | args = flagparse.Namespace(name=m.name, tag=m.tag, quiet=False) 45 | command = commands.Remove(unittest.mock.Mock()) 46 | command.handle(args) 47 | 48 | remove_mock.assert_called_with(m.name, m.tag) 49 | 50 | @clienttest.unittest_mock_model_client("remove") 51 | def test_remove_not_found_quiet(self, remove_mock): 52 | m = kerastest.new_model() 53 | remove_mock.side_effect = errors.NotFoundError(m.name, m.tag) 54 | 55 | command = commands.Remove(unittest.mock.Mock()) 56 | command.handle(flagparse.Namespace(name=m.name, tag=m.tag, quiet=True)) 57 | 58 | remove_mock.assert_called_with(m.name, m.tag) 59 | 60 | @clienttest.unittest_mock_model_client("push") 61 | def test_push(self, push_mock): 62 | with tempfile.NamedTemporaryFile() as tf: 63 | with tarfile.open(tf.name, mode="w") as tar: 64 | tar.add("tests", arcname="") 65 | 66 | m = kerastest.new_model() 67 | path = pathlib.Path(tf.name) 68 | 69 | args = flagparse.Namespace(name=m.name, tag=m.tag, path=path) 70 | command = commands.Push(unittest.mock.Mock()) 71 | command.handle(args) 72 | 73 | @clienttest.unittest_mock_model_client("push") 74 | def test_push_file_not_exists(self, push_mock): 75 | m = kerastest.new_model() 76 | path = pathlib.Path(cryptotest.random_string()) 77 | 78 | with self.assertRaises(flagparse.ExitError): 79 | args = flagparse.Namespace(name=m.name, tag=m.tag, path=path) 80 | command = commands.Push(unittest.mock.Mock()) 81 | command.handle(args) 82 | 83 | 84 | @clienttest.unittest_mock_model_client("export") 85 | def test_export(self, export_mock): 86 | with tempfile.NamedTemporaryFile() as tf: 87 | m = kerastest.new_model() 88 | path = pathlib.Path(tf.name) 89 | 90 | args = flagparse.Namespace(name=m.name, tag=m.tag, path=path) 91 | command = commands.Export(unittest.mock.Mock()) 92 | command.handle(args) 93 | 94 | 95 | if __name__ == "__main__": 96 | unittest.main() 97 | -------------------------------------------------------------------------------- /tests/test_saving.py: -------------------------------------------------------------------------------- 1 | import aiofiles 2 | import io 3 | import pathlib 
4 | import tempfile
5 | import unittest
6 | import unittest.mock
7 | 
8 | from tensorcraft.backend import model
9 | from tensorcraft.backend import saving
10 | from tests import asynctest
11 | from tests import kerastest
12 | 
13 | 
14 | class TestModelRuntime(asynctest.AsyncTestCase):
15 | 
16 |     async def setUpAsync(self) -> None:
17 |         self.workdir = tempfile.TemporaryDirectory()
18 |         self.workpath = pathlib.Path(self.workdir.name)
19 | 
20 |     async def tearDownAsync(self) -> None:
21 |         self.workdir.cleanup()
22 | 
23 |     @asynctest.unittest_run_loop
24 |     async def test_save(self):
25 |         loader = model.Loader("no")
26 |         fs = saving.FsModelsStorage.new(path=self.workpath, loader=loader)
27 | 
28 |         async with kerastest.crossentropy_model_tar("n", "t") as tarpath:
29 |             async with aiofiles.open(tarpath, "rb") as model_tar:
30 |                 stream = io.BytesIO(await model_tar.read())
31 |                 m = await fs.save("n", "t", stream)
32 | 
33 |                 d1 = await fs.meta.get(saving.query_by_name_and_tag("n", "t"))
34 |                 d2 = await fs.meta.get(saving.query_by_name_and_tag("n", "latest"))
35 | 
36 |                 self.assertEqual(d1["id"], d2["id"])
37 |                 self.assertTrue(m.loaded)
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     unittest.main()
42 | 
--------------------------------------------------------------------------------
/tests/test_server.py:
--------------------------------------------------------------------------------
1 | import aiofiles
2 | import aiohttp.test_utils as aiohttptest
3 | import aiohttp.web
4 | import pathlib
5 | import tempfile
6 | import unittest
7 | 
8 | from tensorcraft import asynclib
9 | from tensorcraft import server
10 | from tests import kerastest
11 | 
12 | 
13 | class TestServer(aiohttptest.AioHTTPTestCase):
14 |     """Functional test of the server."""
15 | 
16 |     def setUp(self) -> None:
17 |         # Keep a reference to the temporary directory so it is not
18 |         # garbage-collected (and removed) while the tests still run.
19 |         self.workdir = tempfile.TemporaryDirectory()
20 |         self.workpath = pathlib.Path(self.workdir.name)
21 |         super().setUp()
22 | 
23 |     async def tearDownAsync(self) -> None:
24 |         self.workdir.cleanup()
25 | 
26 |     async def get_application(self) -> aiohttp.web.Application:
27 |         """Create the server application."""
28 |         s = await server.Server.new(
29 |             strategy="mirrored",
30 |             pidfile=str(self.workpath.joinpath("k.pid")),
31 |             data_root=str(self.workpath))
32 |         return s.app
33 | 
34 |     @asynclib.asynccontextmanager
35 |     async def pushed_model(self, name: str = None, tag: str = None):
36 |         m = kerastest.new_model(name, tag)
37 | 
38 |         async with kerastest.crossentropy_model_tar(m.name, m.tag) as tarpath:
39 |             async with self.uploaded_model_tar(tarpath, m.name, m.tag) as mod:
40 |                 yield mod
41 | 
42 |     @asynclib.asynccontextmanager
43 |     async def uploaded_model_tar(self, tarpath: pathlib.Path,
44 |                                  name: str,
45 |                                  tag: str) -> kerastest.Model:
46 |         try:
47 |             # Upload the serialized model to the server.
48 |             data = asynclib.reader(tarpath)
49 |             url = "/models/{0}/{1}".format(name, tag)
50 | 
51 |             # Push the model and ensure the upload succeeded.
52 |             resp = await self.client.put(url, data=data)
53 |             self.assertEqual(resp.status, 201)
54 | 
55 |             yield kerastest.Model(name, tag, tarpath, url)
56 |         finally:
57 |             await self.client.delete(url)
58 | 
59 |     @aiohttptest.unittest_run_loop
60 |     async def test_create_twice(self):
61 |         async with self.pushed_model() as m:
62 |             data = asynclib.reader(m.tarpath)
63 | 
64 |             resp = await self.client.put(m.url, data=data)
65 |             self.assertEqual(resp.status, 409)
66 | 
67 |     @aiohttptest.unittest_run_loop
68 |     async def test_predict(self):
69 |         async with self.pushed_model() as m:
70 |             data = dict(x=[[1.0]])
71 |             resp = await self.client.post(m.url + "/predict", json=data)
72 |             self.assertEqual(resp.status, 200)
73 | 
74 |     @aiohttptest.unittest_run_loop
75 |     async def test_predict_not_found(self):
76 |         data = dict(x=[[1.0]])
77 |         resp = await self.client.post("/models/x/y/predict", json=data)
78 |         self.assertEqual(resp.status, 404)
79 | 
80 |     @aiohttptest.unittest_run_loop
81 |     async def test_predict_latest(self):
82 |         async with self.pushed_model() as m1:
83 |             async with self.pushed_model(m1.name) as m2:
84 |                 url = "/models/{0}/latest/predict".format(m2.name)
85 |                 data = dict(x=[[1.0]])
86 | 
87 |                 resp = await self.client.post(url, json=data)
88 |                 self.assertEqual(resp.status, 200)
89 | 
90 |     @aiohttptest.unittest_run_loop
91 |     async def test_list(self):
92 |         async with self.pushed_model() as m:
93 |             resp = await self.client.get("/models")
94 |             self.assertEqual(resp.status, 200)
95 | 
96 |             data = await resp.json()
97 |             # The model is listed under both its own tag and "latest".
98 |             self.assertEqual(2, len(data))
99 | 
100 |             data = data[0]
101 |             data = dict(name=data.get("name"), tag=data.get("tag"))
102 | 
103 |             self.assertEqual(data, dict(name=m.name, tag=m.tag))
104 | 
105 |     @aiohttptest.unittest_run_loop
106 |     async def test_export(self):
107 |         async with self.pushed_model() as m:
108 |             resp = await self.client.get(m.url)
109 |             self.assertEqual(resp.status, 200)
110 | 
111 |             # Export the pushed model back to the file system.
112 |             tarpath = self.workpath.joinpath("export.tar")
113 |             async with aiofiles.open(tarpath, "wb+") as tar:
114 |                 await tar.write(await resp.read())
115 | 
116 |             # Regular file comparison won't work, as the server creates a new
117 |             # TAR archive for the pushed model.
118 |             #
119 |             # Therefore, ensure the returned TAR is a valid model by uploading
120 |             # it back to the server.
121 |             m = kerastest.new_model()
122 |             async with self.uploaded_model_tar(tarpath, m.name, m.tag) as m:
123 |                 pass
124 | 
125 | 
126 | if __name__ == "__main__":
127 |     unittest.main()
128 | 
--------------------------------------------------------------------------------
/tests/test_server_extra.py:
--------------------------------------------------------------------------------
1 | import aiohttp.test_utils as aiohttptest
2 | import aiohttp.web
3 | import pathlib
4 | import tempfile
5 | import unittest
6 | 
7 | from tensorcraft.server import Server
8 | 
9 | 
10 | class TestServerExtra(aiohttptest.AioHTTPTestCase):
11 | 
12 |     def setUp(self) -> None:
13 |         self.workdir = tempfile.TemporaryDirectory()
14 |         super().setUp()
15 | 
16 |     async def tearDownAsync(self) -> None:
17 |         self.workdir.cleanup()
18 | 
19 |     async def get_application(self) -> aiohttp.web.Application:
20 |         data_root = pathlib.Path(self.workdir.name).joinpath("non/existing")
21 | 
22 |         server = await Server.new(
23 |             data_root=data_root,
24 |             pidfile=data_root.joinpath("tensorcraft.pid"))
25 |         return server.app
26 | 
27 |     @aiohttptest.unittest_run_loop
28 |     async def test_must_create_directory(self):
29 |         resp = await self.client.get("/status")
30 |         self.assertEqual(resp.status, 200)
31 | 
32 |     @aiohttptest.unittest_run_loop
33 |     async def test_accept_version(self):
34 |         headers = {"Accept-Version": ">=0.0.0"}
35 |         resp = await self.client.get("/status", headers=headers)
36 |         self.assertEqual(resp.status, 200)
37 | 
38 |     @aiohttptest.unittest_run_loop
39 |     async def test_accept_version_not_accepted(self):
40 |         headers = {"Accept-Version": "==0.0.0"}
41 |         resp = await self.client.get("/status", headers=headers)
42 | 
43 |         self.assertEqual(resp.status, 406)
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     unittest.main()
48 | 
--------------------------------------------------------------------------------
/tests/test_server_ssl.py:
--------------------------------------------------------------------------------
1 | import aiohttp.test_utils as aiohttptest
2 | import pathlib
3 | import tempfile
4 | import unittest
5 | 
6 | from aiohttp.web import Application
7 | 
8 | from tensorcraft.server import Server
9 | from tensorcraft import tlslib
10 | from tests import cryptotest
11 | 
12 | 
13 | class TestServerSSL(aiohttptest.AioHTTPTestCase):
14 | 
15 |     def setUp(self) -> None:
16 |         self.workdir = tempfile.TemporaryDirectory()
17 | 
18 |         workpath = pathlib.Path(self.workdir.name)
19 |         keypath, certpath = cryptotest.create_self_signed_cert(workpath)
20 | 
21 |         self.server_ssl_context = tlslib.create_server_ssl_context(
22 |             tls=True, tlscert=certpath, tlskey=keypath,
23 |         )
24 |         self.client_ssl_context = tlslib.create_client_ssl_context(
25 |             tls=True, tlscert=certpath, tlskey=keypath,
26 |         )
27 | 
28 |         super().setUp()
29 | 
30 |     async def tearDownAsync(self) -> None:
31 |         self.workdir.cleanup()
32 | 
33 |     async def get_application(self) -> Application:
34 |         data_root = pathlib.Path(self.workdir.name).joinpath("non/existing")
35 | 
36 |         server = await Server.new(
37 |             data_root=data_root,
38 |             pidfile=data_root.joinpath("tensorcraft.pid"),
39 |         )
40 |         return server.app
41 | 
42 |     async def get_server(self, app: Application) -> aiohttptest.TestServer:
43 |         return aiohttptest.TestServer(
44 |             app, loop=self.loop, ssl=self.server_ssl_context,
45 |         )
46 | 
47 |     @aiohttptest.unittest_run_loop
48 |     async def test_must_accept_tls(self):
49 |         resp = await self.client.get("/status", ssl=self.client_ssl_context)
50 |         self.assertEqual(resp.status, 200)
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     unittest.main()
55 | 
--------------------------------------------------------------------------------
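
The functional tests above double as a reference for the server's HTTP API: PUT /models/{name}/{tag} answers 201 on a new push and 409 on a duplicate, POST .../predict answers 200 or 404, GET /models lists every pushed model under its own tag and "latest", and DELETE removes it. For manual experimentation, here is a minimal client sketch (not a file from this repository) that exercises those same endpoints with plain aiohttp. The base URL, the model name "demo", the tag "1.0", and the ./model.tar path are illustrative assumptions; adjust them to your deployment.

import asyncio

import aiohttp


async def main() -> None:
    # Assumed address; point this at wherever the server actually listens.
    base = "http://localhost:5678"

    async with aiohttp.ClientSession() as session:
        # Push a model archive: 201 on success, 409 if name/tag exists.
        with open("model.tar", "rb") as tar:
            resp = await session.put(base + "/models/demo/1.0", data=tar)
            print("push:", resp.status)

        # List models: a push shows up under its own tag and "latest".
        resp = await session.get(base + "/models")
        print("models:", await resp.json())

        # Predict against the newest tag; unknown models answer 404.
        resp = await session.post(
            base + "/models/demo/latest/predict", json=dict(x=[[1.0]]))
        print("predict:", resp.status)

        # Remove the model when done.
        resp = await session.delete(base + "/models/demo/1.0")
        print("remove:", resp.status)


asyncio.get_event_loop().run_until_complete(main())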