├── .bazelrc ├── .bazelversion ├── .clang-format ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── pr_labels.yml ├── release-drafter.yml ├── tools │ └── release_linux.sh └── workflows │ ├── lint.yml │ ├── release-notes.yml │ ├── release.yml │ └── unittests.yml ├── .gitignore ├── .gitmodules ├── .tensorflow.bazelrc ├── BUILD ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── WORKSPACE ├── build_pip_pkg.sh ├── configure.py ├── docs └── README.md ├── examples ├── BUILD ├── converter_examples.py └── lce_minimal.cc ├── larq_compute_engine ├── BUILD ├── __init__.py ├── core │ ├── BUILD │ ├── bconv2d │ │ ├── BUILD │ │ ├── optimized_bgemm.h │ │ ├── optimized_indirect_bgemm.h │ │ ├── output_transform.h │ │ ├── params.h │ │ ├── reference.h │ │ └── zero_padding_correction.h │ ├── bgemm │ │ ├── BUILD │ │ ├── bgemm.h │ │ ├── kernels.h │ │ ├── kernels_aarch64.h │ │ ├── kernels_arm32.h │ │ ├── kernels_common.h │ │ ├── ruy_pack.h │ │ └── ruy_trmul_params.h │ ├── bitpacking │ │ ├── BUILD │ │ ├── bitpack.h │ │ ├── bitpack_aarch64.h │ │ ├── tests │ │ │ ├── BUILD │ │ │ ├── bitpack_aarch64_test.cc │ │ │ └── bitpack_test.cc │ │ └── utils.h │ ├── bmaxpool.h │ ├── indirect_bgemm │ │ ├── BUILD │ │ ├── kernel.h │ │ ├── kernel_4x2_portable.h │ │ ├── kernel_8x4x1_aarch64.h │ │ ├── kernel_8x4x2_aarch64.h │ │ ├── kernel_8x4x4_aarch64.h │ │ └── select_kernel.h │ └── types.h ├── mlir │ ├── BUILD │ ├── __init__.py │ ├── ir │ │ ├── lce_ops.cc │ │ ├── lce_ops.h │ │ └── lce_ops.td │ ├── lce_mlir_opt.cc │ ├── python │ │ ├── __init__.py │ │ ├── common.cc │ │ ├── common.h │ │ ├── converter.py │ │ ├── converter_test.py │ │ ├── graphdef_tfl_flatbuffer.cc │ │ ├── pybind_export.cc │ │ ├── saved_model_tfl_flatbuffer.cc │ │ └── util.py │ ├── tests │ │ ├── BUILD │ │ ├── bitpack-weights.mlir │ │ ├── const-fold.mlir │ │ ├── detection_postprocess.mlir │ │ ├── fuse_padding.mlir │ │ ├── lce_ops_options_test.cc │ │ ├── legalize-lce.mlir │ │ ├── lit_test.bzl │ │ ├── op-removal.mlir │ │ ├── optimize.mlir │ │ ├── prepare-tf.mlir │ │ ├── quantize.mlir │ │ ├── run_lit.sh │ │ └── set_batch_size.mlir │ ├── tf_tfl_passes.cc │ ├── tf_tfl_passes.h │ ├── tf_to_tfl_flatbuffer.cc │ ├── tf_to_tfl_flatbuffer.h │ └── transforms │ │ ├── bitpack.cc │ │ ├── bitpack.h │ │ ├── bitpack_activations_patterns.td │ │ ├── bitpack_weights.cc │ │ ├── bitpack_weights_patterns.td │ │ ├── common.h │ │ ├── detection_postprocess.cc │ │ ├── fuse_padding.cc │ │ ├── fuse_padding.td │ │ ├── legalize_tflite.cc │ │ ├── op_removal.cc │ │ ├── op_removal_patterns.td │ │ ├── optimize.cc │ │ ├── optimize_patterns_common.td │ │ ├── optimize_patterns_target_arm.td │ │ ├── padding.h │ │ ├── passes.h │ │ ├── prepare_patterns_common.td │ │ ├── prepare_patterns_target_arm.td │ │ ├── prepare_tf.cc │ │ ├── quantize.cc │ │ ├── quantize_patterns.td │ │ ├── set_batch_size.cc │ │ └── translate_tflite.cc ├── requirements.in ├── requirements.txt ├── tests │ ├── BUILD │ ├── convert_model.py │ ├── end2end_test.py │ ├── preprocess.py │ ├── qemu_test.bzl │ ├── strip_lcedequantize_test.py │ ├── test_aarch64_binary.sh │ └── test_arm32_binary.sh └── tflite │ ├── BUILD │ ├── __init__.py │ ├── benchmark │ ├── BUILD │ ├── README.md │ ├── lce_benchmark_main.cc │ ├── lce_benchmark_tflite_model.cc │ └── lce_benchmark_tflite_model.h │ ├── build_defs.bzl │ ├── java │ ├── BUILD │ ├── build_lce_aar.sh │ └── lce_ops_jni.cc │ ├── kernels │ ├── BUILD │ ├── bconv2d.cc │ ├── bmaxpool.cc │ ├── lce_ops_register.h │ ├── quantization.cc │ 
└── utils.h │ ├── python │ ├── BUILD │ ├── __init__.py │ ├── interpreter.py │ ├── interpreter_base.py │ ├── interpreter_wrapper_lite.cc │ └── interpreter_wrapper_utils.h │ └── tests │ ├── BUILD │ ├── bconv2d_op_model.h │ ├── bconv2d_test.cc │ ├── bmaxpool_test.cc │ ├── interpreter_test.py │ ├── quantization_test.cc │ └── utils.h ├── setup.cfg ├── setup.py ├── test-requirements.txt └── third_party ├── install_android.sh └── tensorflow_patches ├── BUILD └── disable_forced_mkl.patch /.bazelrc: -------------------------------------------------------------------------------- 1 | # Import TensorFlow's configuration first. 2 | try-import %workspace%/.tensorflow.bazelrc 3 | 4 | # Prevent invalid caching if input files are modified during a build. 5 | build --experimental_guard_against_concurrent_changes 6 | 7 | # Allow up to 10 Mb of logging 8 | build --experimental_ui_max_stdouterr_bytes=10485760 9 | 10 | # Disable visibility checks (works around some private deps in TensorFlow) 11 | build --nocheck_visibility 12 | 13 | # Disable framework_shared_object for all LCE builds and tests. 14 | build --config=monolithic 15 | 16 | # Make sure tests are quick and -DNDEBUG is *not* set 17 | test --compilation_mode=fastbuild 18 | test --cxxopt -DTF_LITE_DISABLE_X86_NEON 19 | 20 | # Enable Ruy 21 | build --copt=-DTFLITE_WITH_RUY 22 | 23 | # Disable XLA 24 | build --define=with_xla_support=false 25 | 26 | # Disable MKL 27 | build --define=enable_mkl=false --define=build_with_mkl=false --define=tensorflow_mkldnn_contraction_kernel=0 28 | 29 | # Config for a 32-bit Raspberry Pi - can be activated using --config=rpi3 30 | build:rpi3 --config=elinux_armhf 31 | build:rpi3 --copt=-march=armv7-a --copt=-mfpu=neon-vfpv4 --copt=-std=gnu++11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR --copt=-fno-tree-pre --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8 --define=raspberry_pi_with_neon=true --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-fomit-frame-pointer --verbose_failures 32 | 33 | # Config for 64-bit ARM - can be activated using --config=aarch64 34 | build:aarch64 --config=elinux_aarch64 35 | build:aarch64 --copt=-march=armv8-a --copt=-std=gnu++11 --copt=-DS_IREAD=S_IRUSR --copt=-DS_IWRITE=S_IWUSR --copt=-fno-tree-pre --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2 --copt=-U__GCC_HAVE_SYNC_COMPARE_AND_SWAP_8 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-fomit-frame-pointer --verbose_failures 36 | 37 | # Disable unnecessary features. 38 | build:linux --config=nohdfs --config=nonccl --config=noaws --config=nogcp 39 | build:macos --config=nohdfs --config=nonccl --config=noaws --config=nogcp 40 | build:windows --config=noaws --config=nogcp 41 | 42 | # Extra build options we need for windows. 43 | build:windows --experimental_strict_action_env=true 44 | 45 | # Disable certain warnings that come from TF code and pollute logs 46 | build:android --copt=-Wno-deprecated-declarations 47 | build:linux --copt=-Wno-deprecated-declarations 48 | build:linux --host_copt=-Wno-deprecated-declarations 49 | build:macos --copt=-Wno-deprecated-declarations 50 | 51 | # Windows generates *a lot* of warnings; disable them all. 52 | build:windows --copt=/W0 53 | 54 | # Config-specific options should come above this line. 
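# Editorial example (not part of the original file): the configs above are
# combined on the bazel command line. A hypothetical optimized cross-compile
# of the minimal example for a 64-bit ARM target would look like:
#
#   bazel build -c opt --config=aarch64 //examples:lce_minimal
#
# and swapping in --config=rpi3 would select the 32-bit Raspberry Pi
# toolchain instead.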
55 | 56 | # Options from ./configure 57 | try-import %workspace%/.lce_configure.bazelrc 58 | 59 | # Put user-specific options in .bazelrc.user 60 | try-import %workspace%/.bazelrc.user 61 | -------------------------------------------------------------------------------- /.bazelversion: -------------------------------------------------------------------------------- 1 | 6.5.0 2 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | # Run manually to reformat a file: 2 | # clang-format -i --style=file 3 | BasedOnStyle: Google 4 | DerivePointerAlignment: false 5 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | ## What do these changes do? 6 | 7 | 8 | ## How Has This Been Tested? 9 | 10 | 11 | ## Benchmark Results 12 | 13 | 14 | ## Related issue number 15 | 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: "/" 5 | schedule: 6 | interval: daily 7 | open-pull-requests-limit: 10 8 | -------------------------------------------------------------------------------- /.github/pr_labels.yml: -------------------------------------------------------------------------------- 1 | version: "1" 2 | invalidStatus: "pending" 3 | labelRule: 4 | values: 5 | - "bug" 6 | - "documentation" 7 | - "internal-improvement" 8 | - "performance" 9 | - "dependencies" 10 | - "feature" 11 | - "breaking-change" 12 | - "skip-changelog" 13 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | template: $CHANGES 2 | 3 | categories: 4 | - title: ":warning: Breaking Changes :warning:" 5 | label: "breaking-change" 6 | - title: ":tada: Features" 7 | label: "feature" 8 | - title: ":rocket: Performance" 9 | label: "performance" 10 | - title: ":bug: Bug Fixes" 11 | label: "bug" 12 | - title: ":book: Documentation" 13 | label: "documentation" 14 | - title: ":construction_worker_man: Internal Improvements" 15 | label: "internal-improvement" 16 | - title: ":arrow_up: Dependencies" 17 | label: "dependencies" 18 | 19 | exclude-labels: 20 | - "skip-changelog" 21 | -------------------------------------------------------------------------------- /.github/tools/release_linux.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e -x 3 | 4 | python configure.py 5 | 6 | # Inside the docker container on github actions there is not 7 | # enough space for the bazel cache, but a larger disk is mounted at /github_disk 8 | # so we tell bazel to store everything there 9 | 10 | # `release_cpu_linux` will activate absolute paths to files that only exist in the tensorflow/build:2.16-pythonXX docker container 11 | bazel --output_user_root=/github_disk/bazel_root \ 12 | build :build_pip_pkg \ 13 | -c opt \ 14 | --config=release_cpu_linux \ 15 | --copt=-fvisibility=hidden \ 16 | --verbose_failures 17 | 18 | # Package Whl 19 | bazel-bin/build_pip_pkg artifacts 20 | 21 | # Remove manylinux2014 config flags so that normal builds work as expected 22 | 
rm -f .lce_configure.bazelrc 23 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: {} 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 3.9 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: 3.9 19 | - name: Install Lint Dependencies 20 | run: | 21 | pip install -r test-requirements.txt --no-cache-dir 22 | sudo wget -O /usr/local/bin/buildifier https://github.com/bazelbuild/buildtools/releases/download/4.0.1/buildifier-linux-amd64 23 | sudo chmod +x /usr/local/bin/buildifier 24 | - name: Run PyFlakes 25 | run: pyflakes larq_compute_engine 26 | - name: Black code style 27 | run: black larq_compute_engine --check --target-version py39 28 | - name: clang-format lint 29 | uses: DoozyX/clang-format-lint-action@v0.20 30 | with: 31 | clangFormatVersion: 12 32 | - name: Lint bazel files 33 | run: buildifier -mode=check -r ./ 34 | - name: Type check with PyType 35 | run: pytype --jobs auto 36 | -------------------------------------------------------------------------------- /.github/workflows/release-notes.yml: -------------------------------------------------------------------------------- 1 | name: Release Notes 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | update_draft_release: 10 | if: github.repository == 'larq/compute-engine' 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: toolmantim/release-drafter@v6.1.0 14 | env: 15 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytype 2 | .lce_configure.bazelrc 3 | .DS_Store 4 | .ipynb_checkpoints 5 | node_modules 6 | /.bazelrc.user 7 | /.tf_configure.bazelrc 8 | /bazel-* 9 | /bazel_pip 10 | /tools/python_bin_path.sh 11 | /tensorflow/tools/git/gen 12 | /pip_test 13 | /_python_build 14 | *.pyc 15 | *.so 16 | __pycache__ 17 | *.swp 18 | .vscode/ 19 | cmake_build* 20 | cmake-build* 21 | tensorflow/contrib/cmake/_build/ 22 | .idea/** 23 | /build/ 24 | [Bb]uild/ 25 | /tensorflow/core/util/version_info.cc 26 | /tensorflow/python/framework/fast_tensor_util.cpp 27 | /tensorflow/lite/gen/** 28 | /tensorflow/lite/tools/make/downloads/** 29 | /api_init_files_list.txt 30 | /estimator_api_init_files_list.txt 31 | *.whl 32 | 33 | # Android 34 | .gradle 35 | .idea 36 | *.iml 37 | local.properties 38 | gradleBuild 39 | 40 | # iOS 41 | *.pbxproj 42 | *.xcworkspace 43 | /*.podspec 44 | /tensorflow/lite/**/[ios|objc|swift]*/BUILD 45 | /tensorflow/lite/examples/ios/simple/data/*.tflite 46 | /tensorflow/lite/examples/ios/simple/data/*.txt 47 | Podfile.lock 48 | Pods 49 | xcuserdata 50 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tensorflow"] 2 | path = third_party/tensorflow 3 | url = https://github.com/tensorflow/tensorflow.git 4 | -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | sh_binary( 2 | name = "build_pip_pkg", 3 | srcs = ["build_pip_pkg.sh"], 4 | data = [ 5 | "LICENSE", 6 | 
"MANIFEST.in", 7 | "README.md", 8 | "setup.py", 9 | "//larq_compute_engine:compute_engine_py", 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Gracefully accepting constructive criticism 20 | - Focusing on what is best for the community 21 | - Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | - The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | - Trolling, insulting/derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at roeland@plumerai.co.uk or lukas@plumerai.co.uk. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 
63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Larq Compute Engine 2 | 3 | 👍 🎉 First off, thanks for taking the time to contribute! 👍 🎉 4 | 5 | **Working on your first Pull Request?** You can learn how from this _free_ series 6 | [How to Contribute to an Open Source Project on GitHub](https://egghead.io/courses/how-to-contribute-to-an-open-source-project-on-github). 7 | 8 | ## Ask a question or raise an issue 9 | 10 | If you have questions about Larq Compute Engine or just want to say Hi you can [chat with us on Spectrum](https://spectrum.chat/larq). 11 | 12 | If something is not working as expected, if you run into problems with Larq Compute Engine or if you have ideas for missing features, please open a [new issue](https://github.com/larq/compute-engine/issues). 13 | 14 | ## Project setup 15 | 16 | See our [build guide](https://docs.larq.dev/compute-engine/build/) to get started. 17 | 18 | ## Code style 19 | 20 | We use [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html), [`black`](https://black.readthedocs.io/en/stable/) and [`buildifier`](https://github.com/bazelbuild/buildtools/releases/tag/1.0.0) to format all of our code. 21 | 22 | ## Publish LCE converter `pip` package 23 | 24 | 1. Increment the version number in `setup.py`, and make a PR with that change. 25 | 26 | 2. Wait until your PR is reviewed and merged. 27 | 28 | 3. Go to the [GitHub releases](https://github.com/larq/compute-engine/releases), edit the release notes of the draft release, change the tag to the desired version (e.g. `v0.7.0`) and hit "Publish release". 29 | 30 | 4. A [GitHub action](https://github.com/larq/compute-engine/actions) will automatically publish a release to [PyPI](https://pypi.org/) based on the tag. 31 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN apt-get update && apt-get install curl zip unzip git build-essential openjdk-8-jdk-headless python3-dev python3-pip qemu-user -y --no-install-recommends && rm -rf /var/lib/apt/lists/* 4 | RUN curl -L https://github.com/bazelbuild/bazelisk/releases/download/v1.7.4/bazelisk-linux-amd64 > /usr/local/bin/bazelisk && chmod +x /usr/local/bin/bazelisk 5 | RUN ln -s /usr/bin/python3 /usr/local/bin/python && ln -s /usr/bin/pip3 /usr/local/bin/pip 6 | RUN pip install six numpy --no-cache-dir 7 | 8 | WORKDIR /compute-engine 9 | COPY . . 
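# Editorial sketch of typical usage (an assumption, not part of the original
# Dockerfile): build the image from the repository root and run a Bazel build
# inside the resulting container, e.g.
#   docker build -t lce-build .
#   docker run --rm -it lce-build bazelisk build //larq_compute_engine/tflite/kernels:lce_op_kernels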
10 | RUN ./third_party/install_android.sh 11 | RUN ./configure.py 12 | RUN bazelisk --version 13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include larq_compute_engine *.so 2 | recursive-include larq_compute_engine *.pyd 3 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "larq_compute_engine") 2 | 3 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 4 | 5 | # To update TensorFlow to a new revision. 6 | # 1. Update the git hash in the urls and the 'strip_prefix' parameter. 7 | # 2. Get the sha256 hash of the archive with a command such as... 8 | #    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | shasum -a 256 9 | #    and update the 'sha256' arg with the result. 10 | # 3. Request the new archive to be mirrored on mirror.bazel.build for more 11 | #    reliable downloads. 12 | http_archive( 13 | name = "org_tensorflow", 14 | patch_args = ["-p1"], 15 | patch_tool = "patch", 16 | patches = [ 17 | "//third_party/tensorflow_patches:disable_forced_mkl.patch", 18 | ], 19 | sha256 = "c729e56efc945c6df08efe5c9f5b8b89329c7c91b8f40ad2bb3e13900bd4876d", 20 | strip_prefix = "tensorflow-2.16.1", 21 | urls = [ 22 | "https://github.com/tensorflow/tensorflow/archive/v2.16.1.tar.gz", 23 | ], 24 | ) 25 | 26 | # We must initialize hermetic python first. 27 | http_archive( 28 | name = "bazel_skylib", 29 | sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", 30 | urls = [ 31 | "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", 32 | "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", 33 | ], 34 | ) 35 | 36 | http_archive( 37 | name = "rules_python", 38 | sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", 39 | strip_prefix = "rules_python-0.26.0", 40 | url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", 41 | ) 42 | 43 | load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains") 44 | 45 | py_repositories() 46 | 47 | load( 48 | "@org_tensorflow//tensorflow/tools/toolchains/python:python_repo.bzl", 49 | "python_repository", 50 | ) 51 | 52 | python_repository(name = "python_version_repo") 53 | 54 | load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION") 55 | 56 | python_register_toolchains( 57 | name = "python", 58 | ignore_root_user_error = True, 59 | python_version = TF_PYTHON_VERSION, 60 | ) 61 | 62 | load("@python//:defs.bzl", "interpreter") 63 | load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse") 64 | 65 | NUMPY_ANNOTATIONS = { 66 | "numpy": package_annotation( 67 | additive_build_content = """\ 68 | filegroup( 69 | name = "includes", 70 | srcs = glob(["site-packages/numpy/core/include/**/*.h"]), 71 | ) 72 | cc_library( 73 | name = "numpy_headers", 74 | hdrs = [":includes"], 75 | strip_include_prefix="site-packages/numpy/core/include/", 76 | ) 77 | """, 78 | ), 79 | } 80 | 81 | pip_parse( 82 | name = "pypi", 83 | annotations = NUMPY_ANNOTATIONS, 84 | python_interpreter_target = interpreter, 85 | requirements = "@org_tensorflow//:requirements_lock_" + TF_PYTHON_VERSION.replace(".", "_") +
".txt", 86 | ) 87 | 88 | load("@pypi//:requirements.bzl", tf_install_deps = "install_deps") 89 | 90 | tf_install_deps() 91 | 92 | pip_parse( 93 | name = "pypi_lce", 94 | python_interpreter_target = interpreter, 95 | requirements = "//larq_compute_engine:requirements.txt", 96 | ) 97 | 98 | load("@pypi_lce//:requirements.bzl", lce_install_deps = "install_deps") 99 | 100 | lce_install_deps() 101 | 102 | load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") 103 | 104 | tf_workspace3() 105 | 106 | load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") 107 | 108 | tf_workspace2() 109 | 110 | load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1") 111 | 112 | tf_workspace1() 113 | 114 | load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") 115 | 116 | tf_workspace0() 117 | -------------------------------------------------------------------------------- /build_pip_pkg.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # Modifications copyright (C) 2020 Larq Contributors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | set -e 18 | set -x 19 | 20 | PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" 21 | 22 | function is_linux() { 23 | [[ "${PLATFORM}" == "linux" ]] 24 | } 25 | 26 | function is_windows() { 27 | # On windows, the shell script is actually running in msys 28 | [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] 29 | } 30 | 31 | if is_windows; then 32 | # On windows, the workspace name is lce to avoid the path length limit of MSVC 33 | PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.exe.runfiles/lce/" 34 | else 35 | PIP_FILE_PREFIX="bazel-bin/build_pip_pkg.runfiles/larq_compute_engine/" 36 | fi 37 | 38 | function abspath() { 39 | cd "$(dirname $1)" 40 | echo "$PWD/$(basename $1)" 41 | cd "$OLDPWD" 42 | } 43 | 44 | function main() { 45 | DEST=${1} 46 | BUILD_FLAG=${@:2} 47 | 48 | if [[ -z ${DEST} ]]; then 49 | echo "No destination dir provided" 50 | exit 1 51 | fi 52 | 53 | mkdir -p ${DEST} 54 | DEST=$(abspath "${DEST}") 55 | echo "=== destination directory: ${DEST}" 56 | 57 | TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX) 58 | echo $(date) : "=== Using tmpdir: ${TMPDIR}" 59 | 60 | echo "=== Copy Larq Compute Engine files" 61 | 62 | cp ${PIP_FILE_PREFIX}setup.py "${TMPDIR}" 63 | cp ${PIP_FILE_PREFIX}MANIFEST.in "${TMPDIR}" 64 | cp ${PIP_FILE_PREFIX}README.md "${TMPDIR}" 65 | cp ${PIP_FILE_PREFIX}LICENSE "${TMPDIR}" 66 | if is_linux; then 67 | touch ${TMPDIR}/stub.cc 68 | fi 69 | if is_windows; then 70 | from=$(cygpath -w ${PIP_FILE_PREFIX}larq_compute_engine) 71 | to=$(cygpath -w "${TMPDIR}"/larq_compute_engine) 72 | start robocopy //S "${from}" "${to}" //xf *_test.py 73 | sleep 5 74 | else 75 | rsync -avm -L --exclude='*_test.py' ${PIP_FILE_PREFIX}larq_compute_engine "${TMPDIR}" 76 | fi 77 | pushd ${TMPDIR} 
78 | 79 | if ! is_windows; then 80 | echo "=== Stripping symbols" 81 | chmod +w ${TMPDIR}/larq_compute_engine/mlir/*.so ${TMPDIR}/larq_compute_engine/tflite/python/*.so 82 | strip -x ${TMPDIR}/larq_compute_engine/mlir/*.so ${TMPDIR}/larq_compute_engine/tflite/python/*.so 83 | fi 84 | 85 | echo $(date) : "=== Building wheel" 86 | python setup.py bdist_wheel ${BUILD_FLAG} > /dev/null 87 | 88 | cp dist/*.whl "${DEST}" 89 | popd 90 | rm -rf ${TMPDIR} 91 | echo $(date) : "=== Output wheel file is in: ${DEST}" 92 | } 93 | 94 | main "$@" 95 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Larq Compute Engine Documentation 2 | 3 | Docs are available at https://docs.larq.dev/compute-engine 4 | -------------------------------------------------------------------------------- /examples/BUILD: -------------------------------------------------------------------------------- 1 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_linkopts") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | licenses = ["notice"], # Apache 2.0 6 | ) 7 | 8 | cc_binary( 9 | name = "lce_minimal", 10 | srcs = [ 11 | "lce_minimal.cc", 12 | ], 13 | linkopts = tflite_linkopts() + select({ 14 | "@org_tensorflow//tensorflow:android": [ 15 | "-pie", # Android 5.0 and later supports only PIE 16 | "-lm", # some builtin ops, e.g., tanh, need -lm 17 | ], 18 | "//conditions:default": [], 19 | }), 20 | deps = [ 21 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 22 | "@org_tensorflow//tensorflow/lite:framework", 23 | "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /examples/converter_examples.py: -------------------------------------------------------------------------------- 1 | """Examples of TF lite model conversion.""" 2 | import tensorflow as tf 3 | import larq_compute_engine as lce 4 | import larq_zoo as lqz 5 | 6 | # Example of converting a model from Larq Zoo to TF lite 7 | model = lqz.sota.QuickNet(weights="imagenet") 8 | converted = lce.convert_keras_model(model) 9 | with open("quicknet.tflite", "wb") as f: 10 | f.write(converted) 11 | 12 | # Example of converting an h5 file 13 | model = tf.keras.models.load_model("my_model.h5") 14 | converted = lce.convert_keras_model(model) 15 | with open("my_model.tflite", "wb") as f: 16 | f.write(converted) 17 | -------------------------------------------------------------------------------- /examples/lce_minimal.cc: -------------------------------------------------------------------------------- 1 | #include <cstdio> 2 | 3 | #include "larq_compute_engine/tflite/kernels/lce_ops_register.h" 4 | #include "tensorflow/lite/interpreter.h" 5 | #include "tensorflow/lite/kernels/register.h" 6 | #include "tensorflow/lite/model.h" 7 | #include "tensorflow/lite/optional_debug_tools.h" 8 | 9 | // This file is based on the TF lite minimal example where the 10 | // "BuiltinOpResolver" is modified to include the "Larq Compute Engine" custom 11 | // ops. Here we read a binary model from disk and perform inference using the 12 | // C++ interface. See the BUILD file in this directory to see an example of 13 | // linking "Larq Compute Engine" custom ops to your inference binary.
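// (Editorial note, not part of the original source: the linking referred to
// above is the examples/BUILD cc_binary shown earlier in this dump, whose
// deps add "//larq_compute_engine/tflite/kernels:lce_op_kernels" alongside
// "@org_tensorflow//tensorflow/lite:framework" and the builtin-op kernels.)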
14 | 15 | using namespace tflite; 16 | 17 | #define TFLITE_MINIMAL_CHECK(x) \ 18 | if (!(x)) { \ 19 | fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ 20 | exit(1); \ 21 | } 22 | 23 | int main(int argc, char* argv[]) { 24 | if (argc != 2) { 25 | fprintf(stderr, "lce_minimal <tflite model>\n"); 26 | return 1; 27 | } 28 | const char* filename = argv[1]; 29 | 30 | // Load model 31 | std::unique_ptr<tflite::FlatBufferModel> model = 32 | tflite::FlatBufferModel::BuildFromFile(filename); 33 | TFLITE_MINIMAL_CHECK(model != nullptr); 34 | 35 | // Build the interpreter 36 | tflite::ops::builtin::BuiltinOpResolver resolver; 37 | compute_engine::tflite::RegisterLCECustomOps(&resolver); 38 | 39 | InterpreterBuilder builder(*model, resolver); 40 | std::unique_ptr<tflite::Interpreter> interpreter; 41 | builder(&interpreter); 42 | TFLITE_MINIMAL_CHECK(interpreter != nullptr); 43 | 44 | // Allocate tensor buffers. 45 | TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); 46 | printf("=== Pre-invoke Interpreter State ===\n"); 47 | tflite::PrintInterpreterState(interpreter.get()); 48 | 49 | // Fill input buffers 50 | // TODO(user): Insert code to fill input tensors 51 | 52 | // Run inference 53 | TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); 54 | 55 | printf("\n\n=== Post-invoke Interpreter State ===\n"); 56 | tflite::PrintInterpreterState(interpreter.get()); 57 | 58 | // Read output buffers 59 | // TODO(user): Insert getting data out code. 60 | 61 | return 0; 62 | } 63 | -------------------------------------------------------------------------------- /larq_compute_engine/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | config_setting( 6 | name = "arm32_build", 7 | values = {"cpu": "armeabi"}, 8 | ) 9 | 10 | config_setting( 11 | name = "aarch64_build", 12 | values = {"cpu": "aarch64"}, 13 | ) 14 | 15 | py_library( 16 | name = "compute_engine_py", 17 | srcs = [ 18 | "__init__.py", 19 | "//larq_compute_engine/mlir:__init__.py", 20 | "//larq_compute_engine/mlir:python/__init__.py", 21 | "//larq_compute_engine/tflite:__init__.py", 22 | ], 23 | deps = [ 24 | "//larq_compute_engine/mlir:converter", 25 | "//larq_compute_engine/tflite/python:interpreter", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /larq_compute_engine/__init__.py: -------------------------------------------------------------------------------- 1 | from larq_compute_engine.mlir.python.converter import ( 2 | convert_keras_model, 3 | convert_saved_model, 4 | ) 5 | from larq_compute_engine.tflite.python import interpreter as testing 6 | 7 | try: 8 | from importlib import metadata # type: ignore 9 | except ImportError: 10 | # Running on pre-3.8 Python; use importlib-metadata package 11 | import importlib_metadata as metadata # type: ignore 12 | 13 | try: 14 | __version__ = metadata.version("larq_compute_engine") 15 | except metadata.PackageNotFoundError: 16 | __version__ = "unknown" 17 | 18 | __all__ = ["convert_keras_model", "convert_saved_model", "testing"] 19 | -------------------------------------------------------------------------------- /larq_compute_engine/core/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "types", 7 | hdrs = [ 8 | "types.h", 9 | ], 10 | deps = [ 11 |
"@org_tensorflow//tensorflow/lite/kernels/internal:cppmath", 12 | ], 13 | ) 14 | 15 | cc_library( 16 | name = "bmaxpool", 17 | hdrs = [ 18 | "bmaxpool.h", 19 | ], 20 | deps = [ 21 | ":types", 22 | "@org_tensorflow//tensorflow/lite/kernels/internal:common", 23 | "@org_tensorflow//tensorflow/lite/kernels/internal:types", 24 | ], 25 | ) 26 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bconv2d/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "params", 7 | hdrs = [ 8 | "params.h", 9 | ], 10 | deps = [ 11 | "//larq_compute_engine/core:types", 12 | "@org_tensorflow//tensorflow/lite:builtin_op_data", 13 | ], 14 | ) 15 | 16 | cc_library( 17 | name = "output_transform", 18 | hdrs = [ 19 | "output_transform.h", 20 | ], 21 | deps = [ 22 | "//larq_compute_engine/core:types", 23 | "@org_tensorflow//tensorflow/lite/kernels/internal:common", 24 | "@org_tensorflow//tensorflow/lite/kernels/internal:cppmath", 25 | ], 26 | ) 27 | 28 | cc_library( 29 | name = "zero_padding_correction", 30 | hdrs = ["zero_padding_correction.h"], 31 | ) 32 | 33 | cc_library( 34 | name = "reference", 35 | hdrs = [ 36 | "reference.h", 37 | ], 38 | deps = [ 39 | ":output_transform", 40 | "@org_tensorflow//tensorflow/lite/kernels/internal:types", 41 | ], 42 | ) 43 | 44 | cc_library( 45 | name = "optimized_bgemm", 46 | hdrs = [ 47 | "optimized_bgemm.h", 48 | ], 49 | deps = [ 50 | ":zero_padding_correction", 51 | "//larq_compute_engine/core/bgemm", 52 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context", 53 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm", 54 | "@org_tensorflow//tensorflow/lite/kernels:padding", 55 | "@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base", 56 | "@ruy//ruy/profiler:instrumentation", 57 | ], 58 | ) 59 | 60 | cc_library( 61 | name = "optimized_indirect_bgemm", 62 | hdrs = [ 63 | "optimized_indirect_bgemm.h", 64 | ], 65 | deps = [ 66 | ":zero_padding_correction", 67 | "//larq_compute_engine/core/indirect_bgemm:kernels", 68 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context", 69 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm", 70 | "@org_tensorflow//tensorflow/lite/kernels:padding", 71 | "@org_tensorflow//tensorflow/lite/kernels/internal:optimized_base", 72 | "@ruy//ruy/profiler:instrumentation", 73 | ], 74 | ) 75 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bconv2d/optimized_indirect_bgemm.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ 2 | #define COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ 3 | 4 | #include "larq_compute_engine/core/bconv2d/zero_padding_correction.h" 5 | #include "larq_compute_engine/core/indirect_bgemm/kernel.h" 6 | #include "ruy/profiler/instrumentation.h" 7 | #include "tensorflow/lite/kernels/internal/types.h" 8 | 9 | namespace compute_engine { 10 | namespace core { 11 | namespace bconv2d { 12 | 13 | template 14 | inline void BConv2DOptimizedIndirectBGEMM( 15 | const indirect_bgemm::Kernel* kernel, const BConv2DParams* bconv2d_params, 16 | const RuntimeShape& bitpacked_input_shape, const RuntimeShape& output_shape, 17 | DstScalar* output_ptr, const float* padding_buffer, const int pad_value) { 18 | 
ruy::profiler::ScopeLabel label("BConv2D (optimized, indirect BGEMM)"); 19 | 20 | // If writing bitpacked output with a channel count that isn't a multiple of 21 | // 32 (i.e. where padding bits will be required in the output), fill the 22 | // output tensor with zeroes in advance so that the BGEMM doesn't have to 23 | // worry about doing the padding. 24 | if (std::is_same::value && 25 | (kernel->output_channels % bitpacking_bitwidth != 0)) { 26 | std::fill( 27 | output_ptr, 28 | output_ptr + kernel->num_output_pixels * 29 | bitpacking::GetBitpackedSize(kernel->output_channels), 30 | TBitpacked(0)); 31 | } 32 | 33 | kernel->Dispatch(reinterpret_cast(output_ptr)); 34 | 35 | if (std::is_same::value && 36 | bconv2d_params->padding_type == TfLitePadding::kTfLitePaddingSame && 37 | pad_value == 0) { 38 | ruy::profiler::ScopeLabel label("Zero padding correction"); 39 | 40 | const int stride_width = bconv2d_params->stride_width; 41 | const int stride_height = bconv2d_params->stride_height; 42 | const int dilation_width_factor = bconv2d_params->dilation_width_factor; 43 | const int dilation_height_factor = bconv2d_params->dilation_height_factor; 44 | const int batches = MatchingDim(bitpacked_input_shape, 0, output_shape, 0); 45 | const int input_depth_per_group = 46 | bconv2d_params->channels_in / bconv2d_params->groups; 47 | const int input_width = bitpacked_input_shape.Dims(2); 48 | const int input_height = bitpacked_input_shape.Dims(1); 49 | const int filter_height = bconv2d_params->filter_height; 50 | const int filter_width = bconv2d_params->filter_width; 51 | const int output_depth = output_shape.Dims(3); 52 | const int output_width = output_shape.Dims(2); 53 | const int output_height = output_shape.Dims(1); 54 | 55 | zero_padding_correction::ApplyCorrection( 56 | batches, input_height, input_width, input_depth_per_group, 57 | filter_height, filter_width, output_depth, stride_height, stride_width, 58 | dilation_height_factor, dilation_width_factor, 59 | reinterpret_cast(output_ptr), output_height, output_width, 60 | padding_buffer); 61 | } 62 | } 63 | 64 | } // namespace bconv2d 65 | } // namespace core 66 | } // namespace compute_engine 67 | 68 | #endif // COMPUTE_ENGINE_CORE_BCONV2D_OPTIMIZED_INDIRECT_BGEMM_H_ 69 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bconv2d/params.h: -------------------------------------------------------------------------------- 1 | #ifndef LARQ_COMPUTE_ENGINE_CORE_BCONV2D_PARAMS 2 | #define LARQ_COMPUTE_ENGINE_CORE_BCONV2D_PARAMS 3 | 4 | #include 5 | 6 | #include "tensorflow/lite/c/builtin_op_data.h" 7 | 8 | namespace compute_engine { 9 | namespace core { 10 | namespace bconv2d { 11 | 12 | struct BConv2DParams { 13 | // Input and filter shapes 14 | std::int32_t filter_width; 15 | std::int32_t filter_height; 16 | std::int32_t channels_in; 17 | std::int32_t channels_out; 18 | std::int32_t groups; 19 | 20 | // Strides 21 | std::int32_t stride_height; 22 | std::int32_t stride_width; 23 | 24 | // Dilations 25 | std::int32_t dilation_height_factor; 26 | std::int32_t dilation_width_factor; 27 | 28 | // Padding 29 | TfLitePadding padding_type; 30 | TfLitePaddingValues padding_values; 31 | std::int32_t pad_value; // Must be 0 or 1 32 | }; 33 | 34 | } // namespace bconv2d 35 | } // namespace core 36 | } // namespace compute_engine 37 | 38 | #endif // LARQ_COMPUTE_ENGINE_CORE_BCONV2D_PARAMS 39 | -------------------------------------------------------------------------------- 
/larq_compute_engine/core/bgemm/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "kernels_common", 7 | hdrs = [ 8 | "kernels_common.h", 9 | ], 10 | deps = [ 11 | "//larq_compute_engine/core/bconv2d:output_transform", 12 | ], 13 | ) 14 | 15 | cc_library( 16 | name = "kernels", 17 | hdrs = [ 18 | "kernels.h", 19 | "kernels_aarch64.h", 20 | "kernels_arm32.h", 21 | ], 22 | deps = [ 23 | ":kernels_common", 24 | "//larq_compute_engine/core/bitpacking:utils", 25 | "@ruy//ruy/profiler:instrumentation", 26 | ], 27 | ) 28 | 29 | cc_library( 30 | name = "bgemm", 31 | hdrs = [ 32 | "bgemm.h", 33 | "ruy_pack.h", 34 | "ruy_trmul_params.h", 35 | ], 36 | deps = [ 37 | ":kernels", 38 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_context", 39 | "@org_tensorflow//tensorflow/lite/kernels:cpu_backend_gemm", 40 | "@ruy//ruy/profiler:instrumentation", 41 | ], 42 | ) 43 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bgemm/bgemm.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_CORE_BGEMM_BGEMM_H_ 2 | #define COMPUTE_ENGINE_CORE_BGEMM_BGEMM_H_ 3 | 4 | #include "larq_compute_engine/core/bgemm/kernels_common.h" 5 | #include "larq_compute_engine/core/bgemm/ruy_trmul_params.h" 6 | #include "ruy/context.h" 7 | #include "ruy/context_get_ctx.h" 8 | #include "ruy/matrix.h" 9 | #include "ruy/platform.h" 10 | #include "ruy/prepare_packed_matrices.h" 11 | #include "ruy/profiler/instrumentation.h" 12 | #include "ruy/trmul.h" 13 | #include "tensorflow/lite/kernels/cpu_backend_context.h" 14 | #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h" 15 | #include "tensorflow/lite/kernels/cpu_backend_gemm_ruy.h" 16 | 17 | using namespace tflite; 18 | using namespace tflite::cpu_backend_gemm; 19 | 20 | namespace compute_engine { 21 | namespace core { 22 | namespace bgemm { 23 | 24 | template <typename AccumScalar, typename DstScalar> 25 | void BGemm(const MatrixParams<TBitpacked>& lhs_params, 26 | const TBitpacked* lhs_data, 27 | const MatrixParams<TBitpacked>& rhs_params, 28 | const TBitpacked* rhs_data, 29 | const MatrixParams<DstScalar>& dst_params, DstScalar* dst_data, 30 | const OutputTransform<DstScalar>& output_transform, 31 | CpuBackendContext* context) { 32 | ruy::profiler::ScopeLabel label("BGemm (Ruy)"); 33 | 34 | static_assert(std::is_signed<DstScalar>::value, 35 | "The DstScalar should be signed."); 36 | 37 | // Get ruy context 38 | auto ruy_ctx = get_ctx(context->ruy_context()); 39 | 40 | // Set up the matrix layouts and mul_params. 41 | ruy::Matrix<TBitpacked> lhs; 42 | ruy::Matrix<TBitpacked> rhs; 43 | ruy::Matrix<DstScalar> dst; 44 | // We allow these matrices to be cached. Note that this doesn't force them 45 | // to be cached; it means that the `cache_policy` of the MatrixParams will 46 | // be respected.
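// (Editorial note, an assumption rather than part of the original source:
// for a binary convolution the LHS is the constant bitpacked weight matrix,
// so allowing ruy to cache its packed form avoids re-packing identical data
// on every inference; `use_caching` below only opts in to whatever
// cache_policy the caller already set on the MatrixParams.)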
cpu_backend_gemm::detail::MakeRuyMatrix(lhs_params, lhs_data, &lhs, 48 | /*use_caching=*/true); 49 | cpu_backend_gemm::detail::MakeRuyMatrix(rhs_params, rhs_data, &rhs, 50 | /*use_caching=*/true); 51 | cpu_backend_gemm::detail::MakeRuyMatrix(dst_params, dst_data, &dst); 52 | 53 | // We have to make this a `const` matrix because otherwise gcc will try to 54 | // use the non-const versions of `matrix.data()` 55 | ruy::Mat<TBitpacked> internal_lhs = 56 | ruy::ToInternal((const ruy::Matrix<TBitpacked>)lhs); 57 | ruy::Mat<TBitpacked> internal_rhs = 58 | ruy::ToInternal((const ruy::Matrix<TBitpacked>)rhs); 59 | ruy::Mat<DstScalar> internal_dst = ruy::ToInternal(dst); 60 | 61 | BinaryMulParams<AccumScalar, DstScalar> mul_params; 62 | mul_params.output_transform = output_transform; 63 | 64 | #if RUY_PLATFORM_NEON 65 | constexpr bool HasOptimizedNeonKernel = 66 | std::is_same<DstScalar, float>::value || 67 | std::is_same<DstScalar, std::int8_t>::value || 68 | std::is_same<DstScalar, TBitpacked>::value; 69 | constexpr auto SelectedPath = 70 | HasOptimizedNeonKernel ? ruy::Path::kNeon : ruy::Path::kStandardCpp; 71 | #else 72 | constexpr auto SelectedPath = ruy::Path::kStandardCpp; 73 | #endif 74 | 75 | ruy::TrMulParams bgemm_trmul_params; 76 | PopulateBGemmTrMulParams<SelectedPath>(ruy::Transpose(internal_lhs), 77 | internal_rhs, internal_dst, mul_params, 78 | &bgemm_trmul_params); 79 | 80 | ruy::PreparePackedMatrices(ruy_ctx, &bgemm_trmul_params); 81 | ruy::TrMul(ruy_ctx, &bgemm_trmul_params); 82 | 83 | ruy_ctx->GetMainAllocator()->FreeAll(); 84 | } 85 | 86 | } // namespace bgemm 87 | } // namespace core 88 | } // namespace compute_engine 89 | 90 | #endif // COMPUTE_ENGINE_CORE_BGEMM_BGEMM_H_ 91 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bgemm/ruy_trmul_params.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_CORE_BGEMM_RUY_TRMUL_PARAMS_H_ 2 | #define COMPUTE_ENGINE_CORE_BGEMM_RUY_TRMUL_PARAMS_H_ 3 | 4 | #include "larq_compute_engine/core/bgemm/kernels.h" 5 | #include "larq_compute_engine/core/bgemm/ruy_pack.h" 6 | #include "ruy/create_trmul_params.h" 7 | #include "ruy/mul_params.h" 8 | #include "ruy/path.h" 9 | #include "ruy/trmul_params.h" 10 | 11 | namespace compute_engine { 12 | namespace core { 13 | namespace bgemm { 14 | 15 | inline bool IsColMajorTrMul(const ruy::TrMulParams& params) { 16 | return IsColMajor(params.src[Side::kLhs].layout) && 17 | IsColMajor(params.src[Side::kRhs].layout) && 18 | IsColMajor(params.dst.layout); 19 | } 20 | 21 | template <ruy::Path ThePath, typename LhsScalar, typename RhsScalar, typename DstScalar, typename MulParamsType> 22 | void PopulateBGemmTrMulParams(const Mat<LhsScalar>& lhs, 23 | const Mat<RhsScalar>& rhs, Mat<DstScalar>& dst, 24 | const MulParamsType& mul_params, 25 | ruy::TrMulParams* params) { 26 | params->src[Side::kLhs] = EraseType(lhs); 27 | params->src[Side::kRhs] = EraseType(rhs); 28 | params->dst = EraseType(dst); 29 | 30 | static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, ""); 31 | static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, ""); 32 | static_assert(std::is_trivially_copyable<MulParamsType>::value, ""); 33 | auto* dst_mul_params = 34 | reinterpret_cast<MulParamsType*>(params->mul_params_bytes); 35 | std::memcpy(dst_mul_params, &mul_params, sizeof(MulParamsType)); 36 | 37 | // Optimised code paths only support all matrices being column-major 38 | if (!IsColMajorTrMul(*params) && ThePath != ruy::Path::kStandardCpp) { 39 | PopulateBGemmTrMulParams<ruy::Path::kStandardCpp>(lhs, rhs, dst, mul_params, 40 | params); 41 | return; 42 | }; 43 | 44 | using Kernel = BGemmKernel; 45 | using LhsKernelLayout = typename Kernel::LhsLayout; 46 | using RhsKernelLayout = typename Kernel::RhsLayout; 47 | 48 | params->path = ThePath; 49 | 50 |
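// (Editorial note, not in the original source: the remaining TrMulParams
// fields below tell ruy how each operand should be packed and which compute
// kernel to invoke; LCE substitutes its own bitpacked packing routine
// (RunRuyPack) and binary kernel (RunBGemmKernel) in place of ruy's
// defaults.)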
ruy::detail::CreatePackedMatrix( 51 | Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params); 52 | ruy::detail::CreatePackedMatrix( 53 | Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params); 54 | params->run_pack[Side::kLhs] = &RunRuyPack; 55 | params->run_pack[Side::kRhs] = &RunRuyPack; 56 | params->run_kernel = &RunBGemmKernel; 57 | } 58 | 59 | } // namespace bgemm 60 | } // namespace core 61 | } // namespace compute_engine 62 | 63 | #endif // COMPUTE_ENGINE_CORE_BGEMM_RUY_TRMUL_PARAMS_H_ 64 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bitpacking/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "bitpack", 7 | hdrs = ["bitpack.h"] + select({ 8 | "//larq_compute_engine:aarch64_build": [ 9 | "bitpack_aarch64.h", 10 | ], 11 | "@org_tensorflow//tensorflow:android_arm64": [ 12 | "bitpack_aarch64.h", 13 | ], 14 | "@org_tensorflow//tensorflow:macos_arm64": [ 15 | "bitpack_aarch64.h", 16 | ], 17 | "//conditions:default": [], 18 | }), 19 | deps = [ 20 | "//larq_compute_engine/core:types", 21 | "@flatbuffers", 22 | "@org_tensorflow//tensorflow/lite/kernels/internal:types", 23 | "@ruy//ruy/profiler:instrumentation", 24 | ], 25 | ) 26 | 27 | cc_library( 28 | name = "utils", 29 | hdrs = ["utils.h"], 30 | deps = [ 31 | ":bitpack", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bitpacking/tests/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # Apache 2.0 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_test( 6 | name = "bitpack_test", 7 | size = "small", 8 | srcs = ["bitpack_test.cc"], 9 | linkopts = select({ 10 | "@org_tensorflow//tensorflow:windows": [], 11 | "@org_tensorflow//tensorflow:macos": [ 12 | "-lm", 13 | ], 14 | "//conditions:default": [ 15 | "-lm", 16 | "-lrt", 17 | ], 18 | }), 19 | deps = [ 20 | "//larq_compute_engine/core/bitpacking:bitpack", 21 | "@com_google_absl//absl/strings", 22 | "@com_google_googletest//:gtest_main", 23 | ], 24 | ) 25 | 26 | # Collection of all CC architecture independent tests. Each new cc test needs to 27 | # be added here.
28 | test_suite( 29 | name = "cc_tests", 30 | tests = [ 31 | "bitpack_test", 32 | ], 33 | ) 34 | 35 | cc_test( 36 | name = "bitpack_aarch64_test", 37 | size = "small", 38 | srcs = ["bitpack_aarch64_test.cc"], 39 | deps = [ 40 | "//larq_compute_engine/core/bitpacking:bitpack", 41 | "@com_google_googletest//:gtest_main", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bitpacking/tests/bitpack_aarch64_test.cc: -------------------------------------------------------------------------------- 1 | #include "larq_compute_engine/core/bitpacking/bitpack_aarch64.h" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include "larq_compute_engine/core/bitpacking/bitpack.h" 10 | #include "larq_compute_engine/core/types.h" 11 | 12 | namespace compute_engine { 13 | namespace core { 14 | namespace bitpacking { 15 | 16 | template <typename DstScalar> 17 | void test_bitpacking_order(const int num_4x32_blocks) { 18 | static_assert(std::is_same<DstScalar, float>::value || 19 | std::is_same<DstScalar, std::int8_t>::value, 20 | ""); 21 | 22 | const int num_blocks = 4 * num_4x32_blocks; 23 | const int n = 32 * num_blocks; 24 | 25 | DstScalar input[n]; 26 | const DstScalar zero_point = 27 | std::is_same<DstScalar, std::int8_t>::value ? -42 : 0; 28 | TBitpacked output[num_blocks]; 29 | for (auto i = 0; i < n; ++i) { 30 | // Try to get the position of bit i by packing the one-hot vector e_i 31 | for (auto j = 0; j < n; ++j) { 32 | if (j == i) 33 | input[j] = zero_point - 5; 34 | else 35 | input[j] = zero_point + 5; 36 | } 37 | // Run bitpacking 38 | bitpack_aarch64_4x32(input, num_blocks, output, zero_point); 39 | // See where in the output the bit has popped up 40 | int bit_index = -1; 41 | int bits_found = 0; 42 | for (auto j = 0; j < num_blocks; ++j) { 43 | for (auto k = 0; k < 32; ++k) { 44 | if (output[j] & (TBitpacked(1) << k)) { 45 | bit_index = k + j * 32; 46 | bits_found++; 47 | } 48 | } 49 | } 50 | 51 | // We should have exactly one enabled bit... 52 | EXPECT_EQ(bits_found, 1); 53 | // ...and it should be in the i^th position.
54 | EXPECT_EQ(bit_index, i); 55 | } 56 | } 57 | 58 | TEST(BitpackingAarch64, Float_1x4x32) { test_bitpacking_order<float>(1); } 59 | TEST(BitpackingAarch64, Float_2x4x32) { test_bitpacking_order<float>(2); } 60 | TEST(BitpackingAarch64, Float_3x4x32) { test_bitpacking_order<float>(3); } 61 | TEST(BitpackingAarch64, Float_11x4x32) { test_bitpacking_order<float>(11); } 62 | TEST(BitpackingAarch64, Float_17x4x32) { test_bitpacking_order<float>(17); } 63 | 64 | TEST(BitpackingAarch64, Int8_1x4x32) { test_bitpacking_order<std::int8_t>(1); } 65 | TEST(BitpackingAarch64, Int8_2x4x32) { test_bitpacking_order<std::int8_t>(2); } 66 | TEST(BitpackingAarch64, Int8_3x4x32) { test_bitpacking_order<std::int8_t>(3); } 67 | TEST(BitpackingAarch64, Int8_11x4x32) { 68 | test_bitpacking_order<std::int8_t>(11); 69 | } 70 | TEST(BitpackingAarch64, Int8_17x4x32) { 71 | test_bitpacking_order<std::int8_t>(17); 72 | } 73 | 74 | } // namespace bitpacking 75 | } // namespace core 76 | } // namespace compute_engine 77 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bitpacking/tests/bitpack_test.cc: -------------------------------------------------------------------------------- 1 | #include "larq_compute_engine/core/bitpacking/bitpack.h" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "absl/strings/str_cat.h" 11 | 12 | namespace compute_engine { 13 | namespace core { 14 | namespace bitpacking { 15 | 16 | class BitpackingTest 17 | : public ::testing::TestWithParam<std::tuple<int, int, std::int32_t>> {}; 18 | 19 | template <typename T> 20 | void runBitpackingTest(const int num_rows, const int num_cols, 21 | const std::int32_t zero_point) { 22 | if (std::is_same<T, float>::value && zero_point != 0) { 23 | GTEST_SKIP(); 24 | } 25 | 26 | const int num_packed_cols = GetBitpackedSize(num_cols); 27 | 28 | std::random_device rd; 29 | std::mt19937 gen(rd()); 30 | 31 | std::vector<T> input_matrix(num_rows * num_cols); 32 | std::vector<TBitpacked> output_matrix( 33 | GetBitpackedMatrixSize(num_rows, num_cols)); 34 | 35 | // Generate some random data for the input. 36 | if (std::is_same<T, float>::value) { 37 | std::generate(std::begin(input_matrix), std::end(input_matrix), [&gen]() { 38 | return std::uniform_real_distribution<>(-1.5, 1.5)(gen); 39 | }); 40 | } else if (std::is_same<T, std::int8_t>::value) { 41 | std::generate(std::begin(input_matrix), std::end(input_matrix), [&gen]() { 42 | return std::uniform_int_distribution<>(-128, 127)(gen); 43 | }); 44 | } else { 45 | EXPECT_FALSE(true); 46 | } 47 | 48 | // Perform the bitpacking. 49 | bitpack_matrix(input_matrix.data(), num_rows, num_cols, output_matrix.data(), 50 | zero_point); 51 | 52 | // Verify correctness of the results. 53 | for (auto i = 0; i < num_rows; i++) { 54 | for (auto j = 0; j < bitpacking_bitwidth * num_packed_cols; j++) { 55 | const bool packed_bit = 56 | output_matrix.at(i * num_packed_cols + j / bitpacking_bitwidth) & 57 | (TBitpacked(1) << (j % bitpacking_bitwidth)); 58 | 59 | if (j < num_cols) { 60 | // If this bit position corresponds to an actual value, compare against 61 | // the sign of that value... 62 | bool expected_bit; 63 | if (std::is_same<T, float>::value) { 64 | assert(zero_point == 0); 65 | expected_bit = input_matrix.at(i * num_cols + j) < 0; 66 | } else { 67 | expected_bit = static_cast<std::int32_t>( 68 | input_matrix.at(i * num_cols + j)) < zero_point; 69 | } 70 | EXPECT_EQ(packed_bit, expected_bit); 71 | } else { 72 | // ...otherwise it's a 'padding' bit, and we expect it to be zero.
73 | EXPECT_EQ(packed_bit, 0); 74 | } 75 | } 76 | } 77 | } 78 | 79 | TEST_P(BitpackingTest, BitpackFloats) { 80 | runBitpackingTest<float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 81 | std::get<2>(GetParam())); 82 | } 83 | 84 | TEST_P(BitpackingTest, BitpackInt8) { 85 | runBitpackingTest<std::int8_t>(std::get<0>(GetParam()), 86 | std::get<1>(GetParam()), 87 | std::get<2>(GetParam())); 88 | } 89 | 90 | std::string TestName( 91 | const ::testing::TestParamInfo<BitpackingTest::ParamType>& info) { 92 | // We have to treat the zero point specially, because we can't have a 93 | // hyphen in the name, and so can't naturally represent negative numbers. 94 | std::string param_zp_value_str = 95 | absl::StrCat(std::get<2>(info.param) >= 0 ? "Pos" : "Neg", 96 | std::abs(std::get<2>(info.param))); 97 | return absl::StrCat("Rows_", std::get<0>(info.param), "__Cols_", 98 | std::get<1>(info.param), "__ZeroPoint_", 99 | param_zp_value_str); 100 | } 101 | 102 | INSTANTIATE_TEST_SUITE_P(Bitpacking, BitpackingTest, 103 | ::testing::Combine( 104 | // num_rows 105 | ::testing::Values(1, 2, 3, 8, 10, 15, 64), 106 | // num_cols 107 | ::testing::Values(1, 3, 16, 32, 33, 63, 64, 128), 108 | // zero_point 109 | ::testing::Values(-1000, -1, 0, 23, 127, 128)), 110 | TestName); 111 | 112 | } // namespace bitpacking 113 | } // namespace core 114 | } // namespace compute_engine 115 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bitpacking/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_CORE_BITPACKING_UTILS_H_ 2 | #define COMPUTE_ENGINE_CORE_BITPACKING_UTILS_H_ 3 | 4 | #include "larq_compute_engine/core/bitpacking/bitpack.h" 5 | #include "tensorflow/lite/kernels/internal/types.h" 6 | 7 | namespace compute_engine { 8 | namespace core { 9 | namespace bitpacking { 10 | 11 | using namespace tflite; 12 | 13 | inline int GetBitpackedTensorSize(const RuntimeShape& shape) { 14 | const int dims = shape.DimensionsCount(); 15 | // Pack the tensor along the last dimension 16 | const int rows = FlatSizeSkipDim(shape, dims - 1); 17 | const int cols = shape.Dims(dims - 1); 18 | return GetBitpackedMatrixSize(rows, cols); 19 | } 20 | 21 | // Convenience function for bitpacking a tensor along its last dimension 22 | // and updating the tensor shape 23 | template <typename T> 24 | inline void bitpack_tensor(const RuntimeShape& in_shape, const T* in_data, 25 | const std::int32_t zero_point, 26 | TBitpacked* out_data) { 27 | const int dims = in_shape.DimensionsCount(); 28 | // Pack the tensor along the last dimension 29 | const int rows = FlatSizeSkipDim(in_shape, dims - 1); 30 | const int cols = in_shape.Dims(dims - 1); 31 | 32 | bitpack_matrix(in_data, rows, cols, out_data, zero_point); 33 | } 34 | 35 | // Convenience function for going from a shape to the packed shape 36 | inline RuntimeShape bitpacked_shape(const RuntimeShape& in_shape) { 37 | const int dims = in_shape.DimensionsCount(); 38 | RuntimeShape out_shape(in_shape); 39 | out_shape.SetDim(dims - 1, GetBitpackedSize(in_shape.Dims(dims - 1))); 40 | return out_shape; 41 | } 42 | 43 | } // namespace bitpacking 44 | } // namespace core 45 | } // namespace compute_engine 46 | 47 | #endif // COMPUTE_ENGINE_CORE_BITPACKING_UTILS_H_ 48 | -------------------------------------------------------------------------------- /larq_compute_engine/core/bmaxpool.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_CORE_BMAXPOOL_H_ 2 | #define
COMPUTE_ENGINE_CORE_BMAXPOOL_H_ 3 | 4 | #include "larq_compute_engine/core/types.h" 5 | #include "tensorflow/lite/kernels/internal/common.h" 6 | #include "tensorflow/lite/kernels/internal/types.h" 7 | #include "tensorflow/lite/kernels/padding.h" 8 | 9 | namespace compute_engine { 10 | namespace core { 11 | 12 | using namespace tflite; 13 | 14 | struct BMaxPoolParams { 15 | std::int32_t filter_height{0}; 16 | std::int32_t filter_width{0}; 17 | std::int32_t stride_height{0}; 18 | std::int32_t stride_width{0}; 19 | TfLitePaddingValues padding{}; 20 | TfLitePadding padding_type{}; 21 | }; 22 | 23 | // Effectively takes the AND of everything in the filter region 24 | void BMaxPool(const BMaxPoolParams& params, const RuntimeShape& input_shape, 25 | const TBitpacked* input_data, const RuntimeShape& output_shape, 26 | TBitpacked* output_data) { 27 | TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); 28 | TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); 29 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); 30 | const int input_height = input_shape.Dims(1); 31 | const int input_width = input_shape.Dims(2); 32 | const int output_height = output_shape.Dims(1); 33 | const int output_width = output_shape.Dims(2); 34 | const int filter_height = params.filter_height; 35 | const int filter_width = params.filter_width; 36 | const int stride_height = params.stride_height; 37 | const int stride_width = params.stride_width; 38 | const int channels = MatchingDim(input_shape, 3, output_shape, 3); 39 | 40 | for (int batch = 0; batch < batches; ++batch) { 41 | for (int out_y = 0; out_y < output_height; ++out_y) { 42 | for (int out_x = 0; out_x < output_width; ++out_x) { 43 | const int in_x_origin = (out_x * stride_width) - params.padding.width; 44 | const int in_y_origin = (out_y * stride_height) - params.padding.height; 45 | // Compute the boundaries of the filter region clamped so as to 46 | // ensure that the filter window fits in the input array. 47 | const int filter_x_start = std::max(0, -in_x_origin); 48 | const int filter_y_start = std::max(0, -in_y_origin); 49 | 50 | const int in_x = in_x_origin + filter_x_start; 51 | const int in_y = in_y_origin + filter_y_start; 52 | 53 | const int filter_x_count = 54 | std::min(filter_width - filter_x_start, input_width - in_x); 55 | const int filter_y_count = 56 | std::min(filter_height - filter_y_start, input_height - in_y); 57 | 58 | // How far to jump to the next input pixel in the x direction 59 | const int x_stride = channels; 60 | // How far to jump to the next input pixel in the y direction, and 61 | // 'back' to the first pixel in the x direction. 
62 | const int y_stride = channels * (input_width - filter_x_count);
63 |
64 | // Get the pointer to the input pixel corresponding to the top-left filter
65 | // corner, channel 0
66 | const TBitpacked* in_base =
67 | &input_data[Offset(input_shape, batch, in_y, in_x, 0)];
68 | TBitpacked* out_ptr =
69 | &output_data[Offset(output_shape, batch, out_y, out_x, 0)];
70 |
71 | for (int channel = 0; channel < channels; ++channel) {
72 | const TBitpacked* in_ptr = in_base + channel;
73 |
74 | // Start with all ones
75 | TBitpacked max = ~TBitpacked(0);
76 | for (int y = 0; y < filter_y_count; ++y) {
77 | for (int x = 0; x < filter_x_count; ++x) {
78 | max &= *in_ptr;
79 | in_ptr += x_stride;
80 | }
81 | in_ptr += y_stride;
82 | }
83 | *out_ptr++ = max;
84 | }
85 | }
86 | }
87 | }
88 | }
89 |
90 | } // namespace core
91 | } // namespace compute_engine
92 |
93 | #endif // COMPUTE_ENGINE_CORE_BMAXPOOL_H_
94 |
-------------------------------------------------------------------------------- /larq_compute_engine/core/indirect_bgemm/BUILD: --------------------------------------------------------------------------------
1 | licenses(["notice"]) # Apache 2.0
2 |
3 | package(default_visibility = ["//visibility:public"])
4 |
5 | cc_library(
6 | name = "kernels",
7 | srcs = [
8 | "kernel_4x2_portable.h",
9 | "kernel_8x4x1_aarch64.h",
10 | "kernel_8x4x2_aarch64.h",
11 | "kernel_8x4x4_aarch64.h",
12 | ],
13 | hdrs = [
14 | "kernel.h",
15 | "select_kernel.h",
16 | ],
17 | deps = [
18 | "//larq_compute_engine/core:types",
19 | "//larq_compute_engine/core/bconv2d:output_transform",
20 | "//larq_compute_engine/core/bconv2d:params",
21 | "//larq_compute_engine/core/bitpacking:bitpack",
22 | "@org_tensorflow//tensorflow/lite/kernels/internal:types",
23 | "@ruy//ruy/profiler:instrumentation",
24 | ],
25 | )
26 |
-------------------------------------------------------------------------------- /larq_compute_engine/core/types.h: --------------------------------------------------------------------------------
1 | #ifndef COMPUTE_ENGINE_CORE_TYPES_H_
2 | #define COMPUTE_ENGINE_CORE_TYPES_H_
3 |
4 | #include <algorithm>
5 | #include <bitset>
6 | #include <cstdint>
7 | #include <limits>
8 |
9 | #include "tensorflow/lite/kernels/internal/cppmath.h"
10 |
11 | namespace compute_engine {
12 | namespace core {
13 |
14 | #if defined(__GNUC__)
15 | #define LCE_LIKELY(condition) (__builtin_expect(!!(condition), 1))
16 | #define LCE_UNLIKELY(condition) (__builtin_expect(condition, 0))
17 | #else
18 | #define LCE_LIKELY(condition) (condition)
19 | #define LCE_UNLIKELY(condition) (condition)
20 | #endif
21 |
22 | #if defined(__GNUC__)
23 | #define FORCE_INLINE __attribute__((always_inline)) inline
24 | #else
25 | #define FORCE_INLINE inline
26 | #endif
27 |
28 | // Check that 0 <= index < limit using a single comparison, assuming
29 | // that 0 <= limit if Index is signed. Intended for use in performance
30 | // critical contexts where 0 <= index < limit is almost always true.
31 | inline bool FastBoundsCheck(const int index, const int limit) {
32 | return LCE_LIKELY((unsigned)index < (unsigned)limit);
33 | }
34 |
35 | // In our kernels we may occasionally read (but never write) beyond the end of
36 | // an array. This is the maximum number of extra bytes that will be read, and
37 | // should be added as padding to the end of tensor allocations.
38 | #define LCE_EXTRA_BYTES 16
39 |
40 | // Define these once here, so they can be included everywhere.
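// (Editorial sketch, not part of the original header: a TBitpacked word holds
// 32 binary values, where bit = 1 encodes a negative input and bit = 0 a
// non-negative one, matching the sign convention checked in the bitpacking
// tests above. With that encoding, the dot product of two length-n binarized
// vectors packed into words a_i and b_i can be recovered as
//   dot(a, b) = n - 2 * sum_i xor_popcount(a_i, b_i),
// which is the identity the binary GEMM kernels rely on.)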
41 | using TBitpacked = std::int32_t;
42 | constexpr std::size_t bitpacking_bitwidth =
43 | std::numeric_limits<typename std::make_unsigned<TBitpacked>::type>::digits;
44 |
45 | inline int xor_popcount(const TBitpacked& a, const TBitpacked& b) {
46 | return std::bitset<bitpacking_bitwidth>(a ^ b).count();
47 | }
48 |
49 | // Clamp an int32 value to int8 range
50 | inline std::int8_t saturate(std::int32_t x) {
51 | #ifdef __arm__
52 | std::int8_t y;
53 | asm("ssat %[y], #8, %[x]\n" : [y] "=r"(y) : [x] "r"(x));
54 | return y;
55 | #else
56 | x = std::min<std::int32_t>(x, std::numeric_limits<std::int8_t>::max());
57 | x = std::max<std::int32_t>(x, std::numeric_limits<std::int8_t>::lowest());
58 | return static_cast<std::int8_t>(x);
59 | #endif
60 | }
61 |
62 | // Arithmetic right shift and clamp an int32 value to int8 range
63 | template <int shift>
64 | inline std::int8_t shift_saturate(std::int32_t x) {
65 | #ifdef __arm__
66 | std::int8_t y;
67 | asm("ssat %[y], #8, %[x], asr %[shift]\n"
68 | : [y] "=r"(y)
69 | : [x] "r"(x), [shift] "i"(shift));
70 | return y;
71 | #else
72 | x = x >> shift;
73 | x = std::min<std::int32_t>(x, std::numeric_limits<std::int8_t>::max());
74 | x = std::max<std::int32_t>(x, std::numeric_limits<std::int8_t>::lowest());
75 | return static_cast<std::int8_t>(x);
76 | #endif
77 | }
78 |
79 | // Round-to-nearest. Handling of ties is allowed to be anything, as discussed in
80 | // https://github.com/tensorflow/tensorflow/issues/25087
81 | inline std::int32_t round(float x) {
82 | #if defined(__thumb__) && defined(__VFP_FP__) && !defined(__SOFTFP__)
83 | // The `vcvtr` instruction follows the IEEE 754 rounding standard, which
84 | // rounds halfway points to the nearest *even* integer.
85 | std::int32_t y;
86 | asm("vcvtr.s32.f32 %[x], %[x] \n"
87 | "vmov %[y], %[x] \n"
88 | : [y] "=r"(y)
89 | : [x] "t"(x)); // The "t" means `x` will be in an FPU register
90 | return y;
91 | #else
92 | return ::tflite::TfLiteRound(x);
93 | #endif
94 | }
95 |
96 | template <typename T, typename S>
97 | constexpr T CeilDiv(T a, S b) {
98 | return (a + b - 1) / b;
99 | }
100 |
101 | template <typename T, typename S>
102 | constexpr T Ceil(T a, S b) {
103 | return CeilDiv(a, b) * b;
104 | }
105 |
106 | } // namespace core
107 | } // namespace compute_engine
108 |
109 | #endif // COMPUTE_ENGINE_CORE_TYPES_H_
110 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larq/compute-engine/b8ec518b773302b2f32c2460cfa7e0267d9dbea0/larq_compute_engine/mlir/__init__.py
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/ir/lce_ops.cc: --------------------------------------------------------------------------------
1 | #include "larq_compute_engine/mlir/ir/lce_ops.h"
2 |
3 | #include "flatbuffers/flexbuffers.h"
4 | #include "larq_compute_engine/core/bitpacking/bitpack.h"
5 | #include "larq_compute_engine/mlir/transforms/bitpack.h"
6 | #include "mlir/Dialect/Arith/IR/Arith.h"
7 | #include "tensorflow/lite/schema/schema_generated.h"
8 |
9 | // Generated dialect defs.
10 | #include "larq_compute_engine/mlir/ir/lce_dialect.cc.inc"
11 |
12 | static tflite::Padding ConvertPaddingAttr(llvm::StringRef str) {
13 | return llvm::StringSwitch<tflite::Padding>(str)
14 | .Case("SAME", tflite::Padding_SAME)
15 | .Case("VALID", tflite::Padding_VALID);
16 | }
17 |
18 | static tflite::ActivationFunctionType ConvertActivationAttr(
19 | llvm::StringRef str) {
20 | return llvm::StringSwitch<tflite::ActivationFunctionType>(str)
21 | .Case("NONE", tflite::ActivationFunctionType_NONE)
22 | .Case("RELU", tflite::ActivationFunctionType_RELU)
23 | .Case("RELU_N1_TO_1", tflite::ActivationFunctionType_RELU_N1_TO_1)
24 | .Case("RELU6", tflite::ActivationFunctionType_RELU6);
25 | }
26 |
27 | #define GET_OP_CLASSES
28 | #include "larq_compute_engine/mlir/ir/lce_ops.cc.inc"
29 |
30 | namespace mlir {
31 | namespace lq {
32 |
33 | std::vector<uint8_t> QuantizeOp::buildCustomOptions() { return {}; }
34 | std::vector<uint8_t> DequantizeOp::buildCustomOptions() { return {}; }
35 |
36 | std::vector<uint8_t> Bconv2dOp::buildCustomOptions() {
37 | flexbuffers::Builder fbb;
38 | fbb.Map([&]() {
39 | fbb.Int("channels_in", getChannelsIn());
40 | fbb.Int("dilation_height_factor", getDilationHeightFactor());
41 | fbb.Int("dilation_width_factor", getDilationWidthFactor());
42 | fbb.Int("fused_activation_function",
43 | (int)ConvertActivationAttr(getFusedActivationFunction()));
44 | fbb.Int("pad_values", getPadValues());
45 | fbb.Int("padding", (int)ConvertPaddingAttr(getPadding()));
46 | fbb.Int("stride_height", getStrideHeight());
47 | fbb.Int("stride_width", getStrideWidth());
48 | });
49 | fbb.Finish();
50 | return fbb.GetBuffer();
51 | }
52 |
53 | std::vector<uint8_t> BMaxPool2dOp::buildCustomOptions() {
54 | flexbuffers::Builder fbb;
55 | fbb.Map([&]() {
56 | fbb.Int("padding", (int)ConvertPaddingAttr(getPadding()));
57 | fbb.Int("stride_width", getStrideWidth());
58 | fbb.Int("stride_height", getStrideHeight());
59 | fbb.Int("filter_width", getFilterWidth());
60 | fbb.Int("filter_height", getFilterHeight());
61 | });
62 | fbb.Finish();
63 | return fbb.GetBuffer();
64 | }
65 |
66 | void QuantizeOp::build(OpBuilder& builder, OperationState& state, Value x) {
67 | state.addOperands(x);
68 | const auto existing_shape = x.getType().cast<ShapedType>().getShape();
69 | const auto channels = existing_shape[existing_shape.size() - 1];
70 | std::vector<int64_t> shape = existing_shape.drop_back();
71 | shape.push_back(compute_engine::core::bitpacking::GetBitpackedSize(channels));
72 | state.addTypes(RankedTensorType::get(shape, builder.getIntegerType(32)));
73 | }
74 |
75 | OpFoldResult QuantizeOp::fold(FoldAdaptor adaptor) {
76 | auto operands = adaptor.getOperands();
77 | mlir::OpBuilder builder(getOperation());
78 | if (!operands[0]) return nullptr;
79 | return mlir::TFL::Bitpack(&builder, operands[0]);
80 | }
81 |
82 | OpFoldResult DequantizeOp::fold(FoldAdaptor adaptor) {
83 | auto operands = adaptor.getOperands();
84 | auto result_type = getType().cast<ShapedType>();
85 | if (!operands[0]) return nullptr;
86 | return mlir::TFL::Unpack(operands[0], result_type);
87 | }
88 |
89 | void LarqDialect::initialize() {
90 | addOperations<
91 | #define GET_OP_LIST
92 | #include "larq_compute_engine/mlir/ir/lce_ops.cc.inc"
93 | >();
94 | }
95 |
96 | Operation* LarqDialect::materializeConstant(OpBuilder& builder, Attribute value,
97 | Type type, Location loc) {
98 | if (arith::ConstantOp::isBuildableWith(value, type))
99 | return builder.create<arith::ConstantOp>(loc, type,
100 | cast<TypedAttr>(value));
101 | return nullptr;
102 | }
103 | } // namespace lq
104 | } // namespace mlir
105 |
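Editorial note: the flexbuffer maps serialized by the `buildCustomOptions` methods above are decoded again on the kernel side, and the key names are asserted in `lce_ops_options_test.cc` later in this listing. A minimal reading-side sketch, using only the flexbuffers API already seen in this repository; the `Bconv2dOptions` struct and `ParseBconv2dOptions` function are illustrative names, not part of the codebase:

#include <cstdint>
#include <vector>

#include "flatbuffers/flexbuffers.h"

// Hypothetical decoded form of the Bconv2d custom options map.
struct Bconv2dOptions {
  std::int32_t channels_in;
  std::int32_t dilation_height_factor, dilation_width_factor;
  std::int32_t fused_activation_function;
  std::int32_t pad_values;
  std::int32_t padding;
  std::int32_t stride_height, stride_width;
};

// Decode a buffer produced by Bconv2dOp::buildCustomOptions(); the keys
// mirror the fbb.Int(...) calls above.
inline Bconv2dOptions ParseBconv2dOptions(
    const std::vector<std::uint8_t>& buffer) {
  const flexbuffers::Map m = flexbuffers::GetRoot(buffer).AsMap();
  Bconv2dOptions opts;
  opts.channels_in = m["channels_in"].AsInt32();
  opts.dilation_height_factor = m["dilation_height_factor"].AsInt32();
  opts.dilation_width_factor = m["dilation_width_factor"].AsInt32();
  opts.fused_activation_function = m["fused_activation_function"].AsInt32();
  opts.pad_values = m["pad_values"].AsInt32();
  opts.padding = m["padding"].AsInt32();
  opts.stride_height = m["stride_height"].AsInt32();
  opts.stride_width = m["stride_width"].AsInt32();
  return opts;
}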
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/ir/lce_ops.h: --------------------------------------------------------------------------------
1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_IR_LCE_OPS_H_
2 | #define LARQ_COMPUTE_ENGINE_MLIR_IR_LCE_OPS_H_
3 |
4 | #include "mlir/Bytecode/BytecodeOpInterface.h"
5 | #include "mlir/Dialect/Quant/QuantTypes.h"
6 | #include "mlir/Interfaces/SideEffectInterfaces.h"
7 |
8 | // clang-format off
9 | #include "larq_compute_engine/mlir/ir/lce_dialect.h.inc"
10 | // clang-format on
11 |
12 | #define GET_OP_CLASSES
13 | #include "larq_compute_engine/mlir/ir/lce_ops.h.inc"
14 |
15 | #endif // LARQ_COMPUTE_ENGINE_MLIR_IR_LCE_OPS_H_
16 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/lce_mlir_opt.cc: --------------------------------------------------------------------------------
1 | #include "larq_compute_engine/mlir/ir/lce_ops.h"
2 | #include "mlir/Dialect/Func/IR/FuncOps.h"
3 | #include "mlir/Dialect/Quant/QuantOps.h"
4 | #include "mlir/Tools/mlir-opt/MlirOptMain.h"
5 | #include "mlir/Transforms/Passes.h"
6 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
7 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
8 |
9 | int main(int argc, char** argv) {
10 | mlir::registerTransformsPasses();
11 | mlir::DialectRegistry registry;
12 | registry.insert<mlir::lq::LarqDialect, mlir::func::FuncDialect,
13 | mlir::quant::QuantizationDialect, mlir::TF::TensorFlowDialect,
14 | mlir::TFL::TensorFlowLiteDialect>();
15 | return failed(mlir::MlirOptMain(
16 | argc, argv, "Larq Compute Engine pass driver\n", registry));
17 | }
18 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larq/compute-engine/b8ec518b773302b2f32c2460cfa7e0267d9dbea0/larq_compute_engine/mlir/python/__init__.py
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/common.h: --------------------------------------------------------------------------------
1 | #include <optional>
2 |
3 | #include "larq_compute_engine/mlir/transforms/passes.h"
4 | #include "mlir/IR/BuiltinOps.h"
5 | #include "mlir/IR/MLIRContext.h"
6 | #include "mlir/Pass/Pass.h"
7 | #include "pybind11/pybind11.h"
8 | #include "tensorflow/core/platform/status.h"
9 | #include "tensorflow/core/public/session.h"
10 |
11 | namespace tensorflow {
12 |
13 | LCETarget GetLCETarget(const std::string& target_str);
14 |
15 | Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module, int* num_inputs);
16 |
17 | pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
18 | mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
19 | const LCETarget target, const pybind11::object& default_ranges,
20 | const std::unordered_set<std::string>& saved_model_tags,
21 | llvm::StringRef saved_model_dir,
22 | std::optional<Session*> session, const int num_inputs,
23 | const bool should_quantize, const bool mark_as_post_training_quant);
24 |
25 | } // namespace tensorflow
26 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/converter_test.py: --------------------------------------------------------------------------------
1 | import sys
2 | import unittest
3 | from unittest import mock
4 |
5 | import tensorflow as tf
6 | import larq as lq
7 | from tensorflow.python.eager import context
8 |
9 | sys.modules["larq_compute_engine.mlir._tf_tfl_flatbuffer"] = mock.MagicMock()
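# (Editorial note: these sys.modules entries replace the native pybind11
# extension modules with mocks, so that the converter's Python logic can be
# exercised without building the C++ bindings.)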
10 | sys.modules["larq_compute_engine.tflite.python.interpreter_wrapper_lite"] = (
11 |     mock.MagicMock()
12 | )
13 | sys.modules["larq_compute_engine.mlir.python.tflite_schema"] = mock.MagicMock()
14 |
15 | from larq_compute_engine.mlir.python.converter import convert_keras_model
16 | from larq_compute_engine.mlir._tf_tfl_flatbuffer import (
17 |     convert_saved_model_to_tflite_flatbuffer as mocked_saved_model_converter,
18 | )
19 |
20 |
21 | def get_test_model():
22 |     """Model taken from https://docs.larq.dev/larq/tutorials/mnist/#create-the-model."""
23 |
24 |     # All quantized layers except the first will use the same options
25 |     kwargs = {
26 |         "input_quantizer": "ste_sign",
27 |         "kernel_quantizer": "ste_sign",
28 |         "kernel_constraint": "weight_clip",
29 |     }
30 |
31 |     model = tf.keras.models.Sequential()
32 |
33 |     # In the first layer we only quantize the weights and not the input
34 |     model.add(
35 |         lq.layers.QuantConv2D(
36 |             32,
37 |             (3, 3),
38 |             kernel_quantizer="ste_sign",
39 |             kernel_constraint="weight_clip",
40 |             use_bias=False,
41 |             input_shape=(28, 28, 1),
42 |         )
43 |     )
44 |     model.add(tf.keras.layers.MaxPooling2D((2, 2)))
45 |     model.add(tf.keras.layers.BatchNormalization(scale=False))
46 |
47 |     model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
48 |     model.add(tf.keras.layers.MaxPooling2D((2, 2)))
49 |     model.add(tf.keras.layers.BatchNormalization(scale=False))
50 |
51 |     model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
52 |     model.add(tf.keras.layers.BatchNormalization(scale=False))
53 |     model.add(tf.keras.layers.Flatten())
54 |
55 |     model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))
56 |     model.add(tf.keras.layers.BatchNormalization(scale=False))
57 |     model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))
58 |     model.add(tf.keras.layers.BatchNormalization(scale=False))
59 |     model.add(tf.keras.layers.Activation("softmax"))
60 |     return model
61 |
62 |
63 | class TestConverter(unittest.TestCase):
64 |     def test_model(self):
65 |         with context.eager_mode():
66 |             model = get_test_model()
67 |             convert_keras_model(model)
68 |             mocked_saved_model_converter.assert_called_once_with(
69 |                 mock.ANY, ["serve"], ["serving_default"], 1, "arm", None
70 |             )
71 |
72 |     def test_wrong_arg(self):
73 |         with self.assertRaises(ValueError):
74 |             convert_keras_model("./model.h5")
75 |
76 |     def test_target_arg(self):
77 |         with context.eager_mode():
78 |             model = get_test_model()
79 |
80 |             # These should work
81 |             convert_keras_model(model, target="arm")
82 |             convert_keras_model(model, target="xcore")
83 |
84 |             # Anything else shouldn't
85 |             with self.assertRaises(
86 |                 ValueError, msg='Expected `target` to be "arm" or "xcore"'
87 |             ):
88 |                 convert_keras_model(model, target="x86")
89 |
90 |
91 | if __name__ == "__main__":
92 |     unittest.main()
93 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc: --------------------------------------------------------------------------------
1 | #include <stdexcept>
2 |
3 | #include "larq_compute_engine/mlir/python/common.h"
4 | #include "larq_compute_engine/mlir/tf_tfl_passes.h"
5 | #include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h"
6 | #include "mlir/IR/MLIRContext.h"
7 | #include "pybind11/pybind11.h"
8 | #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
9 | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
10 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
11 | #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
"tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" 12 | 13 | namespace tensorflow { 14 | 15 | pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer( 16 | const pybind11::bytes& graphdef_bytes, 17 | const std::vector& input_arrays, 18 | const std::vector& input_dtypes, 19 | const std::vector>& input_shapes, 20 | const std::vector& output_arrays, const bool should_quantize, 21 | const std::string& target_str, const pybind11::object& default_ranges) { 22 | GraphDef graphdef; 23 | if (!tensorflow::LoadProtoFromBuffer(std::string(graphdef_bytes), &graphdef) 24 | .ok()) { 25 | throw std::runtime_error("Could not load GraphDef."); 26 | } 27 | 28 | auto target = GetLCETarget(target_str); 29 | 30 | // Convert empty shapes to `None`. We could also do that on the python side. 31 | std::vector>> translated_input_shapes; 32 | for (auto x : input_shapes) { 33 | if (x.size() > 0) { 34 | translated_input_shapes.push_back(x); 35 | } else { 36 | translated_input_shapes.push_back(std::nullopt); 37 | } 38 | } 39 | 40 | GraphImportConfig specs; 41 | specs.prune_unused_nodes = true; 42 | specs.convert_legacy_fed_inputs = true; 43 | specs.graph_as_function = false; 44 | specs.upgrade_legacy = true; 45 | if (!ParseInputArrayInfo(input_arrays, input_dtypes, translated_input_shapes, 46 | &specs.inputs) 47 | .ok()) { 48 | throw std::runtime_error("Could not parse input arrays."); 49 | } 50 | if (!ParseOutputArrayInfo(output_arrays, &specs.outputs).ok()) { 51 | throw std::runtime_error("Could not parse output arrays."); 52 | } 53 | 54 | mlir::MLIRContext context; 55 | GraphDebugInfo debug_info; 56 | mlir::StatusScopedDiagnosticHandler statusHandler(&context, 57 | /*propagate=*/true); 58 | auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); 59 | 60 | if (!module.ok()) { 61 | throw std::runtime_error("Could not convert GraphDef."); 62 | } 63 | 64 | return ConvertMLIRModuleToTFLiteFlatBuffer( 65 | &module.value(), context, target, default_ranges, 66 | /*saved_model_tags=*/{}, 67 | /*saved_model_dir=*/"", /*session=*/std::nullopt, input_arrays.size(), 68 | should_quantize, 69 | /*mark_as_post_training_quant=*/false); 70 | } 71 | 72 | } // namespace tensorflow 73 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/pybind_export.cc: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/pytypes.h" 3 | #include "pybind11/stl.h" 4 | 5 | namespace tensorflow { 6 | 7 | using std::string; 8 | 9 | pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer( 10 | const pybind11::bytes& graphdef_bytes, 11 | const std::vector& input_arrays, 12 | const std::vector& input_dtypes, 13 | const std::vector>& input_shapes, 14 | const std::vector& output_arrays, const bool should_quantize, 15 | const std::string& target_str, const pybind11::object& default_ranges); 16 | 17 | pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer( 18 | const std::string& saved_model_dir, 19 | const std::vector& saved_model_tags, 20 | const std::vector& exported_names, 21 | const int saved_model_version, const std::string& target_str, 22 | const pybind11::object& default_ranges); 23 | } // namespace tensorflow 24 | 25 | PYBIND11_MODULE(_tf_tfl_flatbuffer, m) { 26 | m.def("convert_graphdef_to_tflite_flatbuffer", 27 | &tensorflow::ConvertGraphDefToTFLiteFlatBuffer); 28 | m.def("convert_saved_model_to_tflite_flatbuffer", 29 | &tensorflow::ConvertSavedModelToTFLiteFlatBuffer); 30 | }; 31 | 
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc: --------------------------------------------------------------------------------
1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 | Modifications copyright (C) 2021 Larq Contributors.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | ==============================================================================*/
16 |
17 | #include <memory>
18 | #include <string>
19 | #include <unordered_set>
20 |
21 | #include "absl/types/span.h"
22 | #include "larq_compute_engine/mlir/python/common.h"
23 | #include "larq_compute_engine/mlir/tf_tfl_passes.h"
24 | #include "mlir/IR/MLIRContext.h"
25 | #include "mlir/Pass/Pass.h"
26 | #include "pybind11/pybind11.h"
27 | #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h"
28 | #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
29 | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
30 | #include "tensorflow/core/platform/status.h"
31 |
32 | namespace tensorflow {
33 |
34 | pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer(
35 | const std::string& saved_model_dir,
36 | const std::vector<std::string>& saved_model_tags,
37 | const std::vector<std::string>& exported_names,
38 | const int saved_model_version, const std::string& target_str,
39 | const pybind11::object& default_ranges) {
40 | mlir::MLIRContext context;
41 | Status status;
42 |
43 | auto target = GetLCETarget(target_str);
44 |
45 | if (exported_names.empty()) {
46 | throw std::runtime_error("Need at least one exported name.");
47 | }
48 |
49 | tensorflow::GraphImportConfig specs;
50 | specs.upgrade_legacy = true;
51 |
52 | absl::Span<std::string> custom_opdefs;
53 |
54 | // Register all custom ops, including user-specified custom ops.
55 | const toco::TocoFlags toco_flags;
56 | status = internal::RegisterAllCustomOps(toco_flags);
57 | if (!status.ok()) {
58 | throw std::runtime_error(std::string(status.message()));
59 | }
60 |
61 | // Some weirdness required to convert the std::vector<std::string> to an
62 | // absl::Span<std::string>
63 | auto exported_names_vector =
64 | std::vector<std::string>(exported_names.begin(), exported_names.end());
65 | absl::Span<std::string> exported_names_span(exported_names_vector);
66 |
67 | std::unordered_set<std::string> tags(saved_model_tags.begin(),
68 | saved_model_tags.end());
69 |
70 | auto bundle = std::make_unique<SavedModelBundle>();
71 | auto module =
72 | ImportSavedModel(saved_model_dir, saved_model_version, tags,
73 | custom_opdefs, exported_names_span, specs,
74 | /*enable_variable_lifting=*/true, &context, &bundle);
75 |
76 | if (!module.ok()) {
77 | throw std::runtime_error("Could not import SavedModel.");
78 | }
79 |
80 | int num_inputs = 0;
81 | status = GetNumInputs(&module.value(), &num_inputs);
82 | if (!status.ok()) {
83 | throw std::runtime_error(std::string(status.message()));
84 | }
85 |
86 | return ConvertMLIRModuleToTFLiteFlatBuffer(
87 | &module.value(), context, target, default_ranges, tags, saved_model_dir,
88 | bundle ?
bundle->GetSession() : nullptr, num_inputs, 89 | /*should_quantize=*/true, 90 | /*mark_as_post_training_quant=*/true); 91 | } 92 | 93 | } // namespace tensorflow 94 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/BUILD: -------------------------------------------------------------------------------- 1 | load("//larq_compute_engine/mlir/tests:lit_test.bzl", "lce_lit_test_suite") 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | licenses = ["notice"], # Apache 2.0 6 | ) 7 | 8 | lce_lit_test_suite( 9 | name = "lit", 10 | srcs = glob(["*.mlir"]), 11 | data = [ 12 | "//larq_compute_engine/mlir:lce-tf-opt", 13 | "@llvm-project//llvm:FileCheck", 14 | ], 15 | ) 16 | 17 | test_suite( 18 | name = "all", 19 | tests = [ 20 | ":lit", 21 | ], 22 | ) 23 | 24 | cc_test( 25 | name = "lce_ops_options_test", 26 | srcs = ["lce_ops_options_test.cc"], 27 | linkopts = ["-ldl"], 28 | deps = [ 29 | "//larq_compute_engine/mlir:larq_compute_engine", 30 | "@com_google_googletest//:gtest_main", 31 | "@flatbuffers", 32 | "@llvm-project//mlir:IR", 33 | "@org_tensorflow//tensorflow/lite/schema:schema_fbs", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/bitpack-weights.mlir: -------------------------------------------------------------------------------- 1 | // RUN: lce-tf-opt %s -tfl-lce-bitpack-weights -verify-diagnostics | FileCheck %s 2 | 3 | // CHECK-LABEL: @bitpack_bconv2d_filters 4 | func.func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> { 5 | %cst = arith.constant dense<1.0> : tensor<16x3x3x3xf32> 6 | %0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> 7 | return %0 : tensor<256x30x30x16xf32> 8 | 9 | // CHECK: %cst = arith.constant dense<0> : tensor<16x3x3x1xi32> 10 | // CHECK: %0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) <{channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}> : (tensor<256x32x32x1xi32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> 11 | // CHECK-NEXT: return %0 12 | } 13 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/const-fold.mlir: -------------------------------------------------------------------------------- 1 | // RUN: lce-tf-opt %s -canonicalize | FileCheck %s 2 | 3 | // CHECK-LABEL: @quantize 4 | func.func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) { 5 | %pos = arith.constant dense< 0.5> : tensor<1x1x2x32xf32> 6 | %neg = arith.constant dense<-0.5> : tensor<1x1x2x32xf32> 7 | %0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32> 8 | %1 = "lq.Quantize"(%neg) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32> 9 | return %0, %1 : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32> 10 | 11 | // CHECK: %[[pos:.*]] = arith.constant dense<0> : tensor<1x1x2x1xi32> 12 | // CHECK: %[[neg:.*]] = arith.constant dense<-1> : 
tensor<1x1x2x1xi32>
13 | // CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>
14 | }
15 |
16 | // CHECK-LABEL: @dequantize
17 | func.func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
18 | %pos = arith.constant dense<0> : tensor<1x1x2x1xi32>
19 | %neg = arith.constant dense<-1> : tensor<1x1x2x1xi32>
20 | %0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
21 | %1 = "lq.Dequantize"(%neg) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
22 | return %0, %1 : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>
23 |
24 | // CHECK: %[[pos:.*]] = arith.constant dense<1.000000e+00> : tensor<1x1x2x32xf32>
25 | // CHECK: %[[neg:.*]] = arith.constant dense<-1.000000e+00> : tensor<1x1x2x32xf32>
26 | // CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>
27 | }
28 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/detection_postprocess.mlir: --------------------------------------------------------------------------------
1 | // RUN: lce-tf-opt %s -detection-postprocess-int -verify-diagnostics | FileCheck %s
2 |
3 | // CHECK-LABEL: detection_postprocess_int
4 | func.func @detection_postprocess_int(%arg0: tensor<1x10x4x!quant.uniform>, %arg1: tensor<1x10x1x!quant.uniform>, %arg2: tensor<10x4x!quant.uniform>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>) {
5 | %0 = "tfl.dequantize"(%arg0) : (tensor<1x10x4x!quant.uniform>) -> tensor<1x10x4xf32>
6 | %1 = "tfl.dequantize"(%arg1) : (tensor<1x10x1x!quant.uniform>) -> tensor<1x10x1xf32>
7 | %2 = "tfl.dequantize"(%arg2) : (tensor<10x4x!quant.uniform>) -> tensor<10x4xf32>
8 | %3:4 = "tfl.custom"(%0, %1, %2) {custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl} : (tensor<1x10x4xf32>, tensor<1x10x1xf32>, tensor<10x4xf32>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>)
9 | return %3#0, %3#1, %3#2, %3#3 : tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32> // boxes, classes, scores, num_detections
10 |
11 | // CHECK: %0:4 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl} : (tensor<1x10x4x!quant.uniform>, tensor<1x10x1x!quant.uniform>, tensor<10x4x!quant.uniform>) -> (tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20x!quant.uniform>, tensor<1xi32>)
12 | // CHECK-NEXT: %1 = "tfl.dequantize"(%0#2) : (tensor<1x20x!quant.uniform>) -> tensor<1x20xf32>
13 | // CHECK-NEXT: return %0#0, %0#1, %1, %0#3 : tensor<1x20x4xi32>, tensor<1x20xi32>, tensor<1x20xf32>, tensor<1xi32>
14 | }
15 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/lce_ops_options_test.cc: --------------------------------------------------------------------------------
1 | #include <gtest/gtest.h>
2 |
3 | #include "flatbuffers/flexbuffers.h"
4 | #include "larq_compute_engine/mlir/ir/lce_ops.h"
5 | #include "mlir/IR/Builders.h"
6 | #include "mlir/IR/OperationSupport.h"
7 | #include "tensorflow/lite/schema/schema_generated.h"
8 |
9 | using namespace mlir;
10 | using namespace tflite;
11 |
12 | IntegerAttr getIntegerAttr(Builder builder, int value) {
13 | return builder.getIntegerAttr(builder.getIntegerType(32), value);
14 | }
15 |
16 | TEST(LCEOpsSerializationTest, QuantizeTest) {
17 | MLIRContext context;
18 | context.getOrLoadDialect<mlir::lq::LarqDialect>();
19 | OperationState state(UnknownLoc::get(&context),
20 | OperationName("lq.Quantize", &context));
21 | mlir::Operation* op = Operation::create(state);
22 |
23 | ASSERT_EQ(cast<lq::QuantizeOp>(op).buildCustomOptions().size(), 0);
24 | }
25 |
26 | TEST(LCEOpsSerializationTest, DequantizeTest) {
27 | MLIRContext context;
28 | context.getOrLoadDialect<mlir::lq::LarqDialect>();
29 | OperationState state(UnknownLoc::get(&context),
30 | OperationName("lq.Dequantize", &context));
31 | mlir::Operation* op = Operation::create(state);
32 |
33 | ASSERT_EQ(cast<lq::DequantizeOp>(op).buildCustomOptions().size(), 0);
34 | }
35 |
36 | TEST(LCEOpsSerializationTest, BConv2dTest) {
37 | MLIRContext context;
38 | context.getOrLoadDialect<mlir::lq::LarqDialect>();
39 | Builder builder(&context);
40 | OperationState state(UnknownLoc::get(&context),
41 | OperationName("lq.Bconv2d", &context));
42 | mlir::Operation* op = Operation::create(state);
43 |
44 | op->setAttr("channels_in", getIntegerAttr(builder, 64));
45 | op->setAttr("dilation_height_factor", getIntegerAttr(builder, 3));
46 | op->setAttr("dilation_width_factor", getIntegerAttr(builder, 4));
47 | op->setAttr("stride_height", getIntegerAttr(builder, 1));
48 | op->setAttr("stride_width", getIntegerAttr(builder, 2));
49 | op->setAttr("pad_values", getIntegerAttr(builder, 1));
50 |
51 | op->setAttr("fused_activation_function", builder.getStringAttr("RELU"));
52 | op->setAttr("padding", builder.getStringAttr("SAME"));
53 |
54 | std::vector<uint8_t> v = cast<lq::Bconv2dOp>(op).buildCustomOptions();
55 | const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap();
56 |
57 | ASSERT_EQ(m["channels_in"].AsInt32(), 64);
58 | ASSERT_EQ(m["dilation_height_factor"].AsInt32(), 3);
59 | ASSERT_EQ(m["dilation_width_factor"].AsInt32(), 4);
60 | ASSERT_EQ(m["stride_height"].AsInt32(), 1);
61 | ASSERT_EQ(m["stride_width"].AsInt32(), 2);
62 | ASSERT_EQ(m["pad_values"].AsInt32(), 1);
63 | ASSERT_EQ((ActivationFunctionType)m["fused_activation_function"].AsInt32(),
64 | ActivationFunctionType_RELU);
65 | ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME);
66 | }
67 |
68 | TEST(LCEOpsSerializationTest, BMaxPool2dTest) {
69 | MLIRContext context;
70 | context.getOrLoadDialect<mlir::lq::LarqDialect>();
71 | Builder builder(&context);
72 | OperationState state(UnknownLoc::get(&context),
73 | OperationName("lq.BMaxPool2d", &context));
74 | mlir::Operation* op = Operation::create(state);
75 |
76 | op->setAttr("padding", builder.getStringAttr("SAME"));
77 | op->setAttr("stride_width", getIntegerAttr(builder, 2));
78 | op->setAttr("stride_height", getIntegerAttr(builder, 1));
79 | op->setAttr("filter_width", getIntegerAttr(builder, 3));
80 | op->setAttr("filter_height", getIntegerAttr(builder, 4));
81 |
82 | std::vector<uint8_t> v = cast<lq::BMaxPool2dOp>(op).buildCustomOptions();
83 | const flexbuffers::Map& m = flexbuffers::GetRoot(v).AsMap();
84 |
85 | ASSERT_EQ((Padding)m["padding"].AsInt32(), Padding_SAME);
86 | ASSERT_EQ(m["stride_width"].AsInt32(), 2);
87 | ASSERT_EQ(m["stride_height"].AsInt32(), 1);
88 | ASSERT_EQ(m["filter_width"].AsInt32(), 3);
89 | ASSERT_EQ(m["filter_height"].AsInt32(), 4);
90 | }
91 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/legalize-lce.mlir: --------------------------------------------------------------------------------
1 | // RUN: lce-tf-opt %s -tfl-legalize-lce -verify-diagnostics | FileCheck %s
2 | // RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE
3 |
4 | // CHECK-LABEL: @legalize_bconv2d
5 | func.func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
6 | %0 =
"lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> 7 | return %0 : tensor<256x30x30x16xf32> 8 | 9 | // CHECK: %0 = "tfl.custom"(%arg0, %arg1, %arg2, %arg3, %arg4) {custom_code = "LceBconv2d", custom_option = #tfl} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> 10 | // CHECK-NEXT: return %0 11 | 12 | // TRANSLATE: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) <{channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32}> : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> 13 | // TRANSLATE-NEXT: return %0 : tensor<256x30x30x16xf32> 14 | } 15 | 16 | // CHECK-LABEL: @legalize_bmax_pool2d 17 | func.func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> { 18 | %0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> 19 | return %0 : tensor<256x16x16x3xi32> 20 | 21 | // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceBMaxPool2d", custom_option = #tfl} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> 22 | // CHECK-NEXT: return %0 23 | 24 | // TRANSLATE: %0 = "lq.BMaxPool2d"(%arg0) <{filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32}> : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> 25 | // TRANSLATE-NEXT: return %0 : tensor<256x16x16x3xi32> 26 | } 27 | 28 | // CHECK-LABEL: @legalize_quantize 29 | func.func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> { 30 | %0 = "lq.Quantize"(%arg0) {} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> 31 | return %0 : tensor<256x32x32x2xi32> 32 | 33 | // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceQuantize", custom_option = #tfl} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> 34 | // CHECK-NEXT: return %0 35 | 36 | // TRANSLATE: %0 = "lq.Quantize"(%arg0) : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> 37 | // TRANSLATE-NEXT: return %0 : tensor<256x32x32x2xi32> 38 | } 39 | 40 | // CHECK-LABEL: @legalize_dequantize 41 | func.func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> { 42 | %0 = "lq.Dequantize"(%arg0) {} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> 43 | return %0 : tensor<256x32x32x64xf32> 44 | 45 | // CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "LceDequantize", custom_option = #tfl} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> 46 | // CHECK-NEXT: return %0 47 | 48 | // TRANSLATE: %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> 49 | // TRANSLATE-NEXT: return %0 : tensor<256x32x32x64xf32> 50 | } 51 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/lit_test.bzl: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # Modifications 
copyright (C) 2020 Larq Contributors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # https://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Copied from https://github.com/google/iree/blob/master/iree/lit_test.bzl 17 | 18 | """Bazel macros for running lit tests.""" 19 | 20 | def lce_lit_test( 21 | name, 22 | test_file, 23 | data, 24 | size = "small", 25 | driver = "//larq_compute_engine/mlir/tests:run_lit.sh", 26 | **kwargs): 27 | """Creates a lit test from the specified source file. 28 | 29 | Args: 30 | name: name of the generated test suite. 31 | test_file: the test file with the lit test 32 | data: binaries used in the lit tests. 33 | size: size of the tests. 34 | driver: the shell runner for the lit tests. 35 | **kwargs: Any additional arguments that will be passed to the underlying sh_test. 36 | """ 37 | native.sh_test( 38 | name = name, 39 | srcs = [driver], 40 | size = size, 41 | data = data + [test_file], 42 | shard_count = 2, 43 | args = ["$(location %s)" % (test_file,)], 44 | **kwargs 45 | ) 46 | 47 | def lce_lit_test_suite( 48 | name, 49 | data, 50 | srcs, 51 | size = "small", 52 | driver = "//larq_compute_engine/mlir/tests:run_lit.sh", 53 | tags = [], 54 | **kwargs): 55 | """Creates one lit test per source file and a test suite that bundles them. 56 | 57 | Args: 58 | name: name of the generated test suite. 59 | data: binaries used in the lit tests. 60 | srcs: test file sources. 61 | size: size of the tests. 62 | driver: the shell runner for the lit tests. 63 | tags: tags to apply to the test. Note that as in standard test suites, manual 64 | is treated specially and will also apply to the test suite itself. 65 | **kwargs: Any additional arguments that will be passed to the underlying tests and test_suite. 66 | """ 67 | tests = [] 68 | for test_file in srcs: 69 | # It's generally good practice to prefix any generated names with the 70 | # macro name, but we're trying to match the style of the names that are 71 | # used for LLVM internally. 72 | test_name = "%s.test" % (test_file) 73 | lce_lit_test( 74 | name = test_name, 75 | test_file = test_file, 76 | size = size, 77 | data = data, 78 | driver = driver, 79 | tags = tags, 80 | **kwargs 81 | ) 82 | tests.append(test_name) 83 | 84 | native.test_suite( 85 | name = name, 86 | tests = tests, 87 | # Note that only the manual tag really has any effect here. Others are 88 | # used for test suite filtering, but all tests are passed the same tags. 89 | tags = tags, 90 | # If there are kwargs that need to be passed here which only apply to 91 | # the generated tests and not to test_suite, they should be extracted 92 | # into separate named arguments. 
93 | **kwargs 94 | ) 95 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/op-removal.mlir: -------------------------------------------------------------------------------- 1 | // RUN: lce-tf-opt %s -lce-op-removal-tf -verify-diagnostics | FileCheck %s 2 | 3 | // CHECK-LABEL: @snapshot 4 | func.func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> { 5 | %0 = "tf.Snapshot"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> 6 | return %0 : tensor<3xi32> 7 | // Should be converted to Identity and then from Identity to value 8 | // CHECK-NEXT: return %arg0 : tensor<3xi32> 9 | } 10 | 11 | // CHECK-LABEL: @stop_gradient 12 | func.func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> { 13 | %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> 14 | return %0 : tensor<3xi32> 15 | // Should be converted to Identity and then from Identity to value 16 | // CHECK-NEXT: return %arg0 : tensor<3xi32> 17 | } 18 | 19 | // CHECK-LABEL: @check_numerics 20 | func.func @check_numerics(%arg0: tensor<3xf32>) -> tensor<3xf32> { 21 | %0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32> 22 | return %0 : tensor<3xf32> 23 | // Should be converted to Identity and then from Identity to value 24 | // CHECK-NEXT: return %arg0 : tensor<3xf32> 25 | } 26 | 27 | // CHECK-LABEL: @placeholder_with_default 28 | func.func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> { 29 | %0 = "tf.PlaceholderWithDefault"(%arg0): (tensor<3xf32>) -> tensor<3xf32> 30 | return %0 : tensor<3xf32> 31 | // Should be converted to Identity and then from Identity to value 32 | // CHECK-NEXT: return %arg0 : tensor<3xf32> 33 | } 34 | 35 | // CHECK-LABEL: @identity 36 | func.func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) { 37 | %0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32> 38 | %1:2 = "tf.IdentityN"(%arg1,%arg2) : (tensor<20xi32>, tensor<30xi32>) -> (tensor<20xi32>, tensor<30xi32>) 39 | return %0, %1#0, %1#1: tensor<10xi32>, tensor<20xi32>, tensor<30xi32> 40 | 41 | // CHECK-NEXT: return %arg0, %arg1, %arg2 42 | } 43 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/quantize.mlir: -------------------------------------------------------------------------------- 1 | // RUN: lce-tf-opt %s -lce-quantize -verify-diagnostics | FileCheck %s 2 | 3 | // CHECK-LABEL: quantize_bconv2d 4 | func.func @quantize_bconv2d(%arg0: tensor<1x224x224x1xi32>, %arg1: tensor<32x3x3x1xi32>, %arg2: none) -> tensor<1x112x112x32x!quant.uniform> { 5 | %cst0 = arith.constant dense<-1.23697901> : tensor<32xf32> 6 | %0 = "tfl.quantize"(%cst0) {qtype = tensor<32x!quant.uniform>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> 7 | %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) -> tensor<32xf32> 8 | %cst1 = arith.constant dense<1.10976315> : tensor<32xf32> 9 | %2 = "tfl.quantize"(%cst1) {qtype = tensor<32x!quant.uniform>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> 10 | %3 = "tfl.dequantize"(%2) : (tensor<32x!quant.uniform>) -> tensor<32xf32> 11 | %4 = "lq.Bconv2d"(%arg0, %arg1, %1, %3, %arg2) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x1xi32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>, none) -> 
tensor<1x112x112x32xf32> 12 | %5 = "tfl.quantize"(%4) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> 13 | return %5 : tensor<1x112x112x32x!quant.uniform> 14 | 15 | // CHECK: %[[cst1:.*]] = arith.constant dense<1.10976315> : tensor<32xf32> 16 | // CHECK: %[[cst0:.*]] = arith.constant dense<-1.23697901> : tensor<32xf32> 17 | // CHECK: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %arg1, %[[cst0]], %[[cst1]], %arg2) 18 | // CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> 19 | } 20 | 21 | // CHECK-LABEL: quantize_bitpacked_bconv2d 22 | func.func @quantize_bitpacked_bconv2d(%arg0: tensor<1x224x224x1xi32>, %arg1: tensor<32x3x3x1xi32>, %arg2: none, %arg3: none, %arg4: tensor<32xi32>) -> tensor<1x112x112x32x!quant.uniform> { 23 | %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x1xi32>, tensor<32x3x3x1xi32>, none, none, tensor<32xi32>) -> tensor<1x112x112x32xf32> 24 | %1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> 25 | return %1 : tensor<1x112x112x32x!quant.uniform> 26 | 27 | // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) 28 | // CHECK-NEXT: return %0 : tensor<1x112x112x32x!quant.uniform> 29 | } 30 | 31 | // CHECK-LABEL: quantize_lce_dequantize 32 | func.func @quantize_lce_dequantize(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x112x32x!quant.uniform> { 33 | %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x32xf32> 34 | %1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> 35 | return %1 : tensor<1x112x112x32x!quant.uniform> 36 | 37 | // CHECK-NEXT: %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x32x!quant.uniform> 38 | // CHECK-NEXT: return %0 : tensor<1x112x112x32x!quant.uniform> 39 | } 40 | 41 | // CHECK-LABEL: dequantize_lce_quantize 42 | func.func @dequantize_lce_quantize(%arg0: tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x1xi32> { 43 | %0 = "tfl.dequantize"(%arg0) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32> 44 | %1 = "lq.Quantize"(%0) : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x1xi32> 45 | return %1 : tensor<1x112x112x1xi32> 46 | 47 | // CHECK: %[[quant:.*]] = "lq.Quantize"(%arg0) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x1xi32> 48 | // CHECK-NEXT: return %[[quant]] : tensor<1x112x112x1xi32> 49 | } 50 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/tests/run_lit.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2019 Google LLC 4 | # Modifications copyright (C) 2020 Larq Contributors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | # Copied from https://github.com/google/iree/blob/master/iree/tools/run_lit.sh 19 | 20 | set -e 21 | set -o pipefail 22 | 23 | PLATFORM="$(uname -s | tr 'A-Z' 'a-z')" 24 | 25 | function is_macos() { 26 | [[ "${PLATFORM}" == "darwin" ]] 27 | } 28 | 29 | function is_windows() { 30 | # On windows, the shell script is actually running in msys 31 | [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]] 32 | } 33 | 34 | if [ -z "${RUNFILES_DIR}" ]; then 35 | # Some versions of bazel do not set RUNFILES_DIR. Instead they just cd 36 | # into the directory. 37 | RUNFILES_DIR="$PWD" 38 | fi 39 | 40 | function find_executables() { 41 | set -e 42 | local p="$1" 43 | if is_macos; then 44 | # For macOS use -type since the default find doesn't support -xtype 45 | find . -perm +111 -type f -or -type l -print 46 | elif is_windows; then 47 | # For windows, always use the newer -executable find predicate (which is 48 | # not supported by ancient versions of find). 49 | find "${p}" -xtype f -executable -print 50 | else 51 | # For linux, use the perm based executable check, which has been 52 | # supported by find for a very long time. 53 | find "${p}" -xtype f -perm /u=x,g=x,o=x -print 54 | fi 55 | } 56 | 57 | # Bazel helpfully puts all data deps in the ${RUNFILES_DIR}, but 58 | # it unhelpfully preserves the nesting with no way to reason about 59 | # it generically. run_lit expects that anything passed in the runfiles 60 | # can be found on the path for execution. So we just iterate over the 61 | # entries in the MANIFEST and extend the PATH. 62 | SUBPATH="" 63 | for runfile_path in $(find_executables "${RUNFILES_DIR}"); do 64 | # Prepend so that local things override. 65 | EXEDIR="$(dirname ${runfile_path})" 66 | if is_windows; then 67 | cygpath="$(which cygpath)" 68 | EXEDIR="$($cygpath -u "$EXEDIR")" 69 | fi 70 | SUBPATH="${EXEDIR}:$SUBPATH" 71 | done 72 | 73 | echo "run_lit.sh: $1" 74 | echo "PWD=$(pwd)" 75 | 76 | # For each "// RUN:" line, run the command. 77 | runline_matches="$(egrep "^// RUN: " "$1")" 78 | if [ -z "$runline_matches" ]; then 79 | echo "!!! No RUN lines found in test" 80 | exit 1 81 | fi 82 | 83 | echo "$runline_matches" | while read -r runline 84 | do 85 | echo "RUNLINE: $runline" 86 | match="${runline%%// RUN: *}" 87 | command="${runline##// RUN: }" 88 | if [ -z "${command}" ]; then 89 | echo "ERROR: Could not extract command from runline" 90 | exit 1 91 | fi 92 | 93 | # Substitute any embedded '%s' with the file name. 94 | full_command="${command//\%s/$1}" 95 | 96 | # Run it. 97 | export PATH="$SUBPATH" 98 | echo "RUNNING TEST: $full_command" 99 | echo "----------------" 100 | if eval "$full_command"; then 101 | echo "--- COMPLETE ---" 102 | else 103 | echo "!!! 
ERROR EVALUATING: $full_command"
104 | exit 1
105 | fi
106 | done
107 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/tf_tfl_passes.h: --------------------------------------------------------------------------------
1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_TF_TFL_PASSES_H_
2 | #define LARQ_COMPUTE_ENGINE_MLIR_TF_TFL_PASSES_H_
3 |
4 | #include
5 |
6 | #include "larq_compute_engine/mlir/transforms/passes.h"
7 | #include "mlir/Pass/PassManager.h"
8 | #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
9 |
10 | namespace tensorflow {
11 |
12 | void AddPreVariableFreezingTFToLCETFLConversionPasses(
13 | mlir::OpPassManager* pass_manager);
14 |
15 | void AddPostVariableFreezingTFToLCETFLConversionPasses(
16 | llvm::StringRef saved_model_dir,
17 | const mlir::quant::QuantizationSpecs& quant_specs,
18 | mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM);
19 |
20 | } // namespace tensorflow
21 |
22 | #endif // LARQ_COMPUTE_ENGINE_MLIR_TF_TFL_PASSES_H_
23 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h: --------------------------------------------------------------------------------
1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
2 | #define LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
3 |
4 | #include <optional>
5 | #include <unordered_set>
6 |
7 | #include "larq_compute_engine/mlir/transforms/passes.h"
8 | #include "mlir/IR/BuiltinOps.h"
9 | #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
10 | #include "tensorflow/core/public/session.h"
11 | #include "tsl/platform/statusor.h"
12 |
13 | namespace tensorflow {
14 |
15 | // This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable a custom
16 | // OpOrArgLocNameMapper:
17 | // https://github.com/tensorflow/tensorflow/blob/v2.8.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L60-L78
18 | absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
19 | mlir::ModuleOp module, bool export_to_mlir, const LCETarget target,
20 | mlir::quant::QuantizationSpecs quant_specs,
21 | const std::unordered_set<std::string>& saved_model_tags,
22 | llvm::StringRef saved_model_dir,
23 | std::optional<Session*> session, std::string* result);
24 | } // namespace tensorflow
25 |
26 | #endif // LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_
27 |
-------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/bitpack.cc: --------------------------------------------------------------------------------
1 | #include "larq_compute_engine/mlir/transforms/bitpack.h"
2 |
3 | #include <cassert>
4 | #include <vector>
5 |
6 | #include "larq_compute_engine/core/bitpacking/bitpack.h"
7 | #include "larq_compute_engine/core/types.h"
8 | #include "mlir/Dialect/Quant/QuantTypes.h"
9 |
10 | namespace mlir {
11 | namespace TFL {
12 |
13 | using compute_engine::core::bitpacking_bitwidth;
14 | using compute_engine::core::round;
15 | using compute_engine::core::saturate;
16 | using compute_engine::core::TBitpacked;
17 | using namespace compute_engine::core::bitpacking;
18 |
19 | DenseElementsAttr Bitpack(mlir::Builder* builder, Attribute x) {
20 | if (!x) return nullptr;
21 |
22 | // ShapedType is something like tensor<1x2x3xf32> and element_type is f32
23 | auto shaped_type = x.cast<TypedAttr>().getType().cast<ShapedType>();
24 | auto shape = shaped_type.getShape();
25 | auto element_type = shaped_type.getElementType();
26 |
27 | int num_rows = shape[0] * shape[1] * shape[2];
28 | int unpacked_channels = shape[3];
unpacked_channels = shape[3]; 29 | int packed_channels = GetBitpackedSize(unpacked_channels); 30 | 31 | std::vector new_values(num_rows * packed_channels); 32 | 33 | if (element_type.isF32()) { 34 | const auto& dense_elements_iter = 35 | x.cast().getValues(); 36 | 37 | std::vector old_values(num_rows * unpacked_channels); 38 | 39 | int i = 0; 40 | for (float x : dense_elements_iter) { 41 | old_values[i++] = x; 42 | } 43 | assert(i == num_rows * unpacked_channels); 44 | 45 | bitpack_matrix(old_values.data(), num_rows, unpacked_channels, 46 | new_values.data()); 47 | } else { 48 | // constant-fold bitpacking int8 tensors is currently not supported 49 | return nullptr; 50 | } 51 | 52 | RankedTensorType out_tensor_type = 53 | RankedTensorType::get({shape[0], shape[1], shape[2], packed_channels}, 54 | builder->getIntegerType(bitpacking_bitwidth)); 55 | 56 | return DenseElementsAttr::get(out_tensor_type, new_values); 57 | } 58 | 59 | DenseElementsAttr Unpack(Attribute x, ShapedType result_type) { 60 | if (!x) return nullptr; 61 | if (!result_type.hasStaticShape()) return nullptr; 62 | 63 | auto input_shape = 64 | x.cast().getType().cast().getShape(); 65 | auto output_shape = result_type.getShape(); 66 | auto output_type = result_type.getElementType(); 67 | 68 | int num_rows = output_shape[0] * output_shape[1] * output_shape[2]; 69 | int unpacked_channels = output_shape[3]; 70 | int packed_channels = GetBitpackedSize(unpacked_channels); 71 | if (input_shape[0] != output_shape[0] || input_shape[1] != output_shape[1] || 72 | input_shape[2] != output_shape[2] || input_shape[3] != packed_channels) { 73 | return nullptr; 74 | } 75 | 76 | std::vector old_values(num_rows * packed_channels); 77 | 78 | const auto& dense_elements_iter = 79 | x.cast().getValues(); 80 | 81 | int i = 0; 82 | for (TBitpacked x : dense_elements_iter) { 83 | old_values[i++] = x; 84 | } 85 | assert(i == num_rows * packed_channels); 86 | 87 | if (output_type.isF32()) { 88 | std::vector new_values(num_rows * unpacked_channels); 89 | 90 | unpack_matrix(old_values.data(), num_rows, unpacked_channels, 91 | new_values.data()); 92 | 93 | return DenseElementsAttr::get(result_type, new_values); 94 | } else { 95 | auto quant_type = output_type.cast(); 96 | const double scale = quant_type.getScale(); 97 | const int zero_point = quant_type.getZeroPoint(); 98 | 99 | std::int8_t zero_bit_result = saturate(zero_point + round(+1.0 / scale)); 100 | std::int8_t one_bit_result = saturate(zero_point + round(-1.0 / scale)); 101 | 102 | std::vector new_values(num_rows * unpacked_channels); 103 | 104 | unpack_matrix(old_values.data(), num_rows, unpacked_channels, 105 | new_values.data(), zero_bit_result, one_bit_result); 106 | 107 | return DenseElementsAttr::get(result_type, new_values); 108 | } 109 | } 110 | 111 | } // namespace TFL 112 | } // namespace mlir 113 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/bitpack.h: -------------------------------------------------------------------------------- 1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_BITPACK_H_ 2 | #define LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_BITPACK_H_ 3 | 4 | #include "mlir/IR/Attributes.h" 5 | #include "mlir/IR/Builders.h" 6 | #include "mlir/IR/BuiltinTypes.h" 7 | 8 | namespace mlir { 9 | namespace TFL { 10 | 11 | DenseElementsAttr Bitpack(mlir::Builder* builder, Attribute x); 12 | 13 | DenseElementsAttr Unpack(Attribute x, ShapedType result_type); 14 | 15 | } // namespace TFL 16 | } // namespace mlir 17 | 18 | 
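// Illustrative usage sketch (the names `builder` and `filter_attr` are
// hypothetical, not part of this API): inside a rewrite pattern one might
// constant-fold a binary filter via
//
//   DenseElementsAttr packed = Bitpack(&builder, filter_attr);
//
// For a tensor<1x2x3x64xf32> constant of +/-1.0 values this yields a
// tensor<1x2x3x2xi32> constant, with the 64 channel signs packed into two
// 32-bit words; Unpack reverses the transformation given the result type.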
#endif 19 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td: -------------------------------------------------------------------------------- 1 | include "mlir/IR/PatternBase.td" 2 | include "mlir/Dialect/Func/IR/FuncOps.td" 3 | include "mlir/Dialect/Arith/IR/ArithOps.td" 4 | include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" 5 | include "larq_compute_engine/mlir/ir/lce_ops.td" 6 | 7 | 8 | def F32ElementsAttr : ElementsAttrBase< 9 | CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; 10 | 11 | // Checks if the value has only one user. 12 | def HasOneUse : Constraint>; 13 | 14 | def CreateNoneValue : NativeCodeCall< 15 | "$_builder.create($0.getLoc(), $_builder.getUnitAttr())">; 16 | def GetSignsOfVectorAndBroadcast4D : NativeCodeCall<"GetSignsOfVectorAndBroadcast4D($0)">; 17 | def GetBitpackedOutputThresholds : NativeCodeCall<"GetBitpackedOutputThresholds($_builder, $0, $1, $2, $3)">; 18 | 19 | class WriteBitpackedActivationsPat : 20 | Pat<(LQ_QuantizeOp 21 | (LQ_Bconv2dOp:$output 22 | $input, 23 | (Arith_ConstantOp F32ElementsAttr:$filter), 24 | (Arith_ConstantOp F32ElementsAttr:$post_activation_multiplier), 25 | (Arith_ConstantOp F32ElementsAttr:$post_activation_bias), 26 | (TFL_NoValueOp UnitAttr), 27 | $channels_in, 28 | $dilation_height_factor, 29 | $dilation_width_factor, 30 | $fused_activation_function, 31 | ConstantAttr, 32 | padding_type, 33 | $stride_height, 34 | $stride_width)), 35 | (LQ_Bconv2dOp 36 | $input, 37 | (TFL_MulOp 38 | (Arith_ConstantOp $filter), 39 | (Arith_ConstantOp 40 | (GetSignsOfVectorAndBroadcast4D $post_activation_multiplier)), 41 | TFL_AF_None), 42 | (CreateNoneValue $input), 43 | (CreateNoneValue $input), 44 | (Arith_ConstantOp 45 | (GetBitpackedOutputThresholds 46 | $filter, 47 | $post_activation_multiplier, 48 | $post_activation_bias, 49 | $fused_activation_function)), 50 | $channels_in, 51 | $dilation_height_factor, 52 | $dilation_width_factor, 53 | $fused_activation_function, 54 | ConstantAttr, 55 | padding_type, 56 | $stride_height, 57 | $stride_width), 58 | [(HasOneUse $output)]>; 59 | def : WriteBitpackedActivationsPat; 60 | def : WriteBitpackedActivationsPat; 61 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/bitpack_weights.cc: -------------------------------------------------------------------------------- 1 | #include "larq_compute_engine/mlir/ir/lce_ops.h" 2 | #include "larq_compute_engine/mlir/transforms/bitpack.h" 3 | #include "mlir/Dialect/Arith/IR/Arith.h" 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/PatternMatch.h" 6 | #include "mlir/Pass/Pass.h" 7 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 8 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 9 | 10 | namespace mlir { 11 | namespace TFL { 12 | 13 | struct BitpackWeightsLCE 14 | : public PassWrapper> { 15 | llvm::StringRef getArgument() const final { 16 | return "tfl-lce-bitpack-weights"; 17 | } 18 | llvm::StringRef getDescription() const final { 19 | return "Bitpack binary weights"; 20 | } 21 | void runOnOperation() override; 22 | }; 23 | 24 | bool IsConv2DFilter(Attribute filter) { 25 | if (!filter.isa()) return false; 26 | auto filter_type = filter.cast().getType().cast(); 27 | return filter_type.getElementType().isF32() && 28 | filter_type.getShape().size() == 4; 29 | } 30 | 31 | namespace bitpackweights { 32 | #include 
"larq_compute_engine/mlir/transforms/generated_bitpack_weights.inc" 33 | } // namespace bitpackweights 34 | 35 | void BitpackWeightsLCE::runOnOperation() { 36 | RewritePatternSet patterns(&getContext()); 37 | auto func = getOperation(); 38 | 39 | bitpackweights::populateWithGenerated(patterns); 40 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 41 | } 42 | 43 | // Creates an instance of the TensorFlow dialect BitpackWeights pass. 44 | std::unique_ptr> 45 | CreateBitpackWeightsLCEPass() { 46 | return std::make_unique(); 47 | } 48 | 49 | static PassRegistration pass; 50 | 51 | } // namespace TFL 52 | } // namespace mlir 53 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/bitpack_weights_patterns.td: -------------------------------------------------------------------------------- 1 | include "mlir/IR/PatternBase.td" 2 | include "mlir/Dialect/Func/IR/FuncOps.td" 3 | include "mlir/Dialect/Arith/IR/ArithOps.td" 4 | include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" 5 | include "larq_compute_engine/mlir/ir/lce_ops.td" 6 | include "larq_compute_engine/mlir/transforms/op_removal_patterns.td" 7 | 8 | // Bitpack weights 9 | def Conv2DFilter : AttrConstraint>; 10 | def Bitpack : NativeCodeCall<"Bitpack(&$_builder, $0)">; 11 | 12 | def : Pat<(LQ_Bconv2dOp 13 | $input, 14 | (Arith_ConstantOp Conv2DFilter:$filter), 15 | $post_activation_multiplier, 16 | $post_activation_bias, 17 | $output_threshold, 18 | $channels_in, 19 | $dilation_height_factor, 20 | $dilation_width_factor, 21 | $fused_activation_function, 22 | $pad_values, 23 | $padding, 24 | $stride_height, 25 | $stride_width), 26 | (LQ_Bconv2dOp 27 | $input, 28 | (Arith_ConstantOp (Bitpack $filter)), 29 | $post_activation_multiplier, 30 | $post_activation_bias, 31 | $output_threshold, 32 | $channels_in, 33 | $dilation_height_factor, 34 | $dilation_width_factor, 35 | $fused_activation_function, 36 | $pad_values, 37 | $padding, 38 | $stride_height, 39 | $stride_width), 40 | []>; 41 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "mlir/IR/Attributes.h" 4 | #include "mlir/IR/BuiltinAttributes.h" 5 | 6 | namespace mlir { 7 | namespace TFL { 8 | 9 | inline bool IsConstantValue(Attribute values, float expected_value) { 10 | if (!values.isa()) return false; 11 | 12 | for (auto value : values.cast().getValues()) { 13 | if (value != expected_value) return false; 14 | } 15 | return true; 16 | } 17 | 18 | } // namespace TFL 19 | } // namespace mlir 20 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/fuse_padding.cc: -------------------------------------------------------------------------------- 1 | #include "larq_compute_engine/mlir/transforms/padding.h" 2 | #include "mlir/Pass/Pass.h" 3 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 4 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 5 | 6 | namespace mlir { 7 | namespace TFL { 8 | 9 | bool NoBatchAndChannelPadding(Attribute paddings_attr) { 10 | auto paddings = GetValidPadAttr(paddings_attr); 11 | if (!paddings) return false; 12 | return IsNoPadding(paddings, 0) && IsNoPadding(paddings, 3); 13 | } 14 | 15 | // The TFLite op has `stride_height` and `stride_width` as separate attributes. 
16 | // Due to a TableGen limitation we can't pass them both in a single call. 17 | bool IsSamePaddingPartial(Attribute paddings_attr, Value input, Value output, 18 | Attribute strides_attr, uint64_t dimension) { 19 | auto paddings = GetValidPadAttr(paddings_attr); 20 | if (!paddings) return false; 21 | auto input_shape = GetShape4D(input); 22 | if (input_shape.empty()) return false; 23 | auto output_shape = GetShape4D(output); 24 | if (output_shape.empty()) return false; 25 | 26 | if (!strides_attr.isa()) return false; 27 | const int stride = strides_attr.cast().getInt(); 28 | 29 | // Check that there is no padding in the batch and channel dimensions 30 | return IsSamePadding1D(paddings, dimension, input_shape[dimension], 31 | output_shape[dimension], stride); 32 | } 33 | 34 | namespace fuse_padding { 35 | #include "larq_compute_engine/mlir/transforms/generated_fuse_padding.inc" 36 | } 37 | 38 | // Prepare LCE operations in functions for subsequent legalization. 39 | struct FusePadding 40 | : public PassWrapper> { 41 | llvm::StringRef getArgument() const final { return "tfl-fuse-padding"; } 42 | llvm::StringRef getDescription() const final { 43 | return "Fuse padding ops into (Depthwise)Convs"; 44 | } 45 | FusePadding() = default; 46 | FusePadding(const FusePadding& pass) {} 47 | 48 | void runOnOperation() override { 49 | auto* ctx = &getContext(); 50 | RewritePatternSet patterns(ctx); 51 | auto func = getOperation(); 52 | fuse_padding::populateWithGenerated(patterns); 53 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 54 | } 55 | void getDependentDialects(DialectRegistry& registry) const override { 56 | registry.insert<::mlir::TFL::TensorFlowLiteDialect>(); 57 | } 58 | }; 59 | 60 | // Creates an instance of the TensorFlow dialect FusePadding pass. 
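// Illustrative example (assuming the pass is driven through the repo's
// lce_mlir_opt tool; the input file name is hypothetical):
//
//   lce_mlir_opt --tfl-fuse-padding input.mlir
//
// This rewrites e.g. a tfl.pad followed by a VALID-padded tfl.conv_2d into a
// single SAME-padded tfl.conv_2d, provided the explicit paddings match SAME
// semantics and leave the batch and channel dimensions untouched.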
61 | std::unique_ptr> CreateFusePaddingPass() { 62 | return std::make_unique(); 63 | } 64 | 65 | static PassRegistration pass; 66 | 67 | } // namespace TFL 68 | } // namespace mlir 69 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/legalize_tflite.cc: -------------------------------------------------------------------------------- 1 | #include "larq_compute_engine/mlir/ir/lce_ops.h" 2 | #include "mlir/IR/PatternMatch.h" 3 | #include "mlir/Pass/Pass.h" 4 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 5 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 6 | 7 | namespace mlir { 8 | namespace TFL { 9 | 10 | struct LegalizeLCE 11 | : public PassWrapper> { 12 | llvm::StringRef getArgument() const final { return "tfl-legalize-lce"; } 13 | llvm::StringRef getDescription() const final { 14 | return "Legalize LCE ops in TensorFlow Lite dialect"; 15 | } 16 | void getDependentDialects(DialectRegistry& registry) const override { 17 | registry.insert(); 18 | } 19 | void runOnOperation() override; 20 | }; 21 | 22 | template 23 | struct LegalizeToCustomOp : public OpRewritePattern { 24 | using OpRewritePattern::OpRewritePattern; 25 | 26 | LogicalResult matchAndRewrite(LarqOp larq_op, 27 | PatternRewriter& rewriter) const override { 28 | std::vector options = larq_op.buildCustomOptions(); 29 | Operation* op = larq_op.getOperation(); 30 | ShapedType type = RankedTensorType::get( 31 | {static_cast(options.size())}, rewriter.getIntegerType(8)); 32 | 33 | std::string options_bytes(options.begin(), options.end()); 34 | auto attr = ConstBytesAttr::get(op->getContext(), options_bytes); 35 | 36 | rewriter.replaceOpWithNewOp( 37 | op, op->getResultTypes(), op->getOperands(), 38 | "Lce" + std::string(LarqOp::getOperationName().drop_front(3)), attr); 39 | return success(); 40 | } 41 | }; 42 | 43 | void LegalizeLCE::runOnOperation() { 44 | RewritePatternSet patterns(&getContext()); 45 | auto* ctx = &getContext(); 46 | auto func = getOperation(); 47 | 48 | patterns.add< 49 | LegalizeToCustomOp, LegalizeToCustomOp, 50 | LegalizeToCustomOp, LegalizeToCustomOp>( 51 | ctx); 52 | 53 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 54 | } 55 | 56 | // Creates an instance of the LegalizeLCE pass. 
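// For illustration: the mangling above drops the "lq." prefix of the op name
// and prepends "Lce", so lq.Bconv2d is replaced by a tfl.custom op with
// custom_code "LceBconv2d", and the op's attributes travel along in the
// serialized custom-options byte buffer built by buildCustomOptions().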
57 | std::unique_ptr> CreateLegalizeLCEPass() { 58 | return std::make_unique(); 59 | } 60 | 61 | static PassRegistration pass; 62 | 63 | } // namespace TFL 64 | } // namespace mlir 65 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/op_removal.cc: -------------------------------------------------------------------------------- 1 | #include "mlir/Dialect/Arith/IR/Arith.h" 2 | #include "mlir/Dialect/Func/IR/FuncOps.h" 3 | #include "mlir/IR/PatternMatch.h" 4 | #include "mlir/Pass/Pass.h" 5 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 6 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" 7 | 8 | namespace mlir { 9 | namespace TFL { 10 | 11 | // Op removal of pass through ops to make following patterns easier and enable 12 | // early constant folding 13 | struct OpRemoval 14 | : public PassWrapper> { 15 | llvm::StringRef getArgument() const final { return "lce-op-removal-tf"; } 16 | llvm::StringRef getDescription() const final { 17 | return "Remove pass through TensorFlow ops"; 18 | } 19 | void runOnOperation() override; 20 | }; 21 | 22 | namespace op_removal { 23 | #include "larq_compute_engine/mlir/transforms/generated_op_removal.inc" 24 | } // namespace op_removal 25 | 26 | void OpRemoval::runOnOperation() { 27 | RewritePatternSet patterns(&getContext()); 28 | auto func = getOperation(); 29 | 30 | op_removal::populateWithGenerated(patterns); 31 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 32 | } 33 | 34 | // Creates an instance of the TensorFlow dialect OpRemoval pass. 35 | std::unique_ptr> CreateOpRemovalPass() { 36 | return std::make_unique(); 37 | } 38 | 39 | static PassRegistration pass; 40 | 41 | } // namespace TFL 42 | } // namespace mlir 43 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/op_removal_patterns.td: -------------------------------------------------------------------------------- 1 | include "mlir/IR/PatternBase.td" 2 | include "mlir/Dialect/Func/IR/FuncOps.td" 3 | include "mlir/Dialect/Arith/IR/ArithOps.td" 4 | include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" 5 | 6 | def DenseElementsAttr : ElementsAttrBase< 7 | CPred<"$_self.isa()">, 8 | "non-opaque constant tensor">; 9 | 10 | // Convert to std constant for statically shaped, non-opaque constants. 11 | def : Pat<(TF_ConstOp:$res DenseElementsAttr:$value), (Arith_ConstantOp $value), 12 | [(AnyStaticShapeTensor $res)]>; 13 | 14 | // Partially supported in TFLite, treated as passthrough IdentityOp 15 | def : Pat<(TF_CheckNumericsOp $arg, $msg), (TF_IdentityOp $arg)>; 16 | def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>; 17 | def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>; 18 | def : Pat<(TF_PlaceholderWithDefaultOp $arg), (TF_IdentityOp $arg)>; 19 | 20 | //===----------------------------------------------------------------------===// 21 | // Op removal patterns. 
22 | //===----------------------------------------------------------------------===// 23 | def : Pat<(TF_IdentityOp $arg), (replaceWithValue $arg)>; 24 | def : Pat<(TF_IdentityNOp $arg), (replaceWithValue $arg)>; 25 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/optimize_patterns_target_arm.td: -------------------------------------------------------------------------------- 1 | include "larq_compute_engine/mlir/transforms/optimize_patterns_common.td" 2 | 3 | // Insert a binary maxpool if a maxpool is followed by a sign op. 4 | def : Pat<(LQ_QuantizeOp 5 | (TFL_MaxPool2DOp:$pool_output 6 | $input, 7 | $padding, 8 | $stride_w, 9 | $stride_h, 10 | $filter_width, 11 | $filter_height, 12 | $fused_activation_function)), 13 | (LQ_BMaxPool2dOp 14 | (LQ_QuantizeOp $input), 15 | $padding, 16 | $stride_w, 17 | $stride_h, 18 | $filter_width, 19 | $filter_height), 20 | [(HasOneUse $pool_output)]>; 21 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/padding.h: -------------------------------------------------------------------------------- 1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_PADDING_H_ 2 | #define LARQ_COMPUTE_ENGINE_MLIR_TRANSFORMS_PADDING_H_ 3 | 4 | #include "larq_compute_engine/core/types.h" 5 | #include "mlir/IR/BuiltinAttributes.h" 6 | #include "mlir/IR/BuiltinTypes.h" 7 | #include "mlir/IR/Value.h" 8 | 9 | namespace mlir { 10 | namespace TFL { 11 | 12 | inline DenseElementsAttr GetValidPadAttr(Attribute paddings_attr) { 13 | if (!paddings_attr.isa()) return nullptr; 14 | auto paddings = paddings_attr.cast(); 15 | // The shape should be [4,2] 16 | auto pad_type = paddings.getType(); 17 | if (pad_type.getRank() != 2) return nullptr; 18 | auto pad_shape = pad_type.getShape(); 19 | if (pad_shape[0] != 4 || pad_shape[1] != 2) return nullptr; 20 | return paddings; 21 | } 22 | 23 | using ShapeRefType = ::llvm::ArrayRef; 24 | 25 | inline ShapeRefType GetShape4D(Value tensor) { 26 | auto tensor_type = tensor.getType().dyn_cast(); 27 | if (!tensor_type) return ShapeRefType(); 28 | ShapeRefType tensor_shape = tensor_type.getShape(); 29 | if (tensor_shape.size() != 4) return ShapeRefType(); 30 | return tensor_shape; 31 | } 32 | 33 | inline bool IsSamePadding1D(DenseElementsAttr paddings, uint64_t dimension, 34 | int input_size, int output_size, int stride) { 35 | using compute_engine::core::CeilDiv; 36 | int pad_before = paddings.getValues()[{dimension, 0}]; 37 | int pad_after = paddings.getValues()[{dimension, 1}]; 38 | const int pad_total = pad_before + pad_after; 39 | return (output_size == CeilDiv(input_size, stride)) && 40 | (pad_before == (pad_total / 2)) && 41 | (pad_after == ((pad_total + 1) / 2)); 42 | } 43 | 44 | inline bool IsNoPadding(DenseElementsAttr paddings, uint64_t dimension) { 45 | return paddings.getValues()[{dimension, 0}] == 0 && 46 | paddings.getValues()[{dimension, 1}] == 0; 47 | } 48 | 49 | } // namespace TFL 50 | } // namespace mlir 51 | 52 | #endif 53 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/passes.h: -------------------------------------------------------------------------------- 1 | #ifndef LARQ_COMPUTE_ENGINE_MLIR_PASSES_H_ 2 | #define LARQ_COMPUTE_ENGINE_MLIR_PASSES_H_ 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/Pass/Pass.h" 6 | 7 | enum LCETarget { ARM = 0, XCORE = 1 }; 8 | 9 | namespace mlir { 10 | namespace TFL { 11 | 12 
| // Creates an instance of the TensorFlow dialect OpRemoval pass. 13 | std::unique_ptr> CreateOpRemovalPass(); 14 | 15 | // Creates an instance of the TensorFlow dialect PrepareLCE pass. 16 | std::unique_ptr> CreatePrepareLCEPass( 17 | LCETarget target); 18 | 19 | // Creates an instance of the TensorFlow dialect OptimizeLCE pass. 20 | std::unique_ptr> CreateOptimizeLCEPass( 21 | LCETarget target); 22 | 23 | // Creates an instance of the TensorFlow dialect BitpackWeightsLCE pass. 24 | std::unique_ptr> CreateBitpackWeightsLCEPass(); 25 | 26 | // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. 27 | std::unique_ptr> CreateLCEQuantizePass(); 28 | 29 | // Creates an instance of LegalizeLCE pass. 30 | std::unique_ptr> CreateLegalizeLCEPass(); 31 | 32 | // Creates an instance of the TensorFlow dialect DetectionPostProcess pass. 33 | std::unique_ptr> QuantizeDetectionPostProcessPass(); 34 | 35 | // Creates an instance of the FusePadding pass. 36 | std::unique_ptr> CreateFusePaddingPass(); 37 | 38 | // Creates an instance of TranslateToLCE pass. 39 | std::unique_ptr> CreateTranslateToLCEPass(); 40 | 41 | } // namespace TFL 42 | 43 | // Creates an instance of the TensorFlow dialect SetBatchSize pass 44 | std::unique_ptr> CreateSetBatchSizePass(); 45 | 46 | } // namespace mlir 47 | 48 | #endif // LARQ_COMPUTE_ENGINE_MLIR_PASSES_H_ 49 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/prepare_patterns_target_arm.td: -------------------------------------------------------------------------------- 1 | include "larq_compute_engine/mlir/transforms/prepare_patterns_common.td" 2 | 3 | // On ARM we support 'same-zero' padding. 4 | def : PrepareBConvPadValue0Pat; 5 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/quantize.cc: -------------------------------------------------------------------------------- 1 | // This transformation pass applies quantization on Larq dialect. 2 | 3 | #include "larq_compute_engine/mlir/ir/lce_ops.h" 4 | #include "mlir/IR/PatternMatch.h" 5 | #include "mlir/Pass/Pass.h" 6 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 7 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 8 | 9 | namespace mlir { 10 | namespace TFL { 11 | 12 | //===----------------------------------------------------------------------===// 13 | // The actual Quantize Pass. 14 | // 15 | // Applies quantization on the model in TFL dialect. 16 | struct LCEQuantizePass 17 | : public PassWrapper> { 18 | llvm::StringRef getArgument() const final { return "lce-quantize"; } 19 | llvm::StringRef getDescription() const final { 20 | return "Apply hybrid quantization on models in TensorFlow Lite dialect"; 21 | } 22 | void runOnOperation() override; 23 | }; 24 | 25 | namespace lce_quantize { 26 | #include "larq_compute_engine/mlir/transforms/generated_quantize.inc" 27 | } 28 | 29 | void LCEQuantizePass::runOnOperation() { 30 | RewritePatternSet patterns(&getContext()); 31 | auto func = getOperation(); 32 | lce_quantize::populateWithGenerated(patterns); 33 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 34 | } 35 | 36 | // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. 
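// Sketch of the generated rewrites (see quantize_patterns.td): a
// tfl.quantize that consumes a single-use lq.Bconv2d or lq.Dequantize result
// is folded away, and tfl.dequantize(tfl.quantize(...)) pairs feeding the
// post-activation multiplier and bias of lq.Bconv2d are removed, so binary
// ops produce and consume quantized types directly instead of roundtripping
// through float.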
37 | std::unique_ptr> CreateLCEQuantizePass() { 38 | return std::make_unique(); 39 | } 40 | 41 | static PassRegistration pass; 42 | 43 | } // namespace TFL 44 | } // namespace mlir 45 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/quantize_patterns.td: -------------------------------------------------------------------------------- 1 | include "mlir/IR/PatternBase.td" 2 | include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" 3 | include "larq_compute_engine/mlir/ir/lce_ops.td" 4 | 5 | 6 | def : Pat<(LQ_Bconv2dOp 7 | $input, 8 | $filter, 9 | (TFL_DequantizeOp 10 | (TFL_QuantizeOp $post_activation_multiplier, $qtype1)), 11 | (TFL_DequantizeOp 12 | (TFL_QuantizeOp $post_activation_bias, $qtype2)), 13 | $output_threshold, 14 | $channels_in, 15 | $dilation_height_factor, 16 | $dilation_width_factor, 17 | $fused_activation_function, 18 | $pad_values, 19 | $padding, 20 | $stride_height, 21 | $stride_width), 22 | (LQ_Bconv2dOp $input, 23 | $filter, 24 | $post_activation_multiplier, 25 | $post_activation_bias, 26 | $output_threshold, 27 | $channels_in, 28 | $dilation_height_factor, 29 | $dilation_width_factor, 30 | $fused_activation_function, 31 | $pad_values, 32 | $padding, 33 | $stride_height, 34 | $stride_width)>; 35 | 36 | // Checks if the value has only one user. 37 | def HasOneUse : Constraint>; 38 | 39 | def : Pat<(TFL_QuantizeOp 40 | (LQ_Bconv2dOp:$output 41 | $input, 42 | $filter, 43 | $post_activation_multiplier, 44 | $post_activation_bias, 45 | $output_threshold, 46 | $channels_in, 47 | $dilation_height_factor, 48 | $dilation_width_factor, 49 | $fused_activation_function, 50 | $pad_values, 51 | $padding, 52 | $stride_height, 53 | $stride_width), 54 | $qtype), 55 | (LQ_Bconv2dOp $input, 56 | $filter, 57 | $post_activation_multiplier, 58 | $post_activation_bias, 59 | $output_threshold, 60 | $channels_in, 61 | $dilation_height_factor, 62 | $dilation_width_factor, 63 | $fused_activation_function, 64 | $pad_values, 65 | $padding, 66 | $stride_height, 67 | $stride_width), 68 | [(HasOneUse $output)]>; 69 | 70 | def : Pat<(TFL_QuantizeOp (LQ_DequantizeOp:$output $input), $qtype), 71 | (LQ_DequantizeOp $input), 72 | [(HasOneUse $output)]>; 73 | 74 | def : Pat<(LQ_QuantizeOp (TFL_DequantizeOp:$output $input)), 75 | (LQ_QuantizeOp $input), 76 | [(HasOneUse $output)]>; 77 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/set_batch_size.cc: -------------------------------------------------------------------------------- 1 | #include "mlir/Dialect/Func/IR/FuncOps.h" 2 | #include "mlir/Pass/Pass.h" 3 | 4 | // This pass will set the batch dimension of all inputs of the outermost 5 | // function to 1, leaving it to shape inference to do the rest. 
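// For example, an input of type tensor<?x224x224x3xf32> (dynamic batch
// dimension) becomes tensor<1x224x224x3xf32>; inputs whose batch dimension
// is already static are left untouched.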
6 | 7 | namespace mlir { 8 | 9 | mlir::Type SetBatchSize(mlir::Type type) { 10 | auto tensor_type = type.dyn_cast<mlir::TensorType>(); 11 | if (tensor_type && tensor_type.hasRank()) { 12 | auto shape = tensor_type.getShape(); 13 | if (shape.size() > 0 && shape[0] == ShapedType::kDynamic) { 14 | // Create a new shape but set the first dimension to 1 15 | llvm::SmallVector<int64_t> shape_new(shape.begin(), shape.end()); 16 | shape_new[0] = 1; 17 | 18 | return tensor_type.clone(shape_new); 19 | } 20 | } 21 | return nullptr; 22 | } 23 | 24 | struct SetBatchSizePass 25 | : public PassWrapper<SetBatchSizePass, OperationPass<mlir::func::FuncOp>> { 26 | llvm::StringRef getArgument() const final { return "mlir-setbatchsize"; } 27 | llvm::StringRef getDescription() const final { return "Set batch size to 1"; } 28 | void runOnOperation() override { 29 | mlir::func::FuncOp func = getOperation(); 30 | 31 | // We have to edit both the function signature (mlir::Type) *and* the 32 | // function arguments (mlir::Value) 33 | 34 | // mlir::FunctionType is a TableGen-autogenerated MLIR type 35 | mlir::FunctionType signature = func.getFunctionType(); 36 | 37 | // Create a mutable copy of the input types, since getInputs returns an 38 | // immutable llvm::ArrayRef 39 | std::vector<mlir::Type> signature_inputs(signature.getInputs()); 40 | 41 | for (auto& input_type : signature_inputs) { 42 | auto new_type = SetBatchSize(input_type); 43 | if (new_type) input_type = new_type; 44 | } 45 | 46 | auto signature_new = mlir::FunctionType::get( 47 | signature.getContext(), signature_inputs, signature.getResults()); 48 | // Set the new signature 49 | func.setFunctionTypeAttr(mlir::TypeAttr::get(signature_new)); 50 | 51 | // Now apply the same change to the mlir::Value objects 52 | for (mlir::BlockArgument arg : func.getArguments()) { 53 | // mlir::BlockArgument is a subclass of mlir::Value 54 | auto new_type = SetBatchSize(arg.getType()); 55 | if (new_type) arg.setType(new_type); 56 | } 57 | } 58 | }; 59 | 60 | // Creates an instance of the SetBatchSize pass. 
61 | std::unique_ptr> CreateSetBatchSizePass() { 62 | return std::make_unique(); 63 | } 64 | 65 | static PassRegistration pass; 66 | 67 | } // namespace mlir 68 | -------------------------------------------------------------------------------- /larq_compute_engine/mlir/transforms/translate_tflite.cc: -------------------------------------------------------------------------------- 1 | #include "flatbuffers/flexbuffers.h" 2 | #include "larq_compute_engine/mlir/ir/lce_ops.h" 3 | #include "mlir/IR/PatternMatch.h" 4 | #include "mlir/Pass/Pass.h" 5 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 6 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" 7 | 8 | static llvm::StringRef ConvertActivationAttr( 9 | tflite::ActivationFunctionType af_type) { 10 | if (af_type == tflite::ActivationFunctionType_NONE) return "NONE"; 11 | if (af_type == tflite::ActivationFunctionType_RELU) return "RELU"; 12 | if (af_type == tflite::ActivationFunctionType_RELU_N1_TO_1) 13 | return "RELU_N1_TO_1"; 14 | if (af_type == tflite::ActivationFunctionType_RELU6) return "RELU6"; 15 | } 16 | 17 | static llvm::StringRef ConvertPaddingAttr(tflite::Padding padding_type) { 18 | if (padding_type == tflite::Padding_SAME) return "SAME"; 19 | if (padding_type == tflite::Padding_VALID) return "VALID"; 20 | } 21 | 22 | namespace mlir { 23 | namespace TFL { 24 | 25 | struct TranslateToLCE 26 | : public PassWrapper> { 27 | llvm::StringRef getArgument() const final { return "lce-translate-tfl"; } 28 | llvm::StringRef getDescription() const final { 29 | return "Translate TFL custom ops to LCE ops"; 30 | } 31 | void getDependentDialects(DialectRegistry& registry) const override { 32 | registry.insert(); 33 | } 34 | void runOnOperation() override; 35 | }; 36 | 37 | struct TranslateToLCEPattern : public OpRewritePattern { 38 | using OpRewritePattern::OpRewritePattern; 39 | 40 | LogicalResult matchAndRewrite(TFL::CustomOp custom_op, 41 | PatternRewriter& rewriter) const override { 42 | auto stringData = custom_op.getCustomOption().getValue(); 43 | 44 | // Replace CustomOp with relevant LarqOp 45 | if (custom_op.getCustomCode() == "LceQuantize") { 46 | rewriter.replaceOpWithNewOp( 47 | custom_op, custom_op->getResultTypes(), custom_op->getOperands()); 48 | } else if (custom_op.getCustomCode() == "LceDequantize") { 49 | rewriter.replaceOpWithNewOp( 50 | custom_op, custom_op->getResultTypes(), custom_op->getOperands()); 51 | } else if (custom_op.getCustomCode() == "LceBMaxPool2d") { 52 | auto map = 53 | flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size()) 54 | .AsMap(); 55 | rewriter.replaceOpWithNewOp( 56 | custom_op, custom_op->getResultTypes(), custom_op->getOperand(0), 57 | ConvertPaddingAttr( 58 | static_cast(map["padding"].AsInt32())), 59 | map["stride_width"].AsInt32(), map["stride_height"].AsInt32(), 60 | map["filter_width"].AsInt32(), map["filter_height"].AsInt32()); 61 | } else if (custom_op.getCustomCode() == "LceBconv2d") { 62 | auto map = 63 | flexbuffers::GetRoot((uint8_t*)stringData.data(), stringData.size()) 64 | .AsMap(); 65 | rewriter.replaceOpWithNewOp( 66 | custom_op, custom_op->getResultTypes(), custom_op->getOperand(0), 67 | custom_op->getOperand(1), custom_op->getOperand(2), 68 | custom_op->getOperand(3), custom_op->getOperand(4), 69 | map["channels_in"].AsInt32(), map["dilation_height_factor"].AsInt32(), 70 | map["dilation_width_factor"].AsInt32(), 71 | ConvertActivationAttr(static_cast( 72 | map["fused_activation_function"].AsInt32())), 73 | map["pad_values"].AsInt32(), 74 | 
ConvertPaddingAttr( 75 | static_cast(map["padding"].AsInt32())), 76 | map["stride_height"].AsInt32(), map["stride_width"].AsInt32()); 77 | } 78 | 79 | return success(); 80 | } 81 | }; 82 | 83 | void TranslateToLCE::runOnOperation() { 84 | RewritePatternSet patterns(&getContext()); 85 | auto* ctx = &getContext(); 86 | auto func = getOperation(); 87 | patterns.add(ctx); 88 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 89 | } 90 | 91 | // Creates an instance of the TranslateToLCE pass. 92 | std::unique_ptr> CreateTranslateToLCEPass() { 93 | return std::make_unique(); 94 | } 95 | 96 | static PassRegistration pass; 97 | 98 | } // namespace TFL 99 | } // namespace mlir 100 | -------------------------------------------------------------------------------- /larq_compute_engine/requirements.in: -------------------------------------------------------------------------------- 1 | tensorflow==2.16.1 2 | tf-keras==2.16.0 3 | tensorflow-datasets 4 | larq 5 | tqdm 6 | pytest 7 | googleapis-common-protos<2,>=1.52.0 # dependency of tensorflow-datasets, somehow not picked up by pip-compile 8 | -------------------------------------------------------------------------------- /larq_compute_engine/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile --allow-unsafe --no-emit-index-url --strip-extras larq_compute_engine/requirements.in 6 | # 7 | absl-py==2.1.0 8 | # via 9 | # array-record 10 | # etils 11 | # keras 12 | # tensorboard 13 | # tensorflow 14 | # tensorflow-datasets 15 | # tensorflow-metadata 16 | array-record==0.5.1 17 | # via tensorflow-datasets 18 | astunparse==1.6.3 19 | # via tensorflow 20 | certifi==2024.7.4 21 | # via requests 22 | charset-normalizer==3.3.2 23 | # via requests 24 | click==8.1.7 25 | # via tensorflow-datasets 26 | dm-tree==0.1.8 27 | # via tensorflow-datasets 28 | docstring-parser==0.16 29 | # via simple-parsing 30 | etils==1.9.2 31 | # via 32 | # array-record 33 | # tensorflow-datasets 34 | flatbuffers==24.3.25 35 | # via tensorflow 36 | fsspec==2024.6.0 37 | # via etils 38 | gast==0.5.4 39 | # via tensorflow 40 | google-pasta==0.2.0 41 | # via tensorflow 42 | googleapis-common-protos==1.63.1 43 | # via 44 | # -r requirements.in 45 | # tensorflow-metadata 46 | grpcio==1.64.1 47 | # via 48 | # tensorboard 49 | # tensorflow 50 | h5py==3.11.0 51 | # via 52 | # keras 53 | # tensorflow 54 | idna==3.7 55 | # via requests 56 | immutabledict==4.2.0 57 | # via tensorflow-datasets 58 | importlib-resources==6.4.0 59 | # via etils 60 | iniconfig==2.0.0 61 | # via pytest 62 | keras==3.3.3 63 | # via tensorflow 64 | larq==0.13.3 65 | # via -r requirements.in 66 | libclang==18.1.1 67 | # via tensorflow 68 | markdown==3.6 69 | # via tensorboard 70 | markdown-it-py==3.0.0 71 | # via rich 72 | markupsafe==2.1.5 73 | # via werkzeug 74 | mdurl==0.1.2 75 | # via markdown-it-py 76 | ml-dtypes==0.3.2 77 | # via 78 | # keras 79 | # tensorflow 80 | namex==0.0.8 81 | # via keras 82 | numpy==1.26.4 83 | # via 84 | # etils 85 | # h5py 86 | # keras 87 | # larq 88 | # ml-dtypes 89 | # opt-einsum 90 | # pyarrow 91 | # tensorboard 92 | # tensorflow 93 | # tensorflow-datasets 94 | opt-einsum==3.3.0 95 | # via tensorflow 96 | optree==0.11.0 97 | # via keras 98 | packaging==24.1 99 | # via 100 | # larq 101 | # pytest 102 | # tensorflow 103 | pluggy==1.5.0 104 | # via pytest 105 | promise==2.3 106 | # via tensorflow-datasets 
107 | protobuf==4.25.3 108 | # via 109 | # googleapis-common-protos 110 | # tensorboard 111 | # tensorflow 112 | # tensorflow-datasets 113 | # tensorflow-metadata 114 | psutil==5.9.8 115 | # via tensorflow-datasets 116 | pyarrow==16.1.0 117 | # via tensorflow-datasets 118 | pygments==2.18.0 119 | # via rich 120 | pytest==8.2.2 121 | # via -r requirements.in 122 | requests==2.32.3 123 | # via 124 | # tensorflow 125 | # tensorflow-datasets 126 | rich==13.7.1 127 | # via keras 128 | simple-parsing==0.1.5 129 | # via tensorflow-datasets 130 | six==1.16.0 131 | # via 132 | # astunparse 133 | # google-pasta 134 | # promise 135 | # tensorboard 136 | # tensorflow 137 | tensorboard==2.16.2 138 | # via tensorflow 139 | tensorboard-data-server==0.7.2 140 | # via tensorboard 141 | tensorflow==2.16.1 142 | # via 143 | # -r requirements.in 144 | # tf-keras 145 | tensorflow-datasets==4.9.6 146 | # via -r requirements.in 147 | tensorflow-metadata==1.15.0 148 | # via tensorflow-datasets 149 | termcolor==2.4.0 150 | # via 151 | # tensorflow 152 | # tensorflow-datasets 153 | terminaltables==3.1.10 154 | # via larq 155 | tf-keras==2.16.0 156 | # via -r requirements.in 157 | toml==0.10.2 158 | # via tensorflow-datasets 159 | tqdm==4.66.4 160 | # via 161 | # -r requirements.in 162 | # etils 163 | # tensorflow-datasets 164 | typing-extensions==4.12.2 165 | # via 166 | # etils 167 | # optree 168 | # simple-parsing 169 | # tensorflow 170 | urllib3==2.2.2 171 | # via requests 172 | werkzeug==3.0.6 173 | # via tensorboard 174 | wheel==0.43.0 175 | # via astunparse 176 | wrapt==1.16.0 177 | # via 178 | # tensorflow 179 | # tensorflow-datasets 180 | zipp==3.19.2 181 | # via etils 182 | 183 | # The following packages are considered to be unsafe in a requirements file: 184 | setuptools==70.0.0 185 | # via 186 | # tensorboard 187 | # tensorflow 188 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/BUILD: -------------------------------------------------------------------------------- 1 | load("@pypi//:requirements.bzl", tf_requirement = "requirement") 2 | load("@pypi_lce//:requirements.bzl", lce_requirement = "requirement") 3 | load("//larq_compute_engine/tests:qemu_test.bzl", "lce_qemu_test_suite") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | licenses = ["notice"], # Apache 2.0 8 | ) 9 | 10 | exports_files([ 11 | "test_arm32_binary.sh", 12 | "test_aarch64_binary.sh", 13 | ]) 14 | 15 | py_test( 16 | name = "end2end_test", 17 | size = "large", 18 | srcs = [ 19 | "end2end_test.py", 20 | "preprocess.py", 21 | ], 22 | deps = [ 23 | "//larq_compute_engine/mlir:converter", 24 | "//larq_compute_engine/tflite/python:interpreter", 25 | tf_requirement("numpy"), 26 | lce_requirement("larq"), 27 | lce_requirement("pytest"), 28 | lce_requirement("tensorflow"), 29 | lce_requirement("tensorflow_datasets"), 30 | lce_requirement("tf-keras"), 31 | lce_requirement("importlib_resources"), 32 | ], 33 | ) 34 | 35 | py_test( 36 | name = "strip_lcedequantize_test", 37 | srcs = ["strip_lcedequantize_test.py"], 38 | deps = [ 39 | "//larq_compute_engine/mlir:converter", 40 | lce_requirement("larq"), 41 | lce_requirement("pytest"), 42 | lce_requirement("tensorflow"), 43 | lce_requirement("tf-keras"), 44 | ], 45 | ) 46 | 47 | py_test( 48 | name = "convert_model", 49 | srcs = ["convert_model.py"], 50 | deps = [ 51 | "//larq_compute_engine/mlir:converter", 52 | ], 53 | ) 54 | 55 | test_suite( 56 | name = "cc_tests", 57 | tests = [ 58 | 
"//larq_compute_engine/core/bitpacking/tests:cc_tests", 59 | "//larq_compute_engine/tflite/tests:cc_tests", 60 | ], 61 | ) 62 | 63 | lce_qemu_test_suite( 64 | name = "arm32_tests", 65 | platform = "arm32", 66 | tests = [ 67 | "//larq_compute_engine/tflite/tests:bconv2d_test", 68 | "//larq_compute_engine/tflite/tests:bmaxpool_test", 69 | "//larq_compute_engine/tflite/tests:quantization_test", 70 | ], 71 | ) 72 | 73 | lce_qemu_test_suite( 74 | name = "aarch64_tests", 75 | platform = "aarch64", 76 | tests = [ 77 | "//larq_compute_engine/core/bitpacking/tests:bitpack_aarch64_test", 78 | "//larq_compute_engine/core/bitpacking/tests:bitpack_test", 79 | "//larq_compute_engine/tflite/tests:bconv2d_test", 80 | "//larq_compute_engine/tflite/tests:bmaxpool_test", 81 | "//larq_compute_engine/tflite/tests:quantization_test", 82 | ], 83 | ) 84 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/convert_model.py: -------------------------------------------------------------------------------- 1 | """Convert model script. 2 | 3 | This can be used to test model conversion via the CLI. Callers should overwrite 4 | `model_fn` to return a Keras model to be converted. 5 | 6 | Usage Examples: 7 | - bazelisk test larq_compute_engine/tests:convert_model 8 | - bazelisk test larq_compute_engine/tests:convert_model --test_arg="--outfile=/tmp/model.tflite" 9 | - bazelisk run larq_compute_engine/tests:convert_model -- --outfile=/tmp/model.tflite 10 | """ 11 | 12 | import click 13 | 14 | from larq_compute_engine.mlir.python.converter import convert_keras_model 15 | 16 | 17 | def model_fn(): 18 | raise NotImplementedError( 19 | "No model defined. This function should be overwritten by caller." 20 | ) 21 | 22 | 23 | @click.command() 24 | @click.option( 25 | "--outfile", 26 | default="model.tflite", 27 | help="Destination (relative to bazel rundir) used to save converted TFLite flatbuffer.", 28 | type=click.Path(writable=True, resolve_path=True), 29 | ) 30 | def convert_model(outfile): 31 | model_lce = convert_keras_model(model_fn()) 32 | with open(outfile, "wb") as f: 33 | f.write(model_lce) 34 | 35 | click.secho(f"TFLite flatbuffer saved to '{outfile}'.") 36 | 37 | 38 | if __name__ == "__main__": 39 | convert_model() 40 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/preprocess.py: -------------------------------------------------------------------------------- 1 | """From larq_zoo/training/data.py""" 2 | 3 | import tensorflow as tf 4 | 5 | IMAGE_SIZE = 224 6 | CROP_PADDING = 32 7 | 8 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 9 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] 10 | 11 | 12 | def _center_crop(image, image_size): 13 | """Crops to center of image with padding then scales image_size.""" 14 | shape = tf.shape(image) 15 | image_height = shape[0] 16 | image_width = shape[1] 17 | 18 | padded_center_crop_size = tf.cast( 19 | ( 20 | (image_size / (image_size + CROP_PADDING)) 21 | * tf.cast(tf.minimum(image_height, image_width), tf.float32) 22 | ), 23 | tf.int32, 24 | ) 25 | 26 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 27 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 28 | 29 | image = tf.image.crop_to_bounding_box( 30 | image, 31 | offset_height, 32 | offset_width, 33 | padded_center_crop_size, 34 | padded_center_crop_size, 35 | ) 36 | image = tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0] 37 | return image 
38 | 39 | 40 | def _normalize(image, mean_rgb=MEAN_RGB, stddev_rgb=STDDEV_RGB): 41 | """Normalizes images to variance 1 and mean 0 over the whole dataset""" 42 | 43 | image -= tf.broadcast_to(mean_rgb, tf.shape(image)) 44 | image /= tf.broadcast_to(stddev_rgb, tf.shape(image)) 45 | 46 | return image 47 | 48 | 49 | def preprocess_image_tensor(image_tensor, image_size=IMAGE_SIZE): 50 | """Preprocesses the given image Tensor. 51 | 52 | Args: 53 | image_tensor: `Tensor` representing an image array arbitrary size. 54 | image_size: image size. 55 | 56 | Returns: 57 | A preprocessed and normalized image `Tensor`. 58 | """ 59 | image_tensor = _center_crop(image_tensor, image_size) 60 | image_tensor = tf.reshape(image_tensor, [image_size, image_size, 3]) 61 | image_tensor = tf.cast(image_tensor, dtype=tf.float32) 62 | image_tensor = _normalize(image_tensor, mean_rgb=MEAN_RGB, stddev_rgb=STDDEV_RGB) 63 | return image_tensor 64 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/qemu_test.bzl: -------------------------------------------------------------------------------- 1 | def lce_qemu_test_suite( 2 | name, 3 | platform, 4 | tests): 5 | """Test a set of C/C++ binaries using qemu. 6 | 7 | Args: 8 | name: a unique name for this rule. 9 | platform: either "arm32" or "aarch64" 10 | tests: list of cc_test targets 11 | """ 12 | if platform == "arm32": 13 | src = "//larq_compute_engine/tests:test_arm32_binary.sh" 14 | qemu_data = "@local_config_embedded_arm//:armhf_toolchain_all_files" 15 | elif platform == "aarch64": 16 | src = "//larq_compute_engine/tests:test_aarch64_binary.sh" 17 | qemu_data = "@local_config_embedded_arm//:aarch64_toolchain_all_files" 18 | else: 19 | fail("Invalid platform name in lce_qemu_test_suite", platform) 20 | 21 | sh_tests = [] 22 | for test in tests: 23 | # `test` is a Bazel target name 24 | # From this we extract a path to the compiled binary 25 | test_path = test 26 | if test_path.startswith("//"): 27 | test_path = test_path[2:] 28 | else: 29 | test_path = native.package_name() + "/" + test_path 30 | test_path = test_path.replace(":", "/") 31 | 32 | # We also have to create a unique identifier for this sh_test target 33 | test_suffix = test.split(":", None)[-1] 34 | sh_name = name + "_" + test_suffix 35 | 36 | # Finally create a new sh_test target 37 | native.sh_test( 38 | name = sh_name, 39 | size = "large", 40 | srcs = [src], 41 | args = [test_path], 42 | data = [test, qemu_data], 43 | shard_count = 2, 44 | ) 45 | 46 | # And add that sh_test target to the list 47 | sh_tests = sh_tests + [sh_name] 48 | 49 | # Collect the newly created targets in a regular test_suite 50 | native.test_suite(name = name, tests = sh_tests) 51 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/strip_lcedequantize_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import larq as lq 4 | import pytest 5 | import tensorflow as tf 6 | 7 | from larq_compute_engine.mlir.python.converter import convert_keras_model 8 | from larq_compute_engine.mlir.python.util import strip_lcedequantize_ops 9 | 10 | 11 | def toy_model_sign(**kwargs): 12 | img = tf.keras.layers.Input(shape=(224, 224, 3)) 13 | x = lq.layers.QuantConv2D( 14 | 256, 15 | kernel_size=3, 16 | strides=1, 17 | padding="same", 18 | pad_values=1, 19 | input_quantizer="ste_sign", 20 | kernel_quantizer="ste_sign", 21 | kernel_constraint="weight_clip", 22 | )(img) 23 | x = 
lq.quantizers.SteSign()(x) 24 | return tf.keras.Model(inputs=img, outputs=x) 25 | 26 | 27 | def quant(x): 28 | return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0) 29 | 30 | 31 | def toy_model_int8_sign(**kwargs): 32 | img = tf.keras.layers.Input(shape=(224, 224, 3)) 33 | x = quant(img) 34 | x = lq.layers.QuantConv2D( 35 | 256, 36 | kernel_size=3, 37 | strides=1, 38 | padding="same", 39 | pad_values=1, 40 | input_quantizer="ste_sign", 41 | kernel_quantizer="ste_sign", 42 | kernel_constraint="weight_clip", 43 | )(x) 44 | x = lq.quantizers.SteSign()(x) 45 | x = quant(x) 46 | return tf.keras.Model(inputs=img, outputs=x) 47 | 48 | 49 | @pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign]) 50 | @pytest.mark.parametrize("inference_input_type", [tf.float32, tf.int8]) 51 | @pytest.mark.parametrize("inference_output_type", [tf.float32, tf.int8]) 52 | def test_strip_lcedequantize_ops( 53 | model_cls, inference_input_type, inference_output_type 54 | ): 55 | model_lce = convert_keras_model( 56 | model_cls(), 57 | inference_input_type=inference_input_type, 58 | inference_output_type=inference_output_type, 59 | ) 60 | model_lce = strip_lcedequantize_ops(model_lce) 61 | interpreter = tf.lite.Interpreter(model_content=model_lce) 62 | output_details = interpreter.get_output_details() 63 | assert len(output_details) == 1 64 | assert output_details[0]["dtype"] == tf.int32.as_numpy_dtype 65 | 66 | 67 | if __name__ == "__main__": 68 | sys.exit(pytest.main([__file__, "-s"])) 69 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/test_aarch64_binary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | export QEMU_LD_PREFIX="external/aarch64_linux_toolchain/aarch64-none-linux-gnu/libc" 6 | 7 | qemu-aarch64 "$1" 8 | -------------------------------------------------------------------------------- /larq_compute_engine/tests/test_arm32_binary.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | export QEMU_LD_PREFIX="external/armhf_linux_toolchain/arm-none-linux-gnueabihf/libc" 6 | 7 | qemu-arm "$1" 8 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/BUILD: -------------------------------------------------------------------------------- 1 | package( 2 | default_visibility = ["//visibility:public"], 3 | licenses = ["notice"], # Apache 2.0 4 | ) 5 | 6 | exports_files([ 7 | "__init__.py", 8 | ]) 9 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larq/compute-engine/b8ec518b773302b2f32c2460cfa7e0267d9dbea0/larq_compute_engine/tflite/__init__.py -------------------------------------------------------------------------------- /larq_compute_engine/tflite/benchmark/BUILD: -------------------------------------------------------------------------------- 1 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_binary") 2 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings", "tflite_linkopts") 3 | 4 | package( 5 | default_visibility = [ 6 | "//visibility:public", 7 | ], 8 | licenses = ["notice"], # Apache 2.0 9 | ) 10 | 11 | tf_cc_binary( 12 | name = "lce_benchmark_model", 13 | srcs = [ 14 | 
"lce_benchmark_main.cc", 15 | ], 16 | copts = tflite_copts() + tflite_copts_warnings(), 17 | linkopts = tflite_linkopts() + select({ 18 | "@org_tensorflow//tensorflow:android": [ 19 | "-pie", # Android 5.0 and later supports only PIE 20 | "-lm", # some builtin ops, e.g., tanh, need -lm 21 | "-Wl,--rpath=/data/local/tmp/", # Hexagon delegate libraries should be in /data/local/tmp 22 | ], 23 | "//conditions:default": [], 24 | }), 25 | deps = [ 26 | "//larq_compute_engine/tflite/benchmark:lce_benchmark_tflite_model_lib", 27 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 28 | "@org_tensorflow//tensorflow/lite/tools:logging", 29 | ], 30 | ) 31 | 32 | cc_library( 33 | name = "lce_benchmark_tflite_model_lib", 34 | srcs = ["lce_benchmark_tflite_model.cc"], 35 | hdrs = ["lce_benchmark_tflite_model.h"], 36 | copts = tflite_copts() + select({ 37 | "@org_tensorflow//tensorflow:ios": [ 38 | "-xobjective-c++", 39 | ], 40 | "//conditions:default": [], 41 | }), 42 | deps = [ 43 | "@org_tensorflow//tensorflow/lite/tools/benchmark:benchmark_tflite_model_lib", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking TF Lite 2 | 3 | ## Building the benchmark program 4 | 5 | See the [LCE build guide](../../../docs/build.md) on how to configure bazel 6 | and then build the `//larq_compute_engine/tflite/benchmark:lce_benchmark_model` 7 | bazel target. 8 | 9 | ## Running the benchmark 10 | 11 | Simply run 12 | ```bash 13 | ./lce_benchmark_model --graph=path_to_your_model.tflite 14 | ``` 15 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/benchmark/lce_benchmark_main.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | Modifications copyright (C) 2020 Larq Contributors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | #include 18 | #include 19 | 20 | #include "larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h" 21 | #include "larq_compute_engine/tflite/kernels/lce_ops_register.h" 22 | #include "tensorflow/lite/tools/logging.h" 23 | 24 | bool use_reference_bconv = false; 25 | bool use_indirect_bgemm = false; 26 | 27 | void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) { 28 | compute_engine::tflite::RegisterLCECustomOps(resolver, use_reference_bconv, 29 | use_indirect_bgemm); 30 | } 31 | 32 | namespace tflite { 33 | namespace benchmark { 34 | 35 | int Main(int argc, char** argv) { 36 | TFLITE_LOG(INFO) << "STARTING!"; 37 | LceBenchmarkTfLiteModel benchmark(LceBenchmarkTfLiteModel::DefaultParams(), 38 | use_reference_bconv, use_indirect_bgemm); 39 | if (benchmark.Run(argc, argv) != kTfLiteOk) { 40 | TFLITE_LOG(ERROR) << "Benchmarking failed."; 41 | return EXIT_FAILURE; 42 | } 43 | return EXIT_SUCCESS; 44 | } 45 | } // namespace benchmark 46 | } // namespace tflite 47 | 48 | int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); } 49 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | Modifications copyright (C) 2022 Larq Contributors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | ==============================================================================*/ 16 | 17 | #include "larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h" 18 | 19 | #include "tensorflow/lite/tools/logging.h" 20 | 21 | namespace tflite { 22 | namespace benchmark { 23 | 24 | BenchmarkParams LceBenchmarkTfLiteModel::DefaultParams() { 25 | BenchmarkParams default_params = BenchmarkTfLiteModel::DefaultParams(); 26 | default_params.AddParam("use_reference_bconv", 27 | BenchmarkParam::Create(false)); 28 | default_params.AddParam("use_indirect_bgemm", 29 | BenchmarkParam::Create(false)); 30 | 31 | return default_params; 32 | } 33 | 34 | LceBenchmarkTfLiteModel::LceBenchmarkTfLiteModel(BenchmarkParams params, 35 | bool& use_reference_bconv, 36 | bool& use_indirect_bgemm) 37 | : BenchmarkTfLiteModel(std::move(params)), 38 | use_reference_bconv(use_reference_bconv), 39 | use_indirect_bgemm(use_indirect_bgemm) {} 40 | 41 | std::vector LceBenchmarkTfLiteModel::GetFlags() { 42 | std::vector flags = BenchmarkTfLiteModel::GetFlags(); 43 | std::vector lce_flags = { 44 | CreateFlag( 45 | "use_reference_bconv", ¶ms_, 46 | "When true, uses the reference implementation of LceBconv2d."), 47 | CreateFlag("use_indirect_bgemm", ¶ms_, 48 | "When true, uses the optimized indirect BGEMM kernel of" 49 | "LceBconv2d.")}; 50 | 51 | flags.insert(flags.end(), lce_flags.begin(), lce_flags.end()); 52 | 53 | return flags; 54 | } 55 | 56 | void LceBenchmarkTfLiteModel::LogParams() { 57 | BenchmarkTfLiteModel::LogParams(); 58 | const bool verbose = params_.Get("verbose"); 59 | LOG_BENCHMARK_PARAM(bool, "use_reference_bconv", "Use reference Bconv", 60 | verbose); 61 | LOG_BENCHMARK_PARAM(bool, "use_indirect_bgemm", "Use indirect BGEMM", 62 | verbose); 63 | } 64 | 65 | TfLiteStatus LceBenchmarkTfLiteModel::Run(int argc, char** argv) { 66 | TF_LITE_ENSURE_STATUS(ParseFlags(argc, argv)); 67 | use_reference_bconv = params_.Get("use_reference_bconv"); 68 | use_indirect_bgemm = params_.Get("use_indirect_bgemm"); 69 | 70 | return BenchmarkTfLiteModel::Run(); 71 | } 72 | 73 | } // namespace benchmark 74 | } // namespace tflite 75 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | Modifications copyright (C) 2022 Larq Contributors. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | ==============================================================================*/ 16 | 17 | #ifndef COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_ 18 | #define COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_ 19 | 20 | #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" 21 | 22 | namespace tflite { 23 | namespace benchmark { 24 | 25 | // Benchmarks a TFLite model by running tflite interpreter. 
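// Example invocation (paths and run counts are illustrative), combining the
// standard TFLite benchmark flags with the LCE-specific ones added below:
//
//   lce_benchmark_model --graph=/path/to/model.tflite \
//       --use_indirect_bgemm=true --num_runs=50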
26 | class LceBenchmarkTfLiteModel : public BenchmarkTfLiteModel { 27 | public: 28 | explicit LceBenchmarkTfLiteModel(BenchmarkParams params, 29 | bool& use_reference_bconv, 30 | bool& use_indirect_bgemm); 31 | 32 | std::vector<Flag> GetFlags() override; 33 | void LogParams() override; 34 | static BenchmarkParams DefaultParams(); 35 | 36 | using BenchmarkTfLiteModel::Run; 37 | TfLiteStatus Run(int argc, char** argv); 38 | 39 | private: 40 | bool& use_reference_bconv; 41 | bool& use_indirect_bgemm; 42 | }; 43 | 44 | } // namespace benchmark 45 | } // namespace tflite 46 | 47 | #endif // COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_ 48 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/build_defs.bzl: -------------------------------------------------------------------------------- 1 | """Build definitions for Ruy.""" 2 | 3 | def ruy_visibility(): 4 | return [ 5 | "//tensorflow/lite/kernels:__subpackages__", 6 | ] 7 | 8 | # 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support 9 | # ARM32 without NEON then we'll implement runtime detection and dispatch at that point. 10 | # 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". 11 | 12 | def ruy_copts_base(): 13 | return select({ 14 | "//tensorflow:android_arm": [ 15 | "-mfpu=neon", 16 | ], 17 | "//conditions:default": [], 18 | }) + select({ 19 | ":optimized": ["-O3"], 20 | "//conditions:default": [], 21 | }) 22 | 23 | # Used for targets that are compiled with extra features that are skipped at runtime if unavailable. 24 | def ruy_copts_skylake(): 25 | return [] 26 | 27 | # Used for targets that are compiled with extra features that are skipped at runtime if unavailable. 28 | def ruy_copts_avx2(): 29 | return [] 30 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/java/BUILD: -------------------------------------------------------------------------------- 1 | load("@build_bazel_rules_android//android:rules.bzl", "android_library") 2 | load("@org_tensorflow//tensorflow/java:build_defs.bzl", "JAVACOPTS") 3 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_jni_binary") 4 | load("@org_tensorflow//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") 5 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts") 6 | 7 | package( 8 | default_visibility = ["//visibility:public"], 9 | licenses = ["notice"], # Apache 2.0 10 | ) 11 | 12 | # Building tensorflow-lite-lce.aar including 4 variants of .so 13 | # To build the LCE compatible TFLite AAR, run the following command: 14 | # bazel build -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ 15 | # larq_compute_engine/tflite/java:tensorflow-lite-lce 16 | aar_with_jni( 17 | name = "tensorflow-lite-lce", 18 | android_library = ":tensorflowlite_lce", 19 | ) 20 | 21 | android_library( 22 | name = "tensorflowlite_lce", 23 | javacopts = JAVACOPTS, 24 | manifest = "@org_tensorflow//tensorflow/lite/java:AndroidManifest.xml", 25 | proguard_specs = ["@org_tensorflow//tensorflow/lite/java:proguard.flags"], 26 | deps = [ 27 | ":tensorflowlite_lce_native", 28 | "@org_checkerframework_qual", 29 | ], 30 | ) 31 | 32 | java_library( 33 | name = "tensorflowlite_java", 34 | runtime_deps = [ 35 | "@org_tensorflow//tensorflow/lite/java:tensorflowlite", 36 | ], 37 | ) 38 | 39 | # This includes all builtin and LCE ops.
40 | # If you want a smaller binary, you should copy and 41 | # modify lce_ops_jni.cc. 42 | cc_library( 43 | name = "lce_ops_jni", 44 | srcs = [ 45 | "lce_ops_jni.cc", 46 | ], 47 | copts = tflite_copts(), 48 | deps = [ 49 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 50 | "@org_tensorflow//tensorflow/lite/java/src/main/native:native_framework_only", 51 | "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", 52 | ], 53 | alwayslink = 1, 54 | ) 55 | 56 | cc_library( 57 | name = "tensorflowlite_lce_native", 58 | srcs = ["libtensorflowlite_jni.so"], 59 | visibility = ["//visibility:private"], 60 | ) 61 | 62 | tflite_jni_binary( 63 | name = "libtensorflowlite_jni.so", 64 | deps = [ 65 | ":lce_ops_jni", 66 | "@org_tensorflow//tensorflow/lite/delegates/nnapi/java/src/main/native", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/java/build_lce_aar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This is a bash script based on TFLite android AAR build script. 4 | # Additionally, we extract the TFLite Java API sources from the 5 | # `libtensorflowlite_java.jar` target and replace the `classes.jar` file of 6 | # LCE AAR with them. 7 | 8 | set -e 9 | set -x 10 | 11 | TMPDIR=`mktemp -d` 12 | trap "rm -rf $TMPDIR" EXIT 13 | 14 | AAR_NAME=lce-lite 15 | VERSION=$(git describe --tags) 16 | 17 | BUILDER="${BUILDER:-bazel}" 18 | BASEDIR=larq_compute_engine/tflite 19 | 20 | BUILD_OPTS="-c opt --config=android_arm64 --fat_apk_cpu=x86,x86_64,arm64-v8a" 21 | 22 | test -d $BASEDIR || (echo "Aborting: not at top-level build directory"; exit 1) 23 | 24 | function build_lce_aar() { 25 | local OUTDIR=$1 26 | $BUILDER build $BUILD_OPTS $BASEDIR/java:tensorflow-lite-lce.aar 27 | unzip -d $OUTDIR bazel-bin/$BASEDIR/java/tensorflow-lite-lce.aar 28 | # targetSdkVersion is here to prevent the app from requesting spurious 29 | # permissions, such as permission to make phone calls. It worked for v1.0, 30 | # but minSdkVersion might be the preferred way to handle this. 31 | sed -i -e 's/<uses-sdk \/>/<uses-sdk android:targetSdkVersion="25" \/>/' $OUTDIR/AndroidManifest.xml 32 | 33 | $BUILDER build $BUILD_OPTS $BASEDIR/java:tensorflowlite_java 34 | # override the classes.jar with the Java sources from TF Lite Java API 35 | cp bazel-bin/$BASEDIR/java/libtensorflowlite_java.jar $OUTDIR/classes.jar 36 | } 37 | 38 | rm -rf $TMPDIR 39 | mkdir -p $TMPDIR/jni 40 | 41 | build_lce_aar $TMPDIR 42 | 43 | if [[ "$OSTYPE" == "linux-gnu" ]]; then 44 | AAR_FILE=`realpath $AAR_NAME-${VERSION}.aar` 45 | elif [[ "$OSTYPE" == "darwin"* ]]; then 46 | # on macOS get 'grealpath' by installing 'coreutils' package: 47 | # "brew install coreutils" 48 | AAR_FILE=`grealpath $AAR_NAME-${VERSION}.aar` 49 | else 50 | # Unknown. 51 | echo "ERROR: could not detect the OS." 52 | exit 1 53 | fi 54 | 55 | (cd $TMPDIR && zip $AAR_FILE -r *) 56 | echo "New AAR file is $AAR_FILE" 57 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/java/lce_ops_jni.cc: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | #include "larq_compute_engine/tflite/kernels/lce_ops_register.h" 17 | #include "tensorflow/lite/kernels/register.h" 18 | 19 | namespace tflite { 20 | 21 | // The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in 22 | // the tflite namespace. This one instantiates a BuiltinOpResolver, with all the 23 | // builtin ops. For smaller binary sizes users should avoid linking this in, and 24 | // should provide a custom CreateOpResolver() instead. 25 | std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT 26 | auto resolver = new tflite::ops::builtin::BuiltinOpResolver(); 27 | compute_engine::tflite::RegisterLCECustomOps(resolver); 28 | return std::unique_ptr<OpResolver>(resolver); 29 | } 30 | 31 | } // namespace tflite 32 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/kernels/BUILD: -------------------------------------------------------------------------------- 1 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") 2 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_opts_nortti_if_android") 3 | load("//larq_compute_engine/tflite:build_defs.bzl", "ruy_copts_base") 4 | 5 | package( 6 | default_visibility = ["//visibility:public"], 7 | licenses = ["notice"], # Apache 2.0 8 | ) 9 | 10 | cc_library( 11 | name = "utils", 12 | hdrs = [ 13 | "utils.h", 14 | ], 15 | deps = [ 16 | "@org_tensorflow//tensorflow/lite/c:common", 17 | "@org_tensorflow//tensorflow/lite/schema:schema_fbs", 18 | ], 19 | ) 20 | 21 | cc_library( 22 | name = "lce_op_kernels", 23 | srcs = [ 24 | "bconv2d.cc", 25 | "bmaxpool.cc", 26 | "quantization.cc", 27 | ], 28 | hdrs = [ 29 | "lce_ops_register.h", 30 | ], 31 | copts = tflite_copts() + tf_opts_nortti_if_android(), 32 | deps = [ 33 | ":utils", 34 | "//larq_compute_engine/core:bmaxpool", 35 | "//larq_compute_engine/core/bconv2d:optimized_bgemm", 36 | "//larq_compute_engine/core/bconv2d:optimized_indirect_bgemm", 37 | "//larq_compute_engine/core/bconv2d:params", 38 | "//larq_compute_engine/core/bconv2d:reference", 39 | "//larq_compute_engine/core/bitpacking:bitpack", 40 | "//larq_compute_engine/core/bitpacking:utils", 41 | "//larq_compute_engine/core/indirect_bgemm:kernels", 42 | "@flatbuffers", 43 | "@org_tensorflow//tensorflow/lite:framework", 44 | "@org_tensorflow//tensorflow/lite:type_to_tflitetype", 45 | "@org_tensorflow//tensorflow/lite/kernels/internal:kernel_utils", 46 | "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", 47 | "@org_tensorflow//tensorflow/lite/tools:logging", 48 | "@ruy//ruy/profiler:instrumentation", 49 | ], 50 | alwayslink = 1, 51 | ) 52 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/kernels/bmaxpool.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "larq_compute_engine/core/bmaxpool.h" 3 | 4 | #include "flatbuffers/flexbuffers.h" 5 | #include "larq_compute_engine/tflite/kernels/utils.h" 6 | #include
"ruy/profiler/instrumentation.h" 7 | #include "tensorflow/lite/c/builtin_op_data.h" 8 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 9 | #include "tensorflow/lite/kernels/kernel_util.h" 10 | 11 | using namespace tflite; 12 | 13 | namespace ce = compute_engine; 14 | 15 | namespace compute_engine { 16 | namespace tflite { 17 | namespace bmaxpool { 18 | 19 | using ce::core::TBitpacked; 20 | 21 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { 22 | auto* poolparams = new core::BMaxPoolParams{}; 23 | 24 | const std::uint8_t* buffer_t = reinterpret_cast(buffer); 25 | const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); 26 | 27 | poolparams->filter_height = m["filter_height"].AsInt32(); 28 | poolparams->filter_width = m["filter_width"].AsInt32(); 29 | poolparams->stride_height = m["stride_height"].AsInt32(); 30 | poolparams->stride_width = m["stride_width"].AsInt32(); 31 | poolparams->padding_type = ConvertPadding((Padding)m["padding"].AsInt32()); 32 | 33 | return poolparams; 34 | } 35 | 36 | void Free(TfLiteContext* context, void* buffer) { 37 | delete reinterpret_cast(buffer); 38 | } 39 | 40 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 41 | core::BMaxPoolParams* poolparams = 42 | reinterpret_cast(node->user_data); 43 | 44 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); 45 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 46 | TfLiteTensor* output = GetOutput(context, node, 0); 47 | const TfLiteTensor* input = GetInput(context, node, 0); 48 | TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); 49 | TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); 50 | TF_LITE_ENSURE_EQ(context, output->type, kTfLiteInt32); 51 | TF_LITE_ENSURE(context, poolparams->stride_height != 0); 52 | TF_LITE_ENSURE(context, poolparams->stride_width != 0); 53 | TF_LITE_ENSURE(context, poolparams->filter_height != 0); 54 | TF_LITE_ENSURE(context, poolparams->filter_width != 0); 55 | 56 | int height = SizeOfDimension(input, 1); 57 | int width = SizeOfDimension(input, 2); 58 | 59 | // Matching GetWindowedOutputSize in TensorFlow. 
60 | int out_width, out_height; 61 | 62 | poolparams->padding = ComputePaddingHeightWidth( 63 | poolparams->stride_height, poolparams->stride_width, 1, 1, height, width, 64 | poolparams->filter_height, poolparams->filter_width, 65 | poolparams->padding_type, &out_height, &out_width); 66 | 67 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); 68 | output_size->data[0] = SizeOfDimension(input, 0); 69 | output_size->data[1] = out_height; 70 | output_size->data[2] = out_width; 71 | output_size->data[3] = SizeOfDimension(input, 3); 72 | return context->ResizeTensor(context, output, output_size); 73 | } 74 | 75 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 76 | ruy::profiler::ScopeLabel label("Binary MaxPool"); 77 | 78 | core::BMaxPoolParams* poolparams = 79 | reinterpret_cast<core::BMaxPoolParams*>(node->user_data); 80 | 81 | TfLiteTensor* output = GetOutput(context, node, 0); 82 | const TfLiteTensor* input = GetInput(context, node, 0); 83 | 84 | core::BMaxPool(*poolparams, GetTensorShape(input), 85 | GetTensorData<TBitpacked>(input), GetTensorShape(output), 86 | GetTensorData<TBitpacked>(output)); 87 | return kTfLiteOk; 88 | } 89 | 90 | } // namespace bmaxpool 91 | 92 | TfLiteRegistration* Register_BMAXPOOL_2D() { 93 | static TfLiteRegistration r = {bmaxpool::Init, bmaxpool::Free, 94 | bmaxpool::Prepare, bmaxpool::Eval}; 95 | return &r; 96 | } 97 | 98 | } // namespace tflite 99 | } // namespace compute_engine 100 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/kernels/lce_ops_register.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_TFLITE_KERNELS_LCE_OPS_REGISTER_H_ 2 | #define COMPUTE_ENGINE_TFLITE_KERNELS_LCE_OPS_REGISTER_H_ 3 | 4 | #include "tensorflow/lite/context.h" 5 | #include "tensorflow/lite/op_resolver.h" 6 | #include "tensorflow/lite/tools/logging.h" 7 | 8 | // This file contains forward declarations of all custom ops 9 | // implemented in LCE, which can be used to link against the LCE library. 10 | 11 | namespace compute_engine { 12 | namespace tflite { 13 | 14 | using namespace ::tflite; 15 | 16 | TfLiteRegistration* Register_QUANTIZE(); 17 | TfLiteRegistration* Register_DEQUANTIZE(); 18 | TfLiteRegistration* Register_BCONV_2D(); 19 | TfLiteRegistration* Register_BCONV_2D_REF(); 20 | TfLiteRegistration* Register_BCONV_2D_OPT_INDIRECT_BGEMM(); 21 | TfLiteRegistration* Register_BMAXPOOL_2D(); 22 | 23 | // By calling this function on the TF Lite mutable op resolver, all LCE custom 24 | // ops will be registered to the op resolver. 25 | inline void RegisterLCECustomOps(::tflite::MutableOpResolver* resolver, 26 | const bool use_reference_bconv = false, 27 | const bool use_indirect_bgemm = false) { 28 | if (use_reference_bconv && use_indirect_bgemm) { 29 | TFLITE_LOG(WARN) 30 | << "WARNING: 'use_reference_bconv' and `use_indirect_bgemm` " 31 | "are both set to true.
use_indirect_bgemm==true " 32 | "will have no effect."; 33 | } 34 | resolver->AddCustom("LceQuantize", 35 | compute_engine::tflite::Register_QUANTIZE()); 36 | resolver->AddCustom("LceDequantize", 37 | compute_engine::tflite::Register_DEQUANTIZE()); 38 | if (use_reference_bconv) { 39 | resolver->AddCustom("LceBconv2d", 40 | compute_engine::tflite::Register_BCONV_2D_REF()); 41 | } else { 42 | if (use_indirect_bgemm) { 43 | resolver->AddCustom( 44 | "LceBconv2d", 45 | compute_engine::tflite::Register_BCONV_2D_OPT_INDIRECT_BGEMM()); 46 | } else { 47 | resolver->AddCustom("LceBconv2d", 48 | compute_engine::tflite::Register_BCONV_2D()); 49 | } 50 | } 51 | resolver->AddCustom("LceBMaxPool2d", 52 | compute_engine::tflite::Register_BMAXPOOL_2D()); 53 | }; 54 | 55 | } // namespace tflite 56 | } // namespace compute_engine 57 | 58 | #endif // COMPUTE_ENGINE_TFLITE_KERNELS_LCE_OPS_REGISTER_H_ 59 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/kernels/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef COMPUTE_ENGINE_TFLITE_KERNEL_UTILS_H 2 | #define COMPUTE_ENGINE_TFLITE_KERNEL_UTILS_H 3 | 4 | #include "tensorflow/lite/c/builtin_op_data.h" 5 | #include "tensorflow/lite/schema/schema_generated.h" 6 | 7 | namespace tflite { 8 | 9 | // Converts the flatbuffer activation to what is used at runtime. 10 | inline TfLiteFusedActivation ConvertActivation( 11 | ActivationFunctionType activation) { 12 | switch (activation) { 13 | case ActivationFunctionType_NONE: 14 | return kTfLiteActNone; 15 | case ActivationFunctionType_RELU: 16 | return kTfLiteActRelu; 17 | case ActivationFunctionType_RELU_N1_TO_1: 18 | return kTfLiteActReluN1To1; 19 | case ActivationFunctionType_RELU6: 20 | return kTfLiteActRelu6; 21 | default: 22 | return kTfLiteActNone; 23 | } 24 | } 25 | 26 | // Converts the flatbuffer padding enum to TFLite padding 27 | inline TfLitePadding ConvertPadding(Padding padding) { 28 | switch (padding) { 29 | case Padding_SAME: 30 | return kTfLitePaddingSame; 31 | case Padding_VALID: 32 | return kTfLitePaddingValid; 33 | } 34 | return kTfLitePaddingUnknown; 35 | } 36 | 37 | // Converts the TFLite padding enum to what is used at runtime. 
38 | inline PaddingType RuntimePaddingType(TfLitePadding padding) { 39 | switch (padding) { 40 | case TfLitePadding::kTfLitePaddingSame: 41 | return PaddingType::kSame; 42 | case TfLitePadding::kTfLitePaddingValid: 43 | return PaddingType::kValid; 44 | case TfLitePadding::kTfLitePaddingUnknown: 45 | default: 46 | return PaddingType::kNone; 47 | } 48 | } 49 | 50 | } // namespace tflite 51 | 52 | #endif // COMPUTE_ENGINE_TFLITE_KERNEL_UTILS_H 53 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/python/BUILD: -------------------------------------------------------------------------------- 1 | load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") 2 | load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_linkopts") 3 | load("@pypi//:requirements.bzl", tf_requirement = "requirement") 4 | load("@pypi_lce//:requirements.bzl", lce_requirement = "requirement") 5 | 6 | package( 7 | default_visibility = ["//visibility:public"], 8 | licenses = ["notice"], # Apache 2.0 9 | ) 10 | 11 | cc_library( 12 | name = "interpreter_wrapper_utils", 13 | hdrs = ["interpreter_wrapper_utils.h"], 14 | deps = [ 15 | "@org_tensorflow//tensorflow/lite/c:common", 16 | "@pybind11", 17 | ], 18 | ) 19 | 20 | pybind_extension( 21 | name = "interpreter_wrapper_lite", 22 | srcs = ["interpreter_wrapper_lite.cc"], 23 | linkopts = tflite_linkopts(), 24 | module_name = "interpreter_wrapper_lite", 25 | deps = [ 26 | ":interpreter_wrapper_utils", 27 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 28 | "@org_tensorflow//tensorflow/lite:framework", 29 | "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", 30 | "@pybind11", 31 | ], 32 | ) 33 | 34 | py_library( 35 | name = "interpreter_base", 36 | srcs = [ 37 | "__init__.py", 38 | "interpreter_base.py", 39 | ], 40 | deps = [ 41 | tf_requirement("numpy"), 42 | lce_requirement("tqdm"), 43 | ], 44 | ) 45 | 46 | py_library( 47 | name = "interpreter", 48 | srcs = [ 49 | "__init__.py", 50 | "interpreter.py", 51 | ], 52 | deps = [ 53 | ":interpreter_base", 54 | ":interpreter_wrapper_lite", 55 | ], 56 | ) 57 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larq/compute-engine/b8ec518b773302b2f32c2460cfa7e0267d9dbea0/larq_compute_engine/tflite/python/__init__.py -------------------------------------------------------------------------------- /larq_compute_engine/tflite/python/interpreter.py: -------------------------------------------------------------------------------- 1 | from larq_compute_engine.tflite.python.interpreter_base import InterpreterBase 2 | 3 | __all__ = ["Interpreter"] 4 | 5 | 6 | class Interpreter(InterpreterBase): 7 | """Interpreter interface for Larq Compute Engine Models. 8 | 9 | !!! warning 10 | The Larq Compute Engine is optimized for ARM v8, which is used by devices 11 | such as a Raspberry Pi or Android phones. Running this interpreter on any 12 | other architecture (e.g. x86) will fall back on the reference kernels, meaning 13 | it will return correct outputs, but is not optimized for speed in any way! 14 | 15 | !!! 
example 16 | ```python 17 | lce_model = convert_keras_model(model) 18 | interpreter = Interpreter(lce_model) 19 | interpreter.predict(input_data, verbose=1) 20 | ``` 21 | 22 | # Arguments 23 | flatbuffer_model: A serialized Larq Compute Engine model in the flatbuffer format. 24 | num_threads: The number of threads used by the interpreter. 25 | use_reference_bconv: When True, uses the reference implementation of LceBconv2d. 26 | use_indirect_bgemm: When True, uses the optimized indirect BGEMM kernel of LceBconv2d. 27 | use_xnnpack: When True, uses the XNNPack delegate of TFLite. 28 | 29 | # Attributes 30 | input_types: Returns a list of input types. 31 | input_shapes: Returns a list of input shapes. 32 | input_scales: Returns a list of input scales. 33 | input_zero_points: Returns a list of input zero points. 34 | output_types: Returns a list of output types. 35 | output_shapes: Returns a list of output shapes. 36 | output_scales: Returns a list of output scales. 37 | output_zero_points: Returns a list of output zero points. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | flatbuffer_model: bytes, 43 | num_threads: int = 1, 44 | use_reference_bconv: bool = False, 45 | use_indirect_bgemm: bool = False, 46 | use_xnnpack: bool = False, 47 | ): 48 | from larq_compute_engine.tflite.python import interpreter_wrapper_lite 49 | 50 | super().__init__( 51 | interpreter_wrapper_lite.LiteInterpreter( 52 | flatbuffer_model, 53 | num_threads, 54 | use_reference_bconv, 55 | use_indirect_bgemm, 56 | use_xnnpack, 57 | ) 58 | ) 59 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/python/interpreter_base.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterator 2 | from typing import Union, Optional 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | Data = Union[np.ndarray, list[np.ndarray]] 8 | 9 | 10 | def data_generator(x: Union[Data, Iterator[Data]]) -> Iterator[list[np.ndarray]]: 11 | if isinstance(x, np.ndarray): 12 | for inputs in x: 13 | yield [np.expand_dims(inputs, axis=0)] 14 | elif isinstance(x, list): 15 | for inputs in zip(*x): 16 | yield [np.expand_dims(inp, axis=0) for inp in inputs] 17 | elif hasattr(x, "__next__") and hasattr(x, "__iter__"): 18 | for inputs in x: 19 | if isinstance(inputs, np.ndarray): 20 | yield [np.expand_dims(inputs, axis=0)] 21 | else: 22 | yield [np.expand_dims(inp, axis=0) for inp in inputs] 23 | else: 24 | raise ValueError( 25 | "Expected either a list of inputs or a Numpy array with implicit initial " 26 | f"batch dimension or an iterator yielding one of the above.
Received: {x}" 27 | ) 28 | 29 | 30 | class InterpreterBase: 31 | def __init__(self, interpreter): 32 | self.interpreter = interpreter 33 | 34 | @property 35 | def input_types(self) -> list: 36 | """Returns a list of input types.""" 37 | return self.interpreter.input_types 38 | 39 | @property 40 | def input_shapes(self) -> list[tuple[int]]: 41 | """Returns a list of input shapes.""" 42 | return self.interpreter.input_shapes 43 | 44 | @property 45 | def input_scales(self) -> list[Optional[Union[float, list[float]]]]: 46 | """Returns a list of input scales.""" 47 | return self.interpreter.input_scales 48 | 49 | @property 50 | def input_zero_points(self) -> list[Optional[int]]: 51 | """Returns a list of input zero points.""" 52 | return self.interpreter.input_zero_points 53 | 54 | @property 55 | def output_types(self) -> list: 56 | """Returns a list of output types.""" 57 | return self.interpreter.output_types 58 | 59 | @property 60 | def output_shapes(self) -> list[tuple[int]]: 61 | """Returns a list of output shapes.""" 62 | return self.interpreter.output_shapes 63 | 64 | @property 65 | def output_scales(self) -> list[Optional[Union[float, list[float]]]]: 66 | """Returns a list of input scales.""" 67 | return self.interpreter.output_scales 68 | 69 | @property 70 | def output_zero_points(self) -> list[Optional[int]]: 71 | """Returns a list of input zero points.""" 72 | return self.interpreter.output_zero_points 73 | 74 | def predict(self, x: Union[Data, Iterator[Data]], verbose: int = 0) -> Data: 75 | """Generates output predictions for the input samples. 76 | 77 | # Arguments 78 | x: Input samples. A Numpy array, or a list of arrays in case the model has 79 | multiple inputs. 80 | verbose: Verbosity mode, 0 or 1. 81 | 82 | # Returns 83 | Numpy array(s) of output predictions. 
84 | """ 85 | 86 | data_iterator = data_generator(x) 87 | if verbose >= 1: 88 | data_iterator = tqdm(data_iterator) 89 | 90 | prediction_iter = (self.interpreter.predict(inputs) for inputs in data_iterator) 91 | outputs = [np.concatenate(batches) for batches in zip(*prediction_iter)] 92 | 93 | if len(self.output_shapes) == 1: 94 | return outputs[0] 95 | return outputs 96 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/python/interpreter_wrapper_lite.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "larq_compute_engine/tflite/kernels/lce_ops_register.h" 4 | #include "larq_compute_engine/tflite/python/interpreter_wrapper_utils.h" 5 | #include "tensorflow/lite/interpreter.h" 6 | #include "tensorflow/lite/kernels/register.h" 7 | #include "tensorflow/lite/model.h" 8 | #include "tensorflow/lite/optional_debug_tools.h" 9 | 10 | class LiteInterpreterWrapper 11 | : public InterpreterWrapperBase { 12 | public: 13 | LiteInterpreterWrapper(const pybind11::bytes& flatbuffer, 14 | const int num_threads = 1, 15 | const bool use_reference_bconv = false, 16 | const bool use_indirect_bgemm = false, 17 | const bool use_xnnpack = false); 18 | ~LiteInterpreterWrapper(){}; 19 | 20 | private: 21 | std::string flatbuffer_; // Copy of the flatbuffer because the pybind version 22 | // is destroyed 23 | 24 | std::unique_ptr model_; 25 | std::unique_ptr resolver_; 26 | }; 27 | 28 | LiteInterpreterWrapper::LiteInterpreterWrapper( 29 | const pybind11::bytes& flatbuffer, const int num_threads, 30 | const bool use_reference_bconv, const bool use_indirect_bgemm, 31 | const bool use_xnnpack) { 32 | // Make a copy of the flatbuffer because it can get deallocated after the 33 | // constructor is done 34 | flatbuffer_ = static_cast(flatbuffer); 35 | 36 | model_ = tflite::FlatBufferModel::BuildFromBuffer(flatbuffer_.data(), 37 | flatbuffer_.size()); 38 | if (!model_) { 39 | PY_ERROR("Invalid model."); 40 | } 41 | 42 | // Build the interpreter 43 | if (use_xnnpack) { 44 | resolver_ = std::make_unique(); 45 | } else { 46 | resolver_ = std::make_unique< 47 | tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>(); 48 | } 49 | compute_engine::tflite::RegisterLCECustomOps( 50 | resolver_.get(), use_reference_bconv, use_indirect_bgemm); 51 | 52 | tflite::InterpreterBuilder builder(*model_, *resolver_); 53 | builder(&interpreter_, num_threads); 54 | MINIMAL_CHECK(interpreter_ != nullptr); 55 | 56 | // Allocate tensor buffers. 
57 | MINIMAL_CHECK(interpreter_->AllocateTensors() == kTfLiteOk); 58 | } 59 | 60 | PYBIND11_MODULE(interpreter_wrapper_lite, m) { 61 | pybind11::class_<LiteInterpreterWrapper>(m, "LiteInterpreter") 62 | .def(pybind11::init<const pybind11::bytes&, const int, const bool, 63 | const bool, const bool>()) 64 | .def_property("input_types", &LiteInterpreterWrapper::get_input_types, 65 | nullptr) 66 | .def_property("output_types", &LiteInterpreterWrapper::get_output_types, 67 | nullptr) 68 | .def_property("input_shapes", &LiteInterpreterWrapper::get_input_shapes, 69 | nullptr) 70 | .def_property("output_shapes", &LiteInterpreterWrapper::get_output_shapes, 71 | nullptr) 72 | .def_property("input_zero_points", 73 | &LiteInterpreterWrapper::get_input_zero_points, nullptr) 74 | .def_property("output_zero_points", 75 | &LiteInterpreterWrapper::get_output_zero_points, nullptr) 76 | .def_property("input_scales", &LiteInterpreterWrapper::get_input_scales, 77 | nullptr) 78 | .def_property("output_scales", &LiteInterpreterWrapper::get_output_scales, 79 | nullptr) 80 | .def("predict", &LiteInterpreterWrapper::predict); 81 | }; 82 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/tests/BUILD: -------------------------------------------------------------------------------- 1 | load("@pypi//:requirements.bzl", tf_requirement = "requirement") 2 | load("@pypi_lce//:requirements.bzl", lce_requirement = "requirement") 3 | 4 | package( 5 | default_visibility = ["//visibility:public"], 6 | licenses = ["notice"], # Apache 2.0 7 | ) 8 | 9 | cc_library( 10 | name = "utils", 11 | hdrs = [ 12 | "utils.h", 13 | ], 14 | deps = [ 15 | "//larq_compute_engine/core:types", 16 | "@org_tensorflow//tensorflow/lite/kernels/internal:types", 17 | ], 18 | ) 19 | 20 | cc_library( 21 | name = "bconv2d_op_model", 22 | hdrs = [ 23 | "bconv2d_op_model.h", 24 | ], 25 | deps = [ 26 | ":utils", 27 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 28 | "@flatbuffers", 29 | ], 30 | ) 31 | 32 | cc_test( 33 | name = "bconv2d_test", 34 | size = "large", 35 | srcs = ["bconv2d_test.cc"], 36 | deps = [ 37 | ":bconv2d_op_model", 38 | ":utils", 39 | "//larq_compute_engine/core/bitpacking:bitpack", 40 | "//larq_compute_engine/core/bitpacking:utils", 41 | "@com_google_googletest//:gtest", 42 | "@flatbuffers", 43 | "@org_tensorflow//tensorflow/lite:framework", 44 | "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", 45 | "@org_tensorflow//tensorflow/lite/kernels:test_main", 46 | "@org_tensorflow//tensorflow/lite/kernels:test_util", 47 | ], 48 | ) 49 | 50 | cc_test( 51 | name = "bmaxpool_test", 52 | size = "small", 53 | srcs = ["bmaxpool_test.cc"], 54 | deps = [ 55 | ":utils", 56 | "//larq_compute_engine/core/bitpacking:utils", 57 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 58 | "@com_google_googletest//:gtest", 59 | "@flatbuffers", 60 | "@org_tensorflow//tensorflow/lite:framework", 61 | "@org_tensorflow//tensorflow/lite/kernels:test_main", 62 | "@org_tensorflow//tensorflow/lite/kernels:test_util", 63 | ], 64 | ) 65 | 66 | cc_test( 67 | name = "quantization_test", 68 | size = "small", 69 | srcs = ["quantization_test.cc"], 70 | deps = [ 71 | ":utils", 72 | "//larq_compute_engine/tflite/kernels:lce_op_kernels", 73 | "@com_google_googletest//:gtest", 74 | "@flatbuffers", 75 | "@org_tensorflow//tensorflow/lite:framework", 76 | "@org_tensorflow//tensorflow/lite/kernels:test_main", 77 | "@org_tensorflow//tensorflow/lite/kernels:test_util", 78 | ], 79 | ) 80 | 81 | py_test( 82 | name = "interpreter_test", 83 | size = "small", 84 | srcs = ["interpreter_test.py"], 85 | deps = [
86 | "//larq_compute_engine/tflite/python:interpreter", 87 | tf_requirement("numpy"), 88 | lce_requirement("pytest"), 89 | lce_requirement("tensorflow"), 90 | lce_requirement("tf-keras"), 91 | ], 92 | ) 93 | 94 | # COLLECTION OF ALL TFLITE CC TESTS 95 | # each new cc test needs to be added here 96 | test_suite( 97 | name = "cc_tests", 98 | tests = [ 99 | ":bconv2d_test", 100 | ":bmaxpool_test", 101 | ":quantization_test", 102 | ], 103 | ) 104 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/tests/bconv2d_op_model.h: -------------------------------------------------------------------------------- 1 | #ifndef LARQ_COMPUTE_ENGINE_TFLITE_TESTS_BCONV2D_OP 2 | #define LARQ_COMPUTE_ENGINE_TFLITE_TESTS_BCONV2D_OP 3 | 4 | #include 5 | 6 | #include "flatbuffers/flexbuffers.h" // TF:flatbuffers 7 | #include "larq_compute_engine/core/types.h" 8 | #include "larq_compute_engine/tflite/tests/utils.h" 9 | #include "tensorflow/lite/kernels/test_util.h" 10 | 11 | using namespace tflite; 12 | 13 | namespace compute_engine { 14 | namespace tflite { 15 | 16 | TfLiteRegistration* Register_BCONV_2D_OPT(); 17 | 18 | namespace testing { 19 | 20 | using compute_engine::core::TBitpacked; 21 | 22 | typedef TfLiteRegistration* (*register_function)(void); 23 | 24 | class BaseBConv2DOpModel : public SingleOpModel { 25 | public: 26 | BaseBConv2DOpModel( 27 | register_function registration, const TensorData& input, 28 | const TensorData& filter, const TensorData& output, 29 | const TensorData& post_activation_multiplier, 30 | const TensorData& post_activation_bias, const TensorData& thresholds, 31 | int channels_in, int stride_width = 1, int stride_height = 1, 32 | enum Padding padding = Padding_VALID, int pad_values = 0, 33 | enum ActivationFunctionType activation = ActivationFunctionType_NONE, 34 | int dilation_width_factor = 1, int dilation_height_factor = 1, 35 | int num_threads = -1) { 36 | input_ = AddInput(input); 37 | filter_ = AddInput(filter); 38 | post_activation_multiplier_ = AddInput(post_activation_multiplier); 39 | post_activation_bias_ = AddInput(post_activation_bias); 40 | thresholds_ = AddInput(thresholds); 41 | output_ = AddOutput(output); 42 | 43 | flexbuffers::Builder fbb; 44 | fbb.Map([&]() { 45 | fbb.Int("channels_in", channels_in); 46 | fbb.Int("stride_height", stride_height); 47 | fbb.Int("stride_width", stride_width); 48 | fbb.Int("dilation_height_factor", dilation_height_factor); 49 | fbb.Int("dilation_width_factor", dilation_width_factor); 50 | fbb.Int("padding", (int)padding); 51 | fbb.Int("pad_values", pad_values); 52 | fbb.Int("fused_activation_function", (int)activation); 53 | }); 54 | fbb.Finish(); 55 | SetCustomOp("LceBconv2d", fbb.GetBuffer(), registration); 56 | BuildInterpreter({GetShape(input_), GetShape(filter_)}, num_threads, 57 | /*allow_fp32_relax_to_fp16=*/false, 58 | /*apply_delegate=*/true); 59 | } 60 | 61 | protected: 62 | int input_; 63 | int filter_; 64 | int output_; 65 | int post_activation_multiplier_; 66 | int post_activation_bias_; 67 | int thresholds_; 68 | }; 69 | 70 | template 71 | class BConv2DOpModel : public BaseBConv2DOpModel { 72 | public: 73 | using BaseBConv2DOpModel::BaseBConv2DOpModel; 74 | 75 | void SetFilter(const std::vector& f) { 76 | PopulateTensor(filter_, f); 77 | } 78 | 79 | void SetInput(const std::vector& data) { 80 | PopulateTensor(input_, data); 81 | } 82 | 83 | void SetPostActivationMultiplier(const std::vector& f) { 84 | PopulateTensor(post_activation_multiplier_, f); 85 | } 86 | 87 
| void SetPostActivationBias(const std::vector<float>& f) { 88 | PopulateTensor(post_activation_bias_, f); 89 | } 90 | 91 | void SetThresholds(const std::vector<std::int32_t>& f) { 92 | PopulateTensor(thresholds_, f); 93 | } 94 | 95 | std::vector<TOutput> GetOutput() { return ExtractVector<TOutput>(output_); } 96 | std::vector<int> GetOutputShape() { return GetTensorShape(output_); } 97 | }; 98 | 99 | } // namespace testing 100 | } // namespace tflite 101 | } // namespace compute_engine 102 | 103 | #endif // LARQ_COMPUTE_ENGINE_TFLITE_TESTS_BCONV2D_OP 104 | -------------------------------------------------------------------------------- /larq_compute_engine/tflite/tests/interpreter_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import pytest 5 | import tensorflow as tf 6 | 7 | from larq_compute_engine.tflite.python.interpreter import Interpreter 8 | 9 | 10 | @pytest.mark.parametrize("use_iterator", [True, False]) 11 | def test_interpreter(use_iterator): 12 | input_shape = (24, 24, 3) 13 | x = tf.keras.Input(input_shape) 14 | model = tf.keras.Model(x, tf.keras.layers.Flatten()(x)) 15 | converter = tf.lite.TFLiteConverter.from_keras_model(model) 16 | 17 | inputs = np.random.uniform(-1, 1, size=(16, *input_shape)).astype(np.float32) 18 | expected_outputs = inputs.reshape(16, -1) 19 | 20 | interpreter = Interpreter(converter.convert()) 21 | assert interpreter.input_types == [np.float32] 22 | assert interpreter.output_types == [np.float32] 23 | assert interpreter.input_shapes == [(1, *input_shape)] 24 | assert interpreter.output_shapes == [(1, np.prod(input_shape))] 25 | 26 | def input_fn(): 27 | if use_iterator: 28 | return (input for input in inputs) 29 | return inputs 30 | 31 | outputs = interpreter.predict(input_fn(), 1) 32 | np.testing.assert_allclose(outputs, expected_outputs) 33 | 34 | 35 | @pytest.mark.parametrize("use_iterator", [True, False]) 36 | def test_interpreter_multi_input(use_iterator): 37 | x = tf.keras.Input((24, 24, 2)) 38 | y = tf.keras.Input((24, 24, 1)) 39 | model = tf.keras.Model( 40 | [x, y], [tf.keras.layers.Flatten()(x), tf.keras.layers.Flatten()(y)] 41 | ) 42 | converter = tf.lite.TFLiteConverter.from_keras_model(model) 43 | 44 | x_np = np.random.uniform(-1, 1, size=(16, 24, 24, 2)).astype(np.float32) 45 | y_np = np.random.uniform(-1, 1, size=(16, 24, 24, 1)).astype(np.float32) 46 | expected_output_x = x_np.reshape(16, -1) 47 | expected_output_y = y_np.reshape(16, -1) 48 | 49 | interpreter = Interpreter(converter.convert(), num_threads=2) 50 | assert interpreter.input_types == [np.float32, np.float32] 51 | assert interpreter.output_types == [np.float32, np.float32] 52 | assert sorted(interpreter.input_shapes) == [(1, 24, 24, 1), (1, 24, 24, 2)] 53 | assert sorted(interpreter.output_shapes) == [(1, 24 * 24 * 1), (1, 24 * 24 * 2)] 54 | 55 | # Input order is not deterministic, decide based on shape 56 | if interpreter.input_shapes[0][-1] == 1: 57 | x_np, y_np = y_np, x_np 58 | 59 | def input_fn(): 60 | if use_iterator: 61 | return ([x, y] for x, y in zip(x_np, y_np)) 62 | return [x_np, y_np] 63 | 64 | output_x, output_y = interpreter.predict(input_fn()) 65 | # Output order is not deterministic, decide based on shape 66 | if output_y.shape == expected_output_x.shape: 67 | output_x, output_y = output_y, output_x 68 | np.testing.assert_allclose(output_x, expected_output_x) 69 | np.testing.assert_allclose(output_y, expected_output_y) 70 | 71 | 72 | if __name__ == "__main__": 73 | sys.exit(pytest.main([__file__, "-s"])) 74 | 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pytype] 2 | 3 | inputs = larq_compute_engine 4 | # Keep going past errors to analyse as many files as possible. 5 | keep_going = True 6 | # Don't check use of imported modules, because we have no type-stubs for TF. 7 | strict_import = True 8 | # Disable import errors since our pybind modules are not available during type check and we don't supply type stubs 9 | disable = import-error 10 | python_version = 3.9 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup for pip package.""" 2 | 3 | import os 4 | from sys import platform 5 | 6 | from setuptools import Extension, dist, find_packages, setup 7 | 8 | 9 | def readme(): 10 | with open("README.md", "r") as f: 11 | return f.read() 12 | 13 | 14 | class BinaryDistribution(dist.Distribution): 15 | """This class is needed in order to create OS specific wheels.""" 16 | 17 | def has_ext_modules(self): 18 | return True 19 | 20 | 21 | def get_version_number(default): 22 | # The `or default` is because on CI the `getenv` can return the empty string. 23 | version = os.getenv("LCE_RELEASE_VERSION", default) or default 24 | if "." not in version: 25 | raise ValueError(f"Invalid version: {version}") 26 | return version 27 | 28 | 29 | ext_modules = [Extension("_foo", ["stub.cc"])] if platform.startswith("linux") else [] 30 | 31 | setup( 32 | name="larq-compute-engine", 33 | version=get_version_number(default="0.16.0"), 34 | python_requires=">=3.10", 35 | description="Highly optimized inference engine for binarized neural networks.", 36 | long_description=readme(), 37 | long_description_content_type="text/markdown", 38 | author="Plumerai", 39 | author_email="opensource@plumerai.com", 40 | packages=find_packages(), 41 | ext_modules=ext_modules, 42 | url="https://larq.dev/", 43 | install_requires=["flatbuffers>=2.0", "tqdm>=4"], 44 | extras_require={ 45 | "tensorflow": ["tensorflow>=2.8"], 46 | "tensorflow_gpu": ["tensorflow-gpu>=2.8"], 47 | }, 48 | include_package_data=True, 49 | zip_safe=False, 50 | distclass=BinaryDistribution, 51 | classifiers=[ 52 | "Development Status :: 4 - Beta", 53 | "Intended Audience :: Developers", 54 | "Intended Audience :: Education", 55 | "Intended Audience :: Science/Research", 56 | "License :: OSI Approved :: Apache Software License", 57 | "Programming Language :: Python :: 3", 58 | "Programming Language :: Python :: 3 :: Only", 59 | "Programming Language :: Python :: 3.10", 60 | "Programming Language :: Python :: 3.11", 61 | "Programming Language :: Python :: 3.12", 62 | "Topic :: Scientific/Engineering", 63 | "Topic :: Scientific/Engineering :: Mathematics", 64 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 65 | "Topic :: Software Development", 66 | "Topic :: Software Development :: Libraries", 67 | "Topic :: Software Development :: Libraries :: Python Modules", 68 | ], 69 | license="Apache 2.0", 70 | keywords="binarized neural networks", 71 | ) 72 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | black==24.3.0 2 | pyflakes==2.4.0 3 | pytype==2022.3.29 4 | -------------------------------------------------------------------------------- 
/third_party/install_android.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # **NOTE**: This requires Java 8 and won't work on newer versions. See: 5 | # https://stackoverflow.com/questions/46402772/failed-to-install-android-sdk-java-lang-noclassdeffounderror-javax-xml-bind-a 6 | 7 | # Taken from tensorflow/lite/tools/tflite-android.Dockerfile 8 | 9 | # default LCE Android Env. variables 10 | export ANDROID_SDK_URL="https://dl.google.com/android/repository/commandlinetools-linux-6858069_latest.zip" 11 | export ANDROID_HOME="/tmp/lce_android" 12 | export ANDROID_API_LEVEL=30 13 | export ANDROID_BUILD_TOOLS_VERSION=31.0.0 14 | export ANDROID_NDK_VERSION=25.2.9519653 15 | export ANDROID_NDK_API_LEVEL=30 16 | 17 | 18 | # download android SDK 19 | mkdir -p $ANDROID_HOME; cd $ANDROID_HOME; 20 | 21 | echo -e "Downloading Android SDK ... " 22 | curl -o lce_android_sdk.zip $ANDROID_SDK_URL; 23 | echo -e "DONE.\n\n" 24 | 25 | echo -e "Unpacking Android SDK ... " 26 | unzip lce_android_sdk.zip -d /tmp 27 | mkdir -p ${ANDROID_HOME}/cmdline-tools 28 | mv /tmp/cmdline-tools ${ANDROID_HOME}/cmdline-tools/latest 29 | echo -e "DONE.\n\n" 30 | 31 | rm lce_android_sdk.zip; 32 | 33 | # install android platform and build tools 34 | echo -e "Updating SDK manager ... " 35 | yes | $ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager --licenses 36 | $ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager --update 37 | echo -e "DONE.\n\n" 38 | 39 | echo -e "Installing Android SDK Platform and Build Tools ... " 40 | $ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager \ 41 | "build-tools;${ANDROID_BUILD_TOOLS_VERSION}" \ 42 | "platforms;android-${ANDROID_API_LEVEL}" \ 43 | "platform-tools" 44 | echo -e "DONE.\n\n" 45 | 46 | echo -e "Installing Android NDK ... 
" 47 | $ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager \ 48 | "ndk;${ANDROID_NDK_VERSION}" 49 | echo -e "DONE.\n\n" 50 | -------------------------------------------------------------------------------- /third_party/tensorflow_patches/BUILD: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/larq/compute-engine/b8ec518b773302b2f32c2460cfa7e0267d9dbea0/third_party/tensorflow_patches/BUILD -------------------------------------------------------------------------------- /third_party/tensorflow_patches/disable_forced_mkl.patch: -------------------------------------------------------------------------------- 1 | diff --git a/third_party/xla/third_party/tsl/tsl/mkl/build_defs.bzl b/third_party/xla/third_party/tsl/tsl/mkl/build_defs.bzl 2 | index 90030a39744..489ebaa5aa7 100644 3 | --- a/third_party/xla/third_party/tsl/tsl/mkl/build_defs.bzl 4 | +++ b/third_party/xla/third_party/tsl/tsl/mkl/build_defs.bzl 5 | @@ -33,8 +33,9 @@ def if_mkl(if_true, if_false = []): 6 | """ 7 | return select({ 8 | "@local_tsl//tsl/mkl:build_with_mkl_aarch64": if_true, 9 | - "@local_tsl//tsl:linux_x86_64": if_true, 10 | - "@local_tsl//tsl:windows": if_true, 11 | + "@local_tsl//tsl/mkl:build_with_mkl_lnx_x64": if_true, 12 | + "@local_tsl//tsl/mkl:build_with_mkl_lnx_openmp": if_true, 13 | + "@local_tsl//tsl/mkl:build_with_mkl_windows_openmp": if_true, 14 | "//conditions:default": if_false, 15 | }) 16 | 17 | @@ -102,8 +103,8 @@ def mkl_deps(): 18 | """ 19 | return select({ 20 | "@local_tsl//tsl/mkl:build_with_mkl_aarch64": ["@mkl_dnn_acl_compatible//:mkl_dnn_acl"], 21 | - "@local_tsl//tsl:linux_x86_64": ["@onednn//:mkl_dnn"], 22 | - "@local_tsl//tsl:windows": ["@onednn//:mkl_dnn"], 23 | + "@local_tsl//tsl/mkl:build_with_mkl_lnx_x64": ["@onednn//:mkl_dnn"], 24 | + "@local_tsl//tsl/mkl:build_with_mkl_windows_openmp": ["@onednn//:mkl_dnn"], 25 | "//conditions:default": [], 26 | }) 27 | 28 | --------------------------------------------------------------------------------