├── .github ├── ISSUE_TEMPLATE │ ├── bug-performance-issue.md │ ├── feature_request.md │ ├── operators.md │ └── question.md ├── actions │ ├── keras_application_test │ │ └── action.yml │ ├── keras_unit_test │ │ └── action.yml │ ├── pretrained_model_test │ │ └── action.yml │ └── unit_test │ │ └── action.yml └── workflows │ ├── keras_application_test_ci.yml │ ├── keras_unit_test_ci.yml │ ├── pretrained_model_test_ci.yml │ ├── pylint.yml │ └── unit_test_ci.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── Troubleshooting.md ├── VERSION_NUMBER ├── build.bat ├── build.sh ├── examples ├── benchmark_tfmodel_ort.py ├── call_converter_via_python.py ├── custom_op_via_python.py ├── end2end_tfhub.py ├── end2end_tfkeras.py ├── getting_started.py ├── rnn_tips.md └── tf_custom_op │ ├── custom_op.md │ ├── double_and_add_one.cc │ ├── double_and_add_one.so │ └── double_and_add_one_custom_op.py ├── setup.cfg ├── setup.py ├── support_status.md ├── tests ├── ade20k.jpg ├── backend_test_base.py ├── beach.jpg ├── car.JPEG ├── cobalt_group_a_perf_testing_models.yaml ├── cobalt_group_c_perf_testing_models.yaml ├── common.py ├── completed_perf_testing_models.yaml ├── conftest.py ├── huggingface.py ├── in_progress_perf_testing_models.yaml ├── keras2onnx_applications │ ├── .gitignore │ ├── data │ │ └── street.jpg │ ├── lpcnet │ │ ├── README.md │ │ └── convert_lpcnet_to_onnx.py │ ├── mask_rcnn │ │ ├── README.md │ │ └── mask_rcnn.py │ ├── model_source │ │ ├── densenet_1 │ │ │ ├── densenet_1.py │ │ │ ├── subpixel.py │ │ │ └── tensorflow_backend.py │ │ └── densenet_2 │ │ │ └── densenet_2.py │ ├── nightly_build │ │ ├── run_all.py │ │ ├── run_all_v2.py │ │ ├── test_aae.py │ │ ├── test_acgan.py │ │ ├── test_autoencoder.py │ │ ├── test_bgan.py │ │ ├── test_bigan.py │ │ ├── test_ccgan.py │ │ ├── test_chatbot.py │ │ ├── test_cogan.py │ │ ├── test_craft.py │ │ ├── test_crnn.py │ │ ├── test_cyclegan.py │ │ ├── test_dcgan.py │ │ ├── test_deep_rl.py │ │ ├── test_deep_speaker.py │ │ ├── 
test_deep_speech.py │ │ ├── test_deepface.py │ │ ├── test_deeplab_v3.py │ │ ├── test_densenet_1.py │ │ ├── test_densenet_2.py │ │ ├── test_discogan.py │ │ ├── test_dual_path_network.py │ │ ├── test_dualgan.py │ │ ├── test_efn.py │ │ ├── test_fcn.py │ │ ├── test_gan.py │ │ ├── test_image_outpainting.py │ │ ├── test_inception_v4.py │ │ ├── test_infogan.py │ │ ├── test_keras_applications.py │ │ ├── test_keras_applications_v2.py │ │ ├── test_lipnet.py │ │ ├── test_lsgan.py │ │ ├── test_mask_rcnn.py │ │ ├── test_mlstm_fcn.py │ │ ├── test_music_generation.py │ │ ├── test_name_entity_recognition.py │ │ ├── test_nasnet_mobile.py │ │ ├── test_nbeats.py │ │ ├── test_nlp.py │ │ ├── test_nonlocal.py │ │ ├── test_ocr.py │ │ ├── test_open_face.py │ │ ├── test_pix2pix.py │ │ ├── test_pixelda.py │ │ ├── test_prn.py │ │ ├── test_pspnet.py │ │ ├── test_resnext.py │ │ ├── test_resume_parser.py │ │ ├── test_se_densenet.py │ │ ├── test_se_inc_resnet.py │ │ ├── test_segnet.py │ │ ├── test_segnet_2.py │ │ ├── test_semantic_embeddings.py │ │ ├── test_semantic_segmentation.py │ │ ├── test_series_net.py │ │ ├── test_sgan.py │ │ ├── test_srgan.py │ │ ├── test_ssrnet.py │ │ ├── test_super_resolution.py │ │ ├── test_transformers.py │ │ ├── test_unet.py │ │ ├── test_unet_plus_plus.py │ │ ├── test_wavenet.py │ │ ├── test_wgan.py │ │ ├── test_wgan_gp.py │ │ └── test_yolov3.py │ └── yolov3 │ │ ├── README.md │ │ └── yolov3.py ├── keras2onnx_unit_tests │ ├── conftest.py │ ├── mock_keras2onnx │ │ ├── __init__.py │ │ ├── ke2onnx │ │ │ └── batch_norm.py │ │ └── proto │ │ │ ├── __init__.py │ │ │ └── tfcompat.py │ ├── test_cgan.py │ ├── test_layers.py │ ├── test_subclassing.py │ └── test_utils.py ├── models │ ├── ae0 │ │ └── frozen.pb │ ├── conv-layers │ │ └── frozen.pb │ ├── fc-layers │ │ └── frozen.pb │ ├── gru │ │ └── frozen.pb │ ├── lstm │ │ └── frozen.pb │ ├── regression │ │ ├── checkpoint │ │ │ ├── checkpoint │ │ │ ├── model.data-00000-of-00001 │ │ │ ├── model.index │ │ │ └── model.meta │ │ ├── 
graphdef │ │ │ └── frozen.pb │ │ ├── saved_model │ │ │ ├── saved_model.pb │ │ │ └── variables │ │ │ │ ├── variables.data-00000-of-00001 │ │ │ │ └── variables.index │ │ └── tflite │ │ │ ├── model.tflite │ │ │ └── test_api_model.tflite │ └── saved_model_with_redundant_inputs │ │ └── saved_model.pb ├── run_pretrained_models.py ├── run_pretrained_models.yaml ├── run_tfjs.js ├── test_api.py ├── test_backend.py ├── test_cond.py ├── test_const_fold.py ├── test_convert.py ├── test_cudnn.py ├── test_cudnn_compatible_gru.py ├── test_custom_rnncell.py ├── test_einsum_helper.py ├── test_einsum_ml.py ├── test_einsum_optimizers.py ├── test_example.py ├── test_gru.py ├── test_grublock.py ├── test_internals.py ├── test_issue_2025.py ├── test_loops.py ├── test_lstm.py ├── test_lstmblock.py ├── test_onnx_shape_inference.py ├── test_optimizers.py ├── test_profile.py ├── test_seq2seq.py ├── test_stacked_lstm.py ├── test_string_ops.py ├── test_tf_shape_inference.py ├── test_tfjs_runner.py ├── test_tflite_postprocess.py ├── test_tflite_utils.py ├── tfhub │ ├── _tools.py │ ├── tfhub_albert_en_xlarge.py │ ├── tfhub_albert_en_xlarge_keras.py │ ├── tfhub_bert_en_wwm_uncased.py │ ├── tfhub_blazeposedetector.py │ ├── tfhub_enformer.py │ ├── tfhub_esrgan.py │ ├── tfhub_humpback_whale.py │ ├── tfhub_inception_v3.py │ ├── tfhub_lambert_en_uncased_L-24_H-1024_A-16.py │ ├── tfhub_mobile_food_segmenter_V1.py │ ├── tfhub_mobilebert_en_uncased.py │ ├── tfhub_mobilenet_v3_large_075_224.py │ ├── tfhub_mobilenet_v3_small_075_224.py │ ├── tfhub_nasnet_large.py │ ├── tfhub_resnet_v1_101.py │ ├── tfhub_resnet_v1_101_keras.py │ ├── tfhub_resnet_v2_101.py │ ├── tfhub_resnet_v2_101_classification.py │ ├── tfhub_resnet_v2_152.py │ ├── tfhub_spam_detection.py │ ├── tfhub_talkheads_ggelu_bert_en_large.py │ ├── tfhub_thunder.py │ ├── tfhub_yamnet_coral.py │ ├── tfhub_yamnet_tf.py │ └── tfhub_yamnet_tflite.py ├── tfjs_runner.py ├── unity.yaml └── utils │ └── setup_test_env.sh ├── tf2onnx ├── __init__.py ├── 
constants.py ├── convert.py ├── custom_opsets │ ├── __init__.py │ ├── ms.py │ ├── onnx_ml.py │ └── string_ops.py ├── flexbuffers.py ├── graph.py ├── graph_builder.py ├── graph_matcher.py ├── handler.py ├── keras2onnx_api.py ├── late_rewriters │ ├── __init__.py │ └── channel_order_rewriters.py ├── onnx_opset │ ├── __init__.py │ ├── common.py │ ├── controlflow.py │ ├── generator.py │ ├── logical.py │ ├── math.py │ ├── misc.py │ ├── nn.py │ ├── quantize.py │ ├── reduction.py │ ├── rnn.py │ ├── signal.py │ ├── tensor.py │ └── traditionalml.py ├── optimizer │ ├── __init__.py │ ├── back_to_back_optimizer.py │ ├── const_dequantize_optimizer.py │ ├── const_fold_optimizer.py │ ├── einsum_optimizer.py │ ├── global_pool_optimizer.py │ ├── identity_optimizer.py │ ├── loop_optimizer.py │ ├── merge_duplicated_nodes_optimizer.py │ ├── optimizer_base.py │ ├── q_dq_optimizer.py │ ├── reshape_optimizer.py │ ├── transpose_optimizer.py │ └── upsample_optimizer.py ├── rewriter │ ├── __init__.py │ ├── bigru_rewriter.py │ ├── bilstm_rewriter.py │ ├── cond_rewriter.py │ ├── conv2d_with_add_rewriter.py │ ├── conv2d_with_pad_rewriter.py │ ├── conv_dilations_rewriter.py │ ├── custom_rnn_rewriter.py │ ├── dropout_rewriter.py │ ├── eye_rewriter.py │ ├── flatten_rewriter.py │ ├── fused_op_rewriter.py │ ├── gemm_rewriter.py │ ├── gru_rewriter.py │ ├── gru_tf2_rewriter.py │ ├── layer_normalization_rewriter.py │ ├── leakyrelu_rewriter.py │ ├── loop_rewriter.py │ ├── loop_rewriter_base.py │ ├── lstm_rewriter.py │ ├── lstm_rewriter_base.py │ ├── lstm_tf2_rewriter.py │ ├── quantization_ops_rewriter.py │ ├── ragged_variant_shape_rewriter.py │ ├── random_normal_rewriter.py │ ├── random_uniform.py │ ├── rnn.py │ ├── rnn_utils.py │ ├── thresholded_relu_rewriter.py │ ├── transpose_rewriter.py │ └── unit_rnn_rewriter_base.py ├── schemas.py ├── shape_inference.py ├── symbolic_executor.py ├── tf_loader.py ├── tf_utils.py ├── tfjs_utils.py ├── tflite │ ├── ATan2Options.py │ ├── AbsOptions.py │ ├── 
ActivationFunctionType.py │ ├── AddNOptions.py │ ├── AddOptions.py │ ├── ArgMaxOptions.py │ ├── ArgMinOptions.py │ ├── AssignVariableOptions.py │ ├── BatchMatMulOptions.py │ ├── BatchToSpaceNDOptions.py │ ├── BidirectionalSequenceLSTMOptions.py │ ├── BidirectionalSequenceRNNOptions.py │ ├── BitcastOptions.py │ ├── BitwiseXorOptions.py │ ├── BroadcastToOptions.py │ ├── BucketizeOptions.py │ ├── Buffer.py │ ├── BuiltinOperator.py │ ├── BuiltinOptions.py │ ├── CallOnceOptions.py │ ├── CallOptions.py │ ├── CastOptions.py │ ├── CombinerType.py │ ├── ConcatEmbeddingsOptions.py │ ├── ConcatenationOptions.py │ ├── Conv2DOptions.py │ ├── Conv3DOptions.py │ ├── CosOptions.py │ ├── CumsumOptions.py │ ├── CustomOptionsFormat.py │ ├── CustomQuantization.py │ ├── DensifyOptions.py │ ├── DepthToSpaceOptions.py │ ├── DepthwiseConv2DOptions.py │ ├── DequantizeOptions.py │ ├── DimensionMetadata.py │ ├── DimensionType.py │ ├── DivOptions.py │ ├── DynamicUpdateSliceOptions.py │ ├── EmbeddingLookupSparseOptions.py │ ├── EqualOptions.py │ ├── ExpOptions.py │ ├── ExpandDimsOptions.py │ ├── FakeQuantOptions.py │ ├── FillOptions.py │ ├── FloorDivOptions.py │ ├── FloorModOptions.py │ ├── FullyConnectedOptions.py │ ├── FullyConnectedOptionsWeightsFormat.py │ ├── GatherNdOptions.py │ ├── GatherOptions.py │ ├── GeluOptions.py │ ├── GreaterEqualOptions.py │ ├── GreaterOptions.py │ ├── HardSwishOptions.py │ ├── HashtableFindOptions.py │ ├── HashtableImportOptions.py │ ├── HashtableOptions.py │ ├── HashtableSizeOptions.py │ ├── IfOptions.py │ ├── Int32Vector.py │ ├── L2NormOptions.py │ ├── LSHProjectionOptions.py │ ├── LSHProjectionType.py │ ├── LSTMKernelType.py │ ├── LSTMOptions.py │ ├── LeakyReluOptions.py │ ├── LessEqualOptions.py │ ├── LessOptions.py │ ├── LocalResponseNormalizationOptions.py │ ├── LogSoftmaxOptions.py │ ├── LogicalAndOptions.py │ ├── LogicalNotOptions.py │ ├── LogicalOrOptions.py │ ├── MatrixDiagOptions.py │ ├── MatrixSetDiagOptions.py │ ├── MaximumMinimumOptions.py │ ├── 
Metadata.py │ ├── MirrorPadMode.py │ ├── MirrorPadOptions.py │ ├── Model.py │ ├── MulOptions.py │ ├── NegOptions.py │ ├── NonMaxSuppressionV4Options.py │ ├── NonMaxSuppressionV5Options.py │ ├── NotEqualOptions.py │ ├── OneHotOptions.py │ ├── Operator.py │ ├── OperatorCode.py │ ├── PackOptions.py │ ├── PadOptions.py │ ├── PadV2Options.py │ ├── Padding.py │ ├── Pool2DOptions.py │ ├── PowOptions.py │ ├── QuantizationDetails.py │ ├── QuantizationParameters.py │ ├── QuantizeOptions.py │ ├── RNNOptions.py │ ├── RandomOptions.py │ ├── RangeOptions.py │ ├── RankOptions.py │ ├── ReadVariableOptions.py │ ├── ReducerOptions.py │ ├── ReshapeOptions.py │ ├── ResizeBilinearOptions.py │ ├── ResizeNearestNeighborOptions.py │ ├── ReverseSequenceOptions.py │ ├── ReverseV2Options.py │ ├── Rfft2dOptions.py │ ├── RightShiftOptions.py │ ├── SVDFOptions.py │ ├── ScatterNdOptions.py │ ├── SegmentSumOptions.py │ ├── SelectOptions.py │ ├── SelectV2Options.py │ ├── SequenceRNNOptions.py │ ├── ShapeOptions.py │ ├── SignOptions.py │ ├── SignatureDef.py │ ├── SkipGramOptions.py │ ├── SliceOptions.py │ ├── SoftmaxOptions.py │ ├── SpaceToBatchNDOptions.py │ ├── SpaceToDepthOptions.py │ ├── SparseIndexVector.py │ ├── SparseToDenseOptions.py │ ├── SparsityParameters.py │ ├── SplitOptions.py │ ├── SplitVOptions.py │ ├── SquareOptions.py │ ├── SquaredDifferenceOptions.py │ ├── SqueezeOptions.py │ ├── StridedSliceOptions.py │ ├── SubGraph.py │ ├── SubOptions.py │ ├── Tensor.py │ ├── TensorMap.py │ ├── TensorType.py │ ├── TileOptions.py │ ├── TopKV2Options.py │ ├── TransposeConvOptions.py │ ├── TransposeOptions.py │ ├── Uint16Vector.py │ ├── Uint8Vector.py │ ├── UnidirectionalSequenceLSTMOptions.py │ ├── UniqueOptions.py │ ├── UnpackOptions.py │ ├── UnsortedSegmentMaxOptions.py │ ├── UnsortedSegmentMinOptions.py │ ├── UnsortedSegmentProdOptions.py │ ├── UnsortedSegmentSumOptions.py │ ├── VarHandleOptions.py │ ├── VariantSubType.py │ ├── WhereOptions.py │ ├── WhileOptions.py │ ├── ZerosLikeOptions.py │ 
└── __init__.py ├── tflite_handlers │ ├── __init__.py │ ├── tfl_controlflow.py │ ├── tfl_direct.py │ ├── tfl_math.py │ ├── tfl_nn.py │ ├── tfl_postprocess.py │ └── tfl_tensor.py ├── tflite_rewriters │ ├── __init__.py │ ├── tfl_qdq_rewriter.py │ ├── tfl_rfft_rewriter.py │ ├── tfl_scan_output_rewriter.py │ └── tfl_select_zero_rewriter.py ├── tflite_utils.py ├── tfonnx.py ├── utils.py ├── verbose_logging.py └── version.py ├── tools ├── aggregate-patterns.py ├── dump-onnx.py ├── example.bat ├── example.sh ├── gen_doc.py ├── gen_tflite_flatbuffer.py ├── graphtool.py ├── make_regression_test_models.py ├── onnx-experiments.py ├── onnx-optimize.py ├── profile_conversion_time.py ├── pylintrc ├── save_pretrained_model.py ├── tf_graph_tool.py └── tfgraph.py └── tutorials ├── BertTutorial.ipynb ├── ConvertingSSDMobilenetToONNX.ipynb ├── README.md ├── efficientdet.ipynb ├── efficientnet-edge.ipynb ├── efficientnet-lite.ipynb ├── huggingface-bert.ipynb ├── keras-resnet50.ipynb └── mobiledet-tflite.ipynb /.github/ISSUE_TEMPLATE/bug-performance-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug/Performance issue 3 | about: Use this template for reporting a bug or a performance issue. 4 | title: '' 5 | labels: 'bug' 6 | assignees: '' 7 | --- 8 | 9 | 16 | 17 | **Describe the bug** 18 | 19 | 20 | **Urgency** 21 | 22 | 23 | **System information** 24 | - OS Platform and Distribution (e.g., Linux Ubuntu 18.04*): 25 | - TensorFlow Version: 26 | - Python version: 27 | - ONNX version (if applicable, e.g. 1.11*): 28 | - ONNXRuntime version (if applicable, e.g. 
1.11*): 29 | 30 | 31 | **To Reproduce** 32 | 33 | 34 | **Screenshots** 35 | 36 | 37 | **Additional context** 38 | 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Requests for new tf2onnx features 4 | title: '' 5 | labels: 'enhancement' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Before submitting your request, please review past submissions to ensure that it is not a duplicate of a known feature request. 11 | 12 | ### Describe the feature request 13 | 14 | 15 | ### Describe scenario use case -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/operators.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Missing Operator 3 | about: Operator that is not currently supported in tf2onnx. 4 | title: '' 5 | labels: 'unsupported ops' 6 | assignees: '' 7 | 8 | 9 | 10 | --- 11 | # New Operator 12 | 13 | ### Describe the operator 14 | 15 | 16 | ### Do you know whether this operator can be constructed using existing ONNX operators? 17 | 18 | 19 | ### Is this operator used by any model currently? Which one? 20 | 21 | ### Are you willing to contribute it? (Y/N) 22 | 23 | ### Notes 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question about tf2onnx. 4 | title: '' 5 | labels: 'question' 6 | assignees: '' 7 | 8 | 9 | 10 | --- 11 | # Ask a Question 12 | 13 | ### Question 14 | 15 | 16 | ### Further information 17 | - Is this issue related to a specific model?
18 | **Model name**: 19 | 20 | **Model opset**: 21 | 22 | ### Notes 23 | -------------------------------------------------------------------------------- /.github/actions/keras_unit_test/action.yml: -------------------------------------------------------------------------------- 1 | name: Keras2onnx Unit Test Run 2 | 3 | inputs: 4 | tf_version: 5 | description: 'TensorFlow version' 6 | python_version: 7 | description: 'Python version' 8 | ort_version: 9 | description: 'ONNXRuntime version' 10 | onnx_version: 11 | description: 'ONNX version' 12 | 13 | runs: 14 | using: "composite" 15 | steps: 16 | - name: Set up Python (${{ inputs.python_version }}) 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ inputs.python_version }} 20 | 21 | - name: Install dependencies (TF-v${{ inputs.tf_version }}) 22 | shell: bash 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install onnxconverter-common 26 | pip install onnx==${{ inputs.onnx_version }} 27 | pip install h5py==3.7.0 28 | pip install parameterized 29 | pip install timeout-decorator 30 | pip install coloredlogs flatbuffers 31 | pip install tensorflow==${{ inputs.tf_version }} 32 | pip install pytest pytest-cov pytest-runner 33 | pip install onnxruntime==${{ inputs.ort_version }} 34 | pip uninstall -y protobuf 35 | pip install "protobuf~=3.20" 36 | if [[ ${{ inputs.tf_version }} == 1.* ]]; then 37 | pip install numpy==1.19.0 38 | else 39 | pip install "numpy<2" 40 | fi 41 | 42 | pip install -e . 
43 | 44 | echo "----- List all of depdencies:" 45 | pip freeze --all 46 | 47 | - name: Run keras_unit_test (Linux) 48 | shell: bash 49 | if: runner.os == 'Linux' 50 | run: | 51 | python -c "import onnxruntime" 52 | python -c "import onnxconverter_common" 53 | pytest tests/keras2onnx_unit_tests --doctest-modules --junitxml=junit/test-results.xml 54 | -------------------------------------------------------------------------------- /.github/actions/pretrained_model_test/action.yml: -------------------------------------------------------------------------------- 1 | name: Pretrained Model Test Run 2 | 3 | inputs: 4 | tf_version: 5 | description: 'TensorFlow version' 6 | python_version: 7 | description: 'Python version' 8 | ort_version: 9 | description: 'ONNXRuntime version' 10 | onnx_version: 11 | description: 'ONNX version' 12 | opset_version: 13 | description: 'OPSET version' 14 | skip_tflite: 15 | description: 'Skip TFLite tests' 16 | 17 | runs: 18 | using: "composite" 19 | steps: 20 | - name: Set up Python (${{ inputs.python_version }}) 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ inputs.python_version }} 24 | 25 | - name: Install dependencies (TF-v${{ inputs.tf_version }}) 26 | shell: bash 27 | run: | 28 | chmod +x ./tests/utils/setup_test_env.sh 29 | ./tests/utils/setup_test_env.sh ${{ inputs.tf_version }} ${{ inputs.ort_version }} ${{ inputs.onnx_version }} 30 | 31 | - name: Fix Paths (Windows only) 32 | shell: pwsh 33 | if: runner.os == 'Windows' 34 | run: | 35 | $site_dir = python -c "import site; print(site.getsitepackages()[1])" 36 | echo "##vso[task.prependpath]$site_dir\numpy\.libs" 37 | $base_dir = python -c "import site; print(site.getsitepackages()[0])" 38 | echo "##vso[task.prependpath]$base_dir/Library/bin" 39 | 40 | - name: Run pretrained_model_test 41 | shell: bash 42 | run: | 43 | # TODO: fix unity model path 44 | # python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config 
tests/unity.yaml || status=$? 45 | python tests/run_pretrained_models.py --backend onnxruntime --opset ${{ inputs.opset_version }} --skip_tf_tests False --skip_tflite_tests ${{ inputs.skip_tflite }} --skip_tfjs_tests True --config tests/run_pretrained_models.yaml || status=$? 46 | ls 47 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | push: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | enforce-style: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: 3.9 # Specify the desired Python version (e.g., 3.8, 3.9) 22 | 23 | - name: Install dependencies 24 | run: pip install pylint==2.4.4 25 | 26 | - name: Run pylint 27 | run: | 28 | pip freeze 29 | pylint --rcfile=tools/pylintrc --ignore=version.py,tflite --disable=cyclic-import tf2onnx tests/*.py tools -j 0 30 | 31 | # Add other jobs or steps as needed 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea 3 | build 4 | dist 5 | bin 6 | obj 7 | .ipynb_checkpoints 8 | __pycache__ 9 | *.pyc 10 | *.swp 11 | .cache 12 | .eggs 13 | *.egg-info 14 | *.onnx 15 | run.sh 16 | node_modules/* 17 | tests/tfhub/*/*.onnx 18 | tests/tfhub/*/*.tar.gz 19 | tests/tfhub/*/*.tflite 20 | tests/tfhub/*/** 21 | 22 | # Unit test / coverage reports 23 | .coverage* 24 | .cache 25 | .hypothesis/ 26 | .pytest_cache/ 27 | test-output.xml 28 | 29 | # VSCode 30 | .vscode/ 31 | !.vscode/extensions.json 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | # Contributing 4 | 5 | We're always looking for your help to fix bugs and improve the product. Create a pull request and we'll be happy to take a look. 6 | 7 | # Check-in procedure 8 | 1. Fork the repo 9 | 2. git clone your fork 10 | 3. Create a feature branch 11 | 4. Make and check in your changes along with unit tests 12 | 5. git commit your changes 13 | 6. git push origin HEAD 14 | 7. To request a merge into main, send a pull request from the web UI 15 | https://github.com/onnx/tensorflow-onnx. 16 | 17 | 18 | New code *must* be accompanied by unit tests. 19 | 20 | *Note*: After creating a pull request, you will see a build getting triggered right away. You can check whether the style checks and unit tests pass. 21 | 22 | 23 | # Coding guidelines 24 | Please see [Coding Conventions and Standards](http://google.github.io/styleguide/pyguide.html) 25 | 26 | # Licensing guidelines 27 | This project welcomes contributions and suggestions. The contributions require you to 28 | agree to the Developer Certificate of Origin (DCO), declaring that you have the right to, 29 | and actually do, grant us the rights to use your contribution. 30 | 31 | When you submit a pull request, a DCO-bot will automatically determine whether you need 32 | to provide a DCO and decorate the PR appropriately. 33 | 34 | You can sign off your commits by passing the `-s` flag to `git commit`: 35 | 36 | ```sh 37 | git commit -s 38 | ``` 39 | 40 | 41 | # Code of conduct 42 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 43 | For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 44 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
45 | -------------------------------------------------------------------------------- /VERSION_NUMBER: -------------------------------------------------------------------------------- 1 | 1.16.1 2 | -------------------------------------------------------------------------------- /build.bat: -------------------------------------------------------------------------------- 1 | rem SPDX-License-Identifier: Apache-2.0 2 | 3 | python -m pytest --cov=tf2onnx 4 | python setup.py bdist_wheel 5 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | set -x 7 | 8 | apt-get install -y protobuf-compiler libprotoc-dev 9 | pip install setuptools 10 | pip install onnx pytest-cov 11 | 12 | python setup.py test 13 | python setup.py bdist_wheel 14 | -------------------------------------------------------------------------------- /examples/call_converter_via_python.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """ 4 | A simple example how to call tensorflow-onnx via python. 
5 | """ 6 | 7 | import tensorflow as tf 8 | import tf2onnx 9 | 10 | with tf.Session() as sess: 11 | x = tf.placeholder(tf.float32, [2, 3], name="input") 12 | x_ = tf.add(x, x) 13 | _ = tf.identity(x_, name="output") 14 | onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph, input_names=["input:0"], output_names=["output:0"]) 15 | model_proto = onnx_graph.make_model("test") 16 | with open("/tmp/model.onnx", "wb") as f: 17 | f.write(model_proto.SerializeToString()) 18 | -------------------------------------------------------------------------------- /examples/custom_op_via_python.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """ 4 | A simple example how to map a custom op in python. 5 | """ 6 | import tensorflow as tf 7 | import tf2onnx 8 | from onnx import helper 9 | 10 | _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow" 11 | 12 | 13 | def print_handler(ctx, node, name, args): 14 | # replace tf.Print() with Identity 15 | # T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize) 16 | # becomes: 17 | # T output = Identity(T Input) 18 | node.type = "Identity" 19 | node.domain = _TENSORFLOW_DOMAIN 20 | del node.input[1:] 21 | return node 22 | 23 | 24 | with tf.Session() as sess: 25 | x = tf.placeholder(tf.float32, [2, 3], name="input") 26 | x_ = tf.add(x, x) 27 | x_ = tf.Print(x_, [x_], "hello") 28 | _ = tf.identity(x_, name="output") 29 | onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph, 30 | custom_op_handlers={"Print": (print_handler, [])}, 31 | extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)], 32 | input_names=["input:0"], 33 | output_names=["output:0"]) 34 | model_proto = onnx_graph.make_model("test") 35 | with open("/tmp/model.onnx", "wb") as f: 36 | f.write(model_proto.SerializeToString()) 37 | -------------------------------------------------------------------------------- /examples/getting_started.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """ 4 | This example shows how to convert tf functions and keras models using the Python API. 5 | It also demonstrates converting saved_models from the command line. 6 | """ 7 | 8 | import tensorflow as tf 9 | import tf2onnx 10 | import numpy as np 11 | import onnxruntime as ort 12 | import os 13 | 14 | ##################### tf function ##################### 15 | 16 | @tf.function 17 | def f(a, b): 18 | return a + b 19 | 20 | input_signature = [tf.TensorSpec([2, 3], tf.float32), tf.TensorSpec([2, 3], tf.float32)] 21 | onnx_model, _ = tf2onnx.convert.from_function(f, input_signature, opset=13) 22 | 23 | a_val = np.ones([2, 3], np.float32) 24 | b_val = np.zeros([2, 3], np.float32) 25 | 26 | print("Tensorflow result") 27 | print(f(a_val, b_val).numpy()) 28 | 29 | print("ORT result") 30 | sess = ort.InferenceSession(onnx_model.SerializeToString()) 31 | res = sess.run(None, {'a': a_val, 'b': b_val}) 32 | print(res[0]) 33 | 34 | 35 | ##################### Keras Model ##################### 36 | 37 | model = tf.keras.Sequential() 38 | model.add(tf.keras.layers.Dense(4, activation="relu")) 39 | 40 | input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')] 41 | onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13) 42 | 43 | x_val = np.ones((3, 3), np.float32) 44 | 45 | print("Keras result") 46 | print(model(x_val).numpy()) 47 | 48 | print("ORT result") 49 | sess = ort.InferenceSession(onnx_model.SerializeToString()) 50 | res = sess.run(None, {'x': x_val}) 51 | print(res[0]) 52 | 53 | 54 | ##################### Saved Model ##################### 55 | 56 | model.save("savedmodel") 57 | os.system("python -m tf2onnx.convert --saved-model savedmodel --output model.onnx --opset 13") 58 | 59 | print("ORT result") 60 | sess = ort.InferenceSession("model.onnx") 61 | res = sess.run(None, {'dense_input': x_val}) 62 | 
print(res[0]) 63 | 64 | print("Conversion succeeded") -------------------------------------------------------------------------------- /examples/tf_custom_op/double_and_add_one.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/examples/tf_custom_op/double_and_add_one.so -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | [aliases] 4 | test=pytest 5 | 6 | [tool:pytest] 7 | addopts=--cov=tf2onnx 8 | norecursedirs=tests/keras2onnx_applications tests/keras2onnx_unit_tests 9 | -------------------------------------------------------------------------------- /tests/ade20k.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/ade20k.jpg -------------------------------------------------------------------------------- /tests/beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/beach.jpg -------------------------------------------------------------------------------- /tests/car.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/car.JPEG -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ print pytest config.""" 5 | 6 | from common import get_test_config 7 | from tf2onnx import logging 8 | 
9 | 10 | def pytest_configure(): 11 | config = get_test_config() 12 | logging.basicConfig(level=config.log_level) 13 | with logging.set_scope_level(logging.INFO) as logger: 14 | logger.info(config) 15 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/.gitignore: -------------------------------------------------------------------------------- 1 | *.h5 2 | *.onnx 3 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/data/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/keras2onnx_applications/data/street.jpg -------------------------------------------------------------------------------- /tests/keras2onnx_applications/lpcnet/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Introduction 4 | This tool converts the lpcnet model to onnx. 5 | To run this code, first install the original lpcnet model from . 6 | Note that lpcnet is not a package, so please add its directory to the path. 7 | Then run 8 | ``` 9 | python convert_lpcnet_to_onnx.py [model_file] 10 | ``` 11 | model_file is the *.h5 file containing the trained model weights.
12 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/lpcnet/convert_lpcnet_to_onnx.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import lpcnet 4 | import sys 5 | 6 | model, enc, dec = lpcnet.new_lpcnet_model(use_gpu=False) 7 | model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy']) 8 | model_file = sys.argv[1] 9 | model.load_weights(model_file) 10 | 11 | import mock_keras2onnx 12 | oxml_enc = mock_keras2onnx.convert_keras(enc, 'lpcnet_enc') 13 | oxml_dec = mock_keras2onnx.convert_keras(dec, 'lpcnet_dec') 14 | 15 | import onnx 16 | onnx.save(oxml_enc, "lpcnet_enc.onnx") 17 | onnx.save(oxml_dec, "lpcnet_dec.onnx") 18 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/mask_rcnn/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Introduction 4 | The original Keras project of Mask RCNN is: . Follow the 'Installation' section in README.md to set up the model. 5 | There is also a good [tutorial](https://github.com/matterport/Mask_RCNN#step-by-step-detection) to learn about object detection. 6 | 7 | The conversion is supported from opset 11 onward, and the converted model must be run with a recent ONNXRuntime release that supports ONNX opset 11 and the contrib ops required by this model. 8 | 9 | # Convert and Run the model. 10 | ``` 11 | cd 12 | pip install -e . 13 | cd /applications/mask_rcnn 14 | # convert the model to onnx and test it with an image.
15 | python mask_rcnn.py 16 | ``` 17 | The unit test is part of our nightly build; see [test_mask_rcnn.py](../nightly_build/test_mask_rcnn.py). 18 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/model_source/densenet_1/tensorflow_backend.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # From https://github.com/titu1994/DenseNet/blob/master/tensorflow_backend.py 4 | # Modifications Copyright (c) Microsoft. 5 | 6 | import tensorflow as tf 7 | 8 | from mock_keras2onnx.proto import keras 9 | from keras.backend import tensorflow_backend as KTF 10 | from keras.backend.common import image_data_format 11 | 12 | py_all = all 13 | 14 | def depth_to_space(input, scale, data_format=None): 15 | ''' Uses phase shift algorithm to convert channels/depth for spatial resolution ''' 16 | if data_format is None: 17 | data_format = image_data_format() 18 | 19 | if data_format == 'channels_first': 20 | data_format = 'NCHW' 21 | else: 22 | data_format = 'NHWC' 23 | 24 | # Keep the upper-case layout name: tf.depth_to_space only accepts 'NHWC' / 'NCHW' / 'NCHW_VECT_C'. 25 | out = tf.depth_to_space(input, scale, data_format=data_format) 26 | return out 27 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/run_all.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | from os import listdir 5 | from os.path import isfile, join 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--exclude') 10 | args = parser.parse_args() 11 | exclude_set = set(args.exclude.split()) if args.exclude is not None else set() 12 | 13 | os.environ["PYTHONPATH"] = \ 14 | os.environ.get("PYTHONPATH", "") + os.pathsep + "../../keras2onnx_unit_tests" + os.pathsep + "../../../" 15 |
os.environ["TF2ONNX_CATCH_ERRORS"] = "FALSE" 16 | 17 | mypath = '.' 18 | files = [f for f in listdir(mypath) if isfile(join(mypath, f)) and f.find("test_") == 0] 19 | files.sort() 20 | 21 | res_final = True 22 | for f_ in files: 23 | if f_ not in exclude_set: 24 | res = os.system("pytest " + f_ + " --no-cov " 25 | "--doctest-modules --junitxml=junit/test-results-" + f_[5:-3] + ".xml") 26 | if res != 0: 27 | res_final = False 28 | 29 | # Report failure through the exit code. `assert` is not used because it is 30 | # stripped when Python runs with -O, which would hide failed test files. 31 | if not res_final: 32 | raise SystemExit(1) 33 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/run_all_v2.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | os.environ["PYTHONPATH"] = \ 5 | os.environ.get("PYTHONPATH", "") + os.pathsep + "../../keras2onnx_unit_tests" + os.pathsep + "../../../" 6 | os.environ["TF2ONNX_CATCH_ERRORS"] = "FALSE" 7 | 8 | files = ['test_keras_applications_v2.py', 'test_transformers.py', 'test_chatbot.py', 'test_efn.py', \ 9 | 'test_resnext.py'] 10 | files.sort() 11 | 12 | res_final = True 13 | for f_ in files: 14 | res = os.system("pytest " + f_ + " --no-cov " 15 | "--doctest-modules --junitxml=junit/test-results-" + f_[5:-3] + ".xml") 16 | if res != 0: 17 | res_final = False 18 | 19 | # Report failure through the exit code. `assert` is not used because it is 20 | # stripped when Python runs with -O, which would hide failed test files. 21 | if not res_final: 22 | raise SystemExit(1) 23 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/test_densenet_1.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import unittest 6 | import keras_segmentation 7 | from os.path import dirname, abspath 8 | 9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_unit_tests/')) 10 | from test_utils import run_image 11 | 12 | img_path =
os.path.join(os.path.dirname(__file__), '../data', 'street.jpg') 13 | 14 | from mock_keras2onnx.proto import is_keras_older_than 15 | 16 | class TestDenseNet_1(unittest.TestCase): 17 | 18 | def setUp(self): 19 | self.model_files = [] 20 | 21 | def tearDown(self): 22 | for fl in self.model_files: 23 | os.remove(fl) 24 | 25 | @unittest.skipIf(is_keras_older_than("2.2.3"), 26 | "Cannot import normalize_data_format from keras.backend") 27 | def test_densenet(self): 28 | # From https://github.com/titu1994/DenseNet/blob/master/densenet.py 29 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../model_source/densenet_1/')) 30 | import densenet_1 31 | image_dim = (224, 224, 3) 32 | model = densenet_1.DenseNetImageNet121(input_shape=image_dim) 33 | res = run_image(model, self.model_files, img_path, target_size=(224, 224)) 34 | self.assertTrue(*res) 35 | 36 | 37 | if __name__ == "__main__": 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/test_densenet_2.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import unittest 6 | import keras_segmentation 7 | from os.path import dirname, abspath 8 | 9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_unit_tests/')) 10 | from test_utils import run_image 11 | 12 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg') 13 | 14 | from mock_keras2onnx.proto import is_keras_older_than 15 | 16 | class TestDenseNet_2(unittest.TestCase): 17 | 18 | def setUp(self): 19 | self.model_files = [] 20 | 21 | def tearDown(self): 22 | for fl in self.model_files: 23 | os.remove(fl) 24 | 25 | def test_densenet(self): 26 | # From https://github.com/tdeboissiere/DeepLearningImplementations/blob/master/DenseNet/densenet.py 27 | sys.path.insert(0,
os.path.join(dirname(abspath(__file__)), '../model_source/densenet_2/')) 28 | import densenet_2 29 | model = densenet_2.DenseNet(20, 30 | (224, 224, 3), 31 | 4, 32 | 1, 33 | 1, 34 | nb_filter=10) 35 | res = run_image(model, self.model_files, img_path, target_size=(224, 224)) 36 | self.assertTrue(*res) 37 | 38 | 39 | if __name__ == "__main__": 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/test_efn.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import unittest 6 | from os.path import dirname, abspath 7 | from mock_keras2onnx.proto import keras, is_tensorflow_older_than 8 | 9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_unit_tests/')) 10 | from test_utils import run_image 11 | 12 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg') 13 | 14 | 15 | @unittest.skipIf(is_tensorflow_older_than('2.1.0'), "efficientnet needs tensorflow >= 2.1.0") 16 | class TestEfn(unittest.TestCase): 17 | 18 | def setUp(self): 19 | self.model_files = [] 20 | 21 | def tearDown(self): 22 | for fl in self.model_files: 23 | os.remove(fl) 24 | 25 | @unittest.skip("TODO: model discrepancy") 26 | def test_custom(self): 27 | from efficientnet import tfkeras as efn 28 | keras.backend.set_learning_phase(0) 29 | base_model = efn.EfficientNetB0(input_shape=(600, 600, 3), weights=None) 30 | backbone = keras.Model(base_model.input, base_model.get_layer("top_activation").output) 31 | res = run_image(backbone, self.model_files, img_path, target_size=(600, 600), 32 | rtol=1e-2, atol=1e-1) 33 | self.assertTrue(*res) 34 | 35 | def test_efn(self): 36 | from efficientnet import tfkeras as efn 37 | keras.backend.set_learning_phase(0) 38 | model = efn.EfficientNetB0(weights=None) 39 | res = run_image(model, self.model_files, img_path,
target_size=(224, 224), rtol=1e-2) 40 | self.assertTrue(*res) 41 | 42 | 43 | if __name__ == "__main__": 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/test_fcn.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import unittest 6 | import keras_segmentation 7 | from os.path import dirname, abspath 8 | 9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_unit_tests/')) 10 | from test_utils import run_image 11 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg') 12 | 13 | 14 | class TestFCN(unittest.TestCase): 15 | 16 | def setUp(self): 17 | self.model_files = [] 18 | 19 | def tearDown(self): 20 | for fl in self.model_files: 21 | os.remove(fl) 22 | 23 | def test_fcn(self): 24 | # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/fcn.py 25 | model = keras_segmentation.models.fcn.fcn_8(101) 26 | res = run_image(model, self.model_files, img_path, target_size=(416, 608)) 27 | self.assertTrue(*res) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 32 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/nightly_build/test_segnet.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import unittest 6 | import keras_segmentation 7 | from os.path import dirname, abspath 8 | 9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_unit_tests/')) 10 | from test_utils import run_image 11 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg') 12 | 13 | 14 | class TestSegNet(unittest.TestCase): 15 | 16 | def setUp(self): 17 | self.model_files = [] 18 | 19 | def tearDown(self): 20
| for fl in self.model_files: 21 | os.remove(fl) 22 | 23 | def test_segnet(self): 24 | # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py 25 | model = keras_segmentation.models.segnet.segnet(101) 26 | res = run_image(model, self.model_files, img_path, target_size=(416, 608)) 27 | self.assertTrue(*res) 28 | 29 | def test_vgg_segnet(self): 30 | # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py 31 | model = keras_segmentation.models.segnet.vgg_segnet(101) 32 | res = run_image(model, self.model_files, img_path, rtol=3.e-3, target_size=(416, 608)) 33 | self.assertTrue(*res) 34 | 35 | def test_mobilenet_segnet(self): 36 | # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py 37 | model = keras_segmentation.models.segnet.mobilenet_segnet(101) 38 | res = run_image(model, self.model_files, img_path, target_size=(224, 224)) 39 | self.assertTrue(*res) 40 | 41 | if __name__ == "__main__": 42 | unittest.main() 43 | -------------------------------------------------------------------------------- /tests/keras2onnx_applications/yolov3/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Introduction 4 | The original Keras model comes from [qqwweee/keras-yolo3](https://github.com/qqwweee/keras-yolo3); clone the project and follow its 'Quick Start' section to get the pre-trained model. 5 | 6 | We have successfully converted the YOLOv3 model and uploaded it to the ONNX model zoo. 7 | 8 | The model supports `batch_size = 1`.
9 | 10 | # Convert 11 | ``` 12 | export PYTHONPATH=/path/to/keras-yolo3 13 | # run object detection; the model is converted to onnx first if the onnx model does not exist yet 14 | python yolov3.py 15 | ``` 16 | The unit test is part of our nightly build; see [test_yolov3.py](../nightly_build/test_yolov3.py). 17 | -------------------------------------------------------------------------------- /tests/keras2onnx_unit_tests/conftest.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import pytest 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from mock_keras2onnx.proto import keras, is_tf_keras 10 | from test_utils import run_onnx_runtime 11 | from mock_keras2onnx.proto.tfcompat import is_tf2 12 | 13 | K = keras.backend 14 | 15 | 16 | @pytest.fixture(scope='function') 17 | def runner(): 18 | np.random.seed(42) 19 | if is_tf2: 20 | tf.random.set_seed(42) 21 | else: 22 | tf.random.set_random_seed(42) 23 | model_files = [] 24 | 25 | def runner_func(*args, **kwargs): 26 | return run_onnx_runtime(*args, model_files, **kwargs) 27 | 28 | # Ensure Keras layer naming is reset for each function 29 | if hasattr(K, "reset_uids"): 30 | # see https://github.com/onnx/tensorflow-onnx/issues/2370 31 | K.reset_uids() 32 | # Reset the TensorFlow session to avoid resource leaking between tests 33 | K.clear_session() 34 | 35 | # Provide wrapped run_onnx_runtime function 36 | yield runner_func 37 | # Remove model files 38 | for fl in model_files: 39 | os.remove(fl) 40 | -------------------------------------------------------------------------------- /tests/keras2onnx_unit_tests/mock_keras2onnx/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from tf2onnx.keras2onnx_api import convert_keras 3 | 4 | def set_converter(*args, **kwargs): 5 | pass 6 |
-------------------------------------------------------------------------------- /tests/keras2onnx_unit_tests/mock_keras2onnx/ke2onnx/batch_norm.py: -------------------------------------------------------------------------------- 1 | convert_keras_batch_normalization = None -------------------------------------------------------------------------------- /tests/keras2onnx_unit_tests/mock_keras2onnx/proto/tfcompat.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import tensorflow as _tf 5 | 6 | from packaging.version import Version 7 | 8 | is_tf2 = Version(_tf.__version__.split('-')[0]) >= Version("2.0.0") 9 | 10 | 11 | def normalize_tensor_shape(tensor_shape): 12 | if is_tf2: 13 | return [d for d in tensor_shape] 14 | else: 15 | return [d.value for d in tensor_shape] 16 | 17 | 18 | def dump_graph_into_tensorboard(tf_graph): 19 | # type: (_tf.Graph) -> None 20 | _tb_log_dir = os.environ.get('TB_LOG_DIR') 21 | if _tb_log_dir: 22 | if is_tf2: 23 | from tensorflow.python.ops.summary_ops_v2 import graph as write_graph 24 | pb_visual_writer = _tf.summary.create_file_writer(_tb_log_dir) 25 | with pb_visual_writer.as_default(): 26 | write_graph(tf_graph) 27 | pb_visual_writer.close()  # summary writers buffer events; close so the graph reaches disk 28 | else: 29 | from tensorflow.python.summary import summary 30 | pb_visual_writer = summary.FileWriter(_tb_log_dir) 31 | pb_visual_writer.add_graph(tf_graph) 32 | pb_visual_writer.close()  # FileWriter flushes asynchronously; close so the graph is persisted 33 | 34 | 35 | if is_tf2: 36 | tensorflow = _tf.compat.v1 37 | 38 | def is_subclassed(layer): 39 | """Returns True if the object is a subclassed layer or subclassed model.""" 40 | return (layer.__module__.find('keras.engine') == -1 and 41 | layer.__module__.find('keras.layers') == -1) 42 | else: 43 | tensorflow = _tf 44 | 45 | def is_subclassed(layer): 46 | return False 47 | -------------------------------------------------------------------------------- /tests/models/ae0/frozen.pb: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/ae0/frozen.pb -------------------------------------------------------------------------------- /tests/models/conv-layers/frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/conv-layers/frozen.pb -------------------------------------------------------------------------------- /tests/models/fc-layers/frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/fc-layers/frozen.pb -------------------------------------------------------------------------------- /tests/models/gru/frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/gru/frozen.pb -------------------------------------------------------------------------------- /tests/models/lstm/frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/lstm/frozen.pb -------------------------------------------------------------------------------- /tests/models/regression/checkpoint/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model" 2 | all_model_checkpoint_paths: "model" 3 | -------------------------------------------------------------------------------- /tests/models/regression/checkpoint/model.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.data-00000-of-00001 -------------------------------------------------------------------------------- /tests/models/regression/checkpoint/model.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.index -------------------------------------------------------------------------------- /tests/models/regression/checkpoint/model.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.meta -------------------------------------------------------------------------------- /tests/models/regression/graphdef/frozen.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/graphdef/frozen.pb -------------------------------------------------------------------------------- /tests/models/regression/saved_model/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/saved_model.pb -------------------------------------------------------------------------------- /tests/models/regression/saved_model/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/variables/variables.data-00000-of-00001 
-------------------------------------------------------------------------------- /tests/models/regression/saved_model/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/variables/variables.index -------------------------------------------------------------------------------- /tests/models/regression/tflite/model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/tflite/model.tflite -------------------------------------------------------------------------------- /tests/models/regression/tflite/test_api_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/tflite/test_api_model.tflite -------------------------------------------------------------------------------- /tests/models/saved_model_with_redundant_inputs/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/saved_model_with_redundant_inputs/saved_model.pb -------------------------------------------------------------------------------- /tests/test_cudnn.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """Unit Tests for cudnn.""" 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from tensorflow.python.ops import init_ops 10 | from backend_test_base import Tf2OnnxBackendTestBase 11 | from common import check_tf_max_version, skip_tf_cpu, check_opset_min_version, 
unittest_main 12 | 13 | 14 | class CudnnTests(Tf2OnnxBackendTestBase): 15 | """ test cudnn cases """ 16 | @check_tf_max_version("1.15.0", "not supported in tf-2.0") 17 | @skip_tf_cpu("only tf_gpu can run CudnnGRU") 18 | @check_opset_min_version(10, "CudnnGRU") 19 | def test_cudnngru(self): 20 | """ test contrib cudnn gru """ 21 | seq_length = 3 22 | batch_size = 5 23 | input_size = 2 24 | num_layers = 2 25 | num_units = 2 26 | num_dirs = 2 27 | x_val = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32) 28 | h_val = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape( 29 | [num_layers * num_dirs, batch_size, num_units]) 30 | 31 | def func(x, h): 32 | initializer = init_ops.constant_initializer(0.5) 33 | cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional', 34 | kernel_initializer=initializer, bias_initializer=initializer) 35 | cudnngru.build([seq_length, batch_size, input_size]) 36 | outputs = cudnngru.call(x, tuple([h])) 37 | _ = tf.identity(outputs[0], name='output') 38 | 39 | feed_dict = {"input_1:0": x_val, "input_2:0": h_val} 40 | input_names_with_port = ["input_1:0", "input_2:0"] 41 | output_names_with_port = ["output:0"] 42 | self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-05, atol=1e-04) 43 | 44 | 45 | if __name__ == '__main__': 46 | unittest_main() 47 | -------------------------------------------------------------------------------- /tests/test_profile.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """Unit Tests for Benchmarks.""" 5 | import os 6 | import subprocess 7 | import sys 8 | from backend_test_base import Tf2OnnxBackendTestBase 9 | from common import ( 10 | check_opset_min_version, check_tf_min_version, 11 | unittest_main, check_onnxruntime_min_version 12 | ) 13 | 14 | # pylint:
disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop 15 | # pylint: disable=invalid-name 16 | # pylint: enable=invalid-name 17 | 18 | class ProfileTests(Tf2OnnxBackendTestBase): 19 | 20 | folder = os.path.join(os.path.dirname(__file__), '..', 'tools') 21 | 22 | @check_tf_min_version("2.0") 23 | @check_opset_min_version(12) 24 | @check_onnxruntime_min_version('1.4.0') 25 | def test_profile_conversion_time(self): 26 | filename = os.path.join(ProfileTests.folder, 'profile_conversion_time.py') 27 | # Run the script with the current interpreter; a bare "python" may be missing or a different install. 28 | proc = subprocess.Popen( 29 | [sys.executable, filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE) 30 | try: 31 | outs = proc.communicate(timeout=15)[0] 32 | except subprocess.TimeoutExpired: 33 | proc.kill() 34 | # Reap the killed child and drain its pipes; a slow run is tolerated rather than failed. 35 | proc.communicate() 36 | return 37 | assert b"Profile complete." in outs or outs == b'' 38 | 39 | 40 | if __name__ == '__main__': 41 | unittest_main() 42 | -------------------------------------------------------------------------------- /tests/test_tfjs_runner.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """Unit Tests for tfjs_runner.py""" 5 | 6 | import json 7 | import unittest 8 | 9 | import numpy as np 10 | 11 | from tfjs_runner import numpy_to_json, json_to_numpy 12 | 13 | # pylint: disable=missing-docstring 14 | 15 | 16 | class TestTfjsRunner(unittest.TestCase): 17 | def test_tfjs_runner(self): 18 | float_array = np.array([[1.1, 2.2], [3.3, 4.4]], np.float32) 19 | int_array = np.array([[1, 2], [3, 4]], np.int32) 20 | bool_array = np.array([[True, False], [True, True]], bool) 21 | string_array = np.array([['Hello world', ''], ['π', 'Tensor']], str) 22 | complex_array = np.array([[1 + 0.1j, 2 + 0.2j], [3 + 0.3j, 4 + 0.4j]], np.complex64) 23 | 24 | arrays = [float_array, int_array, bool_array, string_array, complex_array] 25 | for arr in arrays: 26 | array_enc = json.dumps(numpy_to_json(arr)) 27 | print(array_enc) 28 | array_dec =
json_to_numpy(json.loads(array_enc)) 29 | np.testing.assert_equal(arr, array_dec) 30 | 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_albert_en_xlarge.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | import numpy.random as rnd 5 | from collections import OrderedDict 6 | from _tools import generate_text_inputs, benchmark 7 | 8 | 9 | def main(opset=13): 10 | url = "https://tfhub.dev/tensorflow/albert_en_xlarge/3?tf-hub-format=compressed" 11 | dest = "tf-albert-en-xlarge" 12 | name = "albert-en-xlarge" 13 | onnx_name = os.path.join(dest, "%s-%d.zip" % (name, opset)) 14 | 15 | inputs = generate_text_inputs() 16 | benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output") 17 | 18 | inputs = [OrderedDict([ 19 | ('input_word_ids', numpy.array([rnd.randint(0, 1000) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))), 20 | ('input_mask', numpy.array([rnd.randint(0, 2) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))),  # high is exclusive: draws 0 or 1 21 | ('input_type_ids', numpy.array([i//5 for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))) 22 | ]) for i in range(0, 10)] 23 | 24 | benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output") 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_bert_en_wwm_uncased.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | from collections import OrderedDict 4 | import numpy 5 | import numpy.random as rnd 6 | from _tools import generate_random_images, benchmark 7 | 8 | 9 | def main(opset=13): 10 | url =
"https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/4?tf-hub-format=compressed" 11 | dest = "tf-bert-en-wwm-uncased-L-24-H-1024-A-16" 12 | name = "bert-en-wwm-uncased-L-24-H-1024-A-16" 13 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 14 | 15 | inputs = [OrderedDict([ 16 | ('input_word_ids', numpy.array([rnd.randint(0, 1000) for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1))), 17 | ('input_mask', numpy.array([rnd.randint(0, 2) for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1))),  # high is exclusive: draws 0 or 1 18 | ('input_type_ids', numpy.array([i//5 for i in range(0, 32)], dtype=numpy.int32).reshape((1, -1))) 19 | ]) for i in range(0, 10)] 20 | 21 | benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output") 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_blazeposedetector.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/mediapipe/tfjs-model/blazeposedetector/1/default/1?tfjs-format=compressed" 9 | dest = "tf-blazeposedetector" 10 | name = "blazeposedetector" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 513, 513, 3), scale=1.)
14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_esrgan.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed" 9 | dest = "tf-esrgan-tf2" 10 | name = "esrgan-tf2" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images() 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_inception_v3.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/google/inaturalist/inception_v3/feature_vector/5?tf-hub-format=compressed" 9 | dest = "tf-inception_v3" 10 | name = "inception_v3" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 299, 299, 3)) 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_lambert_en_uncased_L-24_H-1024_A-16.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | from collections import OrderedDict 4 | import numpy 5 | import numpy.random as rnd 6 | from _tools import generate_text_inputs,
benchmark 7 | 8 | 9 | def main(opset=13): 10 | 11 | if False: 12 | import tensorflow as tf 13 | import tensorflow_text 14 | import tensorflow_hub as hub 15 | sentences = tf.constant(["Hi I'm some text"]) 16 | text_input = tf.keras.layers.Input(shape=(), dtype=tf.string) 17 | encoder = hub.KerasLayer( 18 | "https://tfhub.dev/tensorflow/lambert_en_uncased_L-24_H-1024_A-16/2", trainable=True) 19 | preprocessor = hub.KerasLayer( 20 | "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3") 21 | encoder_inputs = preprocessor(text_input) 22 | embedded_inputs = {k: v.numpy() for k, v in preprocessor(sentences).items()} 23 | for k, v in embedded_inputs.items(): 24 | print(k, v.dtype, v.shape) 25 | 26 | url = "https://tfhub.dev/tensorflow/lambert_en_uncased_L-24_H-1024_A-16/2?tf-hub-format=compressed" 27 | dest = "tf-lambert_en_uncased_L-24_H-1024_A-16" 28 | name = "lambert_en_uncased_L-24_H-1024_A-16" 29 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 30 | 31 | inputs = [OrderedDict([ 32 | ('input_word_ids', numpy.array([rnd.randint(0, 1000) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))), 33 | ('input_mask', numpy.array([rnd.randint(0, 2) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))),  # high is exclusive: draws 0 or 1 34 | ('input_type_ids', numpy.array([i//5 for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))) 35 | ]) for i in range(0, 10)] 36 | 37 | benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output") 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_mobile_food_segmenter_V1.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | import tf2onnx 6 | import onnxruntime as ort 7 | 8 | 9 | def main(opset=13): 10 | url =
"https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1?tf-hub-format=compressed" 11 | dest = "tf-mobile_food_segmenter_V1" 12 | name = "mobile_food_segmenter_V1" 13 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 14 | 15 | imgs = generate_random_images(shape=(1, 513, 513, 3), scale=1.) 16 | 17 | if True: 18 | benchmark(url, dest, onnx_name, opset, imgs, tag='') 19 | # The conversion works but tensorflow fails with 20 | # TypeError: 'AutoTrackable' object is not callable 21 | 22 | if True: 23 | import tensorflow.compat.v2 as tf 24 | import tensorflow_hub as hub 25 | 26 | m = hub.KerasLayer('https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1') 27 | inputs = { 28 | "X": tf.keras.Input(shape=[1, 513, 513, 3], dtype="float32", name="X"), 29 | } 30 | outputs = m(inputs)["default"] 31 | # TypeError: pruned(images) missing required arguments: images 32 | print(outputs) 33 | model = tf.keras.Model(inputs, outputs) 34 | 35 | if not os.path.exists(dest): 36 | os.makedirs(dest) 37 | 38 | # This model is a large model. 
39 | tf2onnx.convert.from_keras(model, opset=opset, output_path=onnx_name)  # honour main's opset argument instead of a hard-coded 13 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_mobilebert_en_uncased.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | from collections import OrderedDict 4 | import numpy 5 | import numpy.random as rnd 6 | from _tools import generate_text_inputs, benchmark 7 | 8 | 9 | def main(opset=13):  # Benchmark the MobileBERT encoder through _tools.benchmark with output_name="attention_scores". 10 | 11 | if False:  # NOTE(review): dead code, never executes; kept as a manual recipe that prints the BERT preprocessor's output dtypes/shapes 12 | import tensorflow as tf 13 | import tensorflow_text 14 | import tensorflow_hub as hub 15 | sentences = tf.constant(["Hi I'm some text"]) 16 | text_input = tf.keras.layers.Input(shape=(), dtype=tf.string) 17 | preprocessor = hub.KerasLayer( 18 | "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3") 19 | encoder_inputs = preprocessor(text_input) 20 | embedded_inputs = {k: v.numpy() for k, v in preprocessor(sentences).items()} 21 | for k, v in embedded_inputs.items(): 22 | print(k, v.dtype, v.shape) 23 | 24 | url = "https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1?tf-hub-format=compressed" 25 | dest = "tf-mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT" 26 | name = "mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT" 27 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 28 | 29 | inputs = generate_text_inputs()  # feed format is defined by _tools.generate_text_inputs (not visible here) 30 | benchmark(url, dest, onnx_name, opset, inputs, 31 | output_name="attention_scores") #, ort_name="mobile_bert_encoder_50") 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_mobilenet_v3_large_075_224.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5
| 6 | 7 | def main(opset=13):  # Pass the MobileNet-v3-large (0.75, 224) classification URL and random images to _tools.benchmark. 8 | url = "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/classification/5?tf-hub-format=compressed" 9 | dest = "tf-mobilenet-v3-large-075-224" 10 | name = "mobilenet-v3-large-075-224" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.) 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_mobilenet_v3_small_075_224.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy  # NOTE(review): unused in this script 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13):  # Pass the MobileNet-v3-small (0.75, 224) feature-vector URL and random images to _tools.benchmark. 8 | url = "https://tfhub.dev/google/imagenet/mobilenet_v3_small_075_224/feature_vector/5?tf-hub-format=compressed" 9 | dest = "tf-mobilenet-v3-small-075-224" 10 | name = "mobilenet-v3-small-075-224" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_nasnet_large.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy  # NOTE(review): unused in this script 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13):  # Pass the NASNet-large feature-vector URL and random images to _tools.benchmark. 8 | url = "https://tfhub.dev/google/imagenet/nasnet_large/feature_vector/5?tf-hub-format=compressed" 9 | dest = "tf-nasnet-large" 10 | name = "nasnet-large" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 331, 331, 3))  # NASNet-large uses 331x331 inputs rather than 224x224 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_resnet_v1_101.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy  # NOTE(review): unused in this script 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13):  # Pass the ResNet-v1-101 feature-vector URL and random images to _tools.benchmark. 8 | url = "https://tfhub.dev/google/imagenet/resnet_v1_101/feature_vector/5?tf-hub-format=compressed" 9 | dest = "tf-resnet_v1_101" 10 | name = "resnet_v1_101" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_resnet_v1_101_keras.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | import onnxruntime as ort 5 | import tensorflow as tf 6 | import tensorflow_hub as hub 7 | import tf2onnx 8 | from _tools import generate_random_images, check_discrepencies 9 | 10 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.) 11 | 12 | model = tf.keras.Sequential([ 13 | hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v1_101/feature_vector/5", 14 | trainable=False)]) 15 | model.build([None, 224, 224, 3]) 16 | 17 | expected_output = model(imgs[0]) 18 | 19 | dest = "tf-resnet_v1_101" 20 | if not os.path.exists(dest): 21 | os.makedirs(dest) 22 | dest_name = os.path.join(dest, "resnet_v1_101-13-keras.onnx") 23 | if not os.path.exists(dest_name): 24 | tf2onnx.convert.from_keras(model, opset=13, output_path=dest_name) 25 | 26 | sess = ort.InferenceSession(dest_name) 27 | print('inputs', [_.name for _ in sess.get_inputs()]) 28 | ort_output = sess.run(None, {"keras_layer_input": imgs[0]}) 29 | 30 | print("Actual") 31 | print(ort_output) 32 | print("Expected") 33 | print(expected_output) 34 | 35 | diff = expected_output.numpy() - ort_output[0] 36 | max_diff = numpy.abs(diff).max() 37 | rel_diff = (numpy.abs(diff) / (expected_output.numpy() + 1e-5)).max() 38 | print(max_diff, rel_diff, [ort_output[0].min(), ort_output[0].max()]) 39 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_resnet_v2_101.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 
| 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/google/imagenet/resnet_v2_101/feature_vector/5?tf-hub-format=compressed" 9 | dest = "tf-resnet_v2_101" 10 | name = "resnet_v2_101" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.) 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_resnet_v2_101_classification.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/google/imagenet/resnet_v1_101/classification/5?tf-hub-format=compressed" 9 | dest = "tf-resnet_v2_101_classification" 10 | name = "resnet_v2_101_classification" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.) 
14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_resnet_v2_152.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy  # NOTE(review): unused in this script 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13):  # Pass the ResNet-v2-152 classification URL and random images to _tools.benchmark. 8 | url = "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/5?tf-hub-format=compressed" 9 | dest = "tf-resnet_v2_152" 10 | name = "resnet_v2_152" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 224, 224, 3)) 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_spam_detection.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import random  # NOTE(review): unused in this script 4 | import numpy 5 | from _tools import generate_random_images, benchmark 6 | 7 | 8 | def main(opset=13):  # Pass the spam-detection tutorial model URL and random int32 inputs to _tools.benchmark. 9 | url = "https://tfhub.dev/tensorflow/tutorials/spam-detection/1?tf-hub-format=compressed" 10 | dest = "tf-spam-detection" 11 | name = "spam-detection" 12 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 13 | 14 | imgs = generate_random_images((1, 20), dtype=numpy.int32)  # despite the name, (1, 20) int32 values — presumably token ids, not images 15 | 16 | benchmark(url, dest, onnx_name, opset, imgs) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_talkheads_ggelu_bert_en_large.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | from collections import OrderedDict 4 | import numpy 5 | import
numpy.random as rnd 6 | from _tools import generate_random_images, benchmark 7 | 8 | 9 | def main(opset=13): 10 | url = "https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_large/2?tf-hub-format=compressed" 11 | dest = "tf-talkheads_ggelu_bert_en_large" 12 | name = "talkheads_ggelu_bert_en_large" 13 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 14 | 15 | inputs = [OrderedDict([ 16 | ('input_word_ids', numpy.array([rnd.randint(0, 1000) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))), 17 | ('input_mask', numpy.array([rnd.randint(0, 1) for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))), 18 | ('input_type_ids', numpy.array([i//5 for i in range(0, 128)], dtype=numpy.int32).reshape((1, -1))) 19 | ]) for i in range(0, 10)] 20 | 21 | benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output") 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_thunder.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/google/movenet/singlepose/thunder/3?tf-hub-format=compressed" 9 | dest = "tf-thunder" 10 | name = "thunder" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32) 14 | 15 | benchmark(url, dest, onnx_name, opset, imgs, 16 | signature='serving_default') 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_yamnet_coral.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import 
generate_random_images, benchmark_tflite 5 | 6 | 7 | def main(opset=13):  # Pass the Coral (Edge TPU) YAMNet TFLite URL and random inputs to _tools.benchmark_tflite. 8 | url = "https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite" 9 | dest = "tf-yamnet-coral" 10 | name = "yamnet" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)  # NOTE(review): image-shaped int32 input (same as the movenet script) although the TF/TFLite YAMNet scripts feed a 1-D float32 waveform — TODO confirm the Coral model's real input 14 | 15 | benchmark_tflite(url, dest, onnx_name, opset, imgs)  # the log below records why this fails: the Edge TPU custom op is unsupported 16 | # WARNING - Error loading model into tflite interpreter: Encountered unresolved custom op: edgetpu-custom-op.Node number 14 (edgetpu-custom-op) failed to prepare. 17 | # WARNING - Could not parse attributes for custom op 'TFL_edgetpu-custom-op': 'utf-8' codec can't decode byte 0xc8 in position 0: invalid continuation byte 18 | # WARNING - For now, onnxruntime only support float32 type for Gemm rewriter 19 | # ERROR - Tensorflow op [tower0/network/layer32/final_output1_prequant: TFL_edgetpu-custom-op] is not supported 20 | # ERROR - Unsupported ops: Counter({'TFL_edgetpu-custom-op': 1}) 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_yamnet_tf.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark, benchmark_tflite 5 | 6 | 7 | def main(opset=13):  # Convert/benchmark YAMNet from a local TFLite export via _tools.benchmark_tflite. 8 | url = "https://tfhub.dev/google/yamnet/1?tf-hub-format=compressed" 9 | dest = "tf-yamnet-tf" 10 | name = "yamnet" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | tfl = os.path.join(dest, 'model.tflite')  # TFLite file consumed by benchmark_tflite below 13 | 14 | imgs = generate_random_images(shape=(16000, ), dtype=numpy.float32, scale=0.)  # 16000-sample float32 vector; scale=0. presumably makes it silent — confirm in _tools
15 | 16 | # benchmark(url, dest, onnx_name, opset, imgs, convert_tflite=tfl) 17 | 18 | onnx_name = os.path.join(dest, "%s-tfl-%d.onnx" % (name, opset)) 19 | benchmark_tflite(tfl, dest, onnx_name, opset, imgs) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /tests/tfhub/tfhub_yamnet_tflite.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | import os 3 | import numpy 4 | from _tools import generate_random_images, benchmark_tflite 5 | 6 | 7 | def main(opset=13): 8 | url = "https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite" 9 | dest = "tf-yamnet-tflite" 10 | name = "yamnet" 11 | onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset)) 12 | 13 | imgs = generate_random_images(shape=(15600, ), dtype=numpy.float32, scale=0.) 14 | 15 | benchmark_tflite(url, dest, onnx_name, opset, imgs, names=[ 16 | ('stft/rfft3', 'FFT_stft/rfft4_reshape__190:0'), 17 | ('magnitude_spectrogram', 'ComplexAbsmagnitude_spectrogram__206:0'), 18 | ('log_mel_spectrogram', 'log_mel_spectrogram'), 19 | ]) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /tests/utils/setup_test_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # # Check if the argument is provided 4 | if [ "$#" -ne 3 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | # Assign the argument to a variable 10 | TF_VERSION=$1 11 | ORT_VERSION=$2 12 | ONNX_VERSION=$3 13 | 14 | echo "==== TensorFlow version: $TF_VERSION" 15 | echo "==== ONNXRuntime version: $ORT_VERSION" 16 | echo "==== ONNX version: $ONNX_VERSION" 17 | 18 | pip install pytest pytest-cov pytest-runner coverage graphviz requests pyyaml pillow pandas parameterized sympy coloredlogs flatbuffers timeout-decorator 
19 | pip install onnx==$ONNX_VERSION 20 | pip install onnxruntime==$ORT_VERSION 21 | pip install "numpy<2" 22 | 23 | pip install onnxruntime-extensions 24 | pip install "tensorflow-text<=$TF_VERSION" 25 | 26 | pip uninstall -y tensorflow 27 | pip install tensorflow==$TF_VERSION 28 | pip uninstall -y protobuf 29 | pip install "protobuf~=3.20" 30 | 31 | python setup.py install 32 | 33 | echo "----- List all of depdencies:" 34 | pip freeze --all -------------------------------------------------------------------------------- /tf2onnx/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """tf2onnx package.""" 4 | 5 | __all__ = ["utils", "graph_matcher", "graph", "graph_builder", 6 | "tfonnx", "shape_inference", "schemas", "tf_utils", "tf_loader", "convert"] 7 | 8 | import onnx 9 | from .version import git_version, version as __version__ 10 | from . import verbose_logging as logging 11 | from tf2onnx import tfonnx, utils, graph, graph_builder, graph_matcher, shape_inference, schemas, convert # pylint: disable=wrong-import-order 12 | -------------------------------------------------------------------------------- /tf2onnx/custom_opsets/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """ custom tf2onnx mapping functions. """ 4 | 5 | from . import ms 6 | from . import onnx_ml 7 | from . 
import string_ops 8 | -------------------------------------------------------------------------------- /tf2onnx/late_rewriters/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """tf2onnx.late_rewriters module.""" 4 | 5 | from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_first, rewrite_channels_last  # re-exported; listed in __all__ below 6 | 7 | 8 | __all__ = [ 9 | "rewrite_channels_first", 10 | "rewrite_channels_last", 11 | ] 12 | -------------------------------------------------------------------------------- /tf2onnx/onnx_opset/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """tf2onnx.onnx_opset module""" 4 | 5 | from . import (  # NOTE(review): presumably imported for the side effect of registering each module's @tf_op handlers — confirm in tf2onnx.handler 6 | common, 7 | controlflow, 8 | generator, 9 | logical, 10 | math, 11 | misc, 12 | nn, 13 | quantize, 14 | reduction, 15 | rnn, 16 | signal, 17 | tensor, 18 | traditionalml 19 | ) 20 | -------------------------------------------------------------------------------- /tf2onnx/onnx_opset/misc.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | misc 6 | """ 7 | 8 | import logging 9 | 10 | from tf2onnx.handler import tf_op 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | # pylint: disable=unused-argument,missing-docstring 16 | 17 | @tf_op(["CheckNumerics", "StopGradient"]) 18 | class MoveToIdent:  # CheckNumerics/StopGradient become ONNX Identity nodes 19 | @classmethod 20 | def version_1(cls, ctx, node, **kwargs): 21 | node.type = "Identity" 22 | if node.inputs[0].is_const(): 23 | # should not remove the identity node if it is output of the graph 24 | if node.output[0] in ctx.outputs: 25 | return 26 | # if identity has a const as input, remove it 27 | input_name = node.input[0] 28 | output_name = node.output[0] 29 | ctx.replace_all_inputs(output_name, input_name) # ops=ctx.get_nodes() 30 | ctx.remove_node(node.name)
31 | 32 | 33 | @tf_op(["Placeholder", "PlaceholderV2", "PlaceholderWithDefault"]) 34 | class DirectOp:  # placeholder ops are registered but left untouched by this handler 35 | @classmethod 36 | def version_1(cls, ctx, node, **kwargs): 37 | pass 38 | 39 | 40 | @tf_op("NoOp") 41 | class NukeNode:  # NoOp nodes are simply removed from the graph 42 | @classmethod 43 | def version_1(cls, ctx, node, **kwargs): 44 | ctx.remove_node(node.name) 45 | -------------------------------------------------------------------------------- /tf2onnx/onnx_opset/traditionalml.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | traditional ml 6 | """ 7 | 8 | import logging 9 | 10 | logger = logging.getLogger(__name__)  # NOTE(review): this module currently defines no handlers, only its logger 11 | -------------------------------------------------------------------------------- /tf2onnx/rewriter/fused_op_rewriter.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter.fused_op_rewriter - rewrite tensorflow _Fused ops from grappler into other tf ops 6 | """ 7 | 8 | 9 | # pylint: disable=missing-docstring 10 | 11 | 12 | def rewrite_fused_ops(g, ops):  # Expand each grappler _Fused* op into its base op followed by one new node per fused op; returns g.get_nodes(). 13 | for node in ops: 14 | if node.type in ["_FusedConv2D", "_FusedMatMul", "_FusedDepthwiseConv2dNative"]: 15 | op_types = [op.decode() for op in node.get_attr_value("fused_ops")] 16 | extra_inputs = node.input[2:]  # operands consumed by the fused tail ops 17 | g.replace_inputs(node, node.input[:2])  # the base op keeps only its first two inputs 18 | last_output = node.output[0] 19 | node.type = node.type.replace("_Fused", "") 20 | dtype = g.get_dtype(node.output[0]) 21 | shape = g.get_shape(node.output[0]) 22 | first_node = None 23 | for op in op_types: 24 | num_inputs = {"BiasAdd": 2, "FusedBatchNorm": 5}.get(op, 1 + len(extra_inputs))  # counts include the running output: BiasAdd 2, FusedBatchNorm 5, any other op takes all remaining extras 25 | my_inputs = [last_output] + extra_inputs[:num_inputs - 1] 26 | new_node = g.make_node(op, my_inputs, skip_conversion=False, 27 | op_name_scope=node.name, dtypes=[dtype], shapes=[shape]) 28 | last_output = new_node.output[0] 29 | extra_inputs = extra_inputs[num_inputs - 1:] 30 | if
first_node is None: 31 | first_node = new_node 32 | 33 | consumers = [n for n in g.find_output_consumers(node.output[0]) if n != first_node] 34 | g.replace_all_inputs(node.output[0], last_output, consumers) 35 | 36 | return g.get_nodes() 37 | -------------------------------------------------------------------------------- /tf2onnx/rewriter/leakyrelu_rewriter.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter - rewrite tensorflow subgraph to onnx leakyrelu op 6 | """ 7 | 8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher 9 | 10 | 11 | # pylint: disable=missing-docstring 12 | 13 | 14 | def rewrite_leakyrelu(g, ops): 15 | if g.opset < 6: 16 | return ops 17 | 18 | pattern = \ 19 | OpTypePattern('Maximum', name='max', inputs=[ 20 | OpTypePattern('Mul', name='mul', inputs=[ 21 | OpTypePattern('Const', name='alpha'), 22 | OpTypePattern('*', name='mul_input'), 23 | ]), 24 | OpTypePattern('*', name='max_input'), 25 | ]) 26 | 27 | matcher = GraphMatcher(pattern, allow_reorder=True) 28 | match_results = list(matcher.match_ops(ops)) 29 | for match in match_results: 30 | max_node = match.get_op('max') 31 | mul_node = match.get_op("mul") 32 | 33 | max_input_edge_name = match.get_tensor('max_input') 34 | mul_input_edge_name = match.get_tensor('mul_input') 35 | if max_input_edge_name == mul_input_edge_name: 36 | alpha = match.get_op("alpha").get_tensor_value() 37 | if alpha >= 1: 38 | continue 39 | leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha}, 40 | shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])]) 41 | ops.append(leakyrelu) 42 | g.replace_all_inputs(max_node.output[0], leakyrelu.output[0], ops=ops) 43 | to_delete = [max_node, mul_node] 44 | g.safe_remove_nodes(to_delete) 45 | 46 | return ops 47 | -------------------------------------------------------------------------------- 
/tf2onnx/rewriter/ragged_variant_shape_rewriter.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter - RaggedTensorToVariant -> Shape pattern 6 | """ 7 | 8 | import numpy as np 9 | from tf2onnx import utils 10 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher 11 | 12 | 13 | # pylint: disable=missing-docstring 14 | 15 | 16 | def rewrite_ragged_variant_shape(g, ops):  # Replace Shape(RaggedTensorToVariant(...)) by Shape(input[0]) - 1 for batched inputs with RAGGED_RANK == 1; returns ops unchanged. 17 | pattern1 = \ 18 | OpTypePattern('Shape', name='shape', inputs=[ 19 | OpTypePattern('RaggedTensorToVariant', name='raggedtovariant') 20 | ]) 21 | 22 | pattern_list = [pattern1] 23 | for pattern in pattern_list: 24 | matcher = GraphMatcher(pattern) 25 | match_results = list(matcher.match_ops(ops)) 26 | for match in match_results: 27 | shape = match.get_op('shape') 28 | raggedtovariant = match.get_op('raggedtovariant') 29 | if raggedtovariant.get_attr_value("batched_input") != 1: 30 | continue 31 | if raggedtovariant.get_attr_value("RAGGED_RANK") != 1: 32 | continue 33 | # Shape of batched variant from ragged is same as number of splits minus 1 34 | g.replace_inputs(shape, [raggedtovariant.input[0]])  # input[0] is presumably the row-splits tensor 35 | np_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(shape.output[0])) 36 | const_one = g.make_const(utils.make_name("const_one"), np.array(1, np_dtype)).output[0]  # scalar 1 in Shape's output dtype 37 | g.insert_new_node_on_output("Sub", shape.output[0], inputs=[shape.output[0], const_one])  # presumably rewires Shape's consumers to read shape - 1 38 | 39 | return ops 40 | -------------------------------------------------------------------------------- /tf2onnx/rewriter/rnn.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter.rnn - lstm support 6 | """ 7 | 8 | import logging 9 | 10 | from tf2onnx.rewriter.bilstm_rewriter import rewrite_bidirectional_lstms 11 | from tf2onnx.rewriter.bigru_rewriter import rewrite_bidirectional_grus 12 | from
tf2onnx.rewriter.custom_rnn_rewriter import CustomRnnRewriter 13 | from tf2onnx.rewriter.loop_rewriter import LoopRewriter 14 | from tf2onnx.rewriter.lstm_rewriter import LSTMRewriter 15 | from tf2onnx.rewriter.gru_rewriter import GRUUnitRewriter 16 | 17 | # pylint: disable=invalid-name,unused-argument,missing-docstring 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def rewrite_single_direction_lstm(g, ops):  # ops is unused; the rewriter works on g directly 24 | r = LSTMRewriter(g) 25 | return r.run() 26 | 27 | 28 | def rewrite_bi_direction_lstm(g, ops): 29 | return rewrite_bidirectional_lstms(g, ops) 30 | 31 | 32 | def rewrite_single_direction_gru(g, ops):  # ops is unused; the rewriter works on g directly 33 | r = GRUUnitRewriter(g) 34 | return r.run() 35 | 36 | 37 | def rewrite_bi_direction_gru(g, ops): 38 | return rewrite_bidirectional_grus(g, ops) 39 | 40 | 41 | def rewrite_custom_rnn_cell(g, ops):  # ops is unused 42 | return CustomRnnRewriter(g).run() 43 | 44 | 45 | def rewrite_generic_loop(g, ops):  # ops is unused 46 | return LoopRewriter(g).run() 47 | -------------------------------------------------------------------------------- /tf2onnx/rewriter/thresholded_relu_rewriter.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter - rewrite tensorflow subgraph to onnx ThresholdedRelu op 6 | """ 7 | 8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher 9 | 10 | 11 | # pylint: disable=missing-docstring 12 | 13 | 14 | def rewrite_thresholded_relu(g, ops):  # Fuse x * Cast(Greater(x, theta)) into ONNX ThresholdedRelu(x, alpha=theta); returns ops. 15 | if g.opset < 10:  # rewrite is only applied for opset >= 10 16 | return ops 17 | 18 | pattern = \ 19 | OpTypePattern('Mul', name='mul', inputs=[ 20 | OpTypePattern('Cast', name='cast', inputs=[ 21 | OpTypePattern('Greater', name='greater', inputs=[ 22 | OpTypePattern('*', name='greater_input'), 23 | OpTypePattern('Const', name='theta') 24 | ]) 25 | ]), 26 | OpTypePattern('*', name='mul_input') 27 | ]) 28 | matcher = GraphMatcher(pattern, allow_reorder=True)  # NOTE(review): if reordering also applies to Greater's inputs, Greater(theta, x) would match too and invert the comparison — confirm GraphMatcher semantics 29 | match_results = list(matcher.match_ops(ops)) 30 | 31 | for match in
match_results: 32 | mul_node = match.get_op("mul") 33 | cast_node = match.get_op('cast') 34 | 35 | greater_input_edge_name = match.get_tensor('greater_input') 36 | mul_input_edge_name = match.get_tensor('mul_input') 37 | if greater_input_edge_name == mul_input_edge_name: 38 | theta = match.get_op('theta').get_tensor_value() 39 | thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta}, 40 | shapes=[g.get_shape(mul_node.output[0])], 41 | dtypes=[g.get_dtype(mul_node.output[0])]) 42 | g.replace_all_inputs(mul_node.output[0], thresholded_relu.output[0], ops=ops) 43 | to_delete = [cast_node, mul_node] 44 | g.safe_remove_nodes(to_delete) 45 | return ops 46 | -------------------------------------------------------------------------------- /tf2onnx/rewriter/transpose_rewriter.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | tf2onnx.rewriter - rewrite tensorflow transpose op 6 | """ 7 | 8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher 9 | 10 | 11 | # pylint: disable=missing-docstring 12 | 13 | 14 | def rewrite_transpose(g, ops): 15 | pattern = \ 16 | OpTypePattern('Transpose', name='output', inputs=[ 17 | OpTypePattern(None), 18 | OpTypePattern('Sub', inputs=[ 19 | OpTypePattern('Sub', inputs=["*", "*"]), 20 | OpTypePattern('Range', inputs=["*", "*", "*"]), 21 | ]), 22 | ]) 23 | 24 | matcher = GraphMatcher(pattern) 25 | match_results = list(matcher.match_ops(ops)) 26 | for match in match_results: 27 | output = match.get_op('output') 28 | shape = g.get_shape(output.input[0]) 29 | dims = range(len(shape) - 1, -1, -1) 30 | output.set_attr("perm", dims) 31 | g.remove_input(output, output.input[1], 1) 32 | to_delete = [n for n in match.get_nodes() if n != output] 33 | g.safe_remove_nodes(to_delete) 34 | return ops 35 | -------------------------------------------------------------------------------- 
/tf2onnx/tflite/ATan2Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ATan2Options(object):  # NOTE(review): flatc-generated; change the schema and regenerate rather than hand-editing 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)  # root-table offset read from buf at offset 17 | x = ATan2Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsATan2Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ATan2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)  # identifier bytes spell "TFL3" 28 | 29 | # ATan2Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0)  # this table declares no fields 34 | def ATan2OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ATan2OptionsEnd(builder): 39 | """This method is deprecated.
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/AbsOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class AbsOptions(object):  # NOTE(review): flatc-generated; change the schema and regenerate rather than hand-editing 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)  # root-table offset read from buf at offset 17 | x = AbsOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsAbsOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def AbsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed)  # identifier bytes spell "TFL3" 28 | 29 | # AbsOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0)  # this table declares no fields 34 | def AbsOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def AbsOptionsEnd(builder): 39 | """This method is deprecated.
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ActivationFunctionType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class ActivationFunctionType(object): 8 | NONE = 0 9 | RELU = 1 10 | RELU_N1_TO_1 = 2 11 | RELU6 = 3 12 | TANH = 4 13 | SIGN_BIT = 5 14 | 15 | -------------------------------------------------------------------------------- /tf2onnx/tflite/AddNOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class AddNOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = AddNOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsAddNOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def AddNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # AddNOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def AddNOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def AddNOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ArgMaxOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ArgMaxOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ArgMaxOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsArgMaxOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ArgMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ArgMaxOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # ArgMaxOptions 34 | def OutputType(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def ArgMaxOptionsStart(builder): 42 | """This method is deprecated. 
Please switch to Start.""" 43 | return Start(builder) 44 | def AddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0) 45 | def ArgMaxOptionsAddOutputType(builder, outputType): 46 | """This method is deprecated. Please switch to AddOutputType.""" 47 | return AddOutputType(builder, outputType) 48 | def End(builder): return builder.EndObject() 49 | def ArgMaxOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ArgMinOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ArgMinOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ArgMinOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsArgMinOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ArgMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ArgMinOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # ArgMinOptions 34 | def OutputType(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def ArgMinOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddOutputType(builder, outputType): builder.PrependInt8Slot(0, outputType, 0) 45 | def ArgMinOptionsAddOutputType(builder, outputType): 46 | """This method is deprecated. Please switch to AddOutputType.""" 47 | return AddOutputType(builder, outputType) 48 | def End(builder): return builder.EndObject() 49 | def ArgMinOptionsEnd(builder): 50 | """This method is deprecated. 
Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/AssignVariableOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class AssignVariableOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = AssignVariableOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsAssignVariableOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def AssignVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # AssignVariableOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def AssignVariableOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def AssignVariableOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/BatchToSpaceNDOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class BatchToSpaceNDOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = BatchToSpaceNDOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsBatchToSpaceNDOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def BatchToSpaceNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # BatchToSpaceNDOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def BatchToSpaceNDOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def BatchToSpaceNDOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/BitcastOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class BitcastOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = BitcastOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsBitcastOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def BitcastOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # BitcastOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def BitcastOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def BitcastOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/BitwiseXorOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class BitwiseXorOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = BitwiseXorOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsBitwiseXorOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def BitwiseXorOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # BitwiseXorOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def BitwiseXorOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def BitwiseXorOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/BroadcastToOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class BroadcastToOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = BroadcastToOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsBroadcastToOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def BroadcastToOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # BroadcastToOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def BroadcastToOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def BroadcastToOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/CallOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class CallOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = CallOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsCallOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def CallOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # CallOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # CallOptions 34 | def Subgraph(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def CallOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddSubgraph(builder, subgraph): builder.PrependUint32Slot(0, subgraph, 0) 45 | def CallOptionsAddSubgraph(builder, subgraph): 46 | """This method is deprecated. 
Please switch to AddSubgraph.""" 47 | return AddSubgraph(builder, subgraph) 48 | def End(builder): return builder.EndObject() 49 | def CallOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/CombinerType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class CombinerType(object): 8 | SUM = 0 9 | MEAN = 1 10 | SQRTN = 2 11 | 12 | -------------------------------------------------------------------------------- /tf2onnx/tflite/CosOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class CosOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = CosOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsCosOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def CosOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # CosOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def CosOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def CosOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/CustomOptionsFormat.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class CustomOptionsFormat(object): 8 | FLEXBUFFERS = 0 9 | 10 | -------------------------------------------------------------------------------- /tf2onnx/tflite/DensifyOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class DensifyOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = DensifyOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsDensifyOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def DensifyOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # DensifyOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def DensifyOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def DensifyOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/DequantizeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class DequantizeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = DequantizeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsDequantizeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def DequantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # DequantizeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def DequantizeOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def DequantizeOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/DimensionType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class DimensionType(object): 8 | DENSE = 0 9 | SPARSE_CSR = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/DynamicUpdateSliceOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class DynamicUpdateSliceOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = DynamicUpdateSliceOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsDynamicUpdateSliceOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def DynamicUpdateSliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # DynamicUpdateSliceOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def DynamicUpdateSliceOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def DynamicUpdateSliceOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/EqualOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class EqualOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = EqualOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsEqualOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def EqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # EqualOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def EqualOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def EqualOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ExpOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ExpOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ExpOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsExpOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ExpOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ExpOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ExpOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ExpOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ExpandDimsOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ExpandDimsOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ExpandDimsOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsExpandDimsOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ExpandDimsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ExpandDimsOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ExpandDimsOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ExpandDimsOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/FillOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class FillOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = FillOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsFillOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def FillOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # FillOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def FillOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def FillOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/FloorDivOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class FloorDivOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = FloorDivOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsFloorDivOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def FloorDivOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # FloorDivOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def FloorDivOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def FloorDivOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/FloorModOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class FloorModOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = FloorModOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsFloorModOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def FloorModOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # FloorModOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def FloorModOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def FloorModOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/FullyConnectedOptionsWeightsFormat.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class FullyConnectedOptionsWeightsFormat(object): 8 | DEFAULT = 0 9 | SHUFFLED4x16INT8 = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/GatherNdOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class GatherNdOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = GatherNdOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsGatherNdOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def GatherNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # GatherNdOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def GatherNdOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def GatherNdOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/GeluOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class GeluOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = GeluOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsGeluOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def GeluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # GeluOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # GeluOptions 34 | def Approximate(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) 38 | return False 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def GeluOptionsStart(builder): 42 | """This method is deprecated. 
Please switch to Start.""" 43 | return Start(builder) 44 | def AddApproximate(builder, approximate): builder.PrependBoolSlot(0, approximate, 0) 45 | def GeluOptionsAddApproximate(builder, approximate): 46 | """This method is deprecated. Please switch to AddApproximate.""" 47 | return AddApproximate(builder, approximate) 48 | def End(builder): return builder.EndObject() 49 | def GeluOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/GreaterEqualOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class GreaterEqualOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = GreaterEqualOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsGreaterEqualOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def GreaterEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # GreaterEqualOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def GreaterEqualOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def GreaterEqualOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/GreaterOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class GreaterOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = GreaterOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsGreaterOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def GreaterOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # GreaterOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def GreaterOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def GreaterOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/HardSwishOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class HardSwishOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = HardSwishOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsHardSwishOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def HardSwishOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # HardSwishOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def HardSwishOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def HardSwishOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/HashtableFindOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class HashtableFindOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = HashtableFindOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsHashtableFindOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def HashtableFindOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # HashtableFindOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def HashtableFindOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def HashtableFindOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/HashtableImportOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class HashtableImportOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = HashtableImportOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsHashtableImportOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def HashtableImportOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # HashtableImportOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def HashtableImportOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def HashtableImportOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/HashtableSizeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class HashtableSizeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = HashtableSizeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsHashtableSizeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def HashtableSizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # HashtableSizeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def HashtableSizeOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def HashtableSizeOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LSHProjectionType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class LSHProjectionType(object): 8 | UNKNOWN = 0 9 | SPARSE = 1 10 | DENSE = 2 11 | 12 | -------------------------------------------------------------------------------- /tf2onnx/tflite/LSTMKernelType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class LSTMKernelType(object): 8 | FULL = 0 9 | BASIC = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/LeakyReluOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LeakyReluOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LeakyReluOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLeakyReluOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LeakyReluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LeakyReluOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # LeakyReluOptions 34 | def Alpha(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) 38 | return 0.0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def LeakyReluOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddAlpha(builder, alpha): builder.PrependFloat32Slot(0, alpha, 0.0) 45 | def LeakyReluOptionsAddAlpha(builder, alpha): 46 | """This method is deprecated. Please switch to AddAlpha.""" 47 | return AddAlpha(builder, alpha) 48 | def End(builder): return builder.EndObject() 49 | def LeakyReluOptionsEnd(builder): 50 | """This method is deprecated. 
Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LessEqualOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LessEqualOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LessEqualOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLessEqualOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LessEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LessEqualOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LessEqualOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LessEqualOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LessOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LessOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LessOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLessOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LessOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LessOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LessOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LessOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LogSoftmaxOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LogSoftmaxOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LogSoftmaxOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLogSoftmaxOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LogSoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LogSoftmaxOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LogSoftmaxOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LogSoftmaxOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LogicalAndOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LogicalAndOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LogicalAndOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLogicalAndOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LogicalAndOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LogicalAndOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LogicalAndOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LogicalAndOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LogicalNotOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LogicalNotOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LogicalNotOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLogicalNotOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LogicalNotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LogicalNotOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LogicalNotOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LogicalNotOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/LogicalOrOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class LogicalOrOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = LogicalOrOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsLogicalOrOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def LogicalOrOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # LogicalOrOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def LogicalOrOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def LogicalOrOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/MatrixDiagOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class MatrixDiagOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = MatrixDiagOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsMatrixDiagOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def MatrixDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # MatrixDiagOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def MatrixDiagOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def MatrixDiagOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/MatrixSetDiagOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class MatrixSetDiagOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = MatrixSetDiagOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsMatrixSetDiagOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def MatrixSetDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # MatrixSetDiagOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def MatrixSetDiagOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def MatrixSetDiagOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/MaximumMinimumOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class MaximumMinimumOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = MaximumMinimumOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsMaximumMinimumOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def MaximumMinimumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # MaximumMinimumOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def MaximumMinimumOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def MaximumMinimumOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/MirrorPadMode.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class MirrorPadMode(object): 8 | REFLECT = 0 9 | SYMMETRIC = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/MirrorPadOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class MirrorPadOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = MirrorPadOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsMirrorPadOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def MirrorPadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # MirrorPadOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # MirrorPadOptions 34 | def Mode(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def MirrorPadOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddMode(builder, mode): builder.PrependInt8Slot(0, mode, 0) 45 | def MirrorPadOptionsAddMode(builder, mode): 46 | """This method is deprecated. Please switch to AddMode.""" 47 | return AddMode(builder, mode) 48 | def End(builder): return builder.EndObject() 49 | def MirrorPadOptionsEnd(builder): 50 | """This method is deprecated. 
Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/NegOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class NegOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = NegOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsNegOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def NegOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # NegOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def NegOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def NegOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/NonMaxSuppressionV4Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class NonMaxSuppressionV4Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = NonMaxSuppressionV4Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsNonMaxSuppressionV4Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def NonMaxSuppressionV4OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # NonMaxSuppressionV4Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def NonMaxSuppressionV4OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def NonMaxSuppressionV4OptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/NonMaxSuppressionV5Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class NonMaxSuppressionV5Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = NonMaxSuppressionV5Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsNonMaxSuppressionV5Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def NonMaxSuppressionV5OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # NonMaxSuppressionV5Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def NonMaxSuppressionV5OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def NonMaxSuppressionV5OptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/NotEqualOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class NotEqualOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = NotEqualOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsNotEqualOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def NotEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # NotEqualOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def NotEqualOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def NotEqualOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/OneHotOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class OneHotOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = OneHotOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsOneHotOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def OneHotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # OneHotOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # OneHotOptions 34 | def Axis(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def OneHotOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddAxis(builder, axis): builder.PrependInt32Slot(0, axis, 0) 45 | def OneHotOptionsAddAxis(builder, axis): 46 | """This method is deprecated. 
Please switch to AddAxis.""" 47 | return AddAxis(builder, axis) 48 | def End(builder): return builder.EndObject() 49 | def OneHotOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/PadOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class PadOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = PadOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsPadOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def PadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # PadOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def PadOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def PadOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/PadV2Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class PadV2Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = PadV2Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsPadV2Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def PadV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # PadV2Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def PadV2OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def PadV2OptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/Padding.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class Padding(object): 8 | SAME = 0 9 | VALID = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/PowOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class PowOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = PowOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsPowOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def PowOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # PowOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def PowOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def PowOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/QuantizationDetails.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class QuantizationDetails(object): 8 | NONE = 0 9 | CustomQuantization = 1 10 | 11 | -------------------------------------------------------------------------------- /tf2onnx/tflite/QuantizeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class QuantizeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = QuantizeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsQuantizeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def QuantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # QuantizeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def QuantizeOptionsStart(builder): 35 | """This method is deprecated. 
Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def QuantizeOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/RangeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class RangeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = RangeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsRangeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def RangeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # RangeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def RangeOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def RangeOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/RankOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class RankOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = RankOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsRankOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def RankOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # RankOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def RankOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def RankOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ReadVariableOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ReadVariableOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ReadVariableOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsReadVariableOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ReadVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ReadVariableOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ReadVariableOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ReadVariableOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ReducerOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ReducerOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ReducerOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsReducerOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ReducerOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ReducerOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # ReducerOptions 34 | def KeepDims(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) 38 | return False 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def ReducerOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddKeepDims(builder, keepDims): builder.PrependBoolSlot(0, keepDims, 0) 45 | def ReducerOptionsAddKeepDims(builder, keepDims): 46 | """This method is deprecated. 
Please switch to AddKeepDims.""" 47 | return AddKeepDims(builder, keepDims) 48 | def End(builder): return builder.EndObject() 49 | def ReducerOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ReverseV2Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ReverseV2Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ReverseV2Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsReverseV2Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ReverseV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ReverseV2Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ReverseV2OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ReverseV2OptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/Rfft2dOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class Rfft2dOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = Rfft2dOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsRfft2dOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def Rfft2dOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # Rfft2dOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def Rfft2dOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def Rfft2dOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/RightShiftOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class RightShiftOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = RightShiftOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsRightShiftOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def RightShiftOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # RightShiftOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def RightShiftOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def RightShiftOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ScatterNdOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ScatterNdOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ScatterNdOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsScatterNdOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ScatterNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ScatterNdOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ScatterNdOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ScatterNdOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SegmentSumOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SegmentSumOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SegmentSumOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSegmentSumOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SegmentSumOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SegmentSumOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SegmentSumOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SelectOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SelectOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SelectOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSelectOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SelectOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SelectOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SelectOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SelectOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SelectV2Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SelectV2Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SelectV2Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSelectV2Options(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SelectV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SelectV2Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SelectV2OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SelectV2OptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ShapeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ShapeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ShapeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsShapeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ShapeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ShapeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # ShapeOptions 34 | def OutType(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def ShapeOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddOutType(builder, outType): builder.PrependInt8Slot(0, outType, 0) 45 | def ShapeOptionsAddOutType(builder, outType): 46 | """This method is deprecated. 
Please switch to AddOutType.""" 47 | return AddOutType(builder, outType) 48 | def End(builder): return builder.EndObject() 49 | def ShapeOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SignOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SignOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SignOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSignOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SignOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SignOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SignOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SignOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SliceOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SliceOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SliceOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSliceOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SliceOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SliceOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SliceOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SoftmaxOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SoftmaxOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SoftmaxOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSoftmaxOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SoftmaxOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # SoftmaxOptions 34 | def Beta(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos) 38 | return 0.0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def SoftmaxOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddBeta(builder, beta): builder.PrependFloat32Slot(0, beta, 0.0) 45 | def SoftmaxOptionsAddBeta(builder, beta): 46 | """This method is deprecated. 
Please switch to AddBeta.""" 47 | return AddBeta(builder, beta) 48 | def End(builder): return builder.EndObject() 49 | def SoftmaxOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SpaceToBatchNDOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SpaceToBatchNDOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SpaceToBatchNDOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSpaceToBatchNDOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SpaceToBatchNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SpaceToBatchNDOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SpaceToBatchNDOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SpaceToBatchNDOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SparseIndexVector.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class SparseIndexVector(object): 8 | NONE = 0 9 | Int32Vector = 1 10 | Uint16Vector = 2 11 | Uint8Vector = 3 12 | 13 | -------------------------------------------------------------------------------- /tf2onnx/tflite/SplitOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SplitOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SplitOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSplitOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SplitOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SplitOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # SplitOptions 34 | def NumSplits(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def SplitOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddNumSplits(builder, numSplits): builder.PrependInt32Slot(0, numSplits, 0) 45 | def SplitOptionsAddNumSplits(builder, numSplits): 46 | """This method is deprecated. Please switch to AddNumSplits.""" 47 | return AddNumSplits(builder, numSplits) 48 | def End(builder): return builder.EndObject() 49 | def SplitOptionsEnd(builder): 50 | """This method is deprecated. 
Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SplitVOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SplitVOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SplitVOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSplitVOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SplitVOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SplitVOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # SplitVOptions 34 | def NumSplits(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos) 38 | return 0 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def SplitVOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddNumSplits(builder, numSplits): builder.PrependInt32Slot(0, numSplits, 0) 45 | def SplitVOptionsAddNumSplits(builder, numSplits): 46 | """This method is deprecated. 
Please switch to AddNumSplits.""" 47 | return AddNumSplits(builder, numSplits) 48 | def End(builder): return builder.EndObject() 49 | def SplitVOptionsEnd(builder): 50 | """This method is deprecated. Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SquareOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SquareOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SquareOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSquareOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SquareOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SquareOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SquareOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SquareOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/SquaredDifferenceOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class SquaredDifferenceOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = SquaredDifferenceOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsSquaredDifferenceOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def SquaredDifferenceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # SquaredDifferenceOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def SquaredDifferenceOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def SquaredDifferenceOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/TensorType.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | class TensorType(object): 8 | FLOAT32 = 0 9 | FLOAT16 = 1 10 | INT32 = 2 11 | UINT8 = 3 12 | INT64 = 4 13 | STRING = 5 14 | BOOL = 6 15 | INT16 = 7 16 | COMPLEX64 = 8 17 | INT8 = 9 18 | FLOAT64 = 10 19 | COMPLEX128 = 11 20 | UINT64 = 12 21 | RESOURCE = 13 22 | VARIANT = 14 23 | UINT32 = 15 24 | UINT16 = 16 25 | INT4 = 17 26 | 27 | -------------------------------------------------------------------------------- /tf2onnx/tflite/TileOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class TileOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = TileOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsTileOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def TileOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # TileOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def TileOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def TileOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/TopKV2Options.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class TopKV2Options(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = TopKV2Options() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsTopKV2Options(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def TopKV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # TopKV2Options 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def TopKV2OptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def TopKV2OptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/TransposeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class TransposeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = TransposeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsTransposeOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def TransposeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # TransposeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def TransposeOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def TransposeOptionsEnd(builder): 39 | """This method is deprecated. Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/UniqueOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class UniqueOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = UniqueOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsUniqueOptions(cls, buf, offset=0): 23 | """This method is deprecated. 
Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def UniqueOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # UniqueOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | # UniqueOptions 34 | def IdxOutType(self): 35 | o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) 36 | if o != 0: 37 | return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos) 38 | return 2 39 | 40 | def Start(builder): builder.StartObject(1) 41 | def UniqueOptionsStart(builder): 42 | """This method is deprecated. Please switch to Start.""" 43 | return Start(builder) 44 | def AddIdxOutType(builder, idxOutType): builder.PrependInt8Slot(0, idxOutType, 2) 45 | def UniqueOptionsAddIdxOutType(builder, idxOutType): 46 | """This method is deprecated. Please switch to AddIdxOutType.""" 47 | return AddIdxOutType(builder, idxOutType) 48 | def End(builder): return builder.EndObject() 49 | def UniqueOptionsEnd(builder): 50 | """This method is deprecated. 
Please switch to End.""" 51 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/UnsortedSegmentMaxOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class UnsortedSegmentMaxOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = UnsortedSegmentMaxOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsUnsortedSegmentMaxOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def UnsortedSegmentMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # UnsortedSegmentMaxOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def UnsortedSegmentMaxOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def UnsortedSegmentMaxOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/UnsortedSegmentMinOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class UnsortedSegmentMinOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = UnsortedSegmentMinOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsUnsortedSegmentMinOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def UnsortedSegmentMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # UnsortedSegmentMinOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def UnsortedSegmentMinOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def UnsortedSegmentMinOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/UnsortedSegmentProdOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class UnsortedSegmentProdOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = UnsortedSegmentProdOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsUnsortedSegmentProdOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def UnsortedSegmentProdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # UnsortedSegmentProdOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def UnsortedSegmentProdOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def UnsortedSegmentProdOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/UnsortedSegmentSumOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class UnsortedSegmentSumOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = UnsortedSegmentSumOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsUnsortedSegmentSumOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def UnsortedSegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # UnsortedSegmentSumOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def UnsortedSegmentSumOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def UnsortedSegmentSumOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/WhereOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class WhereOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = WhereOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsWhereOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def WhereOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # WhereOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def WhereOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def WhereOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/ZerosLikeOptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # automatically generated by the FlatBuffers compiler, do not modify 4 | 5 | # namespace: tflite 6 | 7 | import flatbuffers 8 | from flatbuffers.compat import import_numpy 9 | np = import_numpy() 10 | 11 | class ZerosLikeOptions(object): 12 | __slots__ = ['_tab'] 13 | 14 | @classmethod 15 | def GetRootAs(cls, buf, offset=0): 16 | n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) 17 | x = ZerosLikeOptions() 18 | x.Init(buf, n + offset) 19 | return x 20 | 21 | @classmethod 22 | def GetRootAsZerosLikeOptions(cls, buf, offset=0): 23 | """This method is deprecated. Please switch to GetRootAs.""" 24 | return cls.GetRootAs(buf, offset) 25 | @classmethod 26 | def ZerosLikeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False): 27 | return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed) 28 | 29 | # ZerosLikeOptions 30 | def Init(self, buf, pos): 31 | self._tab = flatbuffers.table.Table(buf, pos) 32 | 33 | def Start(builder): builder.StartObject(0) 34 | def ZerosLikeOptionsStart(builder): 35 | """This method is deprecated. Please switch to Start.""" 36 | return Start(builder) 37 | def End(builder): return builder.EndObject() 38 | def ZerosLikeOptionsEnd(builder): 39 | """This method is deprecated. 
Please switch to End.""" 40 | return End(builder) -------------------------------------------------------------------------------- /tf2onnx/tflite/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | -------------------------------------------------------------------------------- /tf2onnx/tflite_handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """tf2onnx.tflite_handlers module""" 4 | 5 | from . import ( 6 | tfl_math, 7 | tfl_nn, 8 | tfl_controlflow, 9 | tfl_direct, 10 | tfl_tensor, 11 | tfl_postprocess, 12 | ) 13 | -------------------------------------------------------------------------------- /tf2onnx/tflite_rewriters/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """tf2onnx.tflite_rewriters module""" 4 | 5 | from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs 6 | from tf2onnx.tflite_rewriters.tfl_qdq_rewriter import rewrite_tfl_qdq 7 | from tf2onnx.tflite_rewriters.tfl_select_zero_rewriter import rewrite_tfl_select_zero 8 | from tf2onnx.tflite_rewriters.tfl_rfft_rewriter import rewrite_tfl_rfft 9 | 10 | __all__ = [ 11 | "rewrite_tfl_scan_outputs", 12 | "rewrite_tfl_qdq", 13 | "rewrite_tfl_select_zero", 14 | "rewrite_tfl_rfft", 15 | ] 16 | -------------------------------------------------------------------------------- /tf2onnx/version.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | version = '1.16.1' 5 | git_version = '13bab8a91e17ccd87541b2f361ab60e8e38359d3' 6 | -------------------------------------------------------------------------------- /tools/example.bat: -------------------------------------------------------------------------------- 1 | rem 
SPDX-License-Identifier: Apache-2.0 2 | 3 | set frozen=tests/models/fc-layers/frozen.pb 4 | set output=/tmp/model.onnx 5 | set input_names=X:0 6 | set output_names=output:0 7 | set output_names1=output 8 | 9 | python tf2onnx\convert.py --input %frozen% --inputs %input_names% --outputs %output_names% --output %output% %1 %2 10 | -------------------------------------------------------------------------------- /tools/example.sh: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | frozen=tests/models/fc-layers/frozen.pb 4 | output=/tmp/model.onnx 5 | input_names=X:0 6 | output_names=output:0 7 | output_names1=output 8 | 9 | python -m tf2onnx.convert --input $frozen --inputs $input_names --outputs $output_names --output $output $@ 10 | -------------------------------------------------------------------------------- /tools/gen_tflite_flatbuffer.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | """ 4 | Generates the files in tf2onnx/tflite used for parsing tflite flatbuffer 5 | WARNING: this script will overwrite all files in tf2onnx/tflite 6 | Before running, download the flatc executable from https://github.com/google/flatbuffers/releases and add it to PATH 7 | This script only tested on Windows 8 | """ 9 | 10 | import os 11 | import subprocess 12 | import tempfile 13 | import wget 14 | 15 | SCHEMA_URL = "https://github.com/tensorflow/tensorflow/raw/master/tensorflow/lite/schema/schema.fbs" 16 | 17 | FILE_HEADER = "# SPDX-License-Identifier: Apache-2.0\n\n" 18 | 19 | def main(): 20 | tmpdir = os.path.join(tempfile.gettempdir(), "tflite_flatbuffer") 21 | os.makedirs(tmpdir, exist_ok=True) 22 | repodir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 23 | dstpath = os.path.join(repodir, "tf2onnx", "tflite") 24 | os.makedirs(dstpath, exist_ok=True) 25 | # Remove existing flatbuffer bindings 26 | 
for file in os.listdir(dstpath): 27 | os.remove(os.path.join(dstpath, file)) 28 | schema_path = os.path.join(tmpdir, "schema.fbs") 29 | # Download schema file 30 | if os.path.exists(schema_path): 31 | os.remove(schema_path) 32 | wget.download(SCHEMA_URL, schema_path) 33 | print() 34 | # Generate flatbuffer code 35 | subprocess.call(["flatc", "-p", "-o", tmpdir, schema_path]) 36 | tmp_result_path = os.path.join(tmpdir, "tflite") 37 | for file in os.listdir(tmp_result_path): 38 | with open(os.path.join(tmp_result_path, file), "rt") as f: 39 | content = f.read() 40 | content = FILE_HEADER + content.replace("from tflite.", "from tf2onnx.tflite.") 41 | with open(os.path.join(dstpath, file), "wt") as f: 42 | f.write(content) 43 | print("Generated", file) 44 | 45 | main() 46 | -------------------------------------------------------------------------------- /tools/graphtool.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | 4 | """ 5 | simple tool to convert .meta to .pb. 
6 | """ 7 | 8 | import argparse 9 | import os 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("infile", nargs="*", help="event files") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def to_pb(src): 22 | """Convert .meta to .pb.""" 23 | _ = tf.train.import_meta_graph(src) 24 | graph = tf.get_default_graph() 25 | 26 | fname = os.path.basename(src)[:-5] 27 | tf.train.write_graph(graph, os.path.dirname(src), fname + '.pb', as_text=False) 28 | tf.train.write_graph(graph, os.path.dirname(src), fname + '.pbtxt', as_text=True) 29 | 30 | writer = tf.summary.FileWriter(os.path.dirname(src)) 31 | writer.add_graph(graph) 32 | writer.close() 33 | 34 | 35 | def main(): 36 | args = get_args() 37 | for src in args.infile: 38 | to_pb(src) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # tf2onnx tutorials 4 | 5 | The following tutorials show how to convert various models to ONNX. 
6 | 7 | ## Image Classifiers 8 | [efficientnet-edge](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientnet-edge.ipynb) 9 | 10 | [efficientnet-lite](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientnet-lite.ipynb) 11 | 12 | [keras-resnet50](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/keras-resnet50.ipynb) - shows how to convert a Keras model via the Python API 13 | 14 | ## Object Detectors 15 | [ssd-mobilenet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/ConvertingSSDMobilenetToONNX.ipynb) 16 | 17 | [efficientdet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientdet.ipynb) 18 | 19 | [mobiledet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/mobiledet-tflite.ipynb) - shows how to convert a TFLite model 20 | 21 | ## NLP 22 | [Hugging Face BERT Example](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/huggingface-bert.ipynb) 23 | 24 | [The original TensorFlow BERT model](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/BertTutorial.ipynb) - deprecated; use the Hugging Face example instead 25 | --------------------------------------------------------------------------------