├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── images ├── conversion_mechanism.png ├── generated_files.png └── training_overview.png ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── tests ├── data │ └── .gitkeep ├── generated_code_expected │ └── .gitkeep ├── sample_v1 │ └── tf_model │ │ └── assets │ │ └── .gitkeep └── test_integration.py ├── tf2rust ├── __init__.py ├── __main__.py ├── constants.py ├── nodes │ ├── __init__.py │ ├── activationNode.py │ ├── batchNormalizationNode.py │ ├── concatenateNode.py │ ├── conv1dNode.py │ ├── denseNode.py │ ├── dropoutNode.py │ ├── embeddingNode.py │ ├── flattenNode.py │ ├── globalAveragePooling1dNode.py │ ├── inputLayerNode.py │ ├── maxPooling1dNode.py │ ├── multiplyNode.py │ ├── node.py │ ├── reshapeNode.py │ ├── tensorFlowAdd2Node.py │ ├── tensorFlowMeanNode.py │ └── thresholdedrelu.py └── utils │ ├── __init__.py │ ├── model_saver.py │ ├── rust_converter.py │ ├── scoring_metrics.py │ └── surgeon │ ├── __init__.py │ ├── _utils │ ├── __init__.py │ ├── layer.py │ ├── node.py │ └── tensor_dict.py │ ├── identify.py │ ├── operations.py │ ├── surgeon.py │ └── utils.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv* 110 | env/ 111 | venv*/ 112 | ENV/ 113 | env.bak/ 114 | venv*.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | /target 141 | /target-linux 142 | Cargo.lock 143 | docker-build 144 | 145 | **.DS_Store 146 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.4.0] 2023-09-12 9 | 10 | ### Changed 11 | 12 | - Update `tensorflow` to `2.13` 13 | - Update generated template Rust code dependencies 14 | - Update Rust edition to 2021 15 | - Use `once_cell` instead of `lazy_static` 16 | 17 | ## [0.3.0] 2023-07-12 18 | 19 | ### Changed 20 | 21 | - Modify package to make it a python wheel that is buildable with poetry 22 | - Moved to tox and pytest 23 | 24 | ## [0.2.0] 2022-11-08 25 | 26 | ### Fixed 27 | 28 | - Added support for Tensorflow > 2.5 29 | 30 | ## [0.1.0] 2022-10-10 31 | 32 | - Initial release. 33 | 34 | ### Added 35 | 36 | - Added changelog tracking. 37 | - Added CI/CD tooling to run builds for the repo. 
38 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # `tf2rust` Community Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our 
standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | dsci-oss@crowdstrike.com or https://crowdstrike.ethicspoint.com/. 64 | 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. 
Warning 85 | 86 | **Community Impact**: A violation through a single incident or series 87 | of actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or 94 | permanent ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within 114 | the community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.0, available at 120 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 121 | 122 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 123 | enforcement ladder](https://github.com/mozilla/diversity). 
124 | 125 | [homepage]: https://www.contributor-covenant.org 126 | 127 | For answers to common questions about this code of conduct, see the FAQ at 128 | https://www.contributor-covenant.org/faq. Translations are available at 129 | https://www.contributor-covenant.org/translations. 130 | 131 | [![Twitter URL](https://img.shields.io/twitter/url?label=Follow%20%40CrowdStrike&style=social&url=https%3A%2F%2Ftwitter.com%2FCrowdStrike)](https://twitter.com/CrowdStrike) 132 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repository 2 | 3 | ## Getting started 4 | 5 | _Welcome!_ We're excited you want to take part in the `tf2rust` community! 6 | 7 | Please review this document for details regarding getting started with your first contribution, packages you'll need to install as a developer, and our Pull Request process. If you have any questions, please let us know by 8 | posting your question as an [issue](https://github.com/CrowdStrike/tf2rust/issues/new). 9 | 10 | ### Before you begin 11 | 12 | * Have you read the [Code of Conduct](CODE_OF_CONDUCT.md)? The Code of Conduct helps us establish community norms and how they'll be enforced. 13 | 14 | ### Table of Contents 15 | 16 | * [How you can contribute](#how-you-can-contribute) 17 | * [Bug reporting](#bug-reporting-and-questions-are-handled-using-githubs-issues) 18 | * [Pull Requests](#pull-requests) 19 | * [Contributor dependencies](#additional-contributor-package-requirements) 20 | * [Unit testing](#unit-testing--code-coverage) 21 | * [Linting](#linting) 22 | * [Breaking changes](#breaking-changes) 23 | * [Branch targeting](#branch-targeting) 24 | * [Suggestions](#suggestions) 25 | 26 | ## How you can contribute 27 | 28 | * See something? Say something! 
Submit a [bug report](https://github.com/CrowdStrike/tf2rust/issues) to let the community know what you've experienced or found. Bonus points if you suggest possible fixes or what you feel may resolve the issue. For example: "_Attempted to use the XZY API class but it errored out. Could a more descriptive error code be returned?_" 29 | * Submit a [Pull Request](#pull-requests) 30 | 31 | ### Bug reporting and questions are handled using GitHub's issues 32 | 33 | We use GitHub issues to track bugs. Report a bug by opening a [new issue](https://github.com/CrowdStrike/tf2rust/issues). 34 | 35 | ## Pull Requests 36 | 37 | ### All contributions will be submitted under the MIT license 38 | 39 | When you submit code changes, your submissions are understood to be under the same MIT [license](LICENSE) that covers the project. 40 | If this is a concern, contact the maintainers before contributing. 41 | 42 | ### Breaking changes 43 | 44 | In an effort to maintain backwards compatibility, we thoroughly unit test every Pull Request for any issues. These unit tests are intended to catch general programmatic errors, possible vulnerabilities (via bandit) and _potential breaking changes_. 45 | 46 | > If you have to adjust a unit test locally in order to produce passing results, there is a possibility you are working with a potential breaking change. 47 | 48 | Please fully document changes to unit tests within your Pull Request. If you did not specify "Breaking Change" on the punch list in the description, and the change is identified as possibly breaking, this may delay or prevent approval of your PR. 49 | 50 | ### Versioning 51 | 52 | We use [SemVer](https://semver.org/) as our versioning scheme. (Example: _2.1.4_) 53 | 54 | ### Pull Request template 55 | 56 | Please use the pull request template provided, making sure the following details are included in your request: 57 | 58 | * Is this a breaking change? 59 | * Are all new or changed code paths covered by unit testing? 
60 | * A complete listing of issues addressed or closed with this change. 61 | * A complete listing of any enhancements provided by this change. 62 | * Any usage details developers may need to make use of this new functionality. 63 | * Does additional documentation need to be developed beyond what is listed in your Pull Request? 64 | * Any other salient points of interest. 65 | 66 | ### Approval / Merging 67 | 68 | All Pull Requests must be approved by at least one maintainer. Once approved, a maintainer will perform the merge and execute any backend 69 | processes related to package deployment. At this time, contributors _do not_ have the ability to merge to the `main` branch. 70 | 71 | ## Suggestions 72 | 73 | If you have suggestions on how this process could be improved, please let us know by [posting an issue](https://github.com/CrowdStrike/tf2rust/issues). 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2022 Crowdstrike 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow to Rust package 2 | 3 | ## Special thanks to `Marian Radu` for his support during the development of the package 4 | 5 | ## Training Workflow 6 | 7 | ![training overview](images/training_overview.png) 8 | 9 | ## General Information 10 | 11 | ![conversion mechanism](images/conversion_mechanism.png) 12 | 13 | A Python package that converts a TensorFlow model (.pb or .h5 format) into pure Rust code. 14 | This package is dependent on [`tf-layers`](https://github.com/CrowdStrike/tf-layers) (Rust): 15 | 16 | Currently, this package supports models that contain the following layers (the layers number is expected to grow in the future with the addition of further architectures): 17 | 18 | * InputLayer - input layer. For further information check: 19 | * Multiply - multiply layer. For further information check: 20 | * Reshape - reshape layer. For further information check: 21 | * Conv1D - 1D convolutional layer. For further information check: 22 | * Embedding - embedding layer. For further information check: 23 | * Dense - dense layer. For further information check: 24 | * Flatten - flatten layer. For further information check: 25 | * Concatenate - concatenate layer. For further information check: 26 | * GlobalAveragePooling - global average pooling layer. For further information check: 27 | * MaxPooling - maxpooling layer. For further information check: 28 | * AveragePooling - averagepooling layer. For further information check: 29 | * BatchNormalization - batchnormalization layer. 
For further information check: 30 | * Add - addition layer. 31 | * Mean - mean layer over a specified axis. 32 | * Activation - different types of activation supported (can be used as an independent layer or inside different NN layers such as Dense, Conv1D, etc). Support available for: 33 | * Linear(Identity) 34 | * Relu 35 | * ThresholdedRelu 36 | * Selu 37 | * Sigmoid 38 | * Softmax 39 | * SoftPlus 40 | * SoftSign 41 | * Tanh 42 | * `Note1`: Some layers might not have all the functionalities from TensorFlow implemented. 43 | * `Note2`: It is mandatory to use an `InputLayer` for each input that the model expects. It is also mandatory that `InputLayer's` `dtype` be exactly specified (default is `float`). 44 | For instance, if an `InputLayer` is followed by an `EmbeddingLayer`, then the type of that particular `InputLayer` must be set to int - e.g. "int64". 45 | Another requirement is to have the `output_shape` of each layer specified (the only unspecified size should be about the batch size). 46 | This is usually done by setting the `input_shape` parameter when initializing the `InputLayer`. 47 | 48 | ## Requirements 49 | 50 | This project targets the Python 3.8 interpreter. You will need to install 51 | `graphviz` using your system dependency manager of choice. On macOS, this can be 52 | done with the command: 53 | 54 | ```bash 55 | brew install graphviz 56 | ``` 57 | 58 | To set up a virtualenv with poetry, execute the following commands in the 59 | project root: 60 | 61 | ```bash 62 | poetry install 63 | poetry shell 64 | ``` 65 | 66 | ## Configuration arguments 67 | 68 | ### --path_to_tf_model 69 | 70 | The path (relative or absolute) to the TensorFlow model to be converted into pure Rust code. It is mandatory. 71 | 72 | ### --path_to_save 73 | 74 | The path (relative or absolute) where to save the generated Rust code. It is mandatory. 75 | 76 | ### --model_name 77 | 78 | The model name. A struct named Model will be created in Rust. 
E.g. model_name = Mnist => Mnist. It is mandatory. 79 | 80 | ### --binary_classification 81 | 82 | Set this flag to true/false depending on whether the model is a binary classifier or not (false for regression or multiclass classifiers). Default is true. 83 | 84 | ### --enable_inplace 85 | 86 | Set this flag to true/false depending on whether you want the model written in Rust to use in-place operations whenever possible (in `predict_from_array` function). Default is true. 87 | 88 | ### --enable_memdrop 89 | 90 | Set this flag to true/false depending on whether you want the model written in Rust to free the memory of intermediate layers' results as soon as possible (instead of the actual ending of `predict_from_array` function). Default is true. 91 | 92 | ### --path_to_fv 93 | 94 | Set the path to an npz array containing the FV for a bunch of samples. The keys for the arrays should match the keys from perform_fx from NeuralBrain (which must be the same as the InputLayers' names when building the model). Also, the expected predictions should be saved as an array in `features.npz` by the key `predictions`. This flag is optional. 95 | 96 | ## Output Files 97 | 98 | ![generated files](images/generated_files.png) 99 | 100 | * saved_model_from_tensorflow: 101 | * computation_graph.json: The computational dependencies. 102 | * model_architecture.json: Different parameters for the actual NN layers (stride, pool_size, kernel_size, activation type, etc). 103 | * model_overview.png: A graph image describing the model. 104 | * model_weights.npz: model's weights. 105 | * rust_generated_code: 106 | * build.rs: A Rust build file used in serializing the model by reading from model_weights.npz 107 | * Cargo.toml: the place where all the imports are specified (and many more). 108 | * rust_generated_code/model: 109 | * model_weights.npz: model weights saved in a format that can be used by Rust. 110 | * thresholds.json: the thresholds for `low`, `bottom`, `medium`, `high` confidence levels. 
111 | * rust_generated_code/src: 112 | * model.rs: A Rust structure encapsulating all the logic behind prediction. 113 | * lib.rs: the file containing the tests. 114 | * rust_generated_code/testdata: 115 | * features.npz: the features to be passed to the model (1D numpy ndarray). 116 | * rust_generated_code/benches: 117 | * benchmarks.rs: the file in charge of benchmarks. 118 | 119 | ### In order to assess the performance of the model, run `cargo bench` 120 | 121 | ### In order to test the predictions and see that the translation went as expected, run `cargo test` 122 | 123 | ### Note: all these commands need to be executed in the `rust_generated_code/` directory 124 | 125 | ## Usage 126 | 127 | To convert a TensorFlow model, use a command line like the following: 128 | 129 | ```bash 130 | python3 -m tf2rust \ 131 | --path_to_tf_model tests/data/mnist/tf_model/ \ 132 | --path_to_save tests/data/generated_classifiers/mnist \ 133 | --model_name MNist \ 134 | --binary_classification True \ 135 | --enable_inplace True \ 136 | --enable_memdrop True \ 137 | --path_to_fv tests/data/mnist/features.npz # for testing purposes, optional 138 | 139 | ``` 140 | 141 | ## Converting `.h5` models to `.pb` 142 | 143 | ```python 144 | from tensorflow.keras.models import load_model, save_model 145 | # Note that models will have different metrics also saved with the models and expect the implementations for these 146 | # metrics. 147 | # We have these implemented in utils/scoring_metrics.py but these are not used, and we can also provide None. 148 | model = load_model('new_model.h5', custom_objects={'tpr': None, 'tnr': None, 'auc': None}) 149 | save_model(model=model, filepath='tf_model/', include_optimizer=False) 150 | ``` 151 | 152 | ## Running the tests 153 | 154 | At the time we'll migrate towards Dockerising this, we'll also switch to tox (this should not pose any difficulties). 
155 | 156 | We have currently set up integration tests, which do the following: 157 | 158 | * Given model artifacts, generate Rust code 159 | * Check that the Rust code is the same code to what we expect to be generated 160 | * Compile the Rust code and see that all tests pass 161 | * Tests take in FVs and the DVs generated by the Tensorflow model 162 | * We check that inference with the Rust model yields the same results as the initial Tensorflow model 163 | 164 | ```bash 165 | pytest 166 | ``` 167 | 168 | ## Next steps 169 | 170 | After everything runs smoothly with your model, please add artifacts and a new test for it. 171 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | [![Twitter URL](https://img.shields.io/twitter/url?label=Follow%20%40CrowdStrike&style=social&url=https%3A%2F%2Ftwitter.com%2FCrowdStrike)](https://twitter.com/CrowdStrike) 2 | 3 | # Security Policy 4 | 5 | This document outlines security policy and procedures for the CrowdStrike `tf2rust` project. 6 | 7 | * [Supported versions](#supported-versions) 8 | * [Reporting a potential security vulnerability](#reporting-a-potential-security-vulnerability) 9 | * [Disclosure and Mitigation Process](#disclosure-and-mitigation-process) 10 | 11 | ## Supported versions 12 | 13 | When discovered, we release security vulnerability patches for the most recent release at an accelerated cadence. 14 | 15 | ## Reporting a potential security vulnerability 16 | 17 | We have multiple avenues to receive security-related vulnerability reports. 18 | 19 | Please report suspected security vulnerabilities by: 20 | 21 | * Submitting a [bug](https://github.com/CrowdStrike/tf2rust/issues/new?assignees=&labels=bug+%3Abug%3A&template=bug_report.md&title=%5B+BUG+%5D+...). 22 | * Starting a new [discussion](https://github.com/CrowdStrike/tf2rust/discussions). 
23 | * Submitting a [pull request](https://github.com/CrowdStrike/tf2rust/pulls) to potentially resolve the issue. (New contributors: please review the content located [here](https://github.com/CrowdStrike/tf2rust/blob/main/CONTRIBUTING.md).) 24 | * Sending an email to __dsci-oss@crowdstrike.com__. 25 | 26 | ## Disclosure and mitigation process 27 | 28 | Upon receiving a security bug report, the issue will be assigned to one of the project maintainers. This person will coordinate the related fix and release 29 | process, involving the following steps: 30 | 31 | * Communicate with you to confirm we have received the report and provide you with a status update. 32 | 33 | * You should receive this message within 48 - 72 business hours. 34 | 35 | * Confirmation of the issue and a determination of affected versions. 36 | * An audit of the codebase to find any potentially similar problems. 37 | * Preparation of patches for all releases still under maintenance. 38 | 39 | * These patches will be submitted as a separate pull request and contain a version update. 40 | * This pull request will be flagged as a security fix. 41 | * Once merged, and after post-merge unit testing has been completed, the patch will be immediately published to both PyPI repositories. 42 | 43 | ## Comments 44 | 45 | If you have suggestions on how this process could be improved, please let us know by [starting a new discussion](https://github.com/CrowdStrike/tf2rust/discussions). 
46 | -------------------------------------------------------------------------------- /images/conversion_mechanism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/images/conversion_mechanism.png -------------------------------------------------------------------------------- /images/generated_files.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/images/generated_files.png -------------------------------------------------------------------------------- /images/training_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/images/training_overview.png -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | in-project = true 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tf2rust" 3 | dependencies = [ 4 | "argparse==1.4", 5 | "numpy==1.24.3", 6 | "pydot==1.4.2", 7 | "scikit-learn==1.1.3", 8 | "tensorflow==2.13.0", 9 | ] 10 | 11 | [tool.poetry] 12 | name = "tf2rust" 13 | version = "0.4.0" 14 | description = "" 15 | authors = ["DSCI OSS "] 16 | 17 | # This is a bit unfortunate but with this setup the dependencies need to be 18 | # duplicated in project.dependencies above. The dependencies above are included 19 | # in the resulting wheel where as the ones listed here are used for development 20 | # and testing. 
21 | [tool.poetry.dependencies] 22 | python = ">=3.9,<3.12" 23 | argparse = "==1.4" 24 | numpy = "==1.24.3" 25 | pydot = "==1.4.2" 26 | scikit-learn = "==1.3.0" 27 | tensorflow = "==2.13.0" 28 | 29 | [tool.poetry.dev-dependencies] 30 | black = "^22.3.0" 31 | isort = "^5.10.1" 32 | pytest = "^6" 33 | tox = "^3.28.0" 34 | twine = "^4.0.2" 35 | 36 | [tool.isort] 37 | profile = "black" 38 | -------------------------------------------------------------------------------- /tests/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tests/data/.gitkeep -------------------------------------------------------------------------------- /tests/generated_code_expected/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tests/generated_code_expected/.gitkeep -------------------------------------------------------------------------------- /tests/sample_v1/tf_model/assets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tests/sample_v1/tf_model/assets/.gitkeep -------------------------------------------------------------------------------- /tests/test_integration.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import filecmp 3 | import os 4 | import pathlib 5 | import shutil 6 | import subprocess 7 | 8 | import pytest 9 | 10 | # Models to test the TensorFlow-to-Rust conversion for. 
11 | MODELS = ["sample_v1"] 12 | 13 | TESTS_PATH = pathlib.Path(__file__).parent.absolute() 14 | 15 | # Path where we need to have the TensorFlow model (ft_model), but also a numpy zip containing 16 | # features and TensorFlow predictions for those features, in order to check that Rust predictions 17 | # match TensorFlow predictions. 18 | MODELS_METADATA = TESTS_PATH / "data" 19 | 20 | # Path where we'll store the Rust generated code that results from the TensorFlow-to-Rust 21 | # conversion. 22 | RUST_GENERATED_CODE = TESTS_PATH / "generated_code" 23 | 24 | # Path where we'll store the Rust expected code that should be generated. 25 | RUST_GENERATED_CODE_EXPECTED = TESTS_PATH / "generated_code_expected" 26 | 27 | 28 | @pytest.mark.parametrize("model", MODELS) 29 | def test_model_conversion(model): 30 | """ 31 | Tests model conversion for all the provided models. 32 | """ 33 | convert_model_and_check_predictions(MODELS_METADATA, model) 34 | 35 | 36 | def convert_model_and_check_predictions(data_path, model_name_versioned): 37 | # Compose the paths needed forward. 38 | model_dir_path = os.path.join(data_path, model_name_versioned) 39 | model_generated_code = os.path.join(RUST_GENERATED_CODE, model_name_versioned) 40 | model_generated_rust_code = os.path.join( 41 | model_generated_code, "rust_generated_code" 42 | ) 43 | model_generated_code_expected = os.path.join( 44 | RUST_GENERATED_CODE_EXPECTED, model_name_versioned 45 | ) 46 | 47 | # 1. Before running the test, make sure the generated_code path is removed 48 | if os.path.exists(RUST_GENERATED_CODE): 49 | shutil.rmtree(RUST_GENERATED_CODE) 50 | 51 | # 2. 
def run_tensorflow_to_rust_conversion(model_path=None, save_path=None, fv_path=None):
    """
    Run the command line that converts a TensorFlow model to Rust.

    Equivalent to:

        python3 -m tf2rust \
            --path_to_tf_model tf_model/ \
            --path_to_save generated_code/ \
            --model_name Test \
            --path_to_fv features.npz

    Raises:
        ValueError: if any of model_path, save_path, fv_path is missing/falsy.
        subprocess.CalledProcessError: if the conversion process exits non-zero.
    """
    if not all((model_path, save_path, fv_path)):
        raise ValueError(
            "Expected to get model_path & save_path & fv_path as arguments!"
        )
    command = [
        "python3",
        "-m",
        "tf2rust",
        "--path_to_tf_model",
        model_path,
        "--path_to_save",
        save_path,
        "--model_name",
        "Test",
        "--path_to_fv",
        fv_path,
    ]
    print(command)

    # Fail fast with a clear error if the conversion itself fails, instead of
    # letting the later filesystem asserts in the caller produce a confusing
    # "path does not exist" message.
    subprocess.run(command, check=True)
def check_rust_code_compiles(code_path=None):
    """
    Check that the generated Rust code compiles and that its tests pass
    (the Rust tests compare TensorFlow predictions against Rust predictions).

    Equivalent to:

        cargo fmt
        cargo test --release

    Raises:
        AssertionError: if cargo produced no output, if no "test result:"
            summary line was found (e.g. the crate failed to compile), or if
            any test summary is not "test result: ok".
    """
    # Make sure tests don't fail because of extra newlines/tabs/spaces.
    subprocess.run(["cargo", "fmt"], cwd=code_path)

    # Build crate and run tests, capturing stdout and stderr together.
    p = subprocess.run(
        ["cargo", "test", "--release"],
        cwd=code_path,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    output_lines = p.stdout.decode("utf-8")
    assert output_lines

    # Go through the output and check that the number of "test result:" lines
    # is equal to the number of "test result: ok" lines.
    test_result_counts = 0
    test_result_ok_counts = 0
    for new_line in output_lines.split("\n"):
        if "test result:" in new_line:
            test_result_counts += 1
        if "test result: ok" in new_line:
            test_result_ok_counts += 1
        print(new_line)
    # A compile failure emits no summary lines at all, which previously made
    # the 0 == 0 comparison pass silently; require at least one summary line.
    assert test_result_counts > 0, "no 'test result:' lines - did cargo build fail?"
    assert test_result_counts == test_result_ok_counts
138 | """ 139 | dirs_cmp = filecmp.dircmp(dir1, dir2) 140 | if ( 141 | len(dirs_cmp.left_only) > 0 142 | or len(dirs_cmp.right_only) > 0 143 | or len(dirs_cmp.funny_files) > 0 144 | ): 145 | return False 146 | 147 | (_, mismatch, errors) = filecmp.cmpfiles( 148 | dir1, dir2, dirs_cmp.common_files, shallow=False 149 | ) 150 | 151 | if len(mismatch) > 0 or len(errors) > 0: 152 | return False 153 | 154 | for common_dir in dirs_cmp.common_dirs: 155 | new_dir1 = os.path.join(dir1, common_dir) 156 | new_dir2 = os.path.join(dir2, common_dir) 157 | if not _are_dir_equal(new_dir1, new_dir2): 158 | return False 159 | 160 | return True 161 | -------------------------------------------------------------------------------- /tf2rust/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tf2rust/__init__.py -------------------------------------------------------------------------------- /tf2rust/__main__.py: -------------------------------------------------------------------------------- 1 | from .utils import model_saver, rust_converter 2 | 3 | if __name__ == "__main__": 4 | model_saver.save_tf_model() 5 | rust_converter.convert_to_rust() 6 | -------------------------------------------------------------------------------- /tf2rust/constants.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | from tensorflow.keras.layers import ( 6 | AlphaDropout, 7 | BatchNormalization, 8 | Dropout, 9 | GaussianNoise, 10 | SpatialDropout1D, 11 | ) 12 | 13 | from .nodes import ( 14 | activationNode, 15 | batchNormalizationNode, 16 | concatenateNode, 17 | conv1dNode, 18 | denseNode, 19 | embeddingNode, 20 | flattenNode, 21 | globalAveragePooling1dNode, 22 | inputLayerNode, 23 | maxPooling1dNode, 24 | multiplyNode, 25 | reshapeNode, 26 | tensorFlowAdd2Node, 27 | 
tensorFlowMeanNode, 28 | thresholdedrelu, 29 | dropoutNode, 30 | ) 31 | 32 | parser = argparse.ArgumentParser(description="Arguments for the python scripts") 33 | parser.add_argument( 34 | "--path_to_tf_model", 35 | type=str, 36 | default=None, 37 | help="The path to the TF model intended to converts (only .pb or .h5 format supported)", 38 | ) 39 | parser.add_argument( 40 | "--path_to_save", 41 | type=str, 42 | default=None, 43 | help="The path where to store the conversion.", 44 | ) 45 | parser.add_argument( 46 | "--model_name", 47 | type=str, 48 | default=None, 49 | help="The name of the model to be converted.", 50 | ) 51 | parser.add_argument( 52 | "--binary_classification", 53 | type=str, 54 | default="True", 55 | help="A flag specifying whether the model is a binary classifier or not.", 56 | ) 57 | parser.add_argument( 58 | "--enable_inplace", 59 | type=str, 60 | default="True", 61 | help="Enable inplace operations in model.rs", 62 | ) 63 | parser.add_argument( 64 | "--enable_memdrop", type=str, default="True", help="Enable memory drop in model.rs" 65 | ) 66 | parser.add_argument( 67 | "--path_to_fv", 68 | type=str, 69 | default=None, 70 | help="The path to the fv (npz format) having as keys the names of the InputLayers (e.g. 
character_level, word_level, extra_level)", 71 | ) 72 | args = parser.parse_args() 73 | 74 | 75 | PATH_TO_TF_MODEL = args.path_to_tf_model 76 | assert PATH_TO_TF_MODEL is not None 77 | PATH_TO_TF_MODEL = Path(PATH_TO_TF_MODEL) 78 | 79 | PATH_TO_SAVE = args.path_to_save 80 | assert PATH_TO_SAVE is not None 81 | PATH_TO_SAVE = Path(PATH_TO_SAVE) 82 | 83 | MODEL_NAME = args.model_name 84 | assert MODEL_NAME is not None 85 | 86 | FILE_PATH_FV = args.path_to_fv 87 | 88 | 89 | def str2bool(v): 90 | return v.lower() in ("yes", "true", "t", "1") 91 | 92 | 93 | ENABLE_INPLACE = str2bool(args.enable_inplace) 94 | ENABLE_MEMDROP = str2bool(args.enable_memdrop) 95 | BINARY_CLASSIFICATION = str2bool(args.binary_classification) 96 | 97 | os.makedirs(PATH_TO_SAVE, exist_ok=True) 98 | os.makedirs(PATH_TO_SAVE.joinpath("saved_model_from_tensorflow/"), exist_ok=True) 99 | os.makedirs(PATH_TO_SAVE.joinpath("rust_generated_code/"), exist_ok=True) 100 | 101 | PROJECT_PATH = PATH_TO_SAVE.joinpath("rust_generated_code/") 102 | 103 | FILE_PATH_WEIGHTS = PATH_TO_SAVE.joinpath("saved_model_from_tensorflow/").joinpath( 104 | "model_weights.npz" 105 | ) 106 | FILE_PATH_MODEL_ARCHITECTURE = PATH_TO_SAVE.joinpath( 107 | "saved_model_from_tensorflow/" 108 | ).joinpath("model_architecture.json") 109 | FILE_PATH_COMPUTATIONAL_GRAPH = PATH_TO_SAVE.joinpath( 110 | "saved_model_from_tensorflow/" 111 | ).joinpath("computation_graph.json") 112 | FILE_PATH_OVERVIEW = PATH_TO_SAVE.joinpath("saved_model_from_tensorflow/").joinpath( 113 | "model_overview.png" 114 | ) 115 | 116 | DELETE_LAYERS = (Dropout, SpatialDropout1D, GaussianNoise, AlphaDropout) 117 | 118 | LAYERS_DICTIONARY = { 119 | "InputLayer": inputLayerNode.InputLayerNode, 120 | "Embedding": embeddingNode.EmbeddingNode, 121 | "Dense": denseNode.DenseNode, 122 | "Conv1D": conv1dNode.Conv1DNode, 123 | "MaxPooling1D": maxPooling1dNode.MaxPool1dNode, 124 | "Concatenate": concatenateNode.ConcatenateNode, 125 | "Flatten": flattenNode.FlattenNode, 126 
| "Reshape": reshapeNode.ReshapeNode, 127 | "Multiply": multiplyNode.MultiplyNode, 128 | "GlobalAveragePooling1D": globalAveragePooling1dNode.GlobalAveragePooling1DNode, 129 | "Activation": activationNode.ActivationNode, 130 | "ThresholdedReLU": thresholdedrelu.ThresholdedReLU, 131 | "BatchNormalization": batchNormalizationNode.BatchNormalizationNode, 132 | "TensorFlowOpLayer": { 133 | "mean": tensorFlowMeanNode.TensorFlowMeanNode, 134 | "add": tensorFlowAdd2Node.TensorFlowADD2Node, 135 | }, 136 | "Dropout": dropoutNode.DropoutNode, 137 | } 138 | 139 | 140 | def get_class(class_name, layer_name): 141 | layer_name = layer_name.lower() 142 | 143 | target_class = None 144 | if class_name in LAYERS_DICTIONARY: 145 | if class_name == "TensorFlowOpLayer": 146 | if "mean" in layer_name: 147 | target_class = LAYERS_DICTIONARY[class_name]["mean"] 148 | elif "add" in layer_name: 149 | target_class = LAYERS_DICTIONARY[class_name]["add"] 150 | else: 151 | target_class = LAYERS_DICTIONARY[class_name] 152 | 153 | if not target_class: 154 | print("Class_name: {}, layer_name: {}".format(class_name, layer_name)) 155 | raise Exception("Unknown type of layer") 156 | 157 | return target_class 158 | -------------------------------------------------------------------------------- /tf2rust/nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tf2rust/nodes/__init__.py -------------------------------------------------------------------------------- /tf2rust/nodes/activationNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class ActivationNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | self.activation_type = layer_info["config"]["activation"].capitalize() 8 | 9 | def can_be_done_inplace(self): 10 | return 
True 11 | 12 | def apply_layer(self): 13 | assert len(self.connections["inbounds"]) == 1 14 | assert len(self.input_shape) == len(self.output_shape) 15 | assert ( 16 | len(self.parents_name) == 1 17 | ), "Node {} has parents {}. It should have exactly one parent".format( 18 | self.name, self.parents_name 19 | ) 20 | 21 | if self.inplace_op: 22 | return [ 23 | "tensorflow_layers::Activation::{node.activation_type}.activation_mut(&mut out_{input});".format( 24 | node=self, input=self.parents_name[0] 25 | ) 26 | ] 27 | else: 28 | return [ 29 | "let {declare_mut}out_{node.name} = tensorflow_layers::Activation::{node.activation_type}.activation(&out_{input});".format( 30 | declare_mut=self._format_mut(), 31 | node=self, 32 | input=self.parents_name[0], 33 | ) 34 | ] 35 | -------------------------------------------------------------------------------- /tf2rust/nodes/batchNormalizationNode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .node import Node 4 | 5 | 6 | class BatchNormalizationNode(Node): 7 | def __init__(self, layer_info, layer_weights): 8 | super().__init__(layer_info, layer_weights) 9 | self.epsilon = str(layer_info["config"]["epsilon"]) 10 | 11 | # gamma, beta, moving_mean, moving_variance 12 | assert ( 13 | len(self.weights_list) == 4 14 | ), "self_weights_list has length {}, but this should be 4".format( 15 | len(self.weights_list) 16 | ) 17 | 18 | self.type = "tensorflow_layers::BatchNormalization" 19 | 20 | @staticmethod 21 | def get_weights(layer): 22 | array_length = len(layer.get_weights()[0]) 23 | 24 | res = [] 25 | order_weights = ["gamma", "beta", "moving_mean", "moving_variance"] 26 | for weight_name in order_weights: 27 | weight = None 28 | for w in layer.weights: 29 | if weight_name in w.name.lower(): 30 | weight = w.numpy() 31 | break 32 | if weight is None: 33 | if weight_name == "gamma": 34 | weight = np.ones((array_length,)) 35 | elif weight_name == "beta": 36 | 
weight = np.zeros((array_length,)) 37 | else: 38 | raise Exception( 39 | "Expected gamma or beta, found: {}".format(weight_name) 40 | ) 41 | 42 | res.append(weight) 43 | 44 | return res 45 | 46 | def can_be_done_inplace(self): 47 | return True 48 | 49 | def initialize_layer(self): 50 | operation_list = [] 51 | 52 | for i, weights in enumerate(self.weights_list): 53 | operation_list.append( 54 | 'let {node.name}_weight_{idx}: Array{dimensions}<{node.dtype}> = weights_dict.by_name("{node.name}_weight_{idx}.npy")?;'.format( 55 | node=self, idx=i, dimensions=len(weights.shape) 56 | ) 57 | ) 58 | 59 | args = ", ".join( 60 | ["{}_weight_{}".format(self.name, i) for i in range(len(self.weights_list))] 61 | ) 62 | args = ", ".join([args, self.epsilon]) 63 | 64 | layer_declaration = "let {node.name} = {node.type}::new({input});".format( 65 | node=self, input=args 66 | ) 67 | operation_list.append(layer_declaration) 68 | 69 | return operation_list 70 | 71 | def declare_build(self): 72 | return (self.name, self.type) 73 | 74 | def apply_layer(self): 75 | assert len(self.connections["inbounds"]) == 1 76 | assert ( 77 | len(self.parents_name) == 1 78 | ), "Node {} has parents {}. 
It should have exactly one parent".format( 79 | self.name, self.parents_name 80 | ) 81 | 82 | if self.inplace_op: 83 | return [ 84 | "self.{node.name}.apply_mut(&mut out_{input});".format( 85 | node=self, input=self.parents_name[0] 86 | ) 87 | ] 88 | else: 89 | return [ 90 | "let {declare_mut}out_{node.name} = self.{node.name}.apply(&out_{input});".format( 91 | declare_mut=self._format_mut(), 92 | node=self, 93 | input=self.parents_name[0], 94 | ) 95 | ] 96 | -------------------------------------------------------------------------------- /tf2rust/nodes/concatenateNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class ConcatenateNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | self.axis = ( 8 | self.input_dimensions + layer_info["config"]["axis"] 9 | ) % self.input_dimensions 10 | 11 | def apply_layer(self): 12 | args = ", ".join( 13 | ["out_{}".format(layer_name) for layer_name in self.parents_name] 14 | ) 15 | 16 | return [ 17 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = concatenate![Axis({node.axis}), {input}];".format( 18 | declare_mut=self._format_mut(), node=self, input=args 19 | ) 20 | ] 21 | -------------------------------------------------------------------------------- /tf2rust/nodes/conv1dNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class Conv1DNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | config = layer_info["config"] 8 | self.kernel_size = config["kernel_size"][0] 9 | self.strides = config["strides"][0] 10 | self.padding = config["padding"] 11 | self.dilation_rate = config["dilation_rate"][0] 12 | self.groups = config["groups"] 13 | self.activation = config["activation"].capitalize() 14 | self.type = 
"tensorflow_layers::Conv1DLayer" 15 | 16 | def initialize_layer(self): 17 | assert len(self.weights_list) >= 1 18 | assert len(self.weights_list[0].shape) == 3 19 | 20 | padding = self.get_padding( 21 | input_shape=self.input_shape, 22 | output_shape=self.output_shape, 23 | kernel_size=self.kernel_size, 24 | strides=self.strides, 25 | padding_type=self.padding, 26 | ) 27 | operation_list = [] 28 | 29 | for i, weights in enumerate(self.weights_list): 30 | operation_list.append( 31 | 'let {node.name}_weight_{idx}: Array{dimensions}<{node.dtype}> = weights_dict.by_name("{node.name}_weight_{idx}.npy")?;'.format( 32 | node=self, idx=i, dimensions=len(weights.shape) 33 | ) 34 | ) 35 | 36 | args = ", ".join( 37 | ["{}_weight_{}".format(self.name, i) for i in range(len(self.weights_list))] 38 | ) 39 | 40 | # if bias doesnt exist 41 | if len(self.weights_list) == 1: 42 | args = ", ".join( 43 | [args, "Array1::zeros({})".format(self.weights_list[0].shape[0])] 44 | ) 45 | 46 | layer_declaration = "let {node.name} = {node.type}::new({input}, {node.strides}, vec!{padding}, {node.dilation_rate}, {node.groups}, tensorflow_layers::Activation::{node.activation});".format( 47 | node=self, input=args, padding=padding 48 | ) 49 | 50 | operation_list.append(layer_declaration) 51 | 52 | return operation_list 53 | 54 | def declare_build(self): 55 | return (self.name, self.type) 56 | 57 | def apply_layer(self): 58 | assert len(self.connections["inbounds"]) == 1 59 | assert ( 60 | len(self.parents_name) == 1 61 | ), "Node {} has parents {}. 
It should have exactly one parent".format( 62 | self.name, self.parents_name 63 | ) 64 | 65 | return [ 66 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = self.{node.name}.apply(&out_{input});".format( 67 | declare_mut=self._format_mut(), node=self, input=self.parents_name[0] 68 | ) 69 | ] 70 | -------------------------------------------------------------------------------- /tf2rust/nodes/denseNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class DenseNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | self.activation = layer_info["config"]["activation"].capitalize() 8 | self.type = "tensorflow_layers::DenseLayer" 9 | 10 | def initialize_layer(self): 11 | assert len(self.weights_list) >= 1 12 | assert len(self.weights_list[0].shape) == 2 13 | 14 | operation_list = [] 15 | 16 | for i, weights in enumerate(self.weights_list): 17 | operation_list.append( 18 | 'let {node.name}_weight_{idx}: Array{dimensions}<{node.dtype}> = weights_dict.by_name("{node.name}_weight_{idx}.npy")?;'.format( 19 | node=self, 20 | idx=i, 21 | dimensions=len(weights.shape), 22 | ) 23 | ) 24 | 25 | args = ", ".join( 26 | ["{}_weight_{}".format(self.name, i) for i in range(len(self.weights_list))] 27 | ) 28 | 29 | # if bias doesnt exist 30 | if len(self.weights_list) == 1: 31 | args = ", ".join( 32 | [args, "Array1::zeros({})".format(self.weights_list[0].shape[1])] 33 | ) 34 | 35 | operation_list.append( 36 | "let {node.name} = {node.type}::new({input}, tensorflow_layers::Activation::{node.activation});".format( 37 | node=self, input=args 38 | ) 39 | ) 40 | 41 | return operation_list 42 | 43 | def declare_build(self): 44 | return (self.name, self.type) 45 | 46 | def apply_layer(self): 47 | assert len(self.connections["inbounds"]) == 1 48 | assert ( 49 | len(self.parents_name) == 1 50 | ), "Node {} has parents {}. 
from .node import Node


class DropoutNode(Node):
    """
    Converter node for a Keras Dropout layer.

    Dropout is the identity at inference time and a Keras Dropout layer
    carries no weights, so only the rate is forwarded to the generated Rust
    code. NOTE(review): Dropout also appears in constants.DELETE_LAYERS, so
    this node may only be reached for models where the layer was not pruned
    — confirm against the surgeon pipeline.
    """

    def __init__(self, layer_info, layer_weights):
        super().__init__(layer_info, layer_weights)
        # The Keras config stores the rate as a float (e.g. 0.5). The previous
        # code called .capitalize() on it, which only exists on str and raised
        # AttributeError for every Dropout layer.
        self.rate = str(layer_info["config"]["rate"])
        self.type = "tensorflow_layers::Dropout"

    def initialize_layer(self):
        # A Keras Dropout layer has no weights (layer.get_weights() is empty),
        # so the weight-loading code copied from DenseNode could never run:
        # its `assert len(self.weights_list) >= 1` always failed. The layer is
        # constructed from its rate alone.
        # NOTE(review): confirm the constructor signature of the Rust
        # tensorflow_layers::Dropout type against the tensorflow_layers crate.
        return [
            "let {node.name} = {node.type}::new({node.rate});".format(node=self)
        ]

    def declare_build(self):
        # (name, rust type) pair used when declaring the model struct fields.
        return (self.name, self.type)

    def apply_layer(self):
        assert len(self.connections["inbounds"]) == 1
        assert (
            len(self.parents_name) == 1
        ), "Node {} has parents {}. It should have exactly one parent".format(
            self.name, self.parents_name
        )

        # NOTE(review): the applyNd call shape mirrors DenseNode; confirm the
        # Rust Dropout type exposes the same apply{n}d interface.
        return [
            "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = self.{node.name}.apply{node.input_dimensions}d(&out_{input});".format(
                declare_mut=self._format_mut(), node=self, input=self.parents_name[0]
            )
        ]
It should have exactly one parent".format( 35 | self.name, self.parents_name 36 | ) 37 | 38 | return [ 39 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = self.{node.name}.apply(&out_{input});".format( 40 | declare_mut=self._format_mut(), node=self, input=self.parents_name[0] 41 | ) 42 | ] 43 | -------------------------------------------------------------------------------- /tf2rust/nodes/flattenNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class FlattenNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | 8 | def apply_layer(self): 9 | assert len(self.connections["inbounds"]) == 1 10 | assert ( 11 | len(self.parents_name) == 1 12 | ), "Node {} has parents {}. It should have exactly one parent".format( 13 | self.name, self.parents_name 14 | ) 15 | 16 | output_shape = ["batch_size"] + list(self.output_shape)[1:] 17 | output_shape = "[" + ", ".join([str(dim) for dim in output_shape]) + "]" 18 | 19 | return [ 20 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = out_{input}.clone().into_shape({shape}).unwrap();".format( 21 | declare_mut=self._format_mut(), 22 | node=self, 23 | input=self.parents_name[0], 24 | shape=output_shape, 25 | ) 26 | ] 27 | -------------------------------------------------------------------------------- /tf2rust/nodes/globalAveragePooling1dNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class GlobalAveragePooling1DNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | 8 | def apply_layer(self): 9 | assert len(self.connections["inbounds"]) == 1 10 | assert ( 11 | len(self.parents_name) == 1 12 | ), "Node {} has parents {}. 
It should have exactly one parent".format( 13 | self.name, self.parents_name 14 | ) 15 | 16 | return [ 17 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = out_{input}.mean_axis(Axis(1)).unwrap();".format( 18 | declare_mut=self._format_mut(), node=self, input=self.parents_name[0] 19 | ) 20 | ] 21 | -------------------------------------------------------------------------------- /tf2rust/nodes/inputLayerNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class InputLayerNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | 8 | def apply_layer(self, input_name): 9 | assert ( 10 | len(self.connections["inbounds"]) == 0 11 | ), "Node {} has parents {}. It should have no parents".format( 12 | self.name, self.parents_name 13 | ) 14 | 15 | return [ 16 | "let out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = {input};".format( 17 | node=self, input=input_name 18 | ) 19 | ] 20 | -------------------------------------------------------------------------------- /tf2rust/nodes/maxPooling1dNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class MaxPool1dNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | config = layer_info["config"] 8 | self.pool_size = config["pool_size"][0] 9 | self.strides = config["strides"][0] 10 | self.padding = config["padding"] 11 | self.type = "tensorflow_layers::MaxPooling1DLayer" 12 | 13 | def initialize_layer(self): 14 | padding = self.get_padding( 15 | input_shape=self.input_shape, 16 | output_shape=self.output_shape, 17 | kernel_size=self.pool_size, 18 | strides=self.strides, 19 | padding_type=self.padding, 20 | ) 21 | 22 | operation_list = [ 23 | "let {node.name} = {node.type}::new({node.pool_size}, {node.strides}, 
vec!{padding});".format( 24 | node=self, padding=padding 25 | ) 26 | ] 27 | 28 | return operation_list 29 | 30 | def declare_build(self): 31 | return (self.name, self.type) 32 | 33 | def apply_layer(self): 34 | assert len(self.connections["inbounds"]) == 1 35 | assert ( 36 | len(self.parents_name) == 1 37 | ), "Node {} has parents {}. It should have exactly one parent".format( 38 | self.name, self.parents_name 39 | ) 40 | 41 | return [ 42 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = self.{node.name}.apply(&out_{input});".format( 43 | declare_mut=self._format_mut(), node=self, input=self.parents_name[0] 44 | ) 45 | ] 46 | -------------------------------------------------------------------------------- /tf2rust/nodes/multiplyNode.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | 3 | 4 | class MultiplyNode(Node): 5 | def __init__(self, layer_info, layer_weights): 6 | super().__init__(layer_info, layer_weights) 7 | 8 | def apply_layer(self): 9 | args = " * ".join( 10 | ["&out_{}".format(layer_name) for layer_name in self.parents_name] 11 | ) 12 | 13 | return [ 14 | "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = {input};".format( 15 | declare_mut=self._format_mut(), node=self, input=args 16 | ) 17 | ] 18 | -------------------------------------------------------------------------------- /tf2rust/nodes/node.py: -------------------------------------------------------------------------------- 1 | class Node: 2 | def __init__(self, layer_info, layer_weights): 3 | self.python_to_rust_primitive_mappings = { 4 | "int16": "usize", 5 | "int32": "usize", 6 | "int64": "usize", 7 | "uint16": "u16", 8 | "float32": "f32", 9 | "float64": "f32", 10 | } 11 | self.class_name = layer_info["class_name"] 12 | self.name = layer_info["name"].lower() 13 | self.input_shape = self.get_shape(layer_info["input_shape"]) 14 | self.output_shape = 
self.get_shape(layer_info["output_shape"]) 15 | self.input_dimensions = len(self.input_shape) 16 | self.output_dimensions = len(self.output_shape) 17 | self.dtype = self.python_to_rust_primitive_mappings[ 18 | layer_info["config"]["dtype"] 19 | ] 20 | self.connections = layer_info["connections"] 21 | self.weights_list = layer_weights 22 | self.parents_name = self.connections["inbounds"] 23 | self.output_as_mut = False 24 | self.inplace_op = False 25 | 26 | # Override this in the corresponding class if this form is not compatible 27 | @staticmethod 28 | def get_weights(layer): 29 | return layer.get_weights() 30 | 31 | def initialize_layer(self): 32 | pass 33 | 34 | def can_be_done_inplace(self): 35 | return False 36 | 37 | def memory_drop(self): 38 | return ["mem::drop(out_{});".format(self.name)] 39 | 40 | def _format_mut(self): 41 | if self.output_as_mut: 42 | return "mut " 43 | return "" 44 | 45 | def declare_build(self): 46 | return None 47 | 48 | def apply_layer(self): 49 | pass 50 | 51 | @staticmethod 52 | def get_padding(input_shape, output_shape, kernel_size, strides, padding_type): 53 | # remove batch_size 54 | assert len(input_shape) == len(output_shape) 55 | no_dims_input = len(input_shape) 56 | 57 | if isinstance(strides, int): 58 | strides = [strides] * no_dims_input 59 | elif isinstance(strides, (list, tuple)): 60 | assert len(strides) == 1 61 | strides = [strides[0]] * no_dims_input 62 | # for the batch_size 63 | strides[0] = 0 64 | 65 | padding = [(0, 0)] * no_dims_input 66 | if padding_type.lower() == "valid": 67 | return padding 68 | 69 | elif padding_type.lower() == "same": 70 | total_padding = max( 71 | (output_shape[1] - 1) * strides[1] + kernel_size - input_shape[1], 0 72 | ) 73 | pad_top = total_padding // 2 74 | pad_bottom = total_padding - pad_top 75 | padding[1] = (pad_top, pad_bottom) 76 | return padding 77 | 78 | elif padding_type.lower() == "causal": 79 | total_padding = max( 80 | (output_shape[1] - 1) * strides[1] + kernel_size - 
from .node import Node


class ReshapeNode(Node):
    """Converter node for a Keras Reshape layer: emits an into_shape() call."""

    def __init__(self, layer_info, layer_weights):
        super().__init__(layer_info, layer_weights)

    def apply_layer(self):
        assert len(self.connections["inbounds"]) == 1
        assert (
            len(self.parents_name) == 1
        ), "Node {} has parents {}. It should have exactly one parent".format(
            self.name, self.parents_name
        )

        # The leading dimension is the dynamic batch size; the remaining
        # dimensions come from the layer's static output shape.
        target_dims = ["batch_size"] + [str(dim) for dim in list(self.output_shape)[1:]]
        shape_literal = "[" + ", ".join(target_dims) + "]"

        return [
            "let {declare_mut}out_{node.name}: Array{node.output_dimensions}<{node.dtype}> = out_{input}.clone().into_shape({shape}).unwrap();".format(
                declare_mut=self._format_mut(),
                node=self,
                input=self.parents_name[0],
                shape=shape_literal,
            )
        ]
from .node import Node


class TensorFlowMeanNode(Node):
    """Code generator for a TensorFlow mean-reduction layer (ndarray `mean_axis`)."""

    def __init__(self, layer_info, layer_weights):
        super().__init__(layer_info, layer_weights)
        # The reduction axis comes from the layer config as a (possibly
        # negative) constant; normalise it into [0, input_dimensions).
        raw_axis = layer_info["config"]["constants"]["1"]
        self.axis = (self.input_dimensions + raw_axis) % self.input_dimensions

    def apply_layer(self):
        """Emit the Rust statement that averages the parent output along `self.axis`."""
        assert len(self.connections["inbounds"]) == 1
        assert (
            len(self.parents_name) == 1
        ), "Node {} has parents {}. It should have exactly one parent".format(
            self.name, self.parents_name
        )

        template = (
            "let {declare_mut}out_{node.name}: Array{node.output_dimensions}"
            "<{node.dtype}> = out_{input}.mean_axis(Axis({node.axis})).unwrap();"
        )
        return [
            template.format(
                declare_mut=self._format_mut(),
                node=self,
                input=self.parents_name[0],
            )
        ]
It should have exactly one parent".format( 17 | self.name, self.parents_name 18 | ) 19 | assert len(self.input_shape) == len(self.output_shape) 20 | 21 | if self.inplace_op: 22 | return [ 23 | "tensorflow_layers::Activation::ThresholdedRelu({node.theta}).activation_mut(&mut out_{input});".format( 24 | node=self, input=self.parents_name[0] 25 | ) 26 | ] 27 | else: 28 | return [ 29 | "let {declare_mut}out_{node.name} = tensorflow_layers::Activation::ThresholdedRelu({node.theta}).activation(&out_{input});".format( 30 | declare_mut=self._format_mut(), 31 | node=self, 32 | input=self.parents_name[0], 33 | ) 34 | ] 35 | -------------------------------------------------------------------------------- /tf2rust/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tf2rust/utils/__init__.py -------------------------------------------------------------------------------- /tf2rust/utils/model_saver.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | 6 | # For some obscure reason and maybe conflict with version of `sk-learn` and `tf`, 7 | # `sk-learn` has to be imported before `tf` is ever used... 
def sanitize(model):
    """Remove every layer whose type is listed in DELETE_LAYERS from `model`.

    Afterwards renders an overview diagram of the cleaned-up model to
    FILE_PATH_OVERVIEW and prints the before/after layer counts.
    Returns the (possibly rebuilt) model.
    """
    n_layers_before = len(model.layers)

    # Collect the doomed layers up front: delete_layer() returns a rebuilt
    # model, so we must not scan model.layers while deleting.
    for doomed in [lyr for lyr in model.layers if isinstance(lyr, DELETE_LAYERS)]:
        model = delete_layer(model, doomed)

    tf.keras.utils.plot_model(
        model,
        to_file=FILE_PATH_OVERVIEW,
        show_shapes=True,
        show_layer_names=True,
        rankdir="TB",
        expand_nested=False,
        dpi=96,
    )

    print(
        "#### The number of layers before/after sanitizing the model: {}/{} ####".format(
            n_layers_before, len(model.layers)
        )
    )
    return model
layer_dict["inbound_nodes"][0]: 81 | inbound_node = in_node[0].lower() 82 | if inbound_node not in computational_graph: 83 | computational_graph[inbound_node] = { 84 | "inbounds": [], 85 | "outbounds": [], 86 | } 87 | 88 | computational_graph[inbound_node]["outbounds"].append(layer_name) 89 | computational_graph[layer_name]["inbounds"].append(inbound_node) 90 | 91 | with open(file_path_computational_graph, "w+") as json_file: 92 | json.dump(computational_graph, json_file, indent=4) 93 | 94 | # save architecture + save weights 95 | dictionary_architecture = { 96 | layer_dict["name"].lower(): layer_dict for layer_dict in json_model 97 | } 98 | 99 | for layer in model.layers: 100 | layer_name = layer.name.lower() 101 | dictionary_architecture[layer_name]["input_shape"] = tuple(layer.input_shape) 102 | dictionary_architecture[layer_name]["output_shape"] = tuple(layer.output_shape) 103 | dictionary_architecture[layer_name]["connections"] = computational_graph[ 104 | layer_name 105 | ] 106 | del dictionary_architecture[layer_name]["inbound_nodes"] 107 | 108 | dictionary_weights = {} 109 | for layer in model.layers: 110 | layer_name = layer.name.lower() 111 | class_name = dictionary_architecture[layer_name]["class_name"] 112 | weights_list = get_class( 113 | class_name=class_name, layer_name=layer_name 114 | ).get_weights(layer) 115 | 116 | channels_last_case = ( 117 | "data_format" in dictionary_architecture[layer_name]["config"] 118 | and dictionary_architecture[layer_name]["config"]["data_format"] 119 | == "channels_last" 120 | ) 121 | if channels_last_case: 122 | dictionary_architecture[layer_name]["config"][ 123 | "data_format" 124 | ] = "channels_first" 125 | 126 | for i, weight in enumerate(weights_list): 127 | name_weight = ("{}_weight_{}".format(layer_name, i)).lower() 128 | if channels_last_case: 129 | dimensions = len(weight.shape) 130 | transpose_dimensions = tuple( 131 | [dimensions - 1] + [i for i in range(dimensions - 1)] 132 | ) 133 | 
dictionary_weights[name_weight] = weight.astype(np.float32).transpose( 134 | transpose_dimensions 135 | ) 136 | else: 137 | dictionary_weights[name_weight] = weight.astype(np.float32) 138 | 139 | np.savez(file_path_weights, **dictionary_weights) 140 | 141 | with open(file_path_model_architecture, "w+") as json_file: 142 | json.dump(dictionary_architecture, json_file, indent=4) 143 | 144 | 145 | def save_tf_model(): 146 | save_rust_model( 147 | path_to_tf_model=PATH_TO_TF_MODEL, 148 | file_path_weights=FILE_PATH_WEIGHTS, 149 | file_path_model_architecture=FILE_PATH_MODEL_ARCHITECTURE, 150 | file_path_computational_graph=FILE_PATH_COMPUTATIONAL_GRAPH, 151 | ) 152 | print( 153 | "#### The model was successfully saved in a suitable format for the Rust converter! ####" 154 | ) 155 | 156 | 157 | if __name__ == "__main__": 158 | save_tf_model() 159 | -------------------------------------------------------------------------------- /tf2rust/utils/rust_converter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import shutil 5 | from collections import deque 6 | 7 | import numpy as np 8 | 9 | from tf2rust.constants import ( 10 | BINARY_CLASSIFICATION, 11 | ENABLE_INPLACE, 12 | ENABLE_MEMDROP, 13 | FILE_PATH_COMPUTATIONAL_GRAPH, 14 | FILE_PATH_FV, 15 | FILE_PATH_MODEL_ARCHITECTURE, 16 | FILE_PATH_WEIGHTS, 17 | MODEL_NAME, 18 | PROJECT_PATH, 19 | get_class, 20 | ) 21 | from tf2rust.nodes.inputLayerNode import InputLayerNode 22 | 23 | MODEL_SERIALIZED_FILE_NAME = "tf_model" 24 | 25 | 26 | # This implementation assumes that the graph is acyclic as there are will be no valid topological sort otherwise. 
def topological_sort(graph):
    """Return the nodes of `graph` in a topological order.

    `graph` maps each node name to a dict with an "outbounds" list of
    successor names.  Implemented as a depth-first search: nodes are
    appended in post-order and the reversed post-order is a topological
    ordering (valid only because the graph is acyclic).
    """
    seen = set()
    postorder = []

    def visit(node):
        seen.add(node)
        for successor in graph[node]["outbounds"]:
            if successor not in seen:
                visit(successor)
        postorder.append(node)

    for node in graph:
        if node not in seen:
            visit(node)

    return list(reversed(postorder))
if ( 94 | len(graph_children[curr_layer_name]) >= 2 95 | and number_of_special_children >= 1 96 | ): 97 | for child_name in graph_children[curr_layer_name]: 98 | layers_option[child_name]["can_be_done_inplace"] = False 99 | 100 | for child_name in graph_children[curr_layer_name]: 101 | in_degree_remaining[child_name] -= 1 102 | if in_degree_remaining[child_name] == 0 and child_name not in visited: 103 | visited[child_name] = True 104 | queue.append(child_name) 105 | 106 | # Remove layers that can be done in place, but doesn't have all the inputs set to mut 107 | # Additionally, compute the output name for each layer (taking into consideration inplace operations). 108 | in_nodes = { 109 | layer_name: { 110 | parent_name: parent_name for parent_name in graph_parents[layer_name] 111 | } 112 | for layer_name in graph_parents 113 | } 114 | out_nodes = {layer_name: layer_name for layer_name in computational_graph} 115 | 116 | in_degree_remaining, visited, queue = _initialize_queue_and_visited() 117 | necessary_outputs_mut = {} 118 | while len(queue) > 0: 119 | curr_layer_name = queue.popleft() 120 | 121 | if layers_option[curr_layer_name]["can_be_done_inplace"]: 122 | number_of_parents_with_mut_output = sum( 123 | [ 124 | 1 125 | for direct_parent_name in in_nodes[curr_layer_name] 126 | if layers_option[direct_parent_name]["can_have_output_mut"] 127 | ] 128 | ) 129 | 130 | if number_of_parents_with_mut_output != len(in_nodes[curr_layer_name]): 131 | layers_option[curr_layer_name]["can_be_done_inplace"] = False 132 | out_nodes[curr_layer_name] = curr_layer_name 133 | else: 134 | assert len(in_nodes[curr_layer_name]) == 1 135 | [direct_parent_name] = in_nodes[curr_layer_name].keys() 136 | necessary_outputs_mut[direct_parent_name] = True 137 | out_nodes[curr_layer_name] = out_nodes[direct_parent_name] 138 | 139 | else: 140 | out_nodes[curr_layer_name] = curr_layer_name 141 | 142 | for child_name in graph_children[curr_layer_name]: 143 | in_degree_remaining[child_name] -= 1 
def define_class_name(name):
    """Derive the Rust struct name for the generated model from `name`.

    Splits `name` on every non-alphanumeric character, CamelCases the
    pieces, and appends a "Model" suffix unless the result already ends
    with one.  e.g. "my-predictor" -> "MyPredictorModel",
    "my model" -> "MyModel".
    """
    words = re.split("[^0-9a-zA-Z]", name)
    res = "".join(word[0:1].upper() + word[1:] for word in words if word != "")
    # BUG FIX: the previous check compared the last *five* characters of
    # `res` against the six-character string " Model" (leading space),
    # which can never match — so names already ending in "Model" were
    # suffixed twice ("MyModelModel").  endswith() also covers the
    # short-name case (strings shorter than 5 chars cannot end in "Model").
    if not res.endswith("Model"):
        res += "Model"
    return res
def convert_none_to_one(lst):
    """Recursively replace every None in an arbitrarily nested list with 1.

    Used to turn Keras shapes with a None batch dimension into concrete
    shapes, e.g. [[None, 350], [None, 50, 15]] -> [[1, 350], [1, 50, 15]].
    Non-list, non-None values are returned unchanged.
    """
    if not isinstance(lst, list):
        return 1 if lst is None else lst
    return [convert_none_to_one(item) for item in lst]
248 | /// 249 | /// Both techniques require specific retrieval methods as per `raw_model` function in the lib source. 250 | mod linkin {{ 251 | 252 | use std::{{env, error::Error, fs, path::Path}}; 253 | 254 | /// On Linux GNU toolchain we can create object file and include in the final binary with linker directives. 255 | #[cfg(target_os = "linux")] 256 | pub fn link_model(path: impl AsRef) -> Result<(), Box> {{ 257 | use std::{{os::unix::fs as unix_fs, process::Command}}; 258 | 259 | const LIB_NAME: &str = "{model_serialized_name}_raw"; 260 | 261 | let out_dir = env::var("OUT_DIR")?; 262 | let out_dir = Path::new(&out_dir); 263 | 264 | // symlink the original lib so we get our choice of symbol names 265 | let orig_model_path = env::current_dir()?.join(path); 266 | let model_path = out_dir.join(LIB_NAME); 267 | match fs::remove_file(&model_path) {{ 268 | Ok(()) => (), 269 | Err(e) if e.kind() == std::io::ErrorKind::NotFound => (), 270 | Err(e) => panic!("failed to unlink {{}}: {{}}", model_path.display(), e), 271 | }} 272 | unix_fs::symlink(&orig_model_path, &model_path).unwrap(); 273 | 274 | // use ld turn the model into a .o 275 | let object_file = out_dir.join(format!("{{}}.o", LIB_NAME)); 276 | let status = Command::new("ld") 277 | .current_dir(out_dir) 278 | .arg("-r") 279 | .arg("-b") 280 | .arg("binary") 281 | .arg(LIB_NAME) 282 | .arg("-o") 283 | .arg(&object_file) 284 | .status()?; 285 | assert!( 286 | status.success(), 287 | "ld failed to convert model to object file" 288 | ); 289 | 290 | // create lib from .o 291 | let status = Command::new("ar") 292 | .current_dir(out_dir) 293 | .arg("cr") 294 | .arg(format!("lib{{}}.a", LIB_NAME)) 295 | .arg(object_file) 296 | .status()?; 297 | assert!(status.success(), "ar failed to create lib from object file"); 298 | 299 | // technically we should also 300 | // println!("cargo:rerun-if-changed={{}}", orig_model_path) 301 | // so we rerun if the model file changes, but: 302 | // a. 
model file changes include source code changes (each model update 303 | // has a new filename), and 304 | // b. checking a 150+ MB model for changes is slooooow. 305 | println!("cargo:rustc-link-lib=static={{}}", LIB_NAME); 306 | println!("cargo:rustc-link-search=native={{}}", out_dir.display()); 307 | 308 | Ok(()) 309 | }} 310 | 311 | /// On Windows we use resource embedding to be able to later retrieve the data via the Windows API `FindResource` and `LoadResource` 312 | #[cfg(target_os = "windows")] 313 | pub fn link_model(_path: impl AsRef) {{ 314 | // For Resource embedding we need to have defined 2 types. 315 | // As per MSVC docs `NAME_ID_MODEL_FILE` can be any `u16`. 316 | const NAME_ID_MODEL_FILE: &str = "101"; 317 | // As per MSVC docs `TYPE_ID_BINARY_FILE` has to be an int over 255 value. 318 | const TYPE_ID_BINARY_FILE: &str = "333"; 319 | 320 | let path_emb = Path::new(&env::var("OUT_DIR")?).join("embed.rc"); 321 | let data = format!( 322 | r#" 323 | #define BINARY_FILE {{}} 324 | #define MODEL_FILE {{}} 325 | MODEL_FILE BINARY_FILE "{model_serialized_name}" 326 | "#, 327 | TYPE_ID_BINARY_FILE, NAME_ID_MODEL_FILE 328 | ); 329 | fs::write(&path_emb, data)?; 330 | 331 | // This crate does some Windows MSVC specific black magic... 332 | // It finds required tools and converts the blob to an embeddable obj and then includes as a resource as per `rc` file. 
333 | embed_resource::compile(path_emb); 334 | 335 | Ok(()) 336 | }} 337 | }} 338 | """.format( 339 | model_serialized_name=MODEL_SERIALIZED_FILE_NAME, 340 | class_name=define_class_name(MODEL_NAME), 341 | file_path=rsrc_path.joinpath(FILE_PATH_WEIGHTS.name).relative_to( 342 | PROJECT_PATH 343 | ), 344 | body="\n\t".join( 345 | i 346 | for lst in ( 347 | nodes_dict[layer_name].initialize_layer() 348 | for layer_name in traversal_order 349 | ) 350 | for i in (lst if isinstance(lst, list) else []) 351 | ).expandtabs(4), 352 | struct_initializer=",\n\t\t".join( 353 | i[0] 354 | for i in ( 355 | nodes_dict[layer_name].declare_build() 356 | for layer_name in traversal_order 357 | ) 358 | if i != None 359 | ).expandtabs( 360 | 4 361 | ), # we want this to fail if len(i) < 1 362 | ) 363 | ) 364 | 365 | 366 | def declare_model( 367 | traversal_order, 368 | nodes_dict, 369 | computational_graph, 370 | model_architecture, 371 | file_path_model, 372 | enable_inplace=True, 373 | enable_memory_drop=True, 374 | ): 375 | rust_class_name = define_class_name(MODEL_NAME) 376 | input_shapes = convert_none_to_one( 377 | [ 378 | model_architecture[layer]["input_shape"][0] 379 | for layer in model_architecture 380 | if model_architecture[layer]["class_name"] == "InputLayer" 381 | ] 382 | ) 383 | 384 | start_layers_name = [ 385 | name 386 | for name in computational_graph 387 | if len(computational_graph[name]["inbounds"]) == 0 388 | ] 389 | assert len(input_shapes) == len(start_layers_name) 390 | 391 | final_layers_name = [ 392 | name 393 | for name in computational_graph 394 | if len(computational_graph[name]["outbounds"]) == 0 395 | ] 396 | 397 | input_dict = { 398 | name: "input_{}".format(i) for i, name in enumerate(start_layers_name) 399 | } 400 | output_dict = { 401 | name: "output_{}".format(i) for i, name in enumerate(final_layers_name) 402 | } 403 | 404 | out_degree_remaining = { 405 | layer_name: len(computational_graph[layer_name]["outbounds"]) 406 | for layer_name in 
computational_graph 407 | } 408 | 409 | logic_operations = ["let batch_size = (input_0).shape()[0];"] 410 | 411 | # populate parents_name + output_as_mut + inplace_op of each node 412 | if enable_inplace: 413 | get_option_preferences( 414 | nodes_dict=nodes_dict, computational_graph=computational_graph 415 | ) 416 | 417 | for layer_name in traversal_order: 418 | layer = nodes_dict[layer_name] 419 | if isinstance(layer, InputLayerNode): 420 | logic_operations += nodes_dict[layer_name].apply_layer( 421 | input_name=input_dict[layer_name] 422 | ) 423 | else: 424 | logic_operations += nodes_dict[layer_name].apply_layer() 425 | 426 | # implement memory drop as soon as possible 427 | if enable_memory_drop: 428 | if not nodes_dict[layer_name].inplace_op: 429 | for parent in nodes_dict[layer_name].parents_name: 430 | out_degree_remaining[parent] -= 1 431 | if out_degree_remaining[parent] == 0: 432 | logic_operations += nodes_dict[parent].memory_drop() 433 | 434 | args_list = [] 435 | arg_dtype = None 436 | for layer_name in input_dict: 437 | name_var = input_dict[layer_name] 438 | if arg_dtype is None: 439 | arg_dtype = nodes_dict[layer_name].dtype 440 | else: 441 | # make sure the same data type is used across all input arrays 442 | # if this fails, the generated code needs to change to allow for different input data types 443 | assert arg_dtype == nodes_dict[layer_name].dtype 444 | input_dimensions = nodes_dict[layer_name].input_dimensions 445 | 446 | args_list.append(("{}: ".format(name_var), "Array{}".format(input_dimensions))) 447 | 448 | result_types_list, result_var_list = [], [] 449 | for layer_name in output_dict: 450 | name_var = "out_{}".format(layer_name) 451 | result_var_list.append(name_var) 452 | 453 | dtype = nodes_dict[layer_name].dtype 454 | input_dimensions = nodes_dict[layer_name].input_dimensions 455 | result_types_list.append("Array{}<{}>".format(input_dimensions, dtype)) 456 | 457 | # helper to format multiple params as a tuple 458 | tuplefy = lambda 
line, items: "({})".format(line) if items > 1 else line 459 | 460 | with open(file_path_model, "w") as f: 461 | f.write( 462 | """\ 463 | // DO NOT EDIT! THIS FILE IS AUTOMATICALLY GENERATED! 464 | use ndarray::{{concatenate, Array1, Array2, Array3, Array4, Axis}}; 465 | use serde::{{Deserialize, Serialize}}; 466 | use std::mem; 467 | 468 | /// Translated TensorFlow {model} Model 469 | #[derive(Serialize, Deserialize, Debug, Clone)] 470 | #[allow(clippy::module_name_repetitions)] 471 | pub(crate) struct {class_name} {{ 472 | {members} 473 | }} 474 | 475 | impl {class_name} {{ 476 | fn fv_to_arrays(mut fv: Vec<{dtype}>) -> {input_type} {{ 477 | {tuple_builder} 478 | }} 479 | 480 | /// Predict from a single feature vector. 481 | /// 482 | /// # Errors 483 | /// 484 | /// Fails if there is an internal error. 485 | #[allow(dead_code)] // we're imported by build.rs, but the build script doesn't call `predict` 486 | #[must_use] 487 | pub(crate) fn predict(&self, fv: Vec<{dtype}>) -> {return_type} {{ 488 | let {tuple_items} = Self::fv_to_arrays(fv); 489 | self.predict_from_arrays({tuple_ref_items}) 490 | }} 491 | 492 | #[allow(clippy::similar_names, clippy::too_many_lines)] 493 | #[rustfmt::skip] 494 | fn predict_from_arrays(&self, {input}) -> {return_type} {{ 495 | {body} 496 | 497 | {result} 498 | }} 499 | 500 | /// Serialize this model to a byte vector 501 | #[allow(dead_code)] 502 | pub(crate) fn serialize(&self) -> Vec {{ 503 | bincode::serialize(&self).unwrap() 504 | }} 505 | }} 506 | """.format( 507 | model=MODEL_NAME, 508 | class_name=rust_class_name, 509 | members="\n\t".join( 510 | "pub(crate) {item[0]}: {item[1]},".format(item=item) 511 | for item in ( 512 | nodes_dict[layer_name].declare_build() 513 | for layer_name in traversal_order 514 | ) 515 | if item != None 516 | ).expandtabs(4), 517 | dtype=arg_dtype, 518 | input=", ".join( 519 | "{}{}<{}>".format(n, t, arg_dtype) for (n, t) in args_list 520 | ), 521 | input_type=tuplefy( 522 | ", 
".join("{}<{}>".format(t, arg_dtype) for (_, t) in args_list), 523 | len(args_list), 524 | ), 525 | tuple_builder=tuplefy( 526 | ",\n\t\t\t".join( 527 | "{type}::from_shape_vec(({shape}), fv.drain(..{idx}).collect()).unwrap()".format( 528 | type=t, 529 | shape=", ".join(str(_) for _ in s), 530 | idx=np.array(s).prod(), 531 | ) 532 | for ((_, t), s) in zip(args_list, input_shapes) 533 | ).expandtabs(4), 534 | len(args_list), 535 | ), 536 | tuple_items=tuplefy( 537 | ", ".join("input_{}".format(i) for i in range(len(args_list))), 538 | len(args_list), 539 | ), 540 | tuple_ref_items=", ".join( 541 | "input_{}".format(i) for i in range(len(args_list)) 542 | ), 543 | return_type=", ".join(result_types_list), 544 | body="\n\t\t".join(logic_operations).expandtabs(4), 545 | result=", ".join(result_var_list), 546 | ) 547 | ) 548 | 549 | 550 | def declare_cargo_toml(file_path_cargo_toml): 551 | with open(file_path_cargo_toml, "w") as f: 552 | f.write( 553 | """\ 554 | [package] 555 | name = "predictor-example" 556 | version = "0.0.1" 557 | authors = ["Crowdstrike DSCI "] 558 | edition = "2021" 559 | description = "Example Predictor" 560 | license = "MIT" 561 | include = ["build.rs", "Cargo.toml", "benches/*", "model/*", "src/*"] 562 | 563 | build = "build.rs" 564 | 565 | [build-dependencies] 566 | bincode = "1.3.1" 567 | ndarray = { version = "0.15.5" } 568 | ndarray-npy = "0.8.1" 569 | serde = { version = "1.0.188", features = ["derive"] } 570 | tensorflow_layers = { package = "tf-layers", version = "0.4.0" } 571 | 572 | [target.'cfg(all(windows, target_env = "msvc"))'.build-dependencies] 573 | embed-resource = "2.3" 574 | 575 | [dependencies] 576 | bincode = "1.3.1" 577 | ndarray = { version = "0.15.5", features = ["serde-1"] } 578 | once_cell = "1.18" 579 | serde = { version = "1.0.188", features = ["derive"] } 580 | serde_json = "1.0.107" 581 | tensorflow_layers = { package = "tf-layers", version = "0.4.0" } 582 | 583 | [target.'cfg(windows)'.dependencies] 584 | windows 
= { version = "0.51.1", features = [ 585 | "Win32_System_LibraryLoader", 586 | "Win32_Foundation", 587 | ] } 588 | 589 | [dev-dependencies] 590 | base64 = "0.21.4" 591 | criterion = "0.5" 592 | itertools = "0.11.0" 593 | ndarray-npy = "0.8.1" 594 | once_cell = "1.18" 595 | 596 | [[bench]] 597 | name = "benchmarks" 598 | harness = false 599 | """ 600 | ) 601 | 602 | 603 | def declare_lib(file_path_lib): 604 | 605 | if BINARY_CLASSIFICATION: 606 | predictor_base = "pub use predictor_base::{BinaryClassThresholds, BinaryModelResult, ScanPrediction};" 607 | return_type = "BinaryModelResult" 608 | content_predict_from = """\ 609 | let decision_value = MODEL.predict(fv)[[0, 1]]; 610 | 611 | BinaryModelResult { 612 | dirty: ScanPrediction { 613 | fv_index: 0, 614 | confidence: THRESHOLDS.dirty_confidence(decision_value), 615 | decision_value, 616 | }, 617 | }""" 618 | 619 | test_content = """\ 620 | #[cfg(test)] 621 | mod tests { 622 | use super::*; 623 | use itertools::izip; 624 | use ndarray::{Array1, Array2, Axis}; 625 | use ndarray_npy::NpzReader; 626 | use std::fs::File; 627 | 628 | #[test] 629 | fn test_predictions() { 630 | let (features, expected_predictions): (Array2, Array1) = { 631 | let mut npz = NpzReader::new(File::open("testdata/features.npz").unwrap()).unwrap(); 632 | 633 | let fv: Array2 = npz.by_name("inputs.npy").unwrap(); 634 | let predictions: Array1 = npz.by_name("predictions.npy").unwrap(); 635 | (fv.mapv(|elem| elem as usize), predictions) 636 | }; 637 | 638 | for (fv, expected_dirty_score) in 639 | izip!(features.axis_iter(Axis(0)), expected_predictions.iter()) 640 | { 641 | let fv_to_vec = fv.into_owned().into_raw_vec(); 642 | let result = predict_from(fv_to_vec); 643 | let dirty_score = result.dirty.decision_value; 644 | let tolerance: f32 = 1.0e-5; 645 | assert!( 646 | (dirty_score - expected_dirty_score).abs() < tolerance, 647 | "predicted: {} whereas: {} was expected", 648 | dirty_score, 649 | expected_dirty_score 650 | ); 651 | } 652 | } 653 
| }""" 654 | 655 | else: 656 | predictor_base = "pub use predictor_base::{BinaryClassThresholds};" 657 | return_type = "Array1" 658 | content_predict_from = """\ 659 | MODEL.predict(fv).index_axis(Axis(0), 0).to_owned()""" 660 | test_content = "" 661 | 662 | rust_class_name = define_class_name(MODEL_NAME) 663 | with open(file_path_lib, "w") as f: 664 | f.write( 665 | """\ 666 | //! Crate for running predictions against the `{class_name}` 667 | 668 | use ndarray::{{Array1, Axis}}; 669 | use once_cell::sync::Lazy; 670 | {predictor_base} 671 | 672 | #[rustfmt::skip] 673 | mod model; 674 | use model::{class_name}; 675 | 676 | /// Model version. 677 | pub const MODEL_VERSION: u32 = 1; 678 | 679 | static MODEL: Lazy<{class_name}> = Lazy::new(|| bincode::deserialize(raw_model::raw_model()).unwrap()); 680 | static THRESHOLDS: Lazy = 681 | Lazy::new(|| serde_json::from_str(include_str!("../model/thresholds.json")).unwrap()); 682 | 683 | /// Predict from a set of feature vectors. 684 | #[must_use] 685 | pub fn predict_from(fv: Vec) -> {return_type} {{ 686 | {content_predict_from} 687 | }} 688 | 689 | /// This module contains a function to load the model from current binary. 690 | mod raw_model {{ 691 | 692 | /// On Linux GNU read symbols from the lib and load them. 693 | #[cfg(target_os = "linux")] 694 | #[must_use] 695 | pub(crate) fn raw_model() -> &'static [u8] {{ 696 | use std::ptr::addr_of; 697 | 698 | // On linux, the model is created from an object file and linked in; we 699 | // have to do some pointer shenanigans to find it. Our build script names 700 | // the input `predictor_pe_model_raw`, then the `ld` invocation prepends 701 | // `_binary` and creates a start and end symbol giving us the range. 
702 | extern "C" {{ 703 | static _binary_tf_model_raw_start: u8; 704 | static _binary_tf_model_raw_end: u8; 705 | }} 706 | 707 | unsafe {{ 708 | let start = addr_of!(_binary_tf_model_raw_start); 709 | let end = addr_of!(_binary_tf_model_raw_end); 710 | let length = (end as usize) - (start as usize); 711 | std::slice::from_raw_parts(start, length) 712 | }} 713 | }} 714 | 715 | /// On Windows MSVC we use `FindResource` and `LoadResource` to get the model from the binary. 716 | #[cfg(target_os = "windows")] 717 | #[must_use] 718 | pub(crate) fn raw_model() -> &'static [u8] {{ 719 | use windows::core::PCWSTR; 720 | 721 | // For Resource embedding we need to have defined 2 types. 722 | // As per MSVC docs `NAME_ID_MODEL_FILE` can be any `u16`. 723 | const NAME_ID_MODEL_FILE: usize = 101; 724 | 725 | // As per MSVC docs `TYPE_ID_BINARY_FILE` has to be an int over 255 value. 726 | const TYPE_ID_BINARY_FILE: usize = 333; 727 | 728 | // We have hardcoded the `MODEL_FILE` name id in the `rc` file, retrieving it via the ID. 729 | let name_id = PCWSTR(unsafe {{ 730 | // SAFETY: The resource name id is a u16 which is supposed to be constructed with `MAKEINTRESOURCE` 731 | // to generate a "string pointer" (the winapi knows its not a pointer) 732 | core::mem::transmute(NAME_ID_MODEL_FILE) 733 | }}); 734 | 735 | // We have hardcoded the `BINARY_FILE` type id in the `rc` file, retrieving it via the ID. 
736 | let type_id = PCWSTR(unsafe {{ 737 | // SAFETY: The resource type id is a u16 which is supposed to be constructed with `MAKEINTRESOURCE` 738 | // to generate a "string pointer" (the winapi knows its not a pointer) 739 | core::mem::transmute(TYPE_ID_BINARY_FILE) 740 | }}); 741 | 742 | let p: *const u16 = std::ptr::null(); 743 | unsafe {{ 744 | // We need current module handle for the next call 745 | let handle = 746 | windows::Win32::System::LibraryLoader::GetModuleHandleW(PCWSTR(p)).unwrap(); 747 | assert!( 748 | !handle.is_invalid(), 749 | "Failed to GetModuleHandleW : {{:?}}", 750 | handle 751 | ); 752 | 753 | // Use `FindResource` to retrieve the pointer to the embedded resource and then load to memory. 754 | let resource = 755 | windows::Win32::System::LibraryLoader::FindResourceW(handle, name_id, type_id); 756 | assert!( 757 | !resource.is_invalid(), 758 | "Failed to find resource in PE : {{:?}}", 759 | resource 760 | ); 761 | 762 | let resource_data = 763 | windows::Win32::System::LibraryLoader::LoadResource(handle, resource); 764 | assert!(resource_data != 0, "Failed to load resource"); 765 | 766 | let size = windows::Win32::System::LibraryLoader::SizeofResource(handle, resource); 767 | 768 | // Finally get the data from the module. 
769 | let data = windows::Win32::System::LibraryLoader::LockResource(resource_data); 770 | assert!(!data.is_null(), "Failed to lock resource"); 771 | 772 | std::slice::from_raw_parts(data.cast::(), size as usize) 773 | }} 774 | }} 775 | }} 776 | 777 | {test_content} 778 | """.format( 779 | predictor_base=predictor_base, 780 | class_name=rust_class_name, 781 | return_type=return_type, 782 | content_predict_from=content_predict_from, 783 | test_content=test_content, 784 | ) 785 | ) 786 | 787 | 788 | def declare_thresholds(file_path_thresholds): 789 | with open(file_path_thresholds, "w") as f: 790 | f.write( 791 | """\ 792 | { 793 | "dirty": { 794 | "bottom": 0.0, 795 | "low": 0.0, 796 | "medium": 0.0, 797 | "high": 0.0 798 | } 799 | } 800 | """ 801 | ) 802 | 803 | 804 | def declare_bench(file_path_bench): 805 | with open(file_path_bench, "w") as f: 806 | f.write( 807 | """\ 808 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 809 | use ndarray::{Array2, Axis, Slice}; 810 | use ndarray_npy::NpzReader; 811 | use once_cell::sync::Lazy; 812 | use predictor_example::predict_from; 813 | use std::fs::File; 814 | use std::time::{Duration, Instant}; 815 | 816 | static FEATURES: Lazy> = Lazy::new(|| { 817 | let mut npz = NpzReader::new(File::open("testdata/features.npz").unwrap()).unwrap(); 818 | let fv: Array2 = npz.by_name("inputs.npy").unwrap(); 819 | fv.mapv(|elem| elem as usize) 820 | }); 821 | 822 | fn bench_predict_one_sample(c: &mut Criterion) { 823 | let fv = FEATURES.index_axis(Axis(0), 0).to_vec(); 824 | c.bench_function("Inference for one sample", move |b| { 825 | b.iter_custom(|iters| { 826 | let mut duration: Vec = Vec::with_capacity(iters as usize); 827 | for _i in 0..iters { 828 | let f = fv.clone(); 829 | 830 | let start = Instant::now(); 831 | let _ = predict_from(black_box(f)); 832 | let end = start.elapsed(); 833 | 834 | duration.push(end) 835 | } 836 | duration.iter().sum() 837 | }) 838 | }); 839 | } 840 | 841 | fn 
bench_predict_multiple_samples(c: &mut Criterion) { 842 | let head_features = FEATURES.slice_axis(Axis(0), Slice::from(0..10)); 843 | c.bench_function("Run inference for a batch of samples", |b| { 844 | b.iter_custom(|iters| { 845 | let mut duration: Vec = Vec::with_capacity(iters as usize); 846 | for _i in 0..iters { 847 | duration.push( 848 | head_features 849 | .axis_iter(Axis(0)) 850 | .map(|fv| { 851 | let start = Instant::now(); 852 | let fv_to_vec = fv.into_owned().into_raw_vec(); 853 | let _ = predict_from(fv_to_vec); 854 | start.elapsed() 855 | }) 856 | .sum(), 857 | ); 858 | } 859 | duration.iter().sum() 860 | }); 861 | }); 862 | } 863 | 864 | criterion_group!( 865 | benches, 866 | bench_predict_one_sample, 867 | bench_predict_multiple_samples 868 | ); 869 | 870 | criterion_main!(benches); 871 | """ 872 | ) 873 | 874 | 875 | def prepare_fv(model_architecture, path_feature_vectors, file_path_testdata): 876 | order_keys = [ 877 | layer_name 878 | for layer_name in model_architecture 879 | if model_architecture[layer_name]["class_name"].lower() == "InputLayer".lower() 880 | ] 881 | dictionary_arrays = np.load(path_feature_vectors) 882 | arrays_list = [dictionary_arrays[key] for key in order_keys] 883 | 884 | array_1d = np.concatenate( 885 | [arr.reshape((-1, np.prod(arr.shape[1:]))) for arr in arrays_list], axis=1 886 | ) 887 | array_1d = array_1d.astype(np.int32) 888 | 889 | predictions = dictionary_arrays["predictions"] 890 | predictions = predictions.astype(np.float32) 891 | 892 | save_args = {"inputs": array_1d, "predictions": predictions} 893 | np.savez(file_path_testdata, **save_args) 894 | 895 | 896 | def convert_to_rust(): 897 | PROJECT_SRC = PROJECT_PATH.joinpath("src") 898 | os.makedirs(PROJECT_SRC, exist_ok=True) 899 | PROJECT_RSRC_PATH = PROJECT_PATH.joinpath("model") 900 | os.makedirs(PROJECT_RSRC_PATH, exist_ok=True) 901 | PROJECT_CONFIG_PATH = PROJECT_PATH.joinpath(".cargo") 902 | os.makedirs(PROJECT_CONFIG_PATH, exist_ok=True) 903 | 
PROJECT_TESTDATA_PATH = PROJECT_PATH.joinpath("testdata") 904 | os.makedirs(PROJECT_TESTDATA_PATH, exist_ok=True) 905 | PROJECT_BENCH_PATH = PROJECT_PATH.joinpath("benches") 906 | os.makedirs(PROJECT_BENCH_PATH, exist_ok=True) 907 | 908 | FILE_PATH_MODEL = PROJECT_SRC.joinpath("model.rs") 909 | FILE_PATH_LIB = PROJECT_SRC.joinpath("lib.rs") 910 | FILE_PATH_BUILD = PROJECT_PATH.joinpath("build.rs") 911 | FILE_PATH_CARGO_TOML = PROJECT_PATH.joinpath("Cargo.toml") 912 | FILE_PATH_THRESHOLDS = PROJECT_RSRC_PATH.joinpath("thresholds.json") 913 | FILE_PATH_TESTDATA = PROJECT_TESTDATA_PATH.joinpath("features.npz") 914 | FILE_PATH_BENCH = PROJECT_BENCH_PATH.joinpath("benchmarks.rs") 915 | 916 | model_architecture = json.load(open(FILE_PATH_MODEL_ARCHITECTURE, "r")) 917 | computational_graph = json.load(open(FILE_PATH_COMPUTATIONAL_GRAPH, "r")) 918 | model_weights = np.load(FILE_PATH_WEIGHTS, allow_pickle=True) 919 | 920 | traversal_order = topological_sort(computational_graph) 921 | 922 | nodes_dict = { 923 | layer_name: construct_node( 924 | layer_info=model_architecture[layer_name], 925 | layer_weights=get_weights_by_name(layer_name, model_weights), 926 | ) 927 | for layer_name in computational_graph 928 | } 929 | 930 | declare_build( 931 | traversal_order=traversal_order, 932 | nodes_dict=nodes_dict, 933 | file_path_build=FILE_PATH_BUILD, 934 | rsrc_path=PROJECT_RSRC_PATH, 935 | ) 936 | declare_model( 937 | traversal_order=traversal_order, 938 | nodes_dict=nodes_dict, 939 | computational_graph=computational_graph, 940 | model_architecture=model_architecture, 941 | file_path_model=FILE_PATH_MODEL, 942 | enable_inplace=ENABLE_INPLACE, 943 | enable_memory_drop=ENABLE_MEMDROP, 944 | ) 945 | declare_lib(file_path_lib=FILE_PATH_LIB) 946 | declare_cargo_toml(file_path_cargo_toml=FILE_PATH_CARGO_TOML) 947 | declare_thresholds(file_path_thresholds=FILE_PATH_THRESHOLDS) 948 | declare_bench(file_path_bench=FILE_PATH_BENCH) 949 | 950 | if FILE_PATH_FV is not None: 951 | prepare_fv( 
952 | model_architecture=model_architecture, 953 | path_feature_vectors=FILE_PATH_FV, 954 | file_path_testdata=FILE_PATH_TESTDATA, 955 | ) 956 | 957 | print("#### The model was successfully converted into pure Rust code! ####") 958 | 959 | 960 | if __name__ == "__main__": 961 | convert_to_rust() 962 | -------------------------------------------------------------------------------- /tf2rust/utils/scoring_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from sklearn.metrics import confusion_matrix, roc_auc_score 4 | 5 | 6 | def compute_auc(y_true, y_pred): 7 | """ 8 | Computes the area under curve (auc) score. 9 | 10 | Args: 11 | y_true: the true labels. 12 | y_pred: the predicted probabilities. 13 | 14 | Returns: auc score. 15 | 16 | """ 17 | return roc_auc_score( 18 | np.asarray(y_true, dtype=np.float32), np.asarray(y_pred, dtype=np.float32) 19 | ) 20 | 21 | 22 | def auc(y_true, y_pred): 23 | """ 24 | TensorFlow wrapper function over compute_auc. 25 | 26 | Args: 27 | y_true: the true labels. 28 | y_pred: the predicted labels. 29 | 30 | Returns: a TensorFlow decorated auc function. 31 | 32 | """ 33 | 34 | return tf.py_function(compute_auc, (y_true, y_pred), tf.double) 35 | 36 | 37 | def true_pos_rate(y_expected, y_predicted): 38 | """ 39 | Computes the true positive rate (TPR). 40 | 41 | Args: 42 | y_expected: true labels. 43 | y_predicted: predicted probabilities. 44 | 45 | Returns: return tpr score. 46 | 47 | """ 48 | 49 | y_expected = np.argmax(y_expected, axis=1) 50 | y_predicted = np.argmax(y_predicted, axis=1) 51 | tn, fp, fn, tp = confusion_matrix(y_expected, y_predicted, labels=[0, 1]).ravel() 52 | 53 | if tp + fn == 0: 54 | return 1 / 2 55 | 56 | return float(tp / (tp + fn)) 57 | 58 | 59 | def tpr(y_true, y_pred): 60 | """ 61 | TensorFlow wrapper function over tpr. 62 | 63 | Args: 64 | y_expected: true labels. 65 | y_predicted: predicted probabilities. 
66 | 67 | Returns: a TensorFlow decorated tpr function. 68 | 69 | """ 70 | 71 | return tf.py_function(true_pos_rate, (y_true, y_pred), tf.double) 72 | 73 | 74 | def true_neg_rate(y_expected, y_predicted): 75 | """ 76 | Computes the true negative rate (TNR). 77 | 78 | Args: 79 | y_expected: true labels. 80 | y_predicted: predicted probabilities. 81 | 82 | Returns: return tnr score. 83 | 84 | """ 85 | 86 | y_expected = np.argmax(y_expected, axis=1) 87 | y_predicted = np.argmax(y_predicted, axis=1) 88 | 89 | tn, fp, fn, tp = confusion_matrix(y_expected, y_predicted, labels=[0, 1]).ravel() 90 | 91 | if tn + fp == 0: 92 | return 1 / 2 93 | 94 | return float(tn / (tn + fp)) 95 | 96 | 97 | def tnr(y_true, y_pred): 98 | """ 99 | TensorFlow wrapper function over tnr. 100 | 101 | Args: 102 | y_expected: true labels. 103 | y_predicted: predicted probabilities. 104 | 105 | Returns: a TensorFlow decorated tnr function. 106 | 107 | """ 108 | 109 | return tf.py_function(true_neg_rate, (y_true, y_pred), tf.double) 110 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tf2rust/utils/surgeon/__init__.py -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CrowdStrike/tf2rust/a11f63e5d461ca0a8f7092d338c1838f9484ac4d/tf2rust/utils/surgeon/_utils/__init__.py -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/_utils/layer.py: -------------------------------------------------------------------------------- 1 | def inbound_nodes(layer): 2 | return layer.inbound_nodes 3 | 
-------------------------------------------------------------------------------- /tf2rust/utils/surgeon/_utils/node.py: -------------------------------------------------------------------------------- 1 | from . import layer as layer_utils 2 | import collections.abc 3 | 4 | 5 | def make_list_if_not(x): 6 | if isinstance(x, collections.abc.Sequence) and not isinstance(x, str): 7 | return x 8 | else: 9 | return [x] 10 | 11 | 12 | def node_indices(node): 13 | return make_list_if_not(node.node_indices) 14 | 15 | 16 | def inbound_layers(node): 17 | return make_list_if_not(node.inbound_layers) 18 | 19 | 20 | def parent_nodes(node): 21 | try: 22 | return node.parent_nodes 23 | except AttributeError: 24 | return [layer_utils.inbound_nodes(inbound_layers(node)[i])[node_index] 25 | for i, node_index in enumerate(node_indices(node))] 26 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/_utils/tensor_dict.py: -------------------------------------------------------------------------------- 1 | class TensorKeys(list): 2 | def __init__(self, refs): 3 | super().__init__(refs) 4 | 5 | def __contains__(self, item): 6 | try: 7 | return super().__contains__(item.ref()) 8 | except AttributeError: 9 | return super().__contains__(item.experimental_ref()) 10 | 11 | 12 | class TensorDict(dict): 13 | def __init__(self): 14 | super().__init__() 15 | # self.d = {} 16 | 17 | def __setitem__(self, key, value): 18 | try: 19 | super().__setitem__(key.ref(), value) 20 | except AttributeError: 21 | super().__setitem__(key.experimental_ref(), value) 22 | 23 | def __getitem__(self, item): 24 | try: 25 | return super().__getitem__(item.ref()) 26 | except AttributeError: 27 | return super().__getitem__(item.experimental_ref()) 28 | 29 | def keys(self): 30 | return TensorKeys(super().keys()) 31 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/identify.py: 
-------------------------------------------------------------------------------- 1 | """Identify which channels to delete.""" 2 | import numpy as np 3 | from tensorflow.keras.models import Model 4 | 5 | from . import utils 6 | 7 | 8 | def get_apoz(model, layer, x_val, node_indices=None): 9 | """Identify neurons with high Average Percentage of Zeros (APoZ). 10 | 11 | The APoZ a.k.a. (A)verage (P)ercentage (o)f activations equal to (Z)ero, 12 | is a metric for the usefulness of a channel defined in this paper: 13 | "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient 14 | Deep Architectures" - [Hu et al. (2016)][] 15 | `high_apoz()` enables the pruning methodology described in this paper to be 16 | replicated. 17 | 18 | If node_indices are not specified and the layer is shared within the model 19 | the APoZ will be calculated over all instances of the shared layer. 20 | 21 | Args: 22 | model: A Keras model. 23 | layer: The layer whose channels will be evaluated for pruning. 24 | x_val: The input of the validation set. This will be used to calculate 25 | the activations of the layer of interest. 26 | node_indices(list[int]): (optional) A list of node indices. 27 | 28 | Returns: 29 | List of the APoZ values for each channel in the layer. 30 | """ 31 | 32 | if isinstance(layer, str): 33 | layer = model.get_layer(name=layer) 34 | 35 | # Check that layer is in the model 36 | if layer not in model.layers: 37 | raise ValueError('layer is not a valid Layer in model.') 38 | 39 | layer_node_indices = utils.find_nodes_in_model(model, layer) 40 | # If no nodes are specified, all of the layer's inbound nodes which are 41 | # in model are selected. 
42 | if not node_indices: 43 | node_indices = layer_node_indices 44 | # Check for duplicate node indices 45 | elif len(node_indices) != len(set(node_indices)): 46 | raise ValueError('`node_indices` contains duplicate values.') 47 | # Check that all of the selected nodes are in the layer 48 | elif not set(node_indices).issubset(layer_node_indices): 49 | raise ValueError('One or more nodes specified by `layer` and ' 50 | '`node_indices` are not in `model`.') 51 | 52 | data_format = getattr(layer, 'data_format', 'channels_last') 53 | # Perform the forward pass and get the activations of the layer. 54 | mean_calculator = utils.MeanCalculator(sum_axis=0) 55 | for node_index in node_indices: 56 | act_layer, act_index = utils.find_activation_layer(layer, node_index) 57 | # Get activations 58 | temp_model = Model(model.inputs, act_layer.get_output_at(act_index)) 59 | a = temp_model.predict(x_val) 60 | 61 | if data_format == 'channels_first': 62 | a = np.swapaxes(a, 1, -1) 63 | # Flatten all except channels axis 64 | activations = np.reshape(a, [-1, a.shape[-1]]) 65 | zeros = (activations == 0).astype(int) 66 | mean_calculator.add(zeros) 67 | 68 | return mean_calculator.calculate() 69 | 70 | 71 | def high_apoz(apoz, method="std", cutoff_std=1, cutoff_absolute=0.99): 72 | """ 73 | Args: 74 | apoz: List of the APoZ values for each channel in the layer. 75 | method: Cutoff method for high APoZ. "std", "absolute" or "both". 76 | cutoff_std: Channels with a higher APoZ than the layer mean plus 77 | `cutoff_std` standard deviations will be identified for pruning. 78 | cutoff_absolute: Channels with a higher APoZ than `cutoff_absolute` 79 | will be identified for pruning. 80 | 81 | Returns: 82 | high_apoz_channels: List of indices of channels with high APoZ. 83 | 84 | """ 85 | if method not in {'std', 'absolute', 'both'}: 86 | raise ValueError('Invalid `mode` argument. 
' 87 | 'Expected one of {"std", "absolute", "both"} ' 88 | 'but got', method) 89 | if method == "std": 90 | cutoff = apoz.mean() + apoz.std()*cutoff_std 91 | elif method == 'absolute': 92 | cutoff = cutoff_absolute 93 | else: 94 | cutoff = min([cutoff_absolute, apoz.mean() + apoz.std()*cutoff_std]) 95 | 96 | cutoff = min(cutoff, 1) 97 | 98 | return np.where(apoz >= cutoff)[0] 99 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/operations.py: -------------------------------------------------------------------------------- 1 | from .surgeon import Surgeon 2 | 3 | 4 | def delete_layer(model, layer, *, node_indices=None, copy=True): 5 | """Delete instances of a layer from a Keras model. 6 | 7 | Args: 8 | model: A Model. 9 | layer: A Layer contained in model. 10 | node_indices: The indices of the inbound_node to the layer instances to 11 | be deleted. 12 | copy: If True, the model will be copied before and after 13 | manipulation. This keeps both the old and new models' layers 14 | clean of each-others data-streams. 15 | 16 | Returns: 17 | Keras Model object with the layer at node_index deleted. 18 | """ 19 | surgeon = Surgeon(model, copy) 20 | surgeon.add_job('delete_layer', layer, node_indices=node_indices) 21 | return surgeon.operate() 22 | 23 | 24 | def insert_layer(model, layer, new_layer, *, node_indices=None, copy=True): 25 | """Insert new_layer before instances of layer. 26 | 27 | If node_indices is not specified. The layer will be inserted before all 28 | instances of the layer in the model. 29 | 30 | Args: 31 | model: A Model. 32 | layer: A Layer contained in model. 33 | new_layer: A layer to be inserted into model before layer. 34 | node_indices: The indices of the inbound_node to layer where the 35 | new layer is to be inserted. 36 | copy: If True, the model will be copied before and after 37 | manipulation. This keeps both the old and new models' layers 38 | clean of each-others data-streams. 
39 | 40 | Returns: 41 | A new Model object with layer inserted. 42 | """ 43 | surgeon = Surgeon(model, copy) 44 | surgeon.add_job('insert_layer', layer, 45 | node_indices=node_indices, new_layer=new_layer) 46 | return surgeon.operate() 47 | 48 | 49 | def replace_layer(model, layer, new_layer, *, node_indices=None, copy=True): 50 | """Replace instances of layer with new_layer. 51 | 52 | If node_indices is not specified, all instances of layer will be 53 | replaced by instances of new_layer 54 | 55 | Args: 56 | model: A Model. 57 | layer: A Layer contained in model. 58 | new_layer: A layer to be inserted into model before layer. 59 | node_indices: The indices of the inbound_node to layer where the 60 | new layer is to be inserted. 61 | copy: If True, the model will be copied before and after 62 | manipulation. This keeps both the old and new models' layers 63 | clean of each-others data-streams. 64 | 65 | Returns: 66 | A new Model object with layer inserted. 67 | """ 68 | surgeon = Surgeon(model, copy) 69 | surgeon.add_job('replace_layer', layer, 70 | node_indices=node_indices, new_layer=new_layer) 71 | return surgeon.operate() 72 | 73 | 74 | def delete_channels(model, layer, channels, *, node_indices=None, copy=None): 75 | """Delete channels from instances of the specified layer. 76 | 77 | This method is designed to facilitate research into pruning networks to 78 | improve their prediction performance and/or reduce computational load by 79 | deleting channels. 80 | All weights associated with the deleted channels in the specified layer 81 | and any affected downstream layers are deleted. 82 | If the layer is shared and node_indices is set, channels will be deleted 83 | from the corresponding layer instances only. This will break the weight 84 | sharing between affected and unaffected instances in subsequent training. 85 | In this case affected instances will be renamed. 86 | 87 | 88 | Args: 89 | model: Model object. 
90 | layer: Layer whose channels are to be deleted. 91 | channels: Indices of the channels to be deleted 92 | node_indices: Indices of the nodes where channels are to be deleted. 93 | copy: If True, the model will be copied before and after 94 | manipulation. This keeps both the old and new models' layers 95 | clean of each-others data-streams. 96 | 97 | Returns: 98 | A new Model with the specified channels and associated weights deleted. 99 | 100 | Notes: 101 | Channels are filters in conv layers and units in other layers. 102 | """ 103 | surgeon = Surgeon(model, copy) 104 | surgeon.add_job('delete_channels', layer, node_indices=node_indices, channels=channels) 105 | return surgeon.operate() 106 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/surgeon.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.keras.layers import BatchNormalization 6 | from tensorflow.keras.models import Model 7 | 8 | from . import utils 9 | from ._utils.tensor_dict import TensorDict 10 | from ._utils import node as node_utils 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | 15 | 16 | class Surgeon: 17 | """Performs network surgery on a model. 18 | 19 | Surgeons can perform multiple network surgeries (jobs) at once. 20 | This is much faster than performing them sequentially. 21 | See `add_jobs` for a list of valid jobs and their required keyword arguments. 
22 | 23 | Examples: 24 | Delete some channels from layer_1 and layer_2: 25 | surgeon = Surgeon(model) 26 | surgeon.add_job('delete_channels', layer_1, channels_1) 27 | surgeon.add_job('delete_channels', layer_2, channels_2) 28 | new_model = surgeon.operate() 29 | 30 | Arguments: 31 | model: The model to be modified 32 | copy: If True, the model will be copied before and after any operations 33 | This keeps the layers in the original model and the new model separate. 34 | """ 35 | def __init__(self, model, copy=None): 36 | if copy: 37 | self.model = utils.clean_copy(model) 38 | else: 39 | self.model = model 40 | self.nodes = [] 41 | self._copy = copy 42 | self._finished_nodes = {} 43 | self._replace_tensors = TensorDict() 44 | self._channels_map = {} 45 | self._new_layers_map = {} 46 | self._insert_layers_map = {} 47 | self._replace_layers_map = {} 48 | self._mod_func_map = {} 49 | self._kwargs_map = {} 50 | self.valid_jobs = ('delete_layer', 51 | 'insert_layer', 52 | 'replace_layer', 53 | 'delete_channels') 54 | 55 | def add_job(self, job, layer, *, 56 | channels=None, new_layer=None, node_indices=None): 57 | """Adds a job for the Surgeon to perform on the model. 58 | 59 | Job options are: 60 | 'delete_layer': delete `layer` from the model 61 | required keyword arguments: None 62 | 'insert_layer': insert `new_layer` before `layer` 63 | required keyword arguments: `new_layer` 64 | 'replace_layer': replace `layer` with `new_layer` 65 | required keyword arguments: `new_layer` 66 | 'delete_channels': delete `channels` from `layer` 67 | required keyword arguments: `channels` 68 | 69 | Jobs can be added in any order. They will be performed in order of 70 | decreasing network depth. 71 | A maximum of one job can be performed per node. 72 | 73 | Args: 74 | job(string): job identifier. One of `Surgeon.valid_jobs`. 75 | layer(Layer): A layer from `model` to be modified. 76 | channels(list[int]): A list of channels used for the job. 77 | Used in `delete_channels`. 
78 | new_layer(Layer): A new layer used for the job. Used in 79 | `insert_layer` and `replace_layer`. 80 | node_indices(list[int]): (optional) A list of node indices used to 81 | selectively apply the job to a subset of 82 | the layer's nodes. Nodes are selected with: 83 | node[i] = layer.inbound_nodes[node_indices[i]] 84 | """ 85 | # If the model has been copied, identify `layer` in the copied model. 86 | if self._copy: 87 | layer = self.model.get_layer(layer.name) 88 | # Check that layer is in the model 89 | if layer not in self.model.layers: 90 | raise ValueError('layer is not a valid Layer in model.') 91 | 92 | layer_node_indices = utils.find_nodes_in_model(self.model, layer) 93 | # If no nodes are specified, all of the layer's inbound nodes which are 94 | # in model are selected. 95 | if not node_indices: 96 | node_indices = layer_node_indices 97 | # Check for duplicate node indices 98 | elif len(node_indices) != len(set(node_indices)): 99 | raise ValueError('`node_indices` contains duplicate values.') 100 | # Check that all of the selected nodes are in the layer 101 | elif not set(node_indices).issubset(layer_node_indices): 102 | raise ValueError('One or more nodes specified by `layer` and ' 103 | '`node_indices` are not in `model`.') 104 | 105 | # Select the modification function and any keyword arguments. 106 | kwargs = {} 107 | if job == 'delete_channels': 108 | # If not all inbound_nodes are selected, the new layer is renamed 109 | # to avoid duplicate layer names. 
110 | if set(node_indices) != set(layer_node_indices): 111 | kwargs['layer_name'] = layer.name + '_' + job 112 | kwargs['channels'] = channels 113 | mod_func = self._delete_channels 114 | 115 | elif job == 'delete_layer': 116 | mod_func = self._delete_layer 117 | 118 | elif job == 'insert_layer': 119 | kwargs['new_layer'] = new_layer 120 | mod_func = self._insert_layer 121 | 122 | elif job == 'replace_layer': 123 | kwargs['new_layer'] = new_layer 124 | mod_func = self._replace_layer 125 | 126 | else: 127 | raise ValueError(job + ' is not a recognised job. Valid jobs ' 128 | 'are:\n-', '\n- '.join(self.valid_jobs)) 129 | 130 | # Get nodes to be operated on for this job 131 | job_nodes = [] 132 | for node_index in node_indices: 133 | job_nodes.append(layer.inbound_nodes[node_index]) 134 | # Check that the nodes do not already have jobs assigned to them. 135 | if set(job_nodes).intersection(self.nodes): 136 | raise ValueError('Cannot apply several jobs to the same node.') 137 | 138 | # Add the modification function and keyword arguments to the 139 | # self._mod_func_map and _kwargs_map dictionaries for later retrieval. 140 | for node in job_nodes: 141 | self._mod_func_map[node] = mod_func 142 | self._kwargs_map[node] = kwargs 143 | self.nodes.extend(job_nodes) 144 | 145 | def operate(self): 146 | """Perform all jobs assigned to the surgeon. 147 | """ 148 | # Operate on each node in self.nodes by order of decreasing depth. 
149 | sorted_nodes = sorted(self.nodes, reverse=True, 150 | key=lambda x: utils.get_node_depth(self.model, x)) 151 | for node in sorted_nodes: 152 | # Rebuild submodel up to this node 153 | sub_output_nodes = node_utils.parent_nodes(node) 154 | outputs, output_masks = self._rebuild_graph(self.model.inputs, 155 | sub_output_nodes) 156 | 157 | # Perform surgery at this node 158 | kwargs = self._kwargs_map[node] 159 | self._mod_func_map[node](node, outputs, output_masks, **kwargs) 160 | 161 | # Finish rebuilding model 162 | output_nodes = [] 163 | for output in self.model.outputs: 164 | layer, node_index, tensor_index = output._keras_history 165 | output_nodes.append(layer.inbound_nodes[node_index]) 166 | new_outputs, _ = self._rebuild_graph(self.model.inputs, output_nodes) 167 | new_model = Model(self.model.inputs, new_outputs) 168 | 169 | if self._copy: 170 | return utils.clean_copy(new_model) 171 | else: 172 | return new_model 173 | 174 | def _rebuild_graph(self, 175 | graph_inputs, 176 | output_nodes, 177 | graph_input_masks=None): 178 | """Rebuild the graph from graph_inputs to output_nodes. 179 | 180 | This does not return a model object, it re-creates the connections 181 | between layers and returns the output tensors and masks of the submodel 182 | This is a building block for the higher level surgery methods. 183 | See `Surgeon.operate` for details of how this method is used. 184 | 185 | Arguments: 186 | graph_inputs: List of the submodel's input tensor(s). 187 | output_nodes(list[Node]): List of the submodel's output node(s) 188 | graph_input_masks: Boolean mask for each submodel input. 
189 | 190 | Returns: 191 | (tuple) containing : 192 | List of the output tensors of the rebuilt submodel 193 | List of the output masks of the rebuilt submodel 194 | tuple[submodel output tensors, output masks] 195 | 196 | """ 197 | if not graph_input_masks: 198 | graph_input_masks = [None] * len(graph_inputs) 199 | 200 | def _rebuild_rec(node): 201 | """Rebuild the graph up to `node` recursively. 202 | 203 | Args: 204 | node(Node): Node to rebuild up to. 205 | Returns: 206 | (tuple) containing : 207 | The output tensor of the rebuilt `node` 208 | The output mask of the rebuilt `node` 209 | 210 | """ 211 | layer = node.outbound_layer 212 | logging.debug('getting inputs for: {0}'.format(layer.name)) 213 | node_output = utils.single_element(node.output_tensors) 214 | # First check for conditions to bottom out the recursion 215 | # Check for replaced tensors before any other checks: 216 | # these are created by the surgery methods. 217 | if node_output in self._replace_tensors.keys(): 218 | logging.debug('bottomed out at replaced output: {0}'.format( 219 | node_output)) 220 | output, output_mask = self._replace_tensors[node_output] 221 | return output, output_mask 222 | # Next check if the current node has already been rebuilt. 223 | elif node in self._finished_nodes.keys(): 224 | logging.debug('reached finished node: {0}'.format(node)) 225 | return self._finished_nodes[node] 226 | # Next check if one of the graph_inputs has been reached. 227 | mask_map = TensorDict() 228 | for input, mask in zip(graph_inputs, graph_input_masks): 229 | mask_map[input] = mask 230 | 231 | try: 232 | output_mask = mask_map[node_output] 233 | logging.debug('bottomed out at a model input') 234 | return node_output, output_mask 235 | except KeyError: 236 | # Otherwise recursively call this method on the inbound nodes. 
237 | inbound_nodes = node_utils.parent_nodes(node) 238 | logging.debug('inbound_layers: {0}'.format( 239 | [node.outbound_layer.name for node in inbound_nodes])) 240 | # Recursively rebuild the model up to `node`s inbound nodes to 241 | # obtain its inputs and input masks 242 | inputs, input_masks = zip( 243 | *[_rebuild_rec(n) for n in inbound_nodes]) 244 | 245 | if all(i is None for i in inputs): 246 | output = None 247 | try: 248 | assert len(node.output_tensors) <= 1 249 | except AssertionError as e: 250 | raise e 251 | except: 252 | pass 253 | 254 | output_mask = np.zeros(node.output_tensors.shape[1:], dtype=bool) 255 | elif any(i is None for i in inputs): 256 | if node.outbound_layer.__class__.__name__ != 'Concatenate': 257 | TypeError('Inputs can only be missing for concatenate layers.') 258 | # remove Nones from inputs list 259 | inputs = [i for i in inputs if i is not None] 260 | new_layer, output_mask = self._apply_delete_mask(node, input_masks) 261 | if len(inputs) == 1: 262 | output = utils.single_element(list(inputs)) 263 | else: 264 | output = new_layer(utils.single_element(list(inputs))) 265 | else: 266 | new_layer, output_mask = self._apply_delete_mask(node, input_masks) 267 | output = new_layer(utils.single_element(list(inputs))) 268 | 269 | # Record that this node has been rebuild 270 | self._finished_nodes[node] = (output, output_mask) 271 | logging.debug('layer complete: {0}'.format(layer.name)) 272 | return output, output_mask 273 | 274 | # Call the recursive _rebuild_rec method to rebuild the submodel up to 275 | # each output layer 276 | outputs, output_masks = zip(*[_rebuild_rec(n) for n in output_nodes]) 277 | return utils.single_element(outputs), output_masks 278 | 279 | def _delete_layer(self, node, inputs, input_masks): 280 | """Skip adding node.outbound_layer when building the graph.""" 281 | # Skip the deleted layer by replacing its outputs with it inputs 282 | # if not isinstance(inputs, tf.Tensor) and len(inputs) >= 2: 283 | # raise 
ValueError('Cannot insert new layer at node with multiple ' 284 | # 'inbound layers.') 285 | inputs = utils.single_element(inputs) 286 | input_masks = utils.single_element(input_masks) 287 | deleted_layer_output = utils.single_element(node.output_tensors) 288 | self._replace_tensors[deleted_layer_output] = (inputs, input_masks) 289 | 290 | def _insert_layer(self, node, inputs, input_masks, new_layer=None): 291 | """Insert new_layer into the graph before node.outbound_layer.""" 292 | # This will not work for nodes with multiple inbound layers 293 | if not isinstance(inputs, tf.Tensor) and len(inputs) >= 2: 294 | raise ValueError('Cannot insert new layer at node with multiple ' 295 | 'inbound layers.') 296 | # Call the new layer on the inbound layer's output 297 | new_output = new_layer(utils.single_element(inputs)) 298 | # Replace the inbound layer's output with the new layer's output 299 | old_output = utils.get_one_tensor(node.input_tensors) 300 | input_masks = utils.single_element(input_masks) 301 | self._replace_tensors[old_output] = (new_output, input_masks) 302 | 303 | def _replace_layer(self, node, inputs, input_masks, new_layer=None): 304 | """Replace node.outbound_layer with new_layer. Add it to the graph.""" 305 | # Call the new layer on the rebuild submodel's inputs 306 | new_output = new_layer(utils.single_element(inputs)) 307 | 308 | # Replace the original layer's output with the new layer's output 309 | replaced_layer_output = utils.single_element(node.output_tensors) 310 | input_masks = utils.single_element(input_masks) 311 | self._replace_tensors[replaced_layer_output] = (new_output, input_masks) 312 | 313 | def _delete_channels(self, node, inputs, input_masks, channels=None, layer_name=None): 314 | """Delete selected channels of node.outbound_layer. Add it to the graph. 
315 | """ 316 | old_layer = node.outbound_layer 317 | old_layer_output = utils.single_element(node.output_tensors) 318 | # Create a mask to propagate the deleted channels to downstream layers 319 | new_delete_mask = self._make_delete_mask(old_layer, channels) 320 | 321 | if len(set(channels)) == getattr(old_layer, utils.get_channels_attr(old_layer)): 322 | self._replace_tensors[old_layer_output] = (None, new_delete_mask) 323 | return None 324 | 325 | # If this layer has already been operated on, use the cached copy of 326 | # the new layer. Otherwise, apply the inbound delete mask and 327 | # delete channels to obtain the new layer 328 | if old_layer in self._new_layers_map.keys(): 329 | new_layer = self._new_layers_map[old_layer] 330 | else: 331 | temp_layer, new_mask = self._apply_delete_mask(node, input_masks) 332 | # This call is needed to initialise input_shape and output_shape 333 | temp_layer(utils.single_element(inputs)) 334 | new_layer = self._delete_channel_weights(temp_layer, channels) 335 | if layer_name: 336 | new_layer.name = layer_name 337 | self._new_layers_map[old_layer] = new_layer 338 | new_output = new_layer(utils.single_element(inputs)) 339 | # Replace the original layer's output with the modified layer's output 340 | self._replace_tensors[old_layer_output] = (new_output, new_delete_mask) 341 | 342 | def _apply_delete_mask(self, node, inbound_masks): 343 | """Apply the inbound delete mask and return the outbound delete mask 344 | 345 | When specific channels in a layer or layer instance are deleted, the 346 | mask propagates information about which channels are affected to 347 | downstream layers. 348 | If the layer contains weights, those which were previously connected 349 | to the deleted channels are deleted and outbound masks are set to None 350 | since further downstream layers aren't affected. 
351 | If the layer does not contain weights, its output mask is calculated to 352 | reflect any transformations performed by the layer to ensure that 353 | information about the deleted channels is propagated downstream. 354 | 355 | 356 | Arguments: 357 | node(Node): The node where the delete mask is applied. 358 | inbound_masks: Mask(s) from inbound node(s). 359 | 360 | Returns: 361 | new_layer: Pass through `layer` if it has no weights, otherwise a 362 | new `Layer` object with weights corresponding to the 363 | inbound mask deleted. 364 | outbound_mask: Mask corresponding to `new_layer`. 365 | """ 366 | 367 | # if delete_mask is None or all values are True, it does not affect 368 | # this layer or any layers above/downstream from it 369 | layer = node.outbound_layer 370 | new_layer = layer 371 | outbound_mask = None 372 | 373 | if all(mask is None for mask in inbound_masks): 374 | return new_layer, outbound_mask 375 | 376 | # If one or more of the masks are None, replace them with ones. 377 | if any(mask is None for mask in inbound_masks): 378 | inbound_masks = [np.ones(shape[1:], dtype=bool) 379 | if inbound_masks[i] is None else inbound_masks[i] 380 | for i, shape in enumerate(node.input_shapes)] 381 | 382 | # If the layer is shared and has already been affected by this 383 | # operation, use the cached new layer. 
if len(layer.inbound_nodes) > 1 \
                and layer in self._replace_layers_map.keys():
            return self._replace_layers_map[layer]

        output_shape = utils.single_element(node.output_shapes)
        input_shape = utils.single_element(node.input_shapes)
        data_format = getattr(layer, 'data_format', 'channels_last')
        inbound_masks = utils.single_element(inbound_masks)
        # otherwise, delete_mask.shape should be: layer.input_shape[1:]
        # Dispatch on the concrete layer class name.
        layer_class = layer.__class__.__name__
        if layer_class == 'InputLayer':
            raise RuntimeError('This should never get here!')

        elif layer_class == 'Dense':
            # Drop the kernel rows feeding from deleted channels; a Dense
            # layer absorbs the mask so nothing propagates downstream (None).
            if np.all(inbound_masks):
                new_layer = layer
            else:
                weights = layer.get_weights()
                weights[0] = weights[0][np.where(inbound_masks)[0], :]
                config = layer.get_config()
                config['weights'] = weights
                new_layer = type(layer).from_config(config)
            outbound_mask = None

        elif layer_class == 'Flatten':
            # Flatten the mask exactly as the layer flattens its input.
            outbound_mask = np.reshape(inbound_masks, [-1, ])
            new_layer = layer

        elif layer_class in ('Conv1D', 'Conv2D', 'Conv3D'):
            if np.all(inbound_masks):
                new_layer = layer
            else:
                if data_format == 'channels_first':
                    inbound_masks = np.swapaxes(inbound_masks, 0, -1)
                # Conv layer: trim down inbound_masks to filter shape
                k_size = layer.kernel_size
                index = [slice(None, 1, None) for _ in k_size]
                inbound_masks = inbound_masks[tuple(index + [slice(None)])]
                weights = layer.get_weights()
                # Delete unused weights to obtain new_weights
                # Each deleted channel was connected to all of the channels
                # in layer; therefore, the mask must be repeated for each
                # channel.
                # `delete_mask`'s size: size(weights[0])
                delete_mask = np.tile(inbound_masks[..., np.newaxis], list(k_size) + [1, weights[0].shape[-1]])
                new_shape = list(weights[0].shape)
                new_shape[-2] = -1  # Weights always have channels_last
                weights[0] = np.reshape(weights[0][delete_mask], new_shape)
                # Instantiate new layer with new_weights
                config = layer.get_config()
                config['weights'] = weights
                new_layer = type(layer).from_config(config)
            outbound_mask = None

        elif layer_class in ('Cropping1D', 'Cropping2D', 'Cropping3D',
                             'MaxPooling1D', 'MaxPooling2D',
                             'MaxPooling3D',
                             'AveragePooling1D', 'AveragePooling2D',
                             'AveragePooling3D'):
            # Spatial-only layers: crop the mask to the output's spatial
            # extent while keeping the full channels dimension.
            if output_shape is None:
                outbound_mask = None
                new_layer = layer
            else:
                index = [slice(None, x, None) for x in output_shape[1:]]
                if data_format == 'channels_first':
                    index[0] = slice(None)
                elif data_format == 'channels_last':
                    index[-1] = slice(None)
                else:
                    raise ValueError('Invalid data format')
                outbound_mask = inbound_masks[tuple(index)]
                new_layer = layer

        elif layer_class in ('UpSampling1D',
                             'UpSampling2D',
                             'UpSampling3D',
                             'ZeroPadding1D',
                             'ZeroPadding2D',
                             'ZeroPadding3D'):

            # Get slice of mask with all singleton dimensions except
            # channels dimension
            index = [slice(1)] * (len(input_shape) - 1)
            tile_shape = list(output_shape[1:])
            if data_format == 'channels_first':
                index[0] = slice(None)
                tile_shape[0] = 1
            elif data_format == 'channels_last':
                index[-1] = slice(None)
                tile_shape[-1] = 1
            else:
                raise ValueError('Invalid data format')
            channels_vector = inbound_masks[tuple(index)]
            # Tile this slice to create the outbound mask
            outbound_mask = np.tile(channels_vector, tile_shape)
            new_layer = layer

        elif layer_class in ('GlobalMaxPooling1D',
'GlobalMaxPooling2D', 483 | 'GlobalAveragePooling1D', 484 | 'GlobalAveragePooling2D'): 485 | # Get slice of mask with all singleton dimensions except 486 | # channels dimension 487 | index = [0] * (len(input_shape) - 1) 488 | if data_format == 'channels_first': 489 | index[0] = slice(None) 490 | elif data_format == 'channels_last': 491 | index[-1] = slice(None) 492 | else: 493 | raise ValueError('Invalid data format') 494 | channels_vector = inbound_masks[tuple(index)] 495 | # Tile this slice to create the outbound mask 496 | outbound_mask = channels_vector 497 | new_layer = layer 498 | 499 | elif layer_class in ('Dropout', 500 | 'Activation', 501 | 'SpatialDropout1D', 502 | 'SpatialDropout2D', 503 | 'SpatialDropout3D', 504 | 'ActivityRegularization', 505 | 'Masking', 506 | 'LeakyReLU', 507 | 'ELU', 508 | 'ThresholdedReLU', 509 | 'GaussianNoise', 510 | 'GaussianDropout', 511 | 'AlphaDropout'): 512 | # Pass-through layers 513 | outbound_mask = inbound_masks 514 | new_layer = layer 515 | 516 | elif layer_class == 'Reshape': 517 | outbound_mask = np.reshape(inbound_masks, layer.target_shape) 518 | new_layer = layer 519 | 520 | elif layer_class == 'Permute': 521 | outbound_mask = np.transpose(inbound_masks, 522 | [x-1 for x in layer.dims]) 523 | new_layer = layer 524 | 525 | elif layer_class == 'RepeatVector': 526 | outbound_mask = np.repeat( 527 | np.expand_dims(inbound_masks, 0), 528 | layer.n, 529 | axis=0) 530 | new_layer = layer 531 | 532 | elif layer_class == 'Embedding': 533 | # Embedding will always be the first layer so it doesn't need 534 | # to consider the inbound_delete_mask 535 | if inbound_masks is not None: 536 | raise ValueError('Channels cannot be deleted bedore Embedding ' 537 | 'layers because they change the number of ' 538 | 'channels.') 539 | outbound_mask = None 540 | new_layer = layer 541 | 542 | elif layer_class in ('Add', 'Multiply', 'Average', 'Maximum'): 543 | # The inputs must be the same size 544 | if not utils.all_equal(inbound_masks): 
545 | ValueError( 546 | '{0} layers must have the same size inputs. All ' 547 | 'inbound nodes must have the same channels deleted' 548 | .format(layer_class)) 549 | outbound_mask = inbound_masks[1] 550 | new_layer = layer 551 | 552 | elif layer_class == 'Concatenate': 553 | axis = layer.axis 554 | if layer.axis < 0: 555 | axis = axis % len(layer.input_shape[0]) 556 | # Below: axis=axis-1 because the mask excludes the batch dimension 557 | outbound_mask = np.concatenate(inbound_masks, axis=axis-1) 558 | new_layer = layer 559 | 560 | elif layer_class in ('SimpleRNN', 'GRU', 'LSTM'): 561 | if np.all(inbound_masks): 562 | new_layer = layer 563 | else: 564 | weights = layer.get_weights() 565 | weights[0] = weights[0][np.where(inbound_masks[0, :])[0], :] 566 | config = layer.get_config() 567 | config['weights'] = weights 568 | new_layer = type(layer).from_config(config) 569 | outbound_mask = None 570 | 571 | elif layer_class == 'BatchNormalization': 572 | outbound_mask = inbound_masks 573 | # Get slice of mask with all singleton dimensions except 574 | # channels dimension 575 | index = [0] * (len(input_shape)) 576 | assert len(layer.axis) == 1 577 | index[layer.axis[0]] = slice(None) 578 | index = index[1:] 579 | channel_indices = np.where(inbound_masks[tuple(index)] == False)[0] 580 | weights = [np.delete(w, channel_indices, axis=-1) 581 | for w in layer.get_weights()] 582 | new_layer = BatchNormalization.from_config( 583 | layer.get_config()) 584 | new_input_shape = list(input_shape) 585 | assert len(new_layer.axis) == 1 586 | new_input_shape[new_layer.axis[0]] -= len(channel_indices) 587 | new_layer.build(new_input_shape) 588 | new_layer.set_weights(weights) 589 | 590 | else: 591 | # Not implemented: 592 | # - Lambda 593 | # - SeparableConv2D 594 | # - Conv2DTranspose 595 | # - LocallyConnected1D 596 | # - LocallyConnected2D 597 | # - TimeDistributed 598 | # - Bidirectional 599 | # - Dot 600 | # - PReLU 601 | # Warning/error needed for Reshape if channels axis is 
split 602 | raise ValueError('"{0}" layers are currently ' 603 | 'unsupported.'.format(layer_class)) 604 | 605 | if len(layer.inbound_nodes) > 1 and new_layer != layer: 606 | self._replace_layers_map[layer] = (new_layer, outbound_mask) 607 | 608 | return new_layer, outbound_mask 609 | 610 | def _delete_channel_weights(self, layer, channel_indices): 611 | """Delete channels from layer and remove the corresponding weights. 612 | 613 | Arguments: 614 | layer: A layer whose channels are to be deleted 615 | channel_indices: The indices of the channels to be deleted. 616 | 617 | Returns: 618 | A new layer with the channels and corresponding weights deleted. 619 | """ 620 | layer_config = layer.get_config() 621 | channels_attr = utils.get_channels_attr(layer) 622 | channel_count = layer_config[channels_attr] 623 | # Check inputs 624 | if any([i + 1 > channel_count for i in channel_indices]): 625 | raise ValueError('Channels_index value(s) out of range. ' 626 | 'This layer only has {0} channels.' 627 | .format(channel_count)) 628 | print('Deleting {0}/{1} channels from layer: {2}'.format( 629 | len(channel_indices), channel_count, layer.name)) 630 | # numpy.delete ignores negative indices in lists: wrap indices 631 | channel_indices = [i % channel_count for i in channel_indices] 632 | 633 | # Reduce layer channel count in config. 634 | layer_config[channels_attr] -= len(channel_indices) 635 | 636 | # Delete weights corresponding to deleted channels from config. 637 | # Except for recurrent layers, the weights' channels dimension is last. 638 | # Each recurrent layer type has a different internal weights layout. 639 | if layer.__class__.__name__ == 'SimpleRNN': 640 | weights = [np.delete(w, channel_indices, axis=-1) 641 | for w in layer.get_weights()] 642 | weights[1] = np.delete(weights[1], channel_indices, axis=0) 643 | elif layer.__class__.__name__ == 'GRU': 644 | # Repeat the channel indices for all internal GRU weights. 
645 | channel_indices_gru = [layer.units * m + i for m in range(3) 646 | for i in channel_indices] 647 | weights = [np.delete(w, channel_indices_gru, axis=-1) 648 | for w in layer.get_weights()] 649 | weights[1] = np.delete(weights[1], channel_indices, axis=0) 650 | elif layer.__class__.__name__ == 'LSTM': 651 | # Repeat the channel indices for all interal LSTM weights. 652 | channel_indices_lstm = [layer.units * m + i for m in range(4) 653 | for i in channel_indices] 654 | weights = [np.delete(w, channel_indices_lstm, axis=-1) 655 | for w in layer.get_weights()] 656 | weights[1] = np.delete(weights[1], channel_indices, axis=0) 657 | else: 658 | weights = [np.delete(w, channel_indices, axis=-1) 659 | for w in layer.get_weights()] 660 | layer_config['weights'] = weights 661 | 662 | # Create new layer from the modified configuration and return it. 663 | return type(layer).from_config(layer_config) 664 | 665 | def _make_delete_mask(self, layer, channel_indices): 666 | """Make the boolean delete mask for layer's output deleting channels. 667 | 668 | The mask is used to remove the weights of the downstream layers which 669 | were connected to channels which have been deleted in this layer. 670 | The mask is a boolean array with the same size as the layer output 671 | excluding the first (batch) dimension. 672 | All elements of the mask corresponding to the removed channels are set 673 | to False. Other elements are set to True. 674 | 675 | Arguments: 676 | layer: A layer 677 | channel_indices: The indices of the channels to be deleted. 678 | 679 | Returns: 680 | A Numpy array of booleans of the same size as the output of layer 681 | excluding the batch dimension. 682 | """ 683 | data_format = getattr(layer, 'data_format', 'channels_last') 684 | new_delete_mask = np.ones(layer.output_shape[1:], dtype=bool) 685 | if data_format == 'channels_first': 686 | new_delete_mask[channel_indices, ...] 
= False 687 | elif data_format == 'channels_last': 688 | new_delete_mask[..., channel_indices] = False 689 | else: 690 | ValueError('Invalid data_format property value') 691 | return new_delete_mask 692 | -------------------------------------------------------------------------------- /tf2rust/utils/surgeon/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities used across other modules.""" 2 | import numpy as np 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras.activations import linear 5 | from tensorflow.python.keras.engine import keras_tensor 6 | import tensorflow as tf 7 | from ._utils import node as node_utils 8 | 9 | 10 | def clean_copy(model): 11 | """Returns a copy of the model without other model uses of its layers.""" 12 | weights = model.get_weights() 13 | new_model = model.__class__.from_config(model.get_config()) 14 | new_model.set_weights(weights) 15 | return new_model 16 | 17 | 18 | def get_channels_attr(layer): 19 | layer_config = layer.get_config() 20 | if 'units' in layer_config.keys(): 21 | channels_attr = 'units' 22 | elif 'filters' in layer_config.keys(): 23 | channels_attr = 'filters' 24 | else: 25 | raise ValueError('This layer has not got any channels.') 26 | return channels_attr 27 | 28 | 29 | def get_node_depth(model, node): 30 | """Get the depth of a node in a model. 31 | 32 | Arguments: 33 | model: Keras Model object 34 | node: Keras Node object 35 | 36 | Returns: 37 | The node depth as an integer. The model outputs are at depth 0. 38 | 39 | Raises: 40 | KeyError: if the node is not contained in the model. 
41 | """ 42 | for (depth, nodes_at_depth) in model._nodes_by_depth.items(): 43 | if node in nodes_at_depth: 44 | return depth 45 | raise KeyError('The node is not contained in the model.') 46 | 47 | 48 | def check_for_layer_reuse(model, layers=None): 49 | """Returns True if any layers are reused, False if not.""" 50 | if layers is None: 51 | layers = model.layers 52 | return any([len(l.inbound_nodes) > 1 for l in layers]) 53 | 54 | 55 | def find_nodes_in_model(model, layer): 56 | """Find the indices of layer's inbound nodes which are in model""" 57 | model_nodes = get_model_nodes(model) 58 | node_indices = [] 59 | for i, node in enumerate(layer.inbound_nodes): 60 | if node in model_nodes: 61 | node_indices.append(i) 62 | return node_indices 63 | 64 | 65 | def check_nodes_in_model(model, nodes): 66 | """Check if nodes are in model""" 67 | model_nodes = get_model_nodes(model) 68 | nodes_in_model = [False] * len(nodes) 69 | for i, node in enumerate(nodes): 70 | if node in model_nodes: 71 | nodes_in_model[i] = True 72 | return nodes_in_model 73 | 74 | 75 | def get_model_nodes(model): 76 | """Return all nodes in the model""" 77 | return [node for v in model._nodes_by_depth.values() for node in v] 78 | 79 | 80 | def get_shallower_nodes(node): 81 | possible_nodes = node.outbound_layer.outbound_nodes 82 | next_nodes = [] 83 | for n in possible_nodes: 84 | if node in node_utils.parent_nodes(n): 85 | next_nodes.append(n) 86 | return next_nodes 87 | 88 | 89 | def get_node_index(node): 90 | for i, n in enumerate(node.outbound_layer.inbound_nodes): 91 | if node == n: 92 | return i 93 | 94 | 95 | def find_activation_layer(layer, node_index): 96 | """ 97 | Args: 98 | layer(Layer): 99 | node_index: 100 | """ 101 | output_shape = layer.get_output_shape_at(node_index) 102 | maybe_layer = layer 103 | node = maybe_layer.inbound_nodes[node_index] 104 | # Loop will be broken by an error if an output layer is encountered 105 | while True: 106 | # If maybe_layer has a nonlinear activation 
function return it and its index 107 | activation = getattr(maybe_layer, 'activation', linear) 108 | if activation.__name__ != 'linear': 109 | if maybe_layer.get_output_shape_at(node_index) != output_shape: 110 | ValueError('The activation layer ({0}), does not have the same' 111 | ' output shape as {1}'.format(maybe_layer.name, 112 | layer.name)) 113 | return maybe_layer, node_index 114 | 115 | # If not, move to the next layer in the datastream 116 | next_nodes = get_shallower_nodes(node) 117 | # test if node is a list of nodes with more than one item 118 | if len(next_nodes) > 1: 119 | ValueError('The model must not branch between the chosen layer' 120 | ' and the activation layer.') 121 | node = next_nodes[0] 122 | node_index = get_node_index(node) 123 | maybe_layer = node.outbound_layer 124 | 125 | # Check if maybe_layer has weights, no activation layer has been found 126 | if maybe_layer.weights and ( 127 | not maybe_layer.__class__.__name__.startswith('Global')): 128 | AttributeError('There is no nonlinear activation layer between {0}' 129 | ' and {1}'.format(layer.name, maybe_layer.name)) 130 | 131 | 132 | def sort_x_by_y(x, y): 133 | """Sort the iterable x by the order of iterable y""" 134 | x = [x for (_, x) in sorted(zip(y, x))] 135 | return x 136 | 137 | 138 | def single_element(x): 139 | """If x contains a single element, return it; otherwise return x""" 140 | if isinstance(x, (tf.Tensor, keras_tensor.KerasTensor)): 141 | return x 142 | 143 | if isinstance(x, list): 144 | if len(x) == 1: 145 | return x[0] 146 | 147 | if isinstance(x, tuple): 148 | return x[0] 149 | 150 | return x 151 | 152 | 153 | def get_one_tensor(x): 154 | if isinstance(x, (tf.Tensor, keras_tensor.KerasTensor)): 155 | return x 156 | 157 | if isinstance(x, list): 158 | if len(x) == 1: 159 | return x[0] 160 | 161 | return x 162 | 163 | 164 | def bool_to_index(x): 165 | return [i for i, v in enumerate(x) if v] 166 | 167 | 168 | def all_equal(iterator): 169 | try: 170 | iterator = 
iter(iterator) 171 | first = next(iterator) 172 | return all( 173 | np.array_equal(first, rest) for rest in iterator) 174 | except StopIteration: 175 | return True 176 | 177 | 178 | class MeanCalculator: 179 | def __init__(self, sum_axis): 180 | self.values = None 181 | self.n = 0 182 | self.sum_axis = sum_axis 183 | 184 | def add(self, v): 185 | if self.values is None: 186 | self.values = v.sum(axis=self.sum_axis) 187 | else: 188 | self.values += v.sum(axis=self.sum_axis) 189 | self.n += v.shape[self.sum_axis] 190 | 191 | def calculate(self): 192 | return self.values / self.n 193 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py39 3 | isolated_build = True 4 | 5 | [testenv] 6 | passenv = SSH_AUTH_SOCK RUSTUP_HOME CARGO_HOME 7 | wheel = true 8 | deps = -rrequirements.txt 9 | commands = pytest 10 | --------------------------------------------------------------------------------