├── .github
├── ISSUE_TEMPLATE
│ ├── bug-performance-issue.md
│ ├── feature_request.md
│ ├── operators.md
│ └── question.md
├── actions
│ ├── keras_application_test
│ │ └── action.yml
│ ├── keras_unit_test
│ │ └── action.yml
│ ├── pretrained_model_test
│ │ └── action.yml
│ └── unit_test
│ │ └── action.yml
└── workflows
│ ├── keras_application_test_ci.yml
│ ├── keras_unit_test_ci.yml
│ ├── pretrained_model_test_ci.yml
│ ├── pylint.yml
│ └── unit_test_ci.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── Troubleshooting.md
├── VERSION_NUMBER
├── build.bat
├── build.sh
├── examples
├── benchmark_tfmodel_ort.py
├── call_converter_via_python.py
├── custom_op_via_python.py
├── end2end_tfhub.py
├── end2end_tfkeras.py
├── getting_started.py
├── rnn_tips.md
└── tf_custom_op
│ ├── custom_op.md
│ ├── double_and_add_one.cc
│ ├── double_and_add_one.so
│ └── double_and_add_one_custom_op.py
├── setup.cfg
├── setup.py
├── support_status.md
├── tests
├── ade20k.jpg
├── backend_test_base.py
├── beach.jpg
├── car.JPEG
├── cobalt_group_a_perf_testing_models.yaml
├── cobalt_group_c_perf_testing_models.yaml
├── common.py
├── completed_perf_testing_models.yaml
├── conftest.py
├── huggingface.py
├── in_progress_perf_testing_models.yaml
├── keras2onnx_applications
│ ├── .gitignore
│ ├── data
│ │ └── street.jpg
│ ├── lpcnet
│ │ ├── README.md
│ │ └── convert_lpcnet_to_onnx.py
│ ├── mask_rcnn
│ │ ├── README.md
│ │ └── mask_rcnn.py
│ ├── model_source
│ │ ├── densenet_1
│ │ │ ├── densenet_1.py
│ │ │ ├── subpixel.py
│ │ │ └── tensorflow_backend.py
│ │ └── densenet_2
│ │ │ └── densenet_2.py
│ ├── nightly_build
│ │ ├── run_all.py
│ │ ├── run_all_v2.py
│ │ ├── test_aae.py
│ │ ├── test_acgan.py
│ │ ├── test_autoencoder.py
│ │ ├── test_bgan.py
│ │ ├── test_bigan.py
│ │ ├── test_ccgan.py
│ │ ├── test_chatbot.py
│ │ ├── test_cogan.py
│ │ ├── test_craft.py
│ │ ├── test_crnn.py
│ │ ├── test_cyclegan.py
│ │ ├── test_dcgan.py
│ │ ├── test_deep_rl.py
│ │ ├── test_deep_speaker.py
│ │ ├── test_deep_speech.py
│ │ ├── test_deepface.py
│ │ ├── test_deeplab_v3.py
│ │ ├── test_densenet_1.py
│ │ ├── test_densenet_2.py
│ │ ├── test_discogan.py
│ │ ├── test_dual_path_network.py
│ │ ├── test_dualgan.py
│ │ ├── test_efn.py
│ │ ├── test_fcn.py
│ │ ├── test_gan.py
│ │ ├── test_image_outpainting.py
│ │ ├── test_inception_v4.py
│ │ ├── test_infogan.py
│ │ ├── test_keras_applications.py
│ │ ├── test_keras_applications_v2.py
│ │ ├── test_lipnet.py
│ │ ├── test_lsgan.py
│ │ ├── test_mask_rcnn.py
│ │ ├── test_mlstm_fcn.py
│ │ ├── test_music_generation.py
│ │ ├── test_name_entity_recognition.py
│ │ ├── test_nasnet_mobile.py
│ │ ├── test_nbeats.py
│ │ ├── test_nlp.py
│ │ ├── test_nonlocal.py
│ │ ├── test_ocr.py
│ │ ├── test_open_face.py
│ │ ├── test_pix2pix.py
│ │ ├── test_pixelda.py
│ │ ├── test_prn.py
│ │ ├── test_pspnet.py
│ │ ├── test_resnext.py
│ │ ├── test_resume_parser.py
│ │ ├── test_se_densenet.py
│ │ ├── test_se_inc_resnet.py
│ │ ├── test_segnet.py
│ │ ├── test_segnet_2.py
│ │ ├── test_semantic_embeddings.py
│ │ ├── test_semantic_segmentation.py
│ │ ├── test_series_net.py
│ │ ├── test_sgan.py
│ │ ├── test_srgan.py
│ │ ├── test_ssrnet.py
│ │ ├── test_super_resolution.py
│ │ ├── test_transformers.py
│ │ ├── test_unet.py
│ │ ├── test_unet_plus_plus.py
│ │ ├── test_wavenet.py
│ │ ├── test_wgan.py
│ │ ├── test_wgan_gp.py
│ │ └── test_yolov3.py
│ └── yolov3
│ │ ├── README.md
│ │ └── yolov3.py
├── keras2onnx_unit_tests
│ ├── conftest.py
│ ├── mock_keras2onnx
│ │ ├── __init__.py
│ │ ├── ke2onnx
│ │ │ └── batch_norm.py
│ │ └── proto
│ │ │ ├── __init__.py
│ │ │ └── tfcompat.py
│ ├── test_cgan.py
│ ├── test_layers.py
│ ├── test_subclassing.py
│ └── test_utils.py
├── models
│ ├── ae0
│ │ └── frozen.pb
│ ├── conv-layers
│ │ └── frozen.pb
│ ├── fc-layers
│ │ └── frozen.pb
│ ├── gru
│ │ └── frozen.pb
│ ├── lstm
│ │ └── frozen.pb
│ ├── regression
│ │ ├── checkpoint
│ │ │ ├── checkpoint
│ │ │ ├── model.data-00000-of-00001
│ │ │ ├── model.index
│ │ │ └── model.meta
│ │ ├── graphdef
│ │ │ └── frozen.pb
│ │ ├── saved_model
│ │ │ ├── saved_model.pb
│ │ │ └── variables
│ │ │ │ ├── variables.data-00000-of-00001
│ │ │ │ └── variables.index
│ │ └── tflite
│ │ │ ├── model.tflite
│ │ │ └── test_api_model.tflite
│ └── saved_model_with_redundant_inputs
│ │ └── saved_model.pb
├── run_pretrained_models.py
├── run_pretrained_models.yaml
├── run_tfjs.js
├── test_api.py
├── test_backend.py
├── test_cond.py
├── test_const_fold.py
├── test_convert.py
├── test_cudnn.py
├── test_cudnn_compatible_gru.py
├── test_custom_rnncell.py
├── test_einsum_helper.py
├── test_einsum_ml.py
├── test_einsum_optimizers.py
├── test_example.py
├── test_gru.py
├── test_grublock.py
├── test_internals.py
├── test_issue_2025.py
├── test_loops.py
├── test_lstm.py
├── test_lstmblock.py
├── test_onnx_shape_inference.py
├── test_optimizers.py
├── test_profile.py
├── test_seq2seq.py
├── test_stacked_lstm.py
├── test_string_ops.py
├── test_tf_shape_inference.py
├── test_tfjs_runner.py
├── test_tflite_postprocess.py
├── test_tflite_utils.py
├── tfhub
│ ├── _tools.py
│ ├── tfhub_albert_en_xlarge.py
│ ├── tfhub_albert_en_xlarge_keras.py
│ ├── tfhub_bert_en_wwm_uncased.py
│ ├── tfhub_blazeposedetector.py
│ ├── tfhub_enformer.py
│ ├── tfhub_esrgan.py
│ ├── tfhub_humpback_whale.py
│ ├── tfhub_inception_v3.py
│ ├── tfhub_lambert_en_uncased_L-24_H-1024_A-16.py
│ ├── tfhub_mobile_food_segmenter_V1.py
│ ├── tfhub_mobilebert_en_uncased.py
│ ├── tfhub_mobilenet_v3_large_075_224.py
│ ├── tfhub_mobilenet_v3_small_075_224.py
│ ├── tfhub_nasnet_large.py
│ ├── tfhub_resnet_v1_101.py
│ ├── tfhub_resnet_v1_101_keras.py
│ ├── tfhub_resnet_v2_101.py
│ ├── tfhub_resnet_v2_101_classification.py
│ ├── tfhub_resnet_v2_152.py
│ ├── tfhub_spam_detection.py
│ ├── tfhub_talkheads_ggelu_bert_en_large.py
│ ├── tfhub_thunder.py
│ ├── tfhub_yamnet_coral.py
│ ├── tfhub_yamnet_tf.py
│ └── tfhub_yamnet_tflite.py
├── tfjs_runner.py
├── unity.yaml
└── utils
│ └── setup_test_env.sh
├── tf2onnx
├── __init__.py
├── constants.py
├── convert.py
├── custom_opsets
│ ├── __init__.py
│ ├── ms.py
│ ├── onnx_ml.py
│ └── string_ops.py
├── flexbuffers.py
├── graph.py
├── graph_builder.py
├── graph_matcher.py
├── handler.py
├── keras2onnx_api.py
├── late_rewriters
│ ├── __init__.py
│ └── channel_order_rewriters.py
├── onnx_opset
│ ├── __init__.py
│ ├── common.py
│ ├── controlflow.py
│ ├── generator.py
│ ├── logical.py
│ ├── math.py
│ ├── misc.py
│ ├── nn.py
│ ├── quantize.py
│ ├── reduction.py
│ ├── rnn.py
│ ├── signal.py
│ ├── tensor.py
│ └── traditionalml.py
├── optimizer
│ ├── __init__.py
│ ├── back_to_back_optimizer.py
│ ├── const_dequantize_optimizer.py
│ ├── const_fold_optimizer.py
│ ├── einsum_optimizer.py
│ ├── global_pool_optimizer.py
│ ├── identity_optimizer.py
│ ├── loop_optimizer.py
│ ├── merge_duplicated_nodes_optimizer.py
│ ├── optimizer_base.py
│ ├── q_dq_optimizer.py
│ ├── reshape_optimizer.py
│ ├── transpose_optimizer.py
│ └── upsample_optimizer.py
├── rewriter
│ ├── __init__.py
│ ├── bigru_rewriter.py
│ ├── bilstm_rewriter.py
│ ├── cond_rewriter.py
│ ├── conv2d_with_add_rewriter.py
│ ├── conv2d_with_pad_rewriter.py
│ ├── conv_dilations_rewriter.py
│ ├── custom_rnn_rewriter.py
│ ├── dropout_rewriter.py
│ ├── eye_rewriter.py
│ ├── flatten_rewriter.py
│ ├── fused_op_rewriter.py
│ ├── gemm_rewriter.py
│ ├── gru_rewriter.py
│ ├── gru_tf2_rewriter.py
│ ├── layer_normalization_rewriter.py
│ ├── leakyrelu_rewriter.py
│ ├── loop_rewriter.py
│ ├── loop_rewriter_base.py
│ ├── lstm_rewriter.py
│ ├── lstm_rewriter_base.py
│ ├── lstm_tf2_rewriter.py
│ ├── quantization_ops_rewriter.py
│ ├── ragged_variant_shape_rewriter.py
│ ├── random_normal_rewriter.py
│ ├── random_uniform.py
│ ├── rnn.py
│ ├── rnn_utils.py
│ ├── thresholded_relu_rewriter.py
│ ├── transpose_rewriter.py
│ └── unit_rnn_rewriter_base.py
├── schemas.py
├── shape_inference.py
├── symbolic_executor.py
├── tf_loader.py
├── tf_utils.py
├── tfjs_utils.py
├── tflite
│ ├── ATan2Options.py
│ ├── AbsOptions.py
│ ├── ActivationFunctionType.py
│ ├── AddNOptions.py
│ ├── AddOptions.py
│ ├── ArgMaxOptions.py
│ ├── ArgMinOptions.py
│ ├── AssignVariableOptions.py
│ ├── BatchMatMulOptions.py
│ ├── BatchToSpaceNDOptions.py
│ ├── BidirectionalSequenceLSTMOptions.py
│ ├── BidirectionalSequenceRNNOptions.py
│ ├── BitcastOptions.py
│ ├── BitwiseXorOptions.py
│ ├── BroadcastToOptions.py
│ ├── BucketizeOptions.py
│ ├── Buffer.py
│ ├── BuiltinOperator.py
│ ├── BuiltinOptions.py
│ ├── CallOnceOptions.py
│ ├── CallOptions.py
│ ├── CastOptions.py
│ ├── CombinerType.py
│ ├── ConcatEmbeddingsOptions.py
│ ├── ConcatenationOptions.py
│ ├── Conv2DOptions.py
│ ├── Conv3DOptions.py
│ ├── CosOptions.py
│ ├── CumsumOptions.py
│ ├── CustomOptionsFormat.py
│ ├── CustomQuantization.py
│ ├── DensifyOptions.py
│ ├── DepthToSpaceOptions.py
│ ├── DepthwiseConv2DOptions.py
│ ├── DequantizeOptions.py
│ ├── DimensionMetadata.py
│ ├── DimensionType.py
│ ├── DivOptions.py
│ ├── DynamicUpdateSliceOptions.py
│ ├── EmbeddingLookupSparseOptions.py
│ ├── EqualOptions.py
│ ├── ExpOptions.py
│ ├── ExpandDimsOptions.py
│ ├── FakeQuantOptions.py
│ ├── FillOptions.py
│ ├── FloorDivOptions.py
│ ├── FloorModOptions.py
│ ├── FullyConnectedOptions.py
│ ├── FullyConnectedOptionsWeightsFormat.py
│ ├── GatherNdOptions.py
│ ├── GatherOptions.py
│ ├── GeluOptions.py
│ ├── GreaterEqualOptions.py
│ ├── GreaterOptions.py
│ ├── HardSwishOptions.py
│ ├── HashtableFindOptions.py
│ ├── HashtableImportOptions.py
│ ├── HashtableOptions.py
│ ├── HashtableSizeOptions.py
│ ├── IfOptions.py
│ ├── Int32Vector.py
│ ├── L2NormOptions.py
│ ├── LSHProjectionOptions.py
│ ├── LSHProjectionType.py
│ ├── LSTMKernelType.py
│ ├── LSTMOptions.py
│ ├── LeakyReluOptions.py
│ ├── LessEqualOptions.py
│ ├── LessOptions.py
│ ├── LocalResponseNormalizationOptions.py
│ ├── LogSoftmaxOptions.py
│ ├── LogicalAndOptions.py
│ ├── LogicalNotOptions.py
│ ├── LogicalOrOptions.py
│ ├── MatrixDiagOptions.py
│ ├── MatrixSetDiagOptions.py
│ ├── MaximumMinimumOptions.py
│ ├── Metadata.py
│ ├── MirrorPadMode.py
│ ├── MirrorPadOptions.py
│ ├── Model.py
│ ├── MulOptions.py
│ ├── NegOptions.py
│ ├── NonMaxSuppressionV4Options.py
│ ├── NonMaxSuppressionV5Options.py
│ ├── NotEqualOptions.py
│ ├── OneHotOptions.py
│ ├── Operator.py
│ ├── OperatorCode.py
│ ├── PackOptions.py
│ ├── PadOptions.py
│ ├── PadV2Options.py
│ ├── Padding.py
│ ├── Pool2DOptions.py
│ ├── PowOptions.py
│ ├── QuantizationDetails.py
│ ├── QuantizationParameters.py
│ ├── QuantizeOptions.py
│ ├── RNNOptions.py
│ ├── RandomOptions.py
│ ├── RangeOptions.py
│ ├── RankOptions.py
│ ├── ReadVariableOptions.py
│ ├── ReducerOptions.py
│ ├── ReshapeOptions.py
│ ├── ResizeBilinearOptions.py
│ ├── ResizeNearestNeighborOptions.py
│ ├── ReverseSequenceOptions.py
│ ├── ReverseV2Options.py
│ ├── Rfft2dOptions.py
│ ├── RightShiftOptions.py
│ ├── SVDFOptions.py
│ ├── ScatterNdOptions.py
│ ├── SegmentSumOptions.py
│ ├── SelectOptions.py
│ ├── SelectV2Options.py
│ ├── SequenceRNNOptions.py
│ ├── ShapeOptions.py
│ ├── SignOptions.py
│ ├── SignatureDef.py
│ ├── SkipGramOptions.py
│ ├── SliceOptions.py
│ ├── SoftmaxOptions.py
│ ├── SpaceToBatchNDOptions.py
│ ├── SpaceToDepthOptions.py
│ ├── SparseIndexVector.py
│ ├── SparseToDenseOptions.py
│ ├── SparsityParameters.py
│ ├── SplitOptions.py
│ ├── SplitVOptions.py
│ ├── SquareOptions.py
│ ├── SquaredDifferenceOptions.py
│ ├── SqueezeOptions.py
│ ├── StridedSliceOptions.py
│ ├── SubGraph.py
│ ├── SubOptions.py
│ ├── Tensor.py
│ ├── TensorMap.py
│ ├── TensorType.py
│ ├── TileOptions.py
│ ├── TopKV2Options.py
│ ├── TransposeConvOptions.py
│ ├── TransposeOptions.py
│ ├── Uint16Vector.py
│ ├── Uint8Vector.py
│ ├── UnidirectionalSequenceLSTMOptions.py
│ ├── UniqueOptions.py
│ ├── UnpackOptions.py
│ ├── UnsortedSegmentMaxOptions.py
│ ├── UnsortedSegmentMinOptions.py
│ ├── UnsortedSegmentProdOptions.py
│ ├── UnsortedSegmentSumOptions.py
│ ├── VarHandleOptions.py
│ ├── VariantSubType.py
│ ├── WhereOptions.py
│ ├── WhileOptions.py
│ ├── ZerosLikeOptions.py
│ └── __init__.py
├── tflite_handlers
│ ├── __init__.py
│ ├── tfl_controlflow.py
│ ├── tfl_direct.py
│ ├── tfl_math.py
│ ├── tfl_nn.py
│ ├── tfl_postprocess.py
│ └── tfl_tensor.py
├── tflite_rewriters
│ ├── __init__.py
│ ├── tfl_qdq_rewriter.py
│ ├── tfl_rfft_rewriter.py
│ ├── tfl_scan_output_rewriter.py
│ └── tfl_select_zero_rewriter.py
├── tflite_utils.py
├── tfonnx.py
├── utils.py
├── verbose_logging.py
└── version.py
├── tools
├── aggregate-patterns.py
├── dump-onnx.py
├── example.bat
├── example.sh
├── gen_doc.py
├── gen_tflite_flatbuffer.py
├── graphtool.py
├── make_regression_test_models.py
├── onnx-experiments.py
├── onnx-optimize.py
├── profile_conversion_time.py
├── pylintrc
├── save_pretrained_model.py
├── tf_graph_tool.py
└── tfgraph.py
└── tutorials
├── BertTutorial.ipynb
├── ConvertingSSDMobilenetToONNX.ipynb
├── README.md
├── efficientdet.ipynb
├── efficientnet-edge.ipynb
├── efficientnet-lite.ipynb
├── huggingface-bert.ipynb
├── keras-resnet50.ipynb
└── mobiledet-tflite.ipynb
/.github/ISSUE_TEMPLATE/bug-performance-issue.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug/Performance issue
3 | about: Use this template for reporting a bug or a performance issue.
4 | title: ''
5 | labels: 'bug'
6 | assignees: ''
7 | ---
8 |
9 |
16 |
17 | **Describe the bug**
18 |
19 |
20 | **Urgency**
21 |
22 |
23 | **System information**
24 | - OS Platform and Distribution (e.g., Linux Ubuntu 18.04*):
25 | - TensorFlow Version:
26 | - Python version:
27 | - ONNX version (if applicable, e.g. 1.11*):
28 | - ONNXRuntime version (if applicable, e.g. 1.11*):
29 |
30 |
31 | **To Reproduce**
32 |
33 |
34 | **Screenshots**
35 |
36 |
37 | **Additional context**
38 |
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Requests for new tf2onnx features
4 | title: ''
5 | labels: 'enhancement'
6 | assignees: ''
7 |
8 | ---
9 |
10 | Before submitting your request, please review past submissions to ensure that it is not a duplicate of a known feature request.
11 |
12 | ### Describe the feature request
13 |
14 |
15 | ### Describe the use case
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/operators.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Missing Operator
3 | about: Operator that is not currently supported in tf2onnx.
4 | title: ''
5 | labels: 'unsupported ops'
6 | assignees: ''
7 |
8 |
9 |
10 | ---
11 | # New Operator
12 |
13 | ### Describe the operator
14 |
15 |
16 | ### Do you know whether this operator can be constructed using existing ONNX operators?
17 |
18 |
19 | ### Is this operator used by any model currently? Which one?
20 |
21 | ### Are you willing to contribute it? (Y/N)
22 |
23 | ### Notes
24 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Ask a question about tf2onnx.
4 | title: ''
5 | labels: 'question'
6 | assignees: ''
7 |
8 |
9 |
10 | ---
11 | # Ask a Question
12 |
13 | ### Question
14 |
15 |
16 | ### Further information
17 | - Is this issue related to a specific model?
18 | **Model name**:
19 |
20 | **Model opset**:
21 |
22 | ### Notes
23 |
--------------------------------------------------------------------------------
/.github/actions/keras_unit_test/action.yml:
--------------------------------------------------------------------------------
1 | name: Keras2onnx Unit Test Run
2 |
3 | inputs:
4 | tf_version:
5 | description: 'TensorFlow version'
6 | python_version:
7 | description: 'Python version'
8 | ort_version:
9 | description: 'ONNXRuntime version'
10 | onnx_version:
11 | description: 'ONNX version'
12 |
13 | runs:
14 | using: "composite"
15 | steps:
16 | - name: Set up Python (${{ inputs.python_version }})
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: ${{ inputs.python_version }}
20 |
21 |     - name: Install dependencies (TF-v${{ inputs.tf_version }})
22 |       shell: bash
23 |       run: |
24 |         python -m pip install --upgrade pip
25 |         pip install onnxconverter-common
26 |         pip install onnx==${{ inputs.onnx_version }}
27 |         pip install h5py==3.7.0
28 |         pip install parameterized
29 |         pip install timeout-decorator
30 |         pip install coloredlogs flatbuffers
31 |         pip install tensorflow==${{ inputs.tf_version }}
32 |         pip install pytest pytest-cov pytest-runner
33 |         pip install onnxruntime==${{ inputs.ort_version }}
34 |         pip uninstall -y protobuf
35 |         pip install "protobuf~=3.20"
36 |         if [[ ${{ inputs.tf_version }} == 1.* ]]; then
37 |           pip install numpy==1.19.0
38 |         else
39 |           pip install "numpy<2"
40 |         fi
41 |
42 |         pip install -e .
43 |
44 |         echo "----- List all dependencies:"
45 |         pip freeze --all
46 |
47 | - name: Run keras_unit_test (Linux)
48 | shell: bash
49 | if: runner.os == 'Linux'
50 | run: |
51 | python -c "import onnxruntime"
52 | python -c "import onnxconverter_common"
53 | pytest tests/keras2onnx_unit_tests --doctest-modules --junitxml=junit/test-results.xml
54 |
--------------------------------------------------------------------------------
/.github/actions/pretrained_model_test/action.yml:
--------------------------------------------------------------------------------
1 | name: Pretrained Model Test Run
2 |
3 | inputs:
4 | tf_version:
5 | description: 'TensorFlow version'
6 | python_version:
7 | description: 'Python version'
8 | ort_version:
9 | description: 'ONNXRuntime version'
10 | onnx_version:
11 | description: 'ONNX version'
12 | opset_version:
13 | description: 'OPSET version'
14 | skip_tflite:
15 | description: 'Skip TFLite tests'
16 |
17 | runs:
18 | using: "composite"
19 | steps:
20 | - name: Set up Python (${{ inputs.python_version }})
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: ${{ inputs.python_version }}
24 |
25 | - name: Install dependencies (TF-v${{ inputs.tf_version }})
26 | shell: bash
27 | run: |
28 | chmod +x ./tests/utils/setup_test_env.sh
29 | ./tests/utils/setup_test_env.sh ${{ inputs.tf_version }} ${{ inputs.ort_version }} ${{ inputs.onnx_version }}
30 |
31 |     - name: Fix Paths (Windows only)
32 |       shell: pwsh
33 |       if: runner.os == 'Windows'
34 |       run: |
35 |         # "##vso[task.prependpath]" is an Azure Pipelines logging command and has no effect in
36 |         # GitHub Actions; appending to $GITHUB_PATH makes these directories visible to later steps.
37 |         $site_dir = python -c "import site; print(site.getsitepackages()[1])"
38 |         "$site_dir\numpy\.libs" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
39 |         $base_dir = python -c "import site; print(site.getsitepackages()[0])"
40 |         "$base_dir/Library/bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
41 |
42 |     - name: Run pretrained_model_test
43 |       shell: bash
44 |       run: |
45 |         # TODO: fix unity model path
46 |         # python tests/run_pretrained_models.py --backend $CI_ONNX_BACKEND --opset $CI_ONNX_OPSET --config tests/unity.yaml || status=$?
47 |         status=0
48 |         python tests/run_pretrained_models.py --backend onnxruntime --opset ${{ inputs.opset_version }} --skip_tf_tests False --skip_tflite_tests ${{ inputs.skip_tflite }} --skip_tfjs_tests True --config tests/run_pretrained_models.yaml || status=$?
49 |         ls
50 |         # Report the test result; without this the step exits with the status of `ls`,
51 |         # so failing pretrained-model conversions would never fail the job.
52 |         exit "$status"
53 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | push:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | enforce-style:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout code
16 | uses: actions/checkout@v4
17 |
18 | - name: Set up Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: 3.9 # Specify the desired Python version (e.g., 3.8, 3.9)
22 |
23 | - name: Install dependencies
24 | run: pip install pylint==2.4.4
25 |
26 | - name: Run pylint
27 | run: |
28 | pip freeze
29 | pylint --rcfile=tools/pylintrc --ignore=version.py,tflite --disable=cyclic-import tf2onnx tests/*.py tools -j 0
30 |
31 | # Add other jobs or steps as needed
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .idea
3 | build
4 | dist
5 | bin
6 | obj
7 | .ipynb_checkpoints
8 | __pycache__
9 | *.pyc
10 | *.swp
11 | .cache
12 | .eggs
13 | *.egg-info
14 | *.onnx
15 | run.sh
16 | node_modules/*
17 | tests/tfhub/*/*.onnx
18 | tests/tfhub/*/*.tar.gz
19 | tests/tfhub/*/*.tflite
20 | tests/tfhub/*/**
21 |
22 | # Unit test / coverage reports
23 | .coverage*
24 | .cache
25 | .hypothesis/
26 | .pytest_cache/
27 | test-output.xml
28 |
29 | # VSCode
30 | .vscode/
31 | !.vscode/extensions.json
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Contributing
4 |
5 | We're always looking for your help to fix bugs and improve the product. Create a pull request and we'll be happy to take a look.
6 |
7 | # Check-in procedure
8 | 1. Fork the repo
9 | 2. git clone your fork
10 | 3. Create feature branch
11 | 4. Make and check in your changes along with unit tests
12 | 5. git commit your changes
13 | 6. git push origin HEAD
14 | 7. To request merge into main send a pull request from the web ui
15 | https://github.com/onnx/tensorflow-onnx.
16 |
17 |
18 | New code *must* be accompanied by unit tests.
19 |
20 | *Note*: After you create a pull request, a build is triggered right away. Please check that the style checks and unit tests pass.
21 |
22 |
23 | # Coding guidelines
24 | Please see [Coding Conventions and Standards](http://google.github.io/styleguide/pyguide.html)
25 |
26 | # Licensing guidelines
27 | This project welcomes contributions and suggestions. The contributions require you to
28 | agree to the Developer Certificate of Origin (DCO) declaring that you have the right to,
29 | and actually do, grant us the rights to use your contribution.
30 |
31 | When you submit a pull request, a DCO-bot will automatically determine whether you need
32 | to provide a DCO and decorate the PR appropriately.
33 |
34 | Sign off your commits by using the `-s` flag:
35 |
36 | ```sh
37 | git commit -s
38 | ```
39 |
40 |
41 | # Code of conduct
42 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
43 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
44 | or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
45 |
--------------------------------------------------------------------------------
/VERSION_NUMBER:
--------------------------------------------------------------------------------
1 | 1.16.1
2 |
--------------------------------------------------------------------------------
/build.bat:
--------------------------------------------------------------------------------
1 | rem SPDX-License-Identifier: Apache-2.0
2 |
3 | rem Build the wheel only if the test run succeeds.
4 | python -m pytest --cov=tf2onnx || exit /b 1
5 | python setup.py bdist_wheel
6 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 |
6 | # -e: stop at the first failing command, so no wheel is built after a failed test run.
7 | # -x: echo each command as it runs.
8 | set -ex
9 |
10 | # Refresh the package index first; on a fresh image `apt-get install` cannot find packages without it.
11 | apt-get update
12 | apt-get install -y protobuf-compiler libprotoc-dev
13 | pip install setuptools wheel
14 | pip install onnx pytest-cov
15 | pip install -e .
16 |
17 | # `setup.py test` is deprecated by setuptools; run pytest directly
18 | # (options in setup.cfg [tool:pytest] are still applied).
19 | python -m pytest
20 | python setup.py bdist_wheel
21 |
--------------------------------------------------------------------------------
/examples/call_converter_via_python.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | A simple example how to call tensorflow-onnx via python.
5 | """
6 |
7 | import tensorflow as tf
8 | import tf2onnx
9 |
10 | with tf.Session() as sess:
11 | x = tf.placeholder(tf.float32, [2, 3], name="input")
12 | x_ = tf.add(x, x)
13 | _ = tf.identity(x_, name="output")
14 | onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph, input_names=["input:0"], output_names=["output:0"])
15 | model_proto = onnx_graph.make_model("test")
16 | with open("/tmp/model.onnx", "wb") as f:
17 | f.write(model_proto.SerializeToString())
18 |
--------------------------------------------------------------------------------
/examples/custom_op_via_python.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | A simple example how to map a custom op in python.
5 | """
6 | import tensorflow as tf
7 | import tf2onnx
8 | from onnx import helper
9 |
10 | _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
11 |
12 |
13 | def print_handler(ctx, node, name, args):
14 | # replace tf.Print() with Identity
15 | # T output = Print(T input, data, @list(type) U, @string message, @int first_n, @int summarize)
16 | # becomes:
17 | # T output = Identity(T Input)
18 | node.type = "Identity"
19 | node.domain = _TENSORFLOW_DOMAIN
20 | del node.input[1:]
21 | return node
22 |
23 |
24 | with tf.Session() as sess:
25 | x = tf.placeholder(tf.float32, [2, 3], name="input")
26 | x_ = tf.add(x, x)
27 | x_ = tf.Print(x_, [x_], "hello")
28 | _ = tf.identity(x_, name="output")
29 | onnx_graph = tf2onnx.tfonnx.process_tf_graph(sess.graph,
30 | custom_op_handlers={"Print": (print_handler, [])},
31 | extra_opset=[helper.make_opsetid(_TENSORFLOW_DOMAIN, 1)],
32 | input_names=["input:0"],
33 | output_names=["output:0"])
34 | model_proto = onnx_graph.make_model("test")
35 | with open("/tmp/model.onnx", "wb") as f:
36 | f.write(model_proto.SerializeToString())
37 |
--------------------------------------------------------------------------------
/examples/getting_started.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | This example shows how to convert tf functions and keras models using the Python API.
5 | It also demonstrates converting saved_models from the command line.
6 | """
7 |
8 | import tensorflow as tf
9 | import tf2onnx
10 | import numpy as np
11 | import onnxruntime as ort
12 | import os
13 |
14 | ##################### tf function #####################
15 |
16 | @tf.function
17 | def f(a, b):
18 | return a + b
19 |
20 | input_signature = [tf.TensorSpec([2, 3], tf.float32), tf.TensorSpec([2, 3], tf.float32)]
21 | onnx_model, _ = tf2onnx.convert.from_function(f, input_signature, opset=13)
22 |
23 | a_val = np.ones([2, 3], np.float32)
24 | b_val = np.zeros([2, 3], np.float32)
25 |
26 | print("Tensorflow result")
27 | print(f(a_val, b_val).numpy())
28 |
29 | print("ORT result")
30 | sess = ort.InferenceSession(onnx_model.SerializeToString())
31 | res = sess.run(None, {'a': a_val, 'b': b_val})
32 | print(res[0])
33 |
34 |
35 | ##################### Keras Model #####################
36 |
37 | model = tf.keras.Sequential()
38 | model.add(tf.keras.layers.Dense(4, activation="relu"))
39 |
40 | input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')]
41 | onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
42 |
43 | x_val = np.ones((3, 3), np.float32)
44 |
45 | print("Keras result")
46 | print(model(x_val).numpy())
47 |
48 | print("ORT result")
49 | sess = ort.InferenceSession(onnx_model.SerializeToString())
50 | res = sess.run(None, {'x': x_val})
51 | print(res[0])
52 |
53 |
54 | ##################### Saved Model #####################
55 |
56 | model.save("savedmodel")
57 | os.system("python -m tf2onnx.convert --saved-model savedmodel --output model.onnx --opset 13")
58 |
59 | print("ORT result")
60 | sess = ort.InferenceSession("model.onnx")
61 | res = sess.run(None, {'dense_input': x_val})
62 | print(res[0])
63 |
64 | print("Conversion succeeded")
--------------------------------------------------------------------------------
/examples/tf_custom_op/double_and_add_one.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/examples/tf_custom_op/double_and_add_one.so
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | [aliases]
4 | test=pytest
5 |
6 | [tool:pytest]
7 | addopts=--cov=tf2onnx
8 | norecursedirs=tests/keras2onnx_applications tests/keras2onnx_unit_tests
9 |
--------------------------------------------------------------------------------
/tests/ade20k.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/ade20k.jpg
--------------------------------------------------------------------------------
/tests/beach.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/beach.jpg
--------------------------------------------------------------------------------
/tests/car.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/car.JPEG
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """ print pytest config."""
5 |
6 | from common import get_test_config
7 | from tf2onnx import logging
8 |
9 |
10 | def pytest_configure():
11 | config = get_test_config()
12 | logging.basicConfig(level=config.log_level)
13 | with logging.set_scope_level(logging.INFO) as logger:
14 | logger.info(config)
15 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/.gitignore:
--------------------------------------------------------------------------------
1 | *.h5
2 | *.onnx
3 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/data/street.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/keras2onnx_applications/data/street.jpg
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/lpcnet/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Introduction
4 | This tool converts the lpcnet model to onnx.
5 | To run this code, first install the original LPCNet model from [LPCNet](https://github.com/mozilla/LPCNet).
6 | Note that lpcnet is not an installable package, so add its directory to your Python path (e.g. `PYTHONPATH`).
7 | Then run
8 | ```
9 | python convert_lpcnet_to_onnx.py [model_file]
10 | ```
11 | `model_file` is the file containing the trained model weights (a `*.h5` file).
12 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/lpcnet/convert_lpcnet_to_onnx.py:
--------------------------------------------------------------------------------
# SPDX-License-Identifier: Apache-2.0

"""Convert a trained LPCNet Keras model into ONNX encoder/decoder models.

Usage::

    python convert_lpcnet_to_onnx.py MODEL_FILE

``MODEL_FILE`` is the ``*.h5`` file holding the trained weights. The
``lpcnet`` module is not an installable package, so its directory must be on
the Python path before running this script. The results are written to
``lpcnet_enc.onnx`` and ``lpcnet_dec.onnx`` in the current directory.
"""

import sys


def main(argv=None):
    """Build LPCNet, load the weights named by ``argv[0]`` and export it to ONNX.

    Args:
        argv: command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. Only the first argument is used.

    Returns:
        Process exit status: 0 on success, 2 when no model file was given.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        sys.stderr.write("usage: convert_lpcnet_to_onnx.py MODEL_FILE\n")
        return 2
    model_file = args[0]

    # Imported lazily so a usage error is reported even when lpcnet or the
    # converter stack is not importable.
    import lpcnet
    import mock_keras2onnx
    import onnx

    model, enc, dec = lpcnet.new_lpcnet_model(use_gpu=False)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    model.load_weights(model_file)

    oxml_enc = mock_keras2onnx.convert_keras(enc, 'lpcnet_enc')
    oxml_dec = mock_keras2onnx.convert_keras(dec, 'lpcnet_dec')

    onnx.save(oxml_enc, "lpcnet_enc.onnx")
    onnx.save(oxml_dec, "lpcnet_dec.onnx")
    return 0


if __name__ == "__main__":
    sys.exit(main())
18 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/mask_rcnn/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Introduction
4 | The original Keras implementation of Mask R-CNN is [matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN). Follow the 'Installation' section in its README.md to set up the model.
5 | There is also a good [tutorial](https://github.com/matterport/Mask_RCNN#step-by-step-detection) to learn about the object detection.
6 |
7 | The conversion requires opset 11 or later, and the converted model must be run with a recent ONNX Runtime version that supports ONNX opset 11 and the contrib ops this model needs.
8 |
9 | # Convert and Run the model.
10 | ```
11 | cd <keras-onnx repository root>
12 | pip install -e .
13 | cd <keras-onnx repository root>/applications/mask_rcnn
14 | # convert the model to onnx and test it with an image.
15 | python mask_rcnn.py
16 | ```
17 | A unit test for this model is included in our nightly build; see [here](https://github.com/onnx/keras-onnx/blob/master/applications/nightly_build/test_mask_rcnn.py)
18 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/model_source/densenet_1/tensorflow_backend.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # From https://github.com/titu1994/DenseNet/blob/master/tensorflow_backend.py
4 | # Modifications Copyright (c) Microsoft.
5 |
6 | import tensorflow as tf
7 |
8 | from mock_keras2onnx.proto import keras
9 | from keras.backend import tensorflow_backend as KTF
10 | from keras.backend.common import image_data_format
11 |
12 | py_all = all
13 |
def depth_to_space(input, scale, data_format=None):
    '''Use the phase shift algorithm to convert channels/depth into spatial resolution.

    Args:
        input: tensor to rearrange (TensorFlow's depth_to_space expects a
            rank-4 tensor).
        scale: block size forwarded to ``tf.depth_to_space``.
        data_format: Keras-style ``'channels_first'`` or ``'channels_last'``.
            Defaults to ``image_data_format()``. Any value other than
            ``'channels_first'`` is treated as channels-last.

    Returns:
        The tensor produced by ``tf.depth_to_space``.
    '''
    if data_format is None:
        data_format = image_data_format()

    # tf.depth_to_space only accepts the upper-case spellings "NHWC"/"NCHW";
    # the previous .lower() call turned them into values TF rejects.
    tf_data_format = 'NCHW' if data_format == 'channels_first' else 'NHWC'
    return tf.depth_to_space(input, scale, data_format=tf_data_format)
27 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/run_all.py:
--------------------------------------------------------------------------------
# SPDX-License-Identifier: Apache-2.0

"""Run every ``test_*.py`` file in the current directory through pytest.

Each file runs in its own pytest process and writes a JUnit report to
``junit/test-results-<name>.xml``. File names passed (whitespace separated)
via ``--exclude`` are skipped. The script exits with status 1 if any pytest
run fails, 0 otherwise.
"""

import argparse
import os
import subprocess
import sys


def collect_test_files(directory, exclude_set):
    """Return the sorted ``test_*.py`` regular files in *directory* not in *exclude_set*."""
    return sorted(
        name for name in os.listdir(directory)
        if name.startswith("test_") and name.endswith(".py")
        and name not in exclude_set
        and os.path.isfile(os.path.join(directory, name))
    )


def build_pytest_args(test_file):
    """Return the pytest argv for *test_file* (``test_<name>.py`` -> ``test-results-<name>.xml``)."""
    name = test_file[len("test_"):-len(".py")]
    return ["pytest", test_file, "--no-cov", "--doctest-modules",
            "--junitxml=junit/test-results-" + name + ".xml"]


def main(argv=None):
    """Parse *argv*, run each collected test file and return the exit status."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--exclude')
    args = parser.parse_args(argv)
    exclude_set = set(args.exclude.split()) if args.exclude is not None else set()

    # Child pytest processes need the unit-test helpers and the repo root importable.
    os.environ["PYTHONPATH"] = os.pathsep.join(
        [os.environ.get("PYTHONPATH", ""), "../../keras2onnx_unit_tests", "../../../"])
    os.environ["TF2ONNX_CATCH_ERRORS"] = "FALSE"

    all_passed = True
    for test_file in collect_test_files('.', exclude_set):
        # An argument list avoids shell parsing of file names; any non-zero
        # status (including negative signal codes) counts as a failure.
        if subprocess.call(build_pytest_args(test_file)) != 0:
            all_passed = False
    # Exit status instead of `assert`, which is stripped under `python -O`.
    return 0 if all_passed else 1


if __name__ == "__main__":
    sys.exit(main())
33 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/run_all_v2.py:
--------------------------------------------------------------------------------
# SPDX-License-Identifier: Apache-2.0

"""Run a fixed set of TF2 nightly test files through pytest, one process each.

Each file writes a JUnit report to ``junit/test-results-<name>.xml``. The
script exits with status 1 if any pytest run fails, 0 otherwise.
"""

import os
import subprocess
import sys

TEST_FILES = sorted(['test_keras_applications_v2.py', 'test_transformers.py', 'test_chatbot.py',
                     'test_efn.py', 'test_resnext.py'])


def build_pytest_args(test_file):
    """Return the pytest argv for *test_file* (``test_<name>.py`` -> ``test-results-<name>.xml``)."""
    name = test_file[len("test_"):-len(".py")]
    return ["pytest", test_file, "--no-cov", "--doctest-modules",
            "--junitxml=junit/test-results-" + name + ".xml"]


def main():
    """Run every file in TEST_FILES and return 0 if all passed, else 1."""
    # Child pytest processes need the unit-test helpers and the repo root importable.
    os.environ["PYTHONPATH"] = os.pathsep.join(
        [os.environ.get("PYTHONPATH", ""), "../../keras2onnx_unit_tests", "../../../"])
    os.environ["TF2ONNX_CATCH_ERRORS"] = "FALSE"

    all_passed = True
    for test_file in TEST_FILES:
        # Keep going after a failure so every JUnit report is produced; any
        # non-zero status (including negative signal codes) is a failure.
        if subprocess.call(build_pytest_args(test_file)) != 0:
            all_passed = False
    # Exit status instead of `assert`, which is stripped under `python -O`.
    return 0 if all_passed else 1


if __name__ == "__main__":
    sys.exit(main())
23 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/test_densenet_1.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import sys
5 | import unittest
6 | import keras_segmentation
7 | from os.path import dirname, abspath
8 |
9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
10 | from test_utils import run_image
11 |
12 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
13 |
14 | from mock_keras2onnx.proto import is_keras_older_than
15 |
class TestDenseNet_1(unittest.TestCase):
    """Nightly conversion test for the DenseNetImageNet121 model in model_source/densenet_1."""

    def setUp(self):
        # Paths that run_image records; they are deleted again in tearDown.
        self.model_files = []

    def tearDown(self):
        for model_path in self.model_files:
            os.remove(model_path)

    @unittest.skipIf(is_keras_older_than("2.2.3"),
                     "Cannot import normalize_data_format from keras.backend")
    def test_densenet(self):
        # From https://github.com/titu1994/DenseNet/blob/master/densenet.py
        source_dir = os.path.join(dirname(abspath(__file__)), '../model_source/densenet_1/')
        sys.path.insert(0, source_dir)
        import densenet_1
        side = 224
        net = densenet_1.DenseNetImageNet121(input_shape=(side, side, 3))
        outcome = run_image(net, self.model_files, img_path, target_size=(side, side))
        self.assertTrue(*outcome)


if __name__ == "__main__":
    unittest.main()
39 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/test_densenet_2.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import sys
5 | import unittest
6 | import keras_segmentation
7 | from os.path import dirname, abspath
8 |
9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
10 | from test_utils import run_image
11 |
12 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
13 |
14 | from mock_keras2onnx.proto import is_keras_older_than
15 |
class TestDenseNet_2(unittest.TestCase):
    """Nightly conversion test for the DenseNet model in model_source/densenet_2."""

    def setUp(self):
        # Paths that run_image records; they are deleted again in tearDown.
        self.model_files = []

    def tearDown(self):
        for model_path in self.model_files:
            os.remove(model_path)

    def test_densenet(self):
        # From https://github.com/tdeboissiere/DeepLearningImplementations/blob/master/DenseNet/densenet.py
        source_dir = os.path.join(dirname(abspath(__file__)), '../model_source/densenet_2/')
        sys.path.insert(0, source_dir)
        import densenet_2
        image_shape = (224, 224, 3)
        net = densenet_2.DenseNet(20, image_shape, 4, 1, 1, nb_filter=10)
        outcome = run_image(net, self.model_files, img_path, target_size=image_shape[:2])
        self.assertTrue(*outcome)


if __name__ == "__main__":
    unittest.main()
41 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/test_efn.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import sys
5 | import unittest
6 | from os.path import dirname, abspath
7 | from mock_keras2onnx.proto import keras, is_tensorflow_older_than
8 |
9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
10 | from test_utils import run_image
11 |
12 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
13 |
14 |
@unittest.skipIf(is_tensorflow_older_than('2.1.0'), "efficientnet needs tensorflow >= 2.1.0")
class TestEfn(unittest.TestCase):
    """Nightly conversion tests for EfficientNetB0 built with ``efficientnet.tfkeras``."""

    def setUp(self):
        # Paths that run_image records; they are deleted again in tearDown.
        self.model_files = []

    def tearDown(self):
        for model_path in self.model_files:
            os.remove(model_path)

    @unittest.skip("TODO: model discrepancy")
    def test_custom(self):
        from efficientnet import tfkeras as efn
        keras.backend.set_learning_phase(0)
        base = efn.EfficientNetB0(input_shape=(600, 600, 3), weights=None)
        # Cut the network at its "top_activation" layer.
        inputs = base.input
        features = base.get_layer("top_activation").output
        backbone = keras.Model(inputs, features)
        outcome = run_image(backbone, self.model_files, img_path, target_size=(600, 600),
                            rtol=1e-2, atol=1e-1)
        self.assertTrue(*outcome)

    def test_efn(self):
        from efficientnet import tfkeras as efn
        keras.backend.set_learning_phase(0)
        net = efn.EfficientNetB0(weights=None)
        outcome = run_image(net, self.model_files, img_path, target_size=(224, 224), rtol=1e-2)
        self.assertTrue(*outcome)


if __name__ == "__main__":
    unittest.main()
45 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/test_fcn.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import sys
5 | import unittest
6 | import keras_segmentation
7 | from os.path import dirname, abspath
8 |
9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
10 | from test_utils import run_image
11 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
12 |
13 |
class TestFCN(unittest.TestCase):
    """Nightly conversion test for keras_segmentation's ``fcn_8`` model."""

    def setUp(self):
        # Paths that run_image records; they are deleted again in tearDown.
        self.model_files = []

    def tearDown(self):
        for model_path in self.model_files:
            os.remove(model_path)

    def test_fcn(self):
        # From https://github.com/divamgupta/image-segmentation-keras/models/fcn.py
        net = keras_segmentation.models.fcn.fcn_8(101)
        outcome = run_image(net, self.model_files, img_path, target_size=(416, 608))
        self.assertTrue(*outcome)


if __name__ == "__main__":
    unittest.main()
32 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/nightly_build/test_segnet.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import sys
5 | import unittest
6 | import keras_segmentation
7 | from os.path import dirname, abspath
8 |
9 | sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
10 | from test_utils import run_image
11 | img_path = os.path.join(os.path.dirname(__file__), '../data', 'street.jpg')
12 |
13 |
class TestSegNet(unittest.TestCase):
    """Nightly conversion tests for keras_segmentation's SegNet variants."""

    def setUp(self):
        # Paths that run_image records; they are deleted again in tearDown.
        self.model_files = []

    def tearDown(self):
        for model_path in self.model_files:
            os.remove(model_path)

    def _convert_and_check(self, net, **run_kwargs):
        """Run *net* through run_image on the street image and assert success."""
        outcome = run_image(net, self.model_files, img_path, **run_kwargs)
        self.assertTrue(*outcome)

    def test_segnet(self):
        # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py
        net = keras_segmentation.models.segnet.segnet(101)
        self._convert_and_check(net, target_size=(416, 608))

    def test_vgg_segnet(self):
        # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py
        net = keras_segmentation.models.segnet.vgg_segnet(101)
        self._convert_and_check(net, rtol=3.e-3, target_size=(416, 608))

    def test_mobilenet_segnet(self):
        # From https://github.com/divamgupta/image-segmentation-keras/blob/master/keras_segmentation/models/segnet.py
        net = keras_segmentation.models.segnet.mobilenet_segnet(101)
        self._convert_and_check(net, target_size=(224, 224))


if __name__ == "__main__":
    unittest.main()
43 |
--------------------------------------------------------------------------------
/tests/keras2onnx_applications/yolov3/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Introduction
4 | The original Keras model comes from the keras-yolo3 project; clone it and follow its 'Quick Start' section to get the pre-trained model.
5 |
6 | We have successfully converted the yolov3 model and uploaded it to the model zoo.
7 |
8 | The model supports `batch_size = 1`.
9 |
10 | # Convert
11 | ```
12 | export PYTHONPATH=/path/to/keras-yolo3
13 | # run object detection, convert the model to onnx first if the onnx model does not exist
14 | python yolov3.py
15 | ```
16 | A unit test for this model is included in our nightly build; see [here](https://github.com/onnx/keras-onnx/blob/master/applications/nightly_build/test_yolov3.py)
17 |
--------------------------------------------------------------------------------
/tests/keras2onnx_unit_tests/conftest.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import pytest
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from mock_keras2onnx.proto import keras, is_tf_keras
10 | from test_utils import run_onnx_runtime
11 | from mock_keras2onnx.proto.tfcompat import is_tf2
12 |
13 | K = keras.backend
14 |
15 |
@pytest.fixture(scope='function')
def runner():
    """Yield a ``run_onnx_runtime`` wrapper bound to a per-test list of model files.

    Before the test, Keras layer-name counters (when available) and the Keras
    session are reset, then NumPy and TensorFlow are seeded with 42. After the
    test, every path recorded in the list is deleted.
    """
    model_files = []

    def runner_func(*args, **kwargs):
        return run_onnx_runtime(*args, model_files, **kwargs)

    # Ensure Keras layer naming is reset for each function
    if hasattr(K, "reset_uids"):
        # see https://github.com/onnx/tensorflow-onnx/issues/2370
        K.reset_uids()
    # Reset the TensorFlow session to avoid resource leaking between tests
    K.clear_session()

    # Seed only after clear_session. In TF1 (and in TF2 graph mode) the global
    # seed is stored on the default graph, which clear_session replaces.
    # Seeding earlier would therefore be silently discarded.
    np.random.seed(42)
    if is_tf2:
        tf.random.set_seed(42)
    else:
        tf.random.set_random_seed(42)

    # Provide wrapped run_onnx_runtime function
    yield runner_func
    # Remove model files
    for fl in model_files:
        os.remove(fl)
40 |
--------------------------------------------------------------------------------
/tests/keras2onnx_unit_tests/mock_keras2onnx/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from tf2onnx.keras2onnx_api import convert_keras
3 |
def set_converter(*args, **kwargs):
    """Accept and ignore any arguments; this mock registers no converter."""
    return None
6 |
--------------------------------------------------------------------------------
/tests/keras2onnx_unit_tests/mock_keras2onnx/ke2onnx/batch_norm.py:
--------------------------------------------------------------------------------
# Placeholder: this mock package defines the name but provides no
# batch-normalization converter (the value is None).
convert_keras_batch_normalization = None
--------------------------------------------------------------------------------
/tests/keras2onnx_unit_tests/mock_keras2onnx/proto/tfcompat.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | import os
4 | import tensorflow as _tf
5 |
6 | from packaging.version import Version
7 |
# True for TensorFlow 2.0.0 and later. Anything after the first '-' in the
# version string is dropped before parsing.
is_tf2 = Version(_tf.__version__.split('-')[0]) >= Version("2.0.0")
9 |
10 |
def normalize_tensor_shape(tensor_shape):
    """Return the dimensions of *tensor_shape* as a plain Python list.

    Under TF1 each dimension's ``.value`` is taken; under TF2 the dimensions
    are returned as iterated.
    """
    if not is_tf2:
        return [dim.value for dim in tensor_shape]
    return list(tensor_shape)
16 |
17 |
def dump_graph_into_tensorboard(tf_graph):
    # type: (_tf.Graph) -> None
    """Write *tf_graph* as a TensorBoard graph summary when ``TB_LOG_DIR`` is set.

    Does nothing if the ``TB_LOG_DIR`` environment variable is unset or empty.
    The summary writer is closed before returning, so the event file is
    flushed to disk rather than left to asynchronous writing or garbage
    collection.
    """
    _tb_log_dir = os.environ.get('TB_LOG_DIR')
    if not _tb_log_dir:
        return
    if is_tf2:
        from tensorflow.python.ops.summary_ops_v2 import graph as write_graph
        pb_visual_writer = _tf.summary.create_file_writer(_tb_log_dir)
        try:
            with pb_visual_writer.as_default():
                write_graph(tf_graph)
        finally:
            pb_visual_writer.close()
    else:
        from tensorflow.python.summary import summary
        pb_visual_writer = summary.FileWriter(_tb_log_dir)
        try:
            pb_visual_writer.add_graph(tf_graph)
        finally:
            pb_visual_writer.close()
31 |
32 |
if not is_tf2:
    tensorflow = _tf

    def is_subclassed(layer):
        """TF1: layers are never reported as subclassed."""
        return False
else:
    tensorflow = _tf.compat.v1

    def is_subclassed(layer):
        """Returns True if the object is a subclassed layer or subclassed model."""
        # Built-in Keras classes live under keras.engine / keras.layers;
        # anything defined elsewhere is treated as user-subclassed.
        module_name = layer.__module__
        return not ('keras.engine' in module_name or 'keras.layers' in module_name)
45 |
--------------------------------------------------------------------------------
/tests/models/ae0/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/ae0/frozen.pb
--------------------------------------------------------------------------------
/tests/models/conv-layers/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/conv-layers/frozen.pb
--------------------------------------------------------------------------------
/tests/models/fc-layers/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/fc-layers/frozen.pb
--------------------------------------------------------------------------------
/tests/models/gru/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/gru/frozen.pb
--------------------------------------------------------------------------------
/tests/models/lstm/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/lstm/frozen.pb
--------------------------------------------------------------------------------
/tests/models/regression/checkpoint/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "model"
2 | all_model_checkpoint_paths: "model"
3 |
--------------------------------------------------------------------------------
/tests/models/regression/checkpoint/model.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.data-00000-of-00001
--------------------------------------------------------------------------------
/tests/models/regression/checkpoint/model.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.index
--------------------------------------------------------------------------------
/tests/models/regression/checkpoint/model.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/checkpoint/model.meta
--------------------------------------------------------------------------------
/tests/models/regression/graphdef/frozen.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/graphdef/frozen.pb
--------------------------------------------------------------------------------
/tests/models/regression/saved_model/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/saved_model.pb
--------------------------------------------------------------------------------
/tests/models/regression/saved_model/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/tests/models/regression/saved_model/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/saved_model/variables/variables.index
--------------------------------------------------------------------------------
/tests/models/regression/tflite/model.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/tflite/model.tflite
--------------------------------------------------------------------------------
/tests/models/regression/tflite/test_api_model.tflite:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/regression/tflite/test_api_model.tflite
--------------------------------------------------------------------------------
/tests/models/saved_model_with_redundant_inputs/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onnx/tensorflow-onnx/3dd7729e53253e04da079deb56fdce7bb9f4a338/tests/models/saved_model_with_redundant_inputs/saved_model.pb
--------------------------------------------------------------------------------
/tests/test_cudnn.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """Unit Tests for cudnn."""
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | from tensorflow.python.ops import init_ops
10 | from backend_test_base import Tf2OnnxBackendTestBase
11 | from common import check_tf_max_version, skip_tf_cpu, check_opset_min_version, unittest_main
12 |
13 |
class CudnnTests(Tf2OnnxBackendTestBase):
    """ test cudnn cases """
    @check_tf_max_version("1.15.0", "not supported in tf-2.0")
    @skip_tf_cpu("only tf_gpu can run CudnnGPU")
    @check_opset_min_version(10, "CudnnGRU")
    def test_cudnngru(self):
        """Convert a 2-layer bidirectional contrib CudnnGRU with constant-initialised weights."""
        seq_length, batch_size, input_size = 3, 5, 2
        num_layers, num_units, num_dirs = 2, 2, 2
        state_shape = [num_layers * num_dirs, batch_size, num_units]

        # Draw x before h so the random stream matches the original order.
        x_val = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32)
        h_val = np.random.randint(0, 100, state_shape).astype(np.float32)

        def func(x, h):
            const_init = init_ops.constant_initializer(0.5)
            gru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional',
                                                kernel_initializer=const_init, bias_initializer=const_init)
            gru.build([seq_length, batch_size, input_size])
            gru_outputs = gru.call(x, (h,))
            _ = tf.identity(gru_outputs[0], name='output')

        self.run_test_case(func,
                           {"input_1:0": x_val, "input_2:0": h_val},
                           ["input_1:0", "input_2:0"],
                           ["output:0"],
                           rtol=1e-05, atol=1e-04)


if __name__ == '__main__':
    unittest_main()
47 |
--------------------------------------------------------------------------------
/tests/test_profile.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """Unit Tests for Benchmarks."""
import os
import subprocess
import sys
from backend_test_base import Tf2OnnxBackendTestBase
from common import (
    check_opset_min_version, check_tf_min_version,
    unittest_main, check_onnxruntime_min_version
)
12 |
13 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
14 | # pylint: disable=invalid-name
15 | # pylint: enable=invalid-name
16 |
class ProfileTests(Tf2OnnxBackendTestBase):
    """Smoke test for tools/profile_conversion_time.py."""

    folder = os.path.join(os.path.dirname(__file__), '..', 'tools')

    @check_tf_min_version("2.0")
    @check_opset_min_version(12)
    @check_onnxruntime_min_version('1.4.0')
    def test_profile_conversion_time(self):
        """Run the profiling script.

        A run that exceeds 15 s is tolerated. A finished run must print
        "Profile complete." or produce no stdout at all.
        """
        filename = os.path.join(ProfileTests.folder, 'profile_conversion_time.py')
        # Use the interpreter running the tests, not whatever "python" is on PATH.
        proc = subprocess.Popen(
            [sys.executable, filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        try:
            outs = proc.communicate(timeout=15)[0]
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout. Kill it, then
            # communicate again to reap it and drain/close its pipes.
            proc.kill()
            proc.communicate()
            return
        assert b"Profile complete." in outs or outs == b''


if __name__ == '__main__':
    unittest_main()
38 |
--------------------------------------------------------------------------------
/tests/test_tfjs_runner.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """Unit Tests for tfjs_runer.py"""
5 |
6 | import json
7 | import unittest
8 |
9 | import numpy as np
10 |
11 | from tfjs_runner import numpy_to_json, json_to_numpy
12 |
13 | # pylint: disable=missing-docstring
14 |
15 |
class TestTfjsRunner(unittest.TestCase):
    def test_tfjs_runner(self):
        """Round-trip arrays of several dtypes through numpy_to_json/json_to_numpy via JSON text."""
        samples = [
            np.array([[1.1, 2.2], [3.3, 4.4]], np.float32),
            np.array([[1, 2], [3, 4]], np.int32),
            np.array([[True, False], [True, True]], bool),
            np.array([['Hello world', ''], ['π', 'Tensor']], str),
            np.array([[1 + 0.1j, 2 + 0.2j], [3 + 0.3j, 4 + 0.4j]], np.complex64),
        ]
        for original in samples:
            encoded = json.dumps(numpy_to_json(original))
            print(encoded)
            decoded = json_to_numpy(json.loads(encoded))
            np.testing.assert_equal(original, decoded)


if __name__ == '__main__':
    unittest.main()
34 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_albert_en_xlarge.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | import numpy.random as rnd
5 | from collections import OrderedDict
6 | from _tools import generate_text_inputs, benchmark
7 |
8 |
def random_bert_inputs(seq_len, count=10):
    """Build *count* random feeds for a BERT-style TF Hub encoder.

    Each feed is an OrderedDict of int32 arrays of shape ``(1, seq_len)``:
    ``input_word_ids`` drawn from [0, 1000), ``input_mask`` drawn from {0, 1},
    and ``input_type_ids`` equal to ``i // 5`` at position ``i``.
    """
    # numpy's randint excludes the upper bound: (0, 2) draws 0 or 1, whereas
    # the previous randint(0, 1) always produced an all-zero mask.
    return [OrderedDict([
        ('input_word_ids', rnd.randint(0, 1000, size=(1, seq_len), dtype=numpy.int32)),
        ('input_mask', rnd.randint(0, 2, size=(1, seq_len), dtype=numpy.int32)),
        ('input_type_ids', (numpy.arange(seq_len, dtype=numpy.int32) // 5).reshape((1, -1))),
    ]) for _ in range(count)]


def main(opset=13):
    """Benchmark the TF Hub albert_en_xlarge encoder converted with *opset*.

    ``benchmark`` runs twice: once on ``generate_text_inputs()`` and once on
    random 128-token feeds. Both runs request the ``pooled_output`` output.
    """
    url = "https://tfhub.dev/tensorflow/albert_en_xlarge/3?tf-hub-format=compressed"
    dest = "tf-albert-en-xlarge"
    name = "albert-en-xlarge"
    # NOTE(review): the other tfhub scripts use ".onnx" here; confirm ".zip" is intended.
    onnx_name = os.path.join(dest, "%s-%d.zip" % (name, opset))

    inputs = generate_text_inputs()
    benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output")

    inputs = random_bert_inputs(128)
    benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output")


if __name__ == "__main__":
    main()
29 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_bert_en_wwm_uncased.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | from collections import OrderedDict
4 | import numpy
5 | import numpy.random as rnd
6 | from _tools import generate_random_images, benchmark
7 |
8 |
def random_bert_inputs(seq_len, count=10):
    """Build *count* random feeds for a BERT-style TF Hub encoder.

    Each feed is an OrderedDict of int32 arrays of shape ``(1, seq_len)``:
    ``input_word_ids`` drawn from [0, 1000), ``input_mask`` drawn from {0, 1},
    and ``input_type_ids`` equal to ``i // 5`` at position ``i``.
    """
    # numpy's randint excludes the upper bound: (0, 2) draws 0 or 1, whereas
    # the previous randint(0, 1) always produced an all-zero mask.
    return [OrderedDict([
        ('input_word_ids', rnd.randint(0, 1000, size=(1, seq_len), dtype=numpy.int32)),
        ('input_mask', rnd.randint(0, 2, size=(1, seq_len), dtype=numpy.int32)),
        ('input_type_ids', (numpy.arange(seq_len, dtype=numpy.int32) // 5).reshape((1, -1))),
    ]) for _ in range(count)]


def main(opset=13):
    """Benchmark the TF Hub bert_en_wwm_uncased (L-24, H-1024, A-16) encoder at *opset*.

    Uses random 32-token feeds and requests the ``pooled_output`` output.
    """
    url = "https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/4?tf-hub-format=compressed"
    dest = "tf-bert-en-wwm-uncased-L-24-H-1024-A-16"
    name = "bert-en-wwm-uncased-L-24-H-1024-A-16"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    inputs = random_bert_inputs(32)

    benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output")


if __name__ == "__main__":
    main()
26 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_blazeposedetector.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Benchmark the TF Hub BlazePose detector (tfjs-model) converted with *opset*.

    Feeds random images of shape (1, 513, 513, 3) generated with ``scale=1.``.
    """
    name = "blazeposedetector"
    dest = "tf-" + name
    url = "https://tfhub.dev/mediapipe/tfjs-model/blazeposedetector/1/default/1?tfjs-format=compressed"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
    benchmark(url, dest, onnx_name, opset,
              generate_random_images(shape=(1, 513, 513, 3), scale=1.))


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_esrgan.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Benchmark the TF Hub ESRGAN (esrgan-tf2) model converted with *opset*.

    Feeds images from ``generate_random_images()`` with its default arguments.
    """
    name = "esrgan-tf2"
    dest = "tf-" + name
    url = "https://tfhub.dev/captain-pool/esrgan-tf2/1?tf-hub-format=compressed"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
    benchmark(url, dest, onnx_name, opset, generate_random_images())


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_inception_v3.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Benchmark the TF Hub iNaturalist inception_v3 feature-vector model at *opset*.

    Feeds random images of shape (1, 299, 299, 3).
    """
    name = "inception_v3"
    dest = "tf-" + name
    url = "https://tfhub.dev/google/inaturalist/inception_v3/feature_vector/5?tf-hub-format=compressed"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))
    benchmark(url, dest, onnx_name, opset, generate_random_images(shape=(1, 299, 299, 3)))


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_lambert_en_uncased_L-24_H-1024_A-16.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | from collections import OrderedDict
4 | import numpy
5 | import numpy.random as rnd
6 | from _tools import generate_text_inputs, benchmark
7 |
8 |
def random_bert_inputs(seq_len, count=10):
    """Build *count* random feeds for a BERT-style TF Hub encoder.

    Each feed is an OrderedDict of int32 arrays of shape ``(1, seq_len)``:
    ``input_word_ids`` drawn from [0, 1000), ``input_mask`` drawn from {0, 1},
    and ``input_type_ids`` equal to ``i // 5`` at position ``i``.
    """
    # numpy's randint excludes the upper bound: (0, 2) draws 0 or 1, whereas
    # the previous randint(0, 1) always produced an all-zero mask.
    return [OrderedDict([
        ('input_word_ids', rnd.randint(0, 1000, size=(1, seq_len), dtype=numpy.int32)),
        ('input_mask', rnd.randint(0, 2, size=(1, seq_len), dtype=numpy.int32)),
        ('input_type_ids', (numpy.arange(seq_len, dtype=numpy.int32) // 5).reshape((1, -1))),
    ]) for _ in range(count)]


def main(opset=13):
    """Benchmark the TF Hub lambert_en_uncased (L-24, H-1024, A-16) encoder at *opset*.

    Uses random 128-token feeds and requests the ``pooled_output`` output.
    """
    # NOTE(review): the feed names presumably match the dict produced by the
    # TF Hub "bert_en_uncased_preprocess" model. A former disabled
    # (`if False:`) snippet printed that dict. Random feeds are used so that
    # neither tensorflow_text nor the preprocessor is needed.
    url = "https://tfhub.dev/tensorflow/lambert_en_uncased_L-24_H-1024_A-16/2?tf-hub-format=compressed"
    dest = "tf-lambert_en_uncased_L-24_H-1024_A-16"
    name = "lambert_en_uncased_L-24_H-1024_A-16"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    inputs = random_bert_inputs(128)

    benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output")


if __name__ == "__main__":
    main()
42 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_mobile_food_segmenter_V1.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 | import tf2onnx
6 | import onnxruntime as ort
7 |
8 |
def main(opset=13):
    """Convert and benchmark TF-Hub ``mobile_food_segmenter_V1``.

    Two independent attempts are made:

    1. ``benchmark`` on the compressed SavedModel. Recorded result: the
       conversion works, but TensorFlow fails with
       ``TypeError: 'AutoTrackable' object is not callable``.
    2. Wrapping the hub module in a Keras model and converting it with
       ``tf2onnx.convert.from_keras``.

    :param opset: target ONNX opset, used for both the file name and the
        Keras conversion.
    """
    url = "https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1?tf-hub-format=compressed"
    dest = "tf-mobile_food_segmenter_V1"
    name = "mobile_food_segmenter_V1"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    imgs = generate_random_images(shape=(1, 513, 513, 3), scale=1.)

    run_saved_model_benchmark = True
    run_keras_conversion = True

    if run_saved_model_benchmark:
        benchmark(url, dest, onnx_name, opset, imgs, tag='')

    if run_keras_conversion:
        import tensorflow.compat.v2 as tf
        import tensorflow_hub as hub

        m = hub.KerasLayer('https://tfhub.dev/google/seefood/segmenter/mobile_food_segmenter_V1/1')
        inputs = {
            # keras.Input's ``shape`` excludes the batch dimension; the former
            # [1, 513, 513, 3] declared a rank-5 (None, 1, 513, 513, 3) input.
            "X": tf.keras.Input(shape=[513, 513, 3], dtype="float32", name="X"),
        }
        # NOTE(review): this call was recorded failing with
        # "TypeError: pruned(images) missing required arguments: images";
        # the hub signature may need the input keyed/named "images" — confirm.
        outputs = m(inputs)["default"]
        print(outputs)
        model = tf.keras.Model(inputs, outputs)

        os.makedirs(dest, exist_ok=True)

        # This model is a large model.
        # Use the requested opset (it was hard-coded to 13 while onnx_name
        # already embedded ``opset``).
        tf2onnx.convert.from_keras(model, opset=opset, output_path=onnx_name)


if __name__ == "__main__":
    main()
44 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_mobilebert_en_uncased.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | from collections import OrderedDict
4 | import numpy
5 | import numpy.random as rnd
6 | from _tools import generate_text_inputs, benchmark
7 |
8 |
def main(opset=13):
    """Benchmark TF-Hub ``mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT``.

    Feeds come from ``generate_text_inputs()``. ``attention_scores`` is passed
    to ``benchmark`` as ``output_name``.

    :param opset: target ONNX opset; also part of the output file name.
    """
    url = "https://tfhub.dev/tensorflow/mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT/1?tf-hub-format=compressed"
    dest = "tf-mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT"
    name = "mobilebert_en_uncased_L-24_H-128_B-512_A-4_F-4_OPT"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    inputs = generate_text_inputs()
    benchmark(url, dest, onnx_name, opset, inputs, output_name="attention_scores")


if __name__ == "__main__":
    main()
36 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_mobilenet_v3_large_075_224.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on TF-Hub MobileNet-V3 large (0.75, 224) classification.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "mobilenet-v3-large-075-224"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/mobilenet_v3_large_075_224/classification/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    images = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_mobilenet_v3_small_075_224.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on TF-Hub MobileNet-V3 small (0.75, 224) feature vector.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "mobilenet-v3-small-075-224"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/mobilenet_v3_small_075_224/feature_vector/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    images = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_nasnet_large.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub NASNet-Large feature vector model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "nasnet-large"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/nasnet_large/feature_vector/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    # One 331x331 image with 3 channels.
    images = generate_random_images(shape=(1, 331, 331, 3))
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_resnet_v1_101.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub ResNet-V1-101 feature vector model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "resnet_v1_101"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/resnet_v1_101/feature_vector/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    images = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_resnet_v1_101_keras.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | import onnxruntime as ort
5 | import tensorflow as tf
6 | import tensorflow_hub as hub
7 | import tf2onnx
8 | from _tools import generate_random_images, check_discrepencies
9 |
# Compare TF (Keras-wrapped TF-Hub ResNet-V1-101) against its ONNX conversion
# run through onnxruntime on one random batch.
imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)

model = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v1_101/feature_vector/5",
                   trainable=False)])
model.build([None, 224, 224, 3])

expected_output = model(imgs[0])

dest = "tf-resnet_v1_101"
os.makedirs(dest, exist_ok=True)
dest_name = os.path.join(dest, "resnet_v1_101-13-keras.onnx")
# Conversion is cached: an existing file is reused as-is.
if not os.path.exists(dest_name):
    tf2onnx.convert.from_keras(model, opset=13, output_path=dest_name)

sess = ort.InferenceSession(dest_name)
print('inputs', [_.name for _ in sess.get_inputs()])
# Feed the model's single input by its actual name rather than a guessed one.
input_name = sess.get_inputs()[0].name
ort_output = sess.run(None, {input_name: imgs[0]})

print("Actual")
print(ort_output)
print("Expected")
print(expected_output)

expected = expected_output.numpy()
diff = expected - ort_output[0]
max_diff = numpy.abs(diff).max()
# Relative error must be taken against |expected|; dividing by the signed
# value gives negative or exploding ratios for negative outputs.
rel_diff = (numpy.abs(diff) / (numpy.abs(expected) + 1e-5)).max()
print(max_diff, rel_diff, [ort_output[0].min(), ort_output[0].max()])
39 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_resnet_v2_101.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub ResNet-V2-101 feature vector model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "resnet_v2_101"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/resnet_v2_101/feature_vector/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    images = generate_random_images(shape=(1, 224, 224, 3), scale=1.)
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_resnet_v2_101_classification.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub ResNet-V2-101 ImageNet classification model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    # Must be the v2 model to match ``dest``/``name`` (it previously pointed
    # at resnet_v1_101, benchmarking V1 under a V2 file name).
    url = "https://tfhub.dev/google/imagenet/resnet_v2_101/classification/5?tf-hub-format=compressed"
    dest = "tf-resnet_v2_101_classification"
    name = "resnet_v2_101_classification"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    imgs = generate_random_images(shape=(1, 224, 224, 3), scale=1.)

    benchmark(url, dest, onnx_name, opset, imgs)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_resnet_v2_152.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub ResNet-V2-152 classification model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "resnet_v2_152"
    out_dir = "tf-" + model_name
    hub_handle = "google/imagenet/resnet_v2_152/classification/5"
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % hub_handle
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    images = generate_random_images(shape=(1, 224, 224, 3))
    benchmark(hub_url, out_dir, onnx_path, opset, images)


if __name__ == "__main__":
    main()
20 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_spam_detection.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import random
4 | import numpy
5 | from _tools import generate_random_images, benchmark
6 |
7 |
def main(opset=13):
    """Run ``benchmark`` on the TF-Hub spam-detection tutorial model.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "spam-detection"
    out_dir = "tf-" + model_name
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % "tensorflow/tutorials/spam-detection/1"
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    # A (1, 20) int32 batch rather than images, despite the helper's name.
    samples = generate_random_images((1, 20), dtype=numpy.int32)
    benchmark(hub_url, out_dir, onnx_path, opset, samples)


if __name__ == "__main__":
    main()
21 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_talkheads_ggelu_bert_en_large.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | from collections import OrderedDict
4 | import numpy
5 | import numpy.random as rnd
6 | from _tools import generate_random_images, benchmark
7 |
8 |
def make_bert_inputs(n_samples=10, seq_len=128):
    """Build ``n_samples`` random feeds for a BERT-style encoder.

    Each feed is an OrderedDict of int32 arrays shaped ``(1, seq_len)``:

    * ``input_word_ids``: random ids in ``[0, 1000)``;
    * ``input_mask``: random values in ``{0, 1}``;
    * ``input_type_ids``: ``position // 5``.
    """
    type_ids = (numpy.arange(seq_len, dtype=numpy.int32) // 5).reshape((1, -1))
    return [OrderedDict([
        ('input_word_ids', rnd.randint(0, 1000, size=(1, seq_len)).astype(numpy.int32)),
        # numpy's randint excludes ``high``: the former randint(0, 1) always
        # returned 0, i.e. an all-zero mask. high=2 draws from {0, 1}.
        ('input_mask', rnd.randint(0, 2, size=(1, seq_len)).astype(numpy.int32)),
        ('input_type_ids', type_ids.copy()),
    ]) for _ in range(n_samples)]


def main(opset=13):
    """Benchmark TF-Hub ``talkheads_ggelu_bert_en_large`` on random token feeds.

    ``pooled_output`` is passed to ``benchmark`` as ``output_name``.

    :param opset: target ONNX opset; also part of the output file name.
    """
    url = "https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_large/2?tf-hub-format=compressed"
    dest = "tf-talkheads_ggelu_bert_en_large"
    name = "talkheads_ggelu_bert_en_large"
    onnx_name = os.path.join(dest, "%s-%d.onnx" % (name, opset))

    inputs = make_bert_inputs(n_samples=10, seq_len=128)

    benchmark(url, dest, onnx_name, opset, inputs, output_name="pooled_output")


if __name__ == "__main__":
    main()
26 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_thunder.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark
5 |
6 |
def main(opset=13):
    """Run ``benchmark`` on TF-Hub MoveNet single-pose Thunder.

    The model is exercised through its ``serving_default`` signature.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "thunder"
    out_dir = "tf-" + model_name
    hub_url = "https://tfhub.dev/%s?tf-hub-format=compressed" % "google/movenet/singlepose/thunder/3"
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    # One 256x256, 3-channel int32 image.
    images = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
    benchmark(hub_url, out_dir, onnx_path, opset, images, signature='serving_default')


if __name__ == "__main__":
    main()
21 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_yamnet_coral.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark_tflite
5 |
6 |
def main(opset=13):
    """Run ``benchmark_tflite`` on the Coral (Edge TPU) YAMNet TFLite model.

    Recorded outcome: conversion fails because the graph contains an
    ``edgetpu-custom-op`` that tf2onnx does not support (log below).

    :param opset: target ONNX opset; also part of the output file name.
    """
    out_dir = "tf-yamnet-coral"
    model_name = "yamnet"
    tflite_url = "https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite"
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    # NOTE(review): an image-shaped int32 input for an audio classifier looks
    # suspicious — confirm the expected input signature.
    inputs = generate_random_images(shape=(1, 256, 256, 3), dtype=numpy.int32)
    benchmark_tflite(tflite_url, out_dir, onnx_path, opset, inputs)
    # WARNING - Error loading model into tflite interpreter: Encountered unresolved custom op: edgetpu-custom-op.Node number 14 (edgetpu-custom-op) failed to prepare.
    # WARNING - Could not parse attributes for custom op 'TFL_edgetpu-custom-op': 'utf-8' codec can't decode byte 0xc8 in position 0: invalid continuation byte
    # WARNING - For now, onnxruntime only support float32 type for Gemm rewriter
    # ERROR - Tensorflow op [tower0/network/layer32/final_output1_prequant: TFL_edgetpu-custom-op] is not supported
    # ERROR - Unsupported ops: Counter({'TFL_edgetpu-custom-op': 1})


if __name__ == "__main__":
    main()
24 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_yamnet_tf.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark, benchmark_tflite
5 |
6 |
def main(opset=13):
    """Run ``benchmark_tflite`` on ``tf-yamnet-tf/model.tflite``.

    NOTE(review): the SavedModel benchmark that presumably writes
    ``model.tflite`` (via ``convert_tflite=``) is disabled below. The file
    therefore appears to have to exist already from an earlier run.

    :param opset: target ONNX opset; also part of the output file name.
    """
    model_name = "yamnet"
    out_dir = "tf-yamnet-tf"
    url = "https://tfhub.dev/google/yamnet/1?tf-hub-format=compressed"
    tflite_path = os.path.join(out_dir, 'model.tflite')

    # 1-D float32 input of length 16000, generated with scale=0.
    waveform = generate_random_images(shape=(16000, ), dtype=numpy.float32, scale=0.)

    # Disabled SavedModel path:
    # benchmark(url, out_dir, os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset)),
    #           opset, waveform, convert_tflite=tflite_path)

    tfl_onnx_path = os.path.join(out_dir, "%s-tfl-%d.onnx" % (model_name, opset))
    benchmark_tflite(tflite_path, out_dir, tfl_onnx_path, opset, waveform)


if __name__ == "__main__":
    main()
24 |
--------------------------------------------------------------------------------
/tests/tfhub/tfhub_yamnet_tflite.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | import os
3 | import numpy
4 | from _tools import generate_random_images, benchmark_tflite
5 |
6 |
def main(opset=13):
    """Run ``benchmark_tflite`` on the TF-Hub YAMNet TFLite classification model.

    ``names`` pairs are forwarded to ``benchmark_tflite`` as given.

    :param opset: target ONNX opset; also part of the output file name.
    """
    out_dir = "tf-yamnet-tflite"
    model_name = "yamnet"
    tflite_url = "https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite"
    onnx_path = os.path.join(out_dir, "%s-%d.onnx" % (model_name, opset))

    # 1-D float32 input of length 15600, generated with scale=0.
    waveform = generate_random_images(shape=(15600, ), dtype=numpy.float32, scale=0.)

    tensor_pairs = [
        ('stft/rfft3', 'FFT_stft/rfft4_reshape__190:0'),
        ('magnitude_spectrogram', 'ComplexAbsmagnitude_spectrogram__206:0'),
        ('log_mel_spectrogram', 'log_mel_spectrogram'),
    ]
    benchmark_tflite(tflite_url, out_dir, onnx_path, opset, waveform, names=tensor_pairs)


if __name__ == "__main__":
    main()
24 |
--------------------------------------------------------------------------------
/tests/utils/setup_test_env.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Prepare a CI test environment for tf2onnx.
#
# Usage: setup_test_env.sh <tf_version> <ort_version> <onnx_version>
#
# The script does the following, in order:
# - installs the test tooling;
# - installs the requested onnx / onnxruntime / tensorflow versions;
# - installs tf2onnx from the current checkout;
# - prints the resulting package list.

# Stop at the first failing command instead of continuing with a partially
# installed environment. Without this, the exit status is only that of the
# final `pip freeze`.
set -e

# Exactly three positional arguments are required.
if [ "$#" -ne 3 ]; then
    echo "Usage: $0 <tf_version> <ort_version> <onnx_version>"
    exit 1
fi

TF_VERSION=$1
ORT_VERSION=$2
ONNX_VERSION=$3

echo "==== TensorFlow version: $TF_VERSION"
echo "==== ONNXRuntime version: $ORT_VERSION"
echo "==== ONNX version: $ONNX_VERSION"

pip install pytest pytest-cov pytest-runner coverage graphviz requests pyyaml pillow pandas parameterized sympy coloredlogs flatbuffers timeout-decorator
# Quote version specifiers so each always stays a single pip argument.
pip install "onnx==$ONNX_VERSION"
pip install "onnxruntime==$ORT_VERSION"
pip install "numpy<2"

pip install onnxruntime-extensions
pip install "tensorflow-text<=$TF_VERSION"

# Force the exact requested TensorFlow build. NOTE(review): presumably needed
# because tensorflow-text can pull in a different TensorFlow version.
pip uninstall -y tensorflow
pip install "tensorflow==$TF_VERSION"
# Pin protobuf to the 3.20 series after TensorFlow is installed.
pip uninstall -y protobuf
pip install "protobuf~=3.20"

python setup.py install

echo "----- List all of dependencies:"
pip freeze --all
--------------------------------------------------------------------------------
/tf2onnx/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """tf2onnx package."""
4 |
# Submodules exported by ``from tf2onnx import *``. "tf_utils" and
# "tf_loader" are not imported explicitly below. A star import still loads
# them, because names in __all__ that are submodules get imported on demand.
__all__ = ["utils", "graph_matcher", "graph", "graph_builder",
           "tfonnx", "shape_inference", "schemas", "tf_utils", "tf_loader", "convert"]
7 |
8 | import onnx
9 | from .version import git_version, version as __version__
10 | from . import verbose_logging as logging
11 | from tf2onnx import tfonnx, utils, graph, graph_builder, graph_matcher, shape_inference, schemas, convert # pylint: disable=wrong-import-order
12 |
--------------------------------------------------------------------------------
/tf2onnx/custom_opsets/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """ custom tf2onnx mapping functions. """
4 |
5 | from . import ms
6 | from . import onnx_ml
7 | from . import string_ops
8 |
--------------------------------------------------------------------------------
/tf2onnx/late_rewriters/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """tf2onnx.late_rewriters module."""
4 |
5 | from tf2onnx.late_rewriters.channel_order_rewriters import rewrite_channels_first, rewrite_channels_last
6 |
7 |
# Public API of tf2onnx.late_rewriters: the two channel-order rewriters
# re-exported from channel_order_rewriters.
__all__ = [
    "rewrite_channels_first",
    "rewrite_channels_last",
]
12 |
--------------------------------------------------------------------------------
/tf2onnx/onnx_opset/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """tf2onnx.onnx_opset module"""
4 |
5 | from . import (
6 | common,
7 | controlflow,
8 | generator,
9 | logical,
10 | math,
11 | misc,
12 | nn,
13 | quantize,
14 | reduction,
15 | rnn,
16 | signal,
17 | tensor,
18 | traditionalml
19 | )
20 |
--------------------------------------------------------------------------------
/tf2onnx/onnx_opset/misc.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | misc
6 | """
7 |
8 | import logging
9 |
10 | from tf2onnx.handler import tf_op
11 |
12 |
logger = logging.getLogger(__name__)

# pylint: disable=unused-argument,missing-docstring

@tf_op(["CheckNumerics", "StopGradient"])
class MoveToIdent:
    """Convert ``CheckNumerics`` / ``StopGradient`` nodes into ONNX ``Identity``.

    When the Identity's input is a constant and its output is not a graph
    output, the node is removed. Its consumers are rewired to read the
    constant directly.
    """

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        node.type = "Identity"
        if not node.inputs[0].is_const():
            return
        out_name = node.output[0]
        # A graph output must keep its producing node, so leave the Identity.
        if out_name in ctx.outputs:
            return
        ctx.replace_all_inputs(out_name, node.input[0])  # ops=ctx.get_nodes()
        ctx.remove_node(node.name)
31 |
32 |
@tf_op(["Placeholder", "PlaceholderV2", "PlaceholderWithDefault"])
class DirectOp:
    """Handler for placeholder ops: the node is left exactly as it is."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        # Intentionally does nothing.
        return None
38 |
39 |
@tf_op("NoOp")
class NukeNode:
    """Handler that deletes TF ``NoOp`` nodes from the graph."""

    @classmethod
    def version_1(cls, ctx, node, **kwargs):
        doomed = node.name
        ctx.remove_node(doomed)
45 |
--------------------------------------------------------------------------------
/tf2onnx/onnx_opset/traditionalml.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | traditional ml
6 | """
7 |
8 | import logging
9 |
# Module logger; nothing in this module uses it yet.
logger = logging.getLogger(__name__)
11 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/fused_op_rewriter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter.fused_op_rewriter - rewrite tensorflow _Fused ops from grappler into other tf ops
6 | """
7 |
8 |
9 | # pylint: disable=missing-docstring
10 |
11 |
def rewrite_fused_ops(g, ops):
    """Split grappler ``_Fused*`` ops back into a base op plus a chain of post-ops.

    ``_FusedConv2D`` / ``_FusedMatMul`` / ``_FusedDepthwiseConv2dNative``
    become their un-fused type and keep only their first two inputs. Each
    name in the ``fused_ops`` attribute becomes a new node chained after it.
    The remaining original inputs are handed out in order:

    * ``BiasAdd`` takes 1 extra input;
    * ``FusedBatchNorm`` takes 4;
    * any other op takes all that are left.

    Consumers of the original output are redirected to the last node in the
    chain. Returns ``g.get_nodes()``.
    """
    fused_types = ("_FusedConv2D", "_FusedMatMul", "_FusedDepthwiseConv2dNative")
    fixed_arity = {"BiasAdd": 2, "FusedBatchNorm": 5}
    for node in ops:
        if node.type not in fused_types:
            continue
        post_ops = [name.decode() for name in node.get_attr_value("fused_ops")]
        pending = node.input[2:]
        g.replace_inputs(node, node.input[:2])
        base_output = node.output[0]
        node.type = node.type.replace("_Fused", "")
        out_dtype = g.get_dtype(base_output)
        out_shape = g.get_shape(base_output)

        current = base_output
        head = None
        for op_type in post_ops:
            n_extra = fixed_arity.get(op_type, 1 + len(pending)) - 1
            created = g.make_node(op_type, [current] + pending[:n_extra], skip_conversion=False,
                                  op_name_scope=node.name, dtypes=[out_dtype], shapes=[out_shape])
            current = created.output[0]
            pending = pending[n_extra:]
            if head is None:
                head = created

        # The first post-op must keep reading the base output; everyone else
        # now reads the end of the chain.
        others = [c for c in g.find_output_consumers(base_output) if c != head]
        g.replace_all_inputs(base_output, current, others)

    return g.get_nodes()
37 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/leakyrelu_rewriter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter - rewrite tensorflow subgraph to onnx leakyrelu op
6 | """
7 |
8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9 |
10 |
11 | # pylint: disable=missing-docstring
12 |
13 |
def rewrite_leakyrelu(g, ops):
    """Fuse ``Maximum(Mul(alpha, x), x)`` into a single ONNX ``LeakyRelu``.

    For alpha < 1, max(x, alpha * x) is x when x >= 0 and alpha * x when
    x < 0, which is exactly LeakyRelu. The rewrite is applied only when both
    Maximum operands derive from the same tensor and ``alpha`` is a scalar
    constant below 1. Requires opset >= 6; otherwise ``ops`` is returned
    untouched.

    Returns ``ops``, extended with any LeakyRelu nodes created.
    """
    if g.opset < 6:
        return ops

    import numpy as np  # local import: this module has no top-level numpy import

    pattern = \
        OpTypePattern('Maximum', name='max', inputs=[
            OpTypePattern('Mul', name='mul', inputs=[
                OpTypePattern('Const', name='alpha'),
                OpTypePattern('*', name='mul_input'),
            ]),
            OpTypePattern('*', name='max_input'),
        ])

    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        max_node = match.get_op('max')
        mul_node = match.get_op("mul")

        max_input_edge_name = match.get_tensor('max_input')
        mul_input_edge_name = match.get_tensor('mul_input')
        if max_input_edge_name != mul_input_edge_name:
            continue
        alpha = match.get_op("alpha").get_tensor_value()
        # LeakyRelu's alpha is a single float attribute. A non-scalar constant
        # (e.g. per-channel slopes) cannot be expressed, and comparing it with
        # ``>= 1`` would raise, so keep the original subgraph.
        if np.ndim(alpha) != 0:
            continue
        alpha = float(alpha)
        if alpha >= 1:
            continue
        leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
                                shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
        ops.append(leakyrelu)
        g.replace_all_inputs(max_node.output[0], leakyrelu.output[0], ops=ops)
        to_delete = [max_node, mul_node]
        g.safe_remove_nodes(to_delete)

    return ops
47 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/ragged_variant_shape_rewriter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter - RaggedTensorToVariant -> Shape pattern
6 | """
7 |
8 | import numpy as np
9 | from tf2onnx import utils
10 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
11 |
12 |
13 | # pylint: disable=missing-docstring
14 |
15 |
def rewrite_ragged_variant_shape(g, ops):
    """Rewrite ``Shape(RaggedTensorToVariant(...))`` as ``Shape(splits) - 1``.

    Only handled when the RaggedTensorToVariant has ``batched_input == 1``
    and ``RAGGED_RANK == 1``. In that case the shape of the batched variant
    equals the number of row splits minus one. The Shape node is re-pointed
    at the variant's first input, and a ``Sub`` by a constant 1 (of the
    Shape's output dtype) is inserted on its output. Returns ``ops``.
    """
    patterns = [
        OpTypePattern('Shape', name='shape', inputs=[
            OpTypePattern('RaggedTensorToVariant', name='raggedtovariant')
        ]),
    ]
    for pat in patterns:
        for match in list(GraphMatcher(pat).match_ops(ops)):
            shape_node = match.get_op('shape')
            to_variant = match.get_op('raggedtovariant')
            supported = (to_variant.get_attr_value("batched_input") == 1
                         and to_variant.get_attr_value("RAGGED_RANK") == 1)
            if not supported:
                continue
            g.replace_inputs(shape_node, [to_variant.input[0]])
            shape_out = shape_node.output[0]
            np_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(shape_out))
            one = g.make_const(utils.make_name("const_one"), np.array(1, np_dtype)).output[0]
            g.insert_new_node_on_output("Sub", shape_out, inputs=[shape_out, one])

    return ops
40 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/rnn.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter.rnn - lstm support
6 | """
7 |
8 | import logging
9 |
10 | from tf2onnx.rewriter.bilstm_rewriter import rewrite_bidirectional_lstms
11 | from tf2onnx.rewriter.bigru_rewriter import rewrite_bidirectional_grus
12 | from tf2onnx.rewriter.custom_rnn_rewriter import CustomRnnRewriter
13 | from tf2onnx.rewriter.loop_rewriter import LoopRewriter
14 | from tf2onnx.rewriter.lstm_rewriter import LSTMRewriter
15 | from tf2onnx.rewriter.gru_rewriter import GRUUnitRewriter
16 |
17 | # pylint: disable=invalid-name,unused-argument,missing-docstring
18 |
19 |
logger = logging.getLogger(__name__)


def rewrite_single_direction_lstm(g, ops):
    """Run ``LSTMRewriter`` over ``g`` and return its result (``ops`` unused)."""
    return LSTMRewriter(g).run()


def rewrite_bi_direction_lstm(g, ops):
    """Delegate to ``bilstm_rewriter.rewrite_bidirectional_lstms``."""
    rewritten = rewrite_bidirectional_lstms(g, ops)
    return rewritten


def rewrite_single_direction_gru(g, ops):
    """Run ``GRUUnitRewriter`` over ``g`` and return its result (``ops`` unused)."""
    return GRUUnitRewriter(g).run()


def rewrite_bi_direction_gru(g, ops):
    """Delegate to ``bigru_rewriter.rewrite_bidirectional_grus``."""
    rewritten = rewrite_bidirectional_grus(g, ops)
    return rewritten


def rewrite_custom_rnn_cell(g, ops):
    """Run ``CustomRnnRewriter`` over ``g`` and return its result (``ops`` unused)."""
    rewriter = CustomRnnRewriter(g)
    return rewriter.run()


def rewrite_generic_loop(g, ops):
    """Run ``LoopRewriter`` over ``g`` and return its result (``ops`` unused)."""
    rewriter = LoopRewriter(g)
    return rewriter.run()
47 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/thresholded_relu_rewriter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter - rewrite tensorflow subgraph to onnx ThresholdedRelu op
6 | """
7 |
8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9 |
10 |
11 | # pylint: disable=missing-docstring
12 |
13 |
def rewrite_thresholded_relu(g, ops):
    """Fuse ``Mul(Cast(Greater(x, theta)), x)`` into ONNX ``ThresholdedRelu``.

    The product keeps x where x > theta and yields 0 elsewhere, which is
    ThresholdedRelu with alpha = theta. The rewrite requires all of:

    * opset >= 10;
    * the Greater and Mul operands are the same tensor;
    * theta is Greater's second operand;
    * theta is a scalar constant.

    Returns ``ops``.
    """
    if g.opset < 10:
        return ops

    import numpy as np  # local import: this module has no top-level numpy import

    pattern = \
        OpTypePattern('Mul', name='mul', inputs=[
            OpTypePattern('Cast', name='cast', inputs=[
                OpTypePattern('Greater', name='greater', inputs=[
                    OpTypePattern('*', name='greater_input'),
                    OpTypePattern('Const', name='theta')
                ])
            ]),
            OpTypePattern('*', name='mul_input')
        ])
    matcher = GraphMatcher(pattern, allow_reorder=True)
    match_results = list(matcher.match_ops(ops))

    for match in match_results:
        mul_node = match.get_op("mul")
        cast_node = match.get_op('cast')
        greater_node = match.get_op('greater')
        theta_node = match.get_op('theta')

        greater_input_edge_name = match.get_tensor('greater_input')
        mul_input_edge_name = match.get_tensor('mul_input')
        if greater_input_edge_name != mul_input_edge_name:
            continue
        # allow_reorder may also match Greater(theta, x), i.e. x < theta,
        # which is not ThresholdedRelu.
        if greater_node.input[1] != theta_node.output[0]:
            continue
        theta = theta_node.get_tensor_value()
        # ThresholdedRelu's alpha is a single float attribute; skip
        # non-scalar thresholds.
        if np.ndim(theta) != 0:
            continue
        thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name],
                                       attr={"alpha": float(theta)},
                                       shapes=[g.get_shape(mul_node.output[0])],
                                       dtypes=[g.get_dtype(mul_node.output[0])])
        g.replace_all_inputs(mul_node.output[0], thresholded_relu.output[0], ops=ops)
        to_delete = [cast_node, mul_node]
        g.safe_remove_nodes(to_delete)
    return ops
46 |
--------------------------------------------------------------------------------
/tf2onnx/rewriter/transpose_rewriter.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | tf2onnx.rewriter - rewrite tensorflow transpose op
6 | """
7 |
8 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9 |
10 |
11 | # pylint: disable=missing-docstring
12 |
13 |
def rewrite_transpose(g, ops):
    """Replace a dynamically computed axis-reversing perm with a static ``perm``.

    Matches ``Transpose(x, Sub(Sub(...), Range(...)))``, presumably TF's
    default "reverse all axes" perm (rank - 1 - range(rank)). The match is
    rewritten as ``Transpose(x)`` with ``perm = [rank-1, ..., 0]``, and the
    perm subgraph is removed. When the rank of ``x`` is unknown, the match
    is left unchanged. Returns ``ops``.
    """
    pattern = \
        OpTypePattern('Transpose', name='output', inputs=[
            OpTypePattern(None),
            OpTypePattern('Sub', inputs=[
                OpTypePattern('Sub', inputs=["*", "*"]),
                OpTypePattern('Range', inputs=["*", "*", "*"]),
            ]),
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        shape = g.get_shape(output.input[0])
        if shape is None:
            # Unknown rank: a static perm cannot be built, keep the dynamic one.
            continue
        dims = list(range(len(shape) - 1, -1, -1))
        output.set_attr("perm", dims)
        g.remove_input(output, output.input[1], 1)
        to_delete = [n for n in match.get_nodes() if n != output]
        g.safe_remove_nodes(to_delete)
    return ops
35 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/ATan2Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ATan2Options(object):
    """Accessor for the tflite ``ATan2Options`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ATan2Options`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ATan2Options()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsATan2Options(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ATan2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``ATan2Options`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def ATan2OptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def ATan2OptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/AbsOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class AbsOptions(object):
    """Accessor for the tflite ``AbsOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``AbsOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = AbsOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsAbsOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def AbsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``AbsOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def AbsOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def AbsOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ActivationFunctionType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class ActivationFunctionType(object):
    """Integer codes of the tflite ``ActivationFunctionType`` enum.

    NONE=0, RELU=1, RELU_N1_TO_1=2, RELU6=3, TANH=4, SIGN_BIT=5.
    The order below must match the schema this module was generated from.
    """

    NONE, RELU, RELU_N1_TO_1, RELU6, TANH, SIGN_BIT = range(6)
14 |
15 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/AddNOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class AddNOptions(object):
    """Accessor for the tflite ``AddNOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``AddNOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = AddNOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsAddNOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def AddNOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``AddNOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def AddNOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def AddNOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ArgMaxOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ArgMaxOptions(object):
    """Accessor for the tflite ``ArgMaxOptions`` flatbuffer table (one field: OutputType)."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ArgMaxOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ArgMaxOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsArgMaxOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ArgMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table

    def OutputType(self):
        """Return the int8 stored in vtable slot 0, or 0 when that slot's offset is 0."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int8Flags, self._tab.Pos + field_off)
39 |
def Start(builder):
    """Open an ``ArgMaxOptions`` table on *builder* with a single field slot."""
    num_fields = 1
    builder.StartObject(num_fields)


def ArgMaxOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def AddOutputType(builder, outputType):
    """Write *outputType* as an int8 into slot 0 (default value 0)."""
    slot, default = 0, 0
    builder.PrependInt8Slot(slot, outputType, default)


def ArgMaxOptionsAddOutputType(builder, outputType):
    """This method is deprecated. Please switch to AddOutputType."""
    return AddOutputType(builder, outputType)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def ArgMaxOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ArgMinOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ArgMinOptions(object):
    """Accessor for the tflite ``ArgMinOptions`` flatbuffer table (one field: OutputType)."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ArgMinOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ArgMinOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsArgMinOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ArgMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table

    def OutputType(self):
        """Return the int8 stored in vtable slot 0, or 0 when that slot's offset is 0."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int8Flags, self._tab.Pos + field_off)
39 |
def Start(builder):
    """Open an ``ArgMinOptions`` table on *builder* with a single field slot."""
    num_fields = 1
    builder.StartObject(num_fields)


def ArgMinOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def AddOutputType(builder, outputType):
    """Write *outputType* as an int8 into slot 0 (default value 0)."""
    slot, default = 0, 0
    builder.PrependInt8Slot(slot, outputType, default)


def ArgMinOptionsAddOutputType(builder, outputType):
    """This method is deprecated. Please switch to AddOutputType."""
    return AddOutputType(builder, outputType)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def ArgMinOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/AssignVariableOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class AssignVariableOptions(object):
    """Accessor for the tflite ``AssignVariableOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``AssignVariableOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = AssignVariableOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsAssignVariableOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def AssignVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``AssignVariableOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def AssignVariableOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def AssignVariableOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/BatchToSpaceNDOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class BatchToSpaceNDOptions(object):
    """Accessor for the tflite ``BatchToSpaceNDOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``BatchToSpaceNDOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = BatchToSpaceNDOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsBatchToSpaceNDOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BatchToSpaceNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``BatchToSpaceNDOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def BatchToSpaceNDOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def BatchToSpaceNDOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/BitcastOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class BitcastOptions(object):
    """Accessor for the tflite ``BitcastOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``BitcastOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = BitcastOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsBitcastOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BitcastOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``BitcastOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def BitcastOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def BitcastOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/BitwiseXorOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class BitwiseXorOptions(object):
    """Accessor for the tflite ``BitwiseXorOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``BitwiseXorOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = BitwiseXorOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsBitwiseXorOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BitwiseXorOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``BitwiseXorOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def BitwiseXorOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def BitwiseXorOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/BroadcastToOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class BroadcastToOptions(object):
    """Accessor for the tflite ``BroadcastToOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``BroadcastToOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = BroadcastToOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsBroadcastToOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def BroadcastToOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``BroadcastToOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def BroadcastToOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def BroadcastToOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/CallOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class CallOptions(object):
    """Accessor for the tflite ``CallOptions`` flatbuffer table (one field: Subgraph)."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``CallOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = CallOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsCallOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def CallOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table

    def Subgraph(self):
        """Return the uint32 stored in vtable slot 0, or 0 when that slot's offset is 0."""
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Uint32Flags, self._tab.Pos + field_off)
39 |
def Start(builder):
    """Open a ``CallOptions`` table on *builder* with a single field slot."""
    num_fields = 1
    builder.StartObject(num_fields)


def CallOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def AddSubgraph(builder, subgraph):
    """Write *subgraph* as a uint32 into slot 0 (default value 0)."""
    slot, default = 0, 0
    builder.PrependUint32Slot(slot, subgraph, default)


def CallOptionsAddSubgraph(builder, subgraph):
    """This method is deprecated. Please switch to AddSubgraph."""
    return AddSubgraph(builder, subgraph)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def CallOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/CombinerType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class CombinerType(object):
    """Integer codes of the tflite ``CombinerType`` enum.

    SUM=0, MEAN=1, SQRTN=2. The order below must match the schema this
    module was generated from.
    """

    SUM, MEAN, SQRTN = range(3)
11 |
12 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/CosOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class CosOptions(object):
    """Accessor for the tflite ``CosOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``CosOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = CosOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsCosOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def CosOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``CosOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def CosOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def CosOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/CustomOptionsFormat.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class CustomOptionsFormat(object):
    """Integer codes of the tflite ``CustomOptionsFormat`` enum (FLEXBUFFERS=0)."""

    (FLEXBUFFERS,) = range(1)
9 |
10 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/DensifyOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class DensifyOptions(object):
    """Accessor for the tflite ``DensifyOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``DensifyOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DensifyOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsDensifyOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DensifyOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``DensifyOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def DensifyOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def DensifyOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/DequantizeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class DequantizeOptions(object):
    """Accessor for the tflite ``DequantizeOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``DequantizeOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DequantizeOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsDequantizeOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DequantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``DequantizeOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def DequantizeOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def DequantizeOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/DimensionType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class DimensionType(object):
    """Integer codes of the tflite ``DimensionType`` enum.

    DENSE=0, SPARSE_CSR=1. The order below must match the schema this
    module was generated from.
    """

    DENSE, SPARSE_CSR = range(2)
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/DynamicUpdateSliceOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class DynamicUpdateSliceOptions(object):
    """Accessor for the tflite ``DynamicUpdateSliceOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``DynamicUpdateSliceOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = DynamicUpdateSliceOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsDynamicUpdateSliceOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DynamicUpdateSliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open a ``DynamicUpdateSliceOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def DynamicUpdateSliceOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def DynamicUpdateSliceOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/EqualOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class EqualOptions(object):
    """Accessor for the tflite ``EqualOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``EqualOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = EqualOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsEqualOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def EqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``EqualOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def EqualOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def EqualOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ExpOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ExpOptions(object):
    """Accessor for the tflite ``ExpOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ExpOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ExpOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsExpOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ExpOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``ExpOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def ExpOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def ExpOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ExpandDimsOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ExpandDimsOptions(object):
    """Accessor for the tflite ``ExpandDimsOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ExpandDimsOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ExpandDimsOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsExpandDimsOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ExpandDimsOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
def Start(builder):
    """Open an ``ExpandDimsOptions`` table on *builder*; the table has no fields."""
    num_fields = 0
    builder.StartObject(num_fields)


def ExpandDimsOptionsStart(builder):
    """This method is deprecated. Please switch to Start."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    table_offset = builder.EndObject()
    return table_offset


def ExpandDimsOptionsEnd(builder):
    """This method is deprecated. Please switch to End."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/FillOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class FillOptions(object):
    """Accessor for the tflite ``FillOptions`` flatbuffer table; it exposes no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``FillOptions`` bound to the table whose uoffset is stored at *offset*."""
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = FillOptions()
        accessor.Init(buf, table_pos)
        return accessor

    @classmethod
    def GetRootAsFillOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def FillOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        identifier = b"\x54\x46\x4C\x33"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table starting at *pos* in *buf*."""
        table = flatbuffers.table.Table(buf, pos)
        self._tab = table
32 |
33 | def Start(builder): builder.StartObject(0)
34 | def FillOptionsStart(builder):
35 | """This method is deprecated. Please switch to Start."""
36 | return Start(builder)
37 | def End(builder): return builder.EndObject()
38 | def FillOptionsEnd(builder):
39 | """This method is deprecated. Please switch to End."""
40 | return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/FloorDivOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class FloorDivOptions(object):
    """Accessor for a ``tflite.FloorDivOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = FloorDivOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsFloorDivOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def FloorDivOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a FloorDivOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def FloorDivOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def FloorDivOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/FloorModOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class FloorModOptions(object):
    """Accessor for a ``tflite.FloorModOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = FloorModOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsFloorModOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def FloorModOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a FloorModOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def FloorModOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def FloorModOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/FullyConnectedOptionsWeightsFormat.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class FullyConnectedOptionsWeightsFormat(object):
    """Integer codes of the ``tflite.FullyConnectedOptionsWeightsFormat`` enum.

    The numbering is emitted by the FlatBuffers compiler from the schema;
    do not renumber by hand.
    """

    DEFAULT, SHUFFLED4x16INT8 = 0, 1
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/GatherNdOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class GatherNdOptions(object):
    """Accessor for a ``tflite.GatherNdOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = GatherNdOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsGatherNdOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GatherNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a GatherNdOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def GatherNdOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def GatherNdOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/GeluOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class GeluOptions(object):
    """Accessor for a ``tflite.GeluOptions`` table with a single bool field, ``approximate``."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = GeluOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsGeluOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GeluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Approximate(self):
        """Return the ``approximate`` flag, or ``False`` when the field is not present."""
        # vtable entry 4 corresponds to field slot 0 (written by AddApproximate).
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return False
        raw = self._tab.Get(flatbuffers.number_types.BoolFlags, self._tab.Pos + field_off)
        return bool(raw)


def Start(builder):
    """Open a GeluOptions table on ``builder`` (one field slot)."""
    builder.StartObject(1)


def GeluOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def AddApproximate(builder, approximate):
    """Write ``approximate`` into field slot 0, using 0 as the default value."""
    builder.PrependBoolSlot(0, approximate, 0)


def GeluOptionsAddApproximate(builder, approximate):
    """Deprecated alias; use :func:`AddApproximate` instead."""
    return AddApproximate(builder, approximate)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def GeluOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/GreaterEqualOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class GreaterEqualOptions(object):
    """Accessor for a ``tflite.GreaterEqualOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = GreaterEqualOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsGreaterEqualOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GreaterEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a GreaterEqualOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def GreaterEqualOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def GreaterEqualOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/GreaterOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class GreaterOptions(object):
    """Accessor for a ``tflite.GreaterOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = GreaterOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsGreaterOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def GreaterOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a GreaterOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def GreaterOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def GreaterOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/HardSwishOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class HardSwishOptions(object):
    """Accessor for a ``tflite.HardSwishOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = HardSwishOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsHardSwishOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def HardSwishOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a HardSwishOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def HardSwishOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def HardSwishOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/HashtableFindOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class HashtableFindOptions(object):
    """Accessor for a ``tflite.HashtableFindOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = HashtableFindOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsHashtableFindOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def HashtableFindOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a HashtableFindOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def HashtableFindOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def HashtableFindOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/HashtableImportOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class HashtableImportOptions(object):
    """Accessor for a ``tflite.HashtableImportOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = HashtableImportOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsHashtableImportOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def HashtableImportOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a HashtableImportOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def HashtableImportOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def HashtableImportOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/HashtableSizeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class HashtableSizeOptions(object):
    """Accessor for a ``tflite.HashtableSizeOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = HashtableSizeOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsHashtableSizeOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def HashtableSizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a HashtableSizeOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def HashtableSizeOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def HashtableSizeOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LSHProjectionType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class LSHProjectionType(object):
    """Integer codes of the ``tflite.LSHProjectionType`` enum.

    The numbering is emitted by the FlatBuffers compiler from the schema;
    do not renumber by hand.
    """

    UNKNOWN, SPARSE, DENSE = 0, 1, 2
11 |
12 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/LSTMKernelType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class LSTMKernelType(object):
    """Integer codes of the ``tflite.LSTMKernelType`` enum.

    The numbering is emitted by the FlatBuffers compiler from the schema;
    do not renumber by hand.
    """

    FULL, BASIC = 0, 1
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/LeakyReluOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LeakyReluOptions(object):
    """Accessor for a ``tflite.LeakyReluOptions`` table with a single float32 field, ``alpha``."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LeakyReluOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLeakyReluOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LeakyReluOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Alpha(self):
        """Return the ``alpha`` value, or ``0.0`` when the field is not present."""
        # vtable entry 4 corresponds to field slot 0 (written by AddAlpha).
        field_off = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_off == 0:
            return 0.0
        return self._tab.Get(flatbuffers.number_types.Float32Flags, self._tab.Pos + field_off)


def Start(builder):
    """Open a LeakyReluOptions table on ``builder`` (one field slot)."""
    builder.StartObject(1)


def LeakyReluOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def AddAlpha(builder, alpha):
    """Write ``alpha`` into field slot 0, using 0.0 as the default value."""
    builder.PrependFloat32Slot(0, alpha, 0.0)


def LeakyReluOptionsAddAlpha(builder, alpha):
    """Deprecated alias; use :func:`AddAlpha` instead."""
    return AddAlpha(builder, alpha)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LeakyReluOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LessEqualOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LessEqualOptions(object):
    """Accessor for a ``tflite.LessEqualOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LessEqualOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLessEqualOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LessEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LessEqualOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LessEqualOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LessEqualOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LessOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LessOptions(object):
    """Accessor for a ``tflite.LessOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LessOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLessOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LessOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LessOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LessOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LessOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LogSoftmaxOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LogSoftmaxOptions(object):
    """Accessor for a ``tflite.LogSoftmaxOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LogSoftmaxOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLogSoftmaxOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LogSoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LogSoftmaxOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LogSoftmaxOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LogSoftmaxOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LogicalAndOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LogicalAndOptions(object):
    """Accessor for a ``tflite.LogicalAndOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LogicalAndOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLogicalAndOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LogicalAndOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LogicalAndOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LogicalAndOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LogicalAndOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LogicalNotOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LogicalNotOptions(object):
    """Accessor for a ``tflite.LogicalNotOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LogicalNotOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLogicalNotOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LogicalNotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LogicalNotOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LogicalNotOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LogicalNotOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/LogicalOrOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class LogicalOrOptions(object):
    """Accessor for a ``tflite.LogicalOrOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = LogicalOrOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsLogicalOrOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def LogicalOrOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a LogicalOrOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def LogicalOrOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def LogicalOrOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/MatrixDiagOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class MatrixDiagOptions(object):
    """Accessor for a ``tflite.MatrixDiagOptions`` table, which declares no fields."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table whose uoffset is stored at ``buf[offset]``."""
        root_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        table = MatrixDiagOptions()
        table.Init(buf, root_pos)
        return table

    @classmethod
    def GetRootAsMatrixDiagOptions(cls, buf, offset=0):
        """Deprecated alias; use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MatrixDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the 4-byte file identifier ``TFL3``."""
        tfl3 = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, tfl3, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Point this accessor at the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a MatrixDiagOptions table on ``builder`` (zero field slots)."""
    builder.StartObject(0)


def MatrixDiagOptionsStart(builder):
    """Deprecated alias; use :func:`Start` instead."""
    return Start(builder)


def End(builder):
    """Close the open table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def MatrixDiagOptionsEnd(builder):
    """Deprecated alias; use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/MatrixSetDiagOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class MatrixSetDiagOptions(object):
    """Read accessor for the field-less ``MatrixSetDiagOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``MatrixSetDiagOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = MatrixSetDiagOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsMatrixSetDiagOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MatrixSetDiagOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``MatrixSetDiagOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def MatrixSetDiagOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def MatrixSetDiagOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/MaximumMinimumOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class MaximumMinimumOptions(object):
    """Read accessor for the field-less ``MaximumMinimumOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``MaximumMinimumOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = MaximumMinimumOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsMaximumMinimumOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MaximumMinimumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``MaximumMinimumOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def MaximumMinimumOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def MaximumMinimumOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/MirrorPadMode.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class MirrorPadMode(object):
    """Integer constants of the TFLite ``MirrorPadMode`` enum."""

    REFLECT, SYMMETRIC = 0, 1
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/MirrorPadOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class MirrorPadOptions(object):
    """Read accessor for the ``MirrorPadOptions`` table (TFLite schema).

    Exposes one int8 field, ``mode``, through :meth:`Mode`.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``MirrorPadOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = MirrorPadOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsMirrorPadOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def MirrorPadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Mode(self):
        """Return the int8 ``mode`` field, or 0 when it is not present.

        Presumably one of the ``MirrorPadMode`` constants; confirm against the schema.
        """
        # Vtable entry 4 is the one written through builder slot 0 (see AddMode).
        field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_offset == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int8Flags, self._tab.Pos + field_offset)


def Start(builder):
    """Open a ``MirrorPadOptions`` table on *builder* (1 field slot)."""
    builder.StartObject(1)


def MirrorPadOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def AddMode(builder, mode):
    """Write *mode* as an int8 into field slot 0 (default value 0)."""
    builder.PrependInt8Slot(0, mode, 0)


def MirrorPadOptionsAddMode(builder, mode):
    """Deprecated alias of :func:`AddMode`."""
    return AddMode(builder, mode)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def MirrorPadOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/NegOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class NegOptions(object):
    """Read accessor for the field-less ``NegOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``NegOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = NegOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsNegOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def NegOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``NegOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def NegOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def NegOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/NonMaxSuppressionV4Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class NonMaxSuppressionV4Options(object):
    """Read accessor for the field-less ``NonMaxSuppressionV4Options`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``NonMaxSuppressionV4Options`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = NonMaxSuppressionV4Options()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsNonMaxSuppressionV4Options(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def NonMaxSuppressionV4OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``NonMaxSuppressionV4Options`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def NonMaxSuppressionV4OptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def NonMaxSuppressionV4OptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/NonMaxSuppressionV5Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class NonMaxSuppressionV5Options(object):
    """Read accessor for the field-less ``NonMaxSuppressionV5Options`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``NonMaxSuppressionV5Options`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = NonMaxSuppressionV5Options()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsNonMaxSuppressionV5Options(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def NonMaxSuppressionV5OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``NonMaxSuppressionV5Options`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def NonMaxSuppressionV5OptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def NonMaxSuppressionV5OptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/NotEqualOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class NotEqualOptions(object):
    """Read accessor for the field-less ``NotEqualOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``NotEqualOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = NotEqualOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsNotEqualOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def NotEqualOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``NotEqualOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def NotEqualOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def NotEqualOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/OneHotOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class OneHotOptions(object):
    """Read accessor for the ``OneHotOptions`` table (TFLite schema).

    Exposes one int32 field, ``axis``, through :meth:`Axis`.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``OneHotOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = OneHotOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsOneHotOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def OneHotOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Axis(self):
        """Return the int32 ``axis`` field, or 0 when it is not present."""
        # Vtable entry 4 is the one written through builder slot 0 (see AddAxis).
        field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_offset == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int32Flags, self._tab.Pos + field_offset)


def Start(builder):
    """Open a ``OneHotOptions`` table on *builder* (1 field slot)."""
    builder.StartObject(1)


def OneHotOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def AddAxis(builder, axis):
    """Write *axis* as an int32 into field slot 0 (default value 0)."""
    builder.PrependInt32Slot(0, axis, 0)


def OneHotOptionsAddAxis(builder, axis):
    """Deprecated alias of :func:`AddAxis`."""
    return AddAxis(builder, axis)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def OneHotOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/PadOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class PadOptions(object):
    """Read accessor for the field-less ``PadOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``PadOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = PadOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsPadOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def PadOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``PadOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def PadOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def PadOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/PadV2Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class PadV2Options(object):
    """Read accessor for the field-less ``PadV2Options`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``PadV2Options`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = PadV2Options()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsPadV2Options(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def PadV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``PadV2Options`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def PadV2OptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def PadV2OptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/Padding.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class Padding(object):
    """Integer constants of the TFLite ``Padding`` enum."""

    SAME, VALID = 0, 1
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/PowOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class PowOptions(object):
    """Read accessor for the field-less ``PowOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``PowOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = PowOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsPowOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def PowOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``PowOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def PowOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def PowOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/QuantizationDetails.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class QuantizationDetails(object):
    """Integer tags of the TFLite ``QuantizationDetails`` union (0 means no value)."""

    NONE, CustomQuantization = 0, 1
10 |
11 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/QuantizeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class QuantizeOptions(object):
    """Read accessor for the field-less ``QuantizeOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``QuantizeOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = QuantizeOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsQuantizeOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def QuantizeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``QuantizeOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def QuantizeOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def QuantizeOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/RangeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class RangeOptions(object):
    """Read accessor for the field-less ``RangeOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``RangeOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = RangeOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsRangeOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RangeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``RangeOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def RangeOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def RangeOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/RankOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class RankOptions(object):
    """Read accessor for the field-less ``RankOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``RankOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = RankOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsRankOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RankOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``RankOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def RankOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def RankOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ReadVariableOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ReadVariableOptions(object):
    """Read accessor for the field-less ``ReadVariableOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``ReadVariableOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ReadVariableOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsReadVariableOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ReadVariableOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``ReadVariableOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def ReadVariableOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def ReadVariableOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ReducerOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ReducerOptions(object):
    """Read accessor for the ``ReducerOptions`` table (TFLite schema).

    Exposes one boolean field, ``keepDims``, through :meth:`KeepDims`.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``ReducerOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ReducerOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsReducerOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ReducerOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def KeepDims(self):
        """Return the ``keepDims`` flag as a ``bool``; False when it is not present."""
        # Vtable entry 4 is the one written through builder slot 0 (see AddKeepDims).
        field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field_offset == 0:
            return False
        raw = self._tab.Get(flatbuffers.number_types.BoolFlags, self._tab.Pos + field_offset)
        return bool(raw)


def Start(builder):
    """Open a ``ReducerOptions`` table on *builder* (1 field slot)."""
    builder.StartObject(1)


def ReducerOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def AddKeepDims(builder, keepDims):
    """Write *keepDims* as a bool into field slot 0 (default value 0)."""
    builder.PrependBoolSlot(0, keepDims, 0)


def ReducerOptionsAddKeepDims(builder, keepDims):
    """Deprecated alias of :func:`AddKeepDims`."""
    return AddKeepDims(builder, keepDims)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def ReducerOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ReverseV2Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ReverseV2Options(object):
    """Read accessor for the field-less ``ReverseV2Options`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``ReverseV2Options`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = ReverseV2Options()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsReverseV2Options(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ReverseV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``ReverseV2Options`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def ReverseV2OptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def ReverseV2OptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/Rfft2dOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class Rfft2dOptions(object):
    """Read accessor for the field-less ``Rfft2dOptions`` table (TFLite schema).

    The only state is the bound ``flatbuffers.table.Table`` kept in ``_tab``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``Rfft2dOptions`` bound to the root table of *buf*.

        The uoffset read at *offset* is relative to *offset*, so the table
        position is their sum.
        """
        relative = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        accessor = Rfft2dOptions()
        accessor.Init(buf, offset + relative)
        return accessor

    @classmethod
    def GetRootAsRfft2dOptions(cls, buf, offset=0):
        """Deprecated alias of :meth:`GetRootAs`."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def Rfft2dOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check *buf* for the ``TFL3`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at *pos* in *buf*."""
        self._tab = flatbuffers.table.Table(buf, pos)


def Start(builder):
    """Open a ``Rfft2dOptions`` table on *builder* (0 field slots)."""
    builder.StartObject(0)


def Rfft2dOptionsStart(builder):
    """Deprecated alias of :func:`Start`."""
    return Start(builder)


def End(builder):
    """Close the open table; returns the result of ``builder.EndObject()``."""
    return builder.EndObject()


def Rfft2dOptionsEnd(builder):
    """Deprecated alias of :func:`End`."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/RightShiftOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class RightShiftOptions(object):
    """Accessor for a ``tflite.RightShiftOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsRightShiftOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def RightShiftOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a RightShiftOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def RightShiftOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def RightShiftOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ScatterNdOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ScatterNdOptions(object):
    """Accessor for a ``tflite.ScatterNdOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsScatterNdOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ScatterNdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a ScatterNdOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def ScatterNdOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def ScatterNdOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SegmentSumOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SegmentSumOptions(object):
    """Accessor for a ``tflite.SegmentSumOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSegmentSumOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SegmentSumOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SegmentSumOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SegmentSumOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SelectOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SelectOptions(object):
    """Accessor for a ``tflite.SelectOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSelectOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SelectOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SelectOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SelectOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SelectOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SelectV2Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SelectV2Options(object):
    """Accessor for a ``tflite.SelectV2Options`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSelectV2Options(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SelectV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SelectV2Options table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SelectV2OptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SelectV2OptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ShapeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ShapeOptions(object):
    """Accessor for a ``tflite.ShapeOptions`` FlatBuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsShapeOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ShapeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def OutType(self):
        """Return the int8 ``OutType`` field (vtable offset 4), or 0 if absent.

        NOTE(review): presumably a ``TensorType`` code -- confirm against the schema.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 0
39 |
def Start(builder):
    """Begin a ShapeOptions table on ``builder`` (``StartObject(1)``: one field)."""
    builder.StartObject(1)


def ShapeOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def AddOutType(builder, outType):
    """Write ``outType`` into field slot 0 as an int8 (default 0)."""
    builder.PrependInt8Slot(0, outType, 0)


def ShapeOptionsAddOutType(builder, outType):
    """Deprecated alias kept for older callers; use ``AddOutType`` instead."""
    return AddOutType(builder, outType)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def ShapeOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SignOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SignOptions(object):
    """Accessor for a ``tflite.SignOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSignOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SignOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SignOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SignOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SignOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SliceOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SliceOptions(object):
    """Accessor for a ``tflite.SliceOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSliceOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SliceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SliceOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SliceOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SliceOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SoftmaxOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SoftmaxOptions(object):
    """Accessor for a ``tflite.SoftmaxOptions`` FlatBuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSoftmaxOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SoftmaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Beta(self):
        """Return the float32 ``Beta`` field (vtable offset 4), or 0.0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0
39 |
def Start(builder):
    """Begin a SoftmaxOptions table on ``builder`` (``StartObject(1)``: one field)."""
    builder.StartObject(1)


def SoftmaxOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def AddBeta(builder, beta):
    """Write ``beta`` into field slot 0 as a float32 (default 0.0)."""
    builder.PrependFloat32Slot(0, beta, 0.0)


def SoftmaxOptionsAddBeta(builder, beta):
    """Deprecated alias kept for older callers; use ``AddBeta`` instead."""
    return AddBeta(builder, beta)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SoftmaxOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SpaceToBatchNDOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SpaceToBatchNDOptions(object):
    """Accessor for a ``tflite.SpaceToBatchNDOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSpaceToBatchNDOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SpaceToBatchNDOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SpaceToBatchNDOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SpaceToBatchNDOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SpaceToBatchNDOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SparseIndexVector.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class SparseIndexVector:
    """Integer tag constants for ``tflite.SparseIndexVector``.

    NOTE(review): ``NONE = 0`` follows the flatc union-tag convention, so this
    is presumably a union discriminator -- confirm against the schema.
    """

    NONE = 0
    Int32Vector = 1
    Uint16Vector = 2
    Uint8Vector = 3
12 |
13 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/SplitOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SplitOptions(object):
    """Accessor for a ``tflite.SplitOptions`` FlatBuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSplitOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SplitOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def NumSplits(self):
        """Return the int32 ``NumSplits`` field (vtable offset 4), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
39 |
def Start(builder):
    """Begin a SplitOptions table on ``builder`` (``StartObject(1)``: one field)."""
    builder.StartObject(1)


def SplitOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def AddNumSplits(builder, numSplits):
    """Write ``numSplits`` into field slot 0 as an int32 (default 0)."""
    builder.PrependInt32Slot(0, numSplits, 0)


def SplitOptionsAddNumSplits(builder, numSplits):
    """Deprecated alias kept for older callers; use ``AddNumSplits`` instead."""
    return AddNumSplits(builder, numSplits)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SplitOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SplitVOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SplitVOptions(object):
    """Accessor for a ``tflite.SplitVOptions`` FlatBuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSplitVOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SplitVOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def NumSplits(self):
        """Return the int32 ``NumSplits`` field (vtable offset 4), or 0 if absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0
39 |
def Start(builder):
    """Begin a SplitVOptions table on ``builder`` (``StartObject(1)``: one field)."""
    builder.StartObject(1)


def SplitVOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def AddNumSplits(builder, numSplits):
    """Write ``numSplits`` into field slot 0 as an int32 (default 0)."""
    builder.PrependInt32Slot(0, numSplits, 0)


def SplitVOptionsAddNumSplits(builder, numSplits):
    """Deprecated alias kept for older callers; use ``AddNumSplits`` instead."""
    return AddNumSplits(builder, numSplits)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SplitVOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SquareOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SquareOptions(object):
    """Accessor for a ``tflite.SquareOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSquareOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SquareOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SquareOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SquareOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SquareOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/SquaredDifferenceOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class SquaredDifferenceOptions(object):
    """Accessor for a ``tflite.SquaredDifferenceOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsSquaredDifferenceOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def SquaredDifferenceOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a SquaredDifferenceOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def SquaredDifferenceOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def SquaredDifferenceOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/TensorType.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
class TensorType:
    """Integer codes of the ``tflite.TensorType`` enum (FLOAT32 = 0 … INT4 = 17)."""

    # Floating point and basic integer types.
    FLOAT32 = 0
    FLOAT16 = 1
    INT32 = 2
    UINT8 = 3
    INT64 = 4
    STRING = 5
    BOOL = 6
    INT16 = 7
    COMPLEX64 = 8
    INT8 = 9
    FLOAT64 = 10
    COMPLEX128 = 11
    UINT64 = 12
    # Non-numeric handle types.
    RESOURCE = 13
    VARIANT = 14
    # Later additions to the enum.
    UINT32 = 15
    UINT16 = 16
    INT4 = 17
26 |
27 |
--------------------------------------------------------------------------------
/tf2onnx/tflite/TileOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class TileOptions(object):
    """Accessor for a ``tflite.TileOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsTileOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def TileOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a TileOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def TileOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def TileOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/TopKV2Options.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class TopKV2Options(object):
    """Accessor for a ``tflite.TopKV2Options`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsTopKV2Options(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def TopKV2OptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a TopKV2Options table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def TopKV2OptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def TopKV2OptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/TransposeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class TransposeOptions(object):
    """Accessor for a ``tflite.TransposeOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsTransposeOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def TransposeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin a TransposeOptions table on ``builder`` (``StartObject(0)``: no fields)."""
    builder.StartObject(0)


def TransposeOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def TransposeOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/UniqueOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class UniqueOptions(object):
    """Accessor for a ``tflite.UniqueOptions`` FlatBuffers table."""

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUniqueOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UniqueOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def IdxOutType(self):
        """Return the int8 ``IdxOutType`` field (vtable offset 4), or 2 if absent.

        NOTE(review): presumably a ``TensorType`` code (2 would be INT32) --
        confirm against the schema.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
        return 2
39 |
def Start(builder):
    """Begin a UniqueOptions table on ``builder`` (``StartObject(1)``: one field)."""
    builder.StartObject(1)


def UniqueOptionsStart(builder):
    """Deprecated alias kept for older callers; use ``Start`` instead."""
    return Start(builder)


def AddIdxOutType(builder, idxOutType):
    """Write ``idxOutType`` into field slot 0 as an int8 (default 2)."""
    builder.PrependInt8Slot(0, idxOutType, 2)


def UniqueOptionsAddIdxOutType(builder, idxOutType):
    """Deprecated alias kept for older callers; use ``AddIdxOutType`` instead."""
    return AddIdxOutType(builder, idxOutType)


def End(builder):
    """Close the table opened by ``Start`` and return ``builder.EndObject()``."""
    return builder.EndObject()


def UniqueOptionsEnd(builder):
    """Deprecated alias kept for older callers; use ``End`` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/UnsortedSegmentMaxOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class UnsortedSegmentMaxOptions(object):
    """Accessor for a ``tflite.UnsortedSegmentMaxOptions`` FlatBuffers table.

    No field accessors are defined; the class only locates and wraps the
    table inside a buffer.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an accessor for the root table of ``buf``.

        The uoffset stored at ``offset`` is read and the accessor is
        positioned at ``offset`` plus that value.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Instantiate ``cls`` (not a hard-coded class) so subclasses get
        # instances of themselves.
        x = cls()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def GetRootAsUnsortedSegmentMaxOptions(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentMaxOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the file identifier ``TFL3`` (bytes 54 46 4C 33)."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x54\x46\x4C\x33", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table starting at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing an UnsortedSegmentMaxOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def UnsortedSegmentMaxOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def UnsortedSegmentMaxOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/UnsortedSegmentMinOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class UnsortedSegmentMinOptions(object):
    """Read-only view over a TFLite ``UnsortedSegmentMinOptions`` flatbuffer table.

    The class defines no field accessors; it only provides the standard
    root-lookup and file-identifier helpers plus :meth:`Init`.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of ``buf``.

        The unsigned offset stored at ``offset`` is relative to ``offset`` itself.
        """
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = UnsortedSegmentMinOptions()
        view.Init(buf, table_pos)
        return view

    @classmethod
    def GetRootAsUnsortedSegmentMinOptions(cls, buf, offset=0):
        """Deprecated: use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentMinOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the TFLite file identifier."""
        file_identifier = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing an UnsortedSegmentMinOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def UnsortedSegmentMinOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def UnsortedSegmentMinOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/UnsortedSegmentProdOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class UnsortedSegmentProdOptions(object):
    """Read-only view over a TFLite ``UnsortedSegmentProdOptions`` flatbuffer table.

    The class defines no field accessors; it only provides the standard
    root-lookup and file-identifier helpers plus :meth:`Init`.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of ``buf``.

        The unsigned offset stored at ``offset`` is relative to ``offset`` itself.
        """
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = UnsortedSegmentProdOptions()
        view.Init(buf, table_pos)
        return view

    @classmethod
    def GetRootAsUnsortedSegmentProdOptions(cls, buf, offset=0):
        """Deprecated: use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentProdOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the TFLite file identifier."""
        file_identifier = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing an UnsortedSegmentProdOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def UnsortedSegmentProdOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def UnsortedSegmentProdOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/UnsortedSegmentSumOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class UnsortedSegmentSumOptions(object):
    """Read-only view over a TFLite ``UnsortedSegmentSumOptions`` flatbuffer table.

    The class defines no field accessors; it only provides the standard
    root-lookup and file-identifier helpers plus :meth:`Init`.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of ``buf``.

        The unsigned offset stored at ``offset`` is relative to ``offset`` itself.
        """
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = UnsortedSegmentSumOptions()
        view.Init(buf, table_pos)
        return view

    @classmethod
    def GetRootAsUnsortedSegmentSumOptions(cls, buf, offset=0):
        """Deprecated: use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def UnsortedSegmentSumOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the TFLite file identifier."""
        file_identifier = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing an UnsortedSegmentSumOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def UnsortedSegmentSumOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def UnsortedSegmentSumOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/WhereOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class WhereOptions(object):
    """Read-only view over a TFLite ``WhereOptions`` flatbuffer table.

    The class defines no field accessors; it only provides the standard
    root-lookup and file-identifier helpers plus :meth:`Init`.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of ``buf``.

        The unsigned offset stored at ``offset`` is relative to ``offset`` itself.
        """
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = WhereOptions()
        view.Init(buf, table_pos)
        return view

    @classmethod
    def GetRootAsWhereOptions(cls, buf, offset=0):
        """Deprecated: use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def WhereOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the TFLite file identifier."""
        file_identifier = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing a WhereOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def WhereOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def WhereOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/ZerosLikeOptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # automatically generated by the FlatBuffers compiler, do not modify
4 |
5 | # namespace: tflite
6 |
7 | import flatbuffers
8 | from flatbuffers.compat import import_numpy
9 | np = import_numpy()
10 |
class ZerosLikeOptions(object):
    """Read-only view over a TFLite ``ZerosLikeOptions`` flatbuffer table.

    The class defines no field accessors; it only provides the standard
    root-lookup and file-identifier helpers plus :meth:`Init`.
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a view over the root table of ``buf``.

        The unsigned offset stored at ``offset`` is relative to ``offset`` itself.
        """
        table_pos = offset + flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = ZerosLikeOptions()
        view.Init(buf, table_pos)
        return view

    @classmethod
    def GetRootAsZerosLikeOptions(cls, buf, offset=0):
        """Deprecated: use :meth:`GetRootAs` instead."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ZerosLikeOptionsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the TFLite file identifier."""
        file_identifier = b"\x54\x46\x4C\x33"  # ASCII "TFL3"
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)
32 |
def Start(builder):
    """Begin serializing a ZerosLikeOptions table (no fields) on ``builder``."""
    builder.StartObject(0)


def End(builder):
    """Close the table opened by :func:`Start` and return ``builder.EndObject()``."""
    return builder.EndObject()


# Deprecated, table-prefixed aliases kept for callers written against older codegen.

def ZerosLikeOptionsStart(builder):
    """Deprecated: use :func:`Start` instead."""
    return Start(builder)


def ZerosLikeOptionsEnd(builder):
    """Deprecated: use :func:`End` instead."""
    return End(builder)
--------------------------------------------------------------------------------
/tf2onnx/tflite/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
--------------------------------------------------------------------------------
/tf2onnx/tflite_handlers/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """tf2onnx.tflite_handlers module"""
4 |
5 | from . import (
6 | tfl_math,
7 | tfl_nn,
8 | tfl_controlflow,
9 | tfl_direct,
10 | tfl_tensor,
11 | tfl_postprocess,
12 | )
13 |
--------------------------------------------------------------------------------
/tf2onnx/tflite_rewriters/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """tf2onnx.tflite_rewriters module"""
4 |
5 | from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs
6 | from tf2onnx.tflite_rewriters.tfl_qdq_rewriter import rewrite_tfl_qdq
7 | from tf2onnx.tflite_rewriters.tfl_select_zero_rewriter import rewrite_tfl_select_zero
8 | from tf2onnx.tflite_rewriters.tfl_rfft_rewriter import rewrite_tfl_rfft
9 |
# Public API of tf2onnx.tflite_rewriters: the rewriter entry points imported
# above, exported for `from tf2onnx.tflite_rewriters import *`.
__all__ = [
    "rewrite_tfl_scan_outputs",
    "rewrite_tfl_qdq",
    "rewrite_tfl_select_zero",
    "rewrite_tfl_rfft",
]
16 |
--------------------------------------------------------------------------------
/tf2onnx/version.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
# Release version string of the tf2onnx package.
version = "1.16.1"
# 40-hex-digit commit id recorded with this release (presumably the git SHA it was built from).
git_version = "13bab8a91e17ccd87541b2f361ab60e8e38359d3"
6 |
--------------------------------------------------------------------------------
/tools/example.bat:
--------------------------------------------------------------------------------
1 | rem SPDX-License-Identifier: Apache-2.0
2 |
rem Example: convert the fc-layers frozen graph to ONNX.
rem Run from the repository root. Every extra command-line argument is
rem forwarded to tf2onnx.convert (previously only the first two were).
rem setlocal keeps these variables from leaking into the caller's cmd session.
setlocal
set frozen=tests/models/fc-layers/frozen.pb
set output=/tmp/model.onnx
set input_names=X:0
set output_names=output:0
rem output_names1 is not referenced below.
set output_names1=output

rem Invoke the converter as a module, matching tools/example.sh.
python -m tf2onnx.convert --input %frozen% --inputs %input_names% --outputs %output_names% --output %output% %*
10 |
--------------------------------------------------------------------------------
/tools/example.sh:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
# Example: convert the fc-layers frozen graph to ONNX.
# Run from the repository root; any extra arguments are forwarded to
# tf2onnx.convert. Expansions are quoted so paths/arguments containing
# spaces or glob characters are passed through unchanged.
frozen=tests/models/fc-layers/frozen.pb
output=/tmp/model.onnx
input_names=X:0
output_names=output:0
# output_names1 is not referenced below.
output_names1=output

python -m tf2onnx.convert --input "$frozen" --inputs "$input_names" --outputs "$output_names" --output "$output" "$@"
10 |
--------------------------------------------------------------------------------
/tools/gen_tflite_flatbuffer.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | """
4 | Generates the files in tf2onnx/tflite used for parsing tflite flatbuffer
5 | WARNING: this script will overwrite all files in tf2onnx/tflite
6 | Before running, download the flatc executable from https://github.com/google/flatbuffers/releases and add it to PATH
7 | This script only tested on Windows
8 | """
9 |
10 | import os
11 | import subprocess
12 | import tempfile
13 | import wget
14 |
15 | SCHEMA_URL = "https://github.com/tensorflow/tensorflow/raw/master/tensorflow/lite/schema/schema.fbs"
16 |
17 | FILE_HEADER = "# SPDX-License-Identifier: Apache-2.0\n\n"
18 |
def _clear_directory(path):
    """Remove every entry inside ``path`` (files, symlinks and subdirectories), keeping ``path``.

    A plain ``os.remove`` loop fails on subdirectories such as ``__pycache__``,
    which appear once the bindings package has been imported.
    """
    import shutil

    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        if os.path.isdir(entry_path) and not os.path.islink(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)


def main():
    """Regenerate the flatbuffer bindings in ``tf2onnx/tflite`` from the TFLite schema.

    Steps:
      1. Download ``SCHEMA_URL``.
      2. Compile it with ``flatc -p`` into a clean scratch directory.
      3. Replace the contents of ``tf2onnx/tflite`` with the generated modules.
         Each module gets ``FILE_HEADER`` prepended, and ``from tflite.`` imports
         are rewritten to ``from tf2onnx.tflite.``.

    The existing bindings are deleted only after generation has succeeded. A failed
    download or ``flatc`` run therefore leaves the current bindings untouched.

    Raises:
        subprocess.CalledProcessError: if ``flatc`` exits with a non-zero status.
        FileNotFoundError: if ``flatc`` cannot be found on ``PATH``.
    """
    tmpdir = os.path.join(tempfile.gettempdir(), "tflite_flatbuffer")
    os.makedirs(tmpdir, exist_ok=True)
    repodir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    dstpath = os.path.join(repodir, "tf2onnx", "tflite")
    os.makedirs(dstpath, exist_ok=True)

    # Download the schema; delete any earlier copy first so the new download
    # is written to exactly schema_path.
    schema_path = os.path.join(tmpdir, "schema.fbs")
    if os.path.exists(schema_path):
        os.remove(schema_path)
    wget.download(SCHEMA_URL, schema_path)
    print()

    # tmpdir persists between runs, so clear the previous flatc output. Otherwise
    # modules for tables no longer in the schema would be copied into the repo.
    tmp_result_path = os.path.join(tmpdir, "tflite")
    if os.path.isdir(tmp_result_path):
        _clear_directory(tmp_result_path)
    # check=True aborts here, before the existing bindings are touched, if flatc fails.
    subprocess.run(["flatc", "-p", "-o", tmpdir, schema_path], check=True)

    # Generation succeeded: now replace the existing bindings.
    _clear_directory(dstpath)
    for file in os.listdir(tmp_result_path):
        src_file = os.path.join(tmp_result_path, file)
        if not os.path.isfile(src_file):
            continue  # skip anything that is not a regular file
        with open(src_file, "rt", encoding="utf-8") as f:
            content = f.read()
        content = FILE_HEADER + content.replace("from tflite.", "from tf2onnx.tflite.")
        with open(os.path.join(dstpath, file), "wt", encoding="utf-8") as f:
            f.write(content)
        print("Generated", file)


if __name__ == "__main__":
    main()
46 |
--------------------------------------------------------------------------------
/tools/graphtool.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 |
4 | """
5 | simple tool to convert .meta to .pb.
6 | """
7 |
8 | import argparse
9 | import os
10 |
11 | import tensorflow as tf
12 |
13 |
def get_args():
    """Parse the command line.

    Returns:
        argparse.Namespace whose ``infile`` attribute is the (possibly empty)
        list of ``.meta`` paths to convert.
    """
    parser = argparse.ArgumentParser(
        description="Convert TensorFlow .meta files to .pb and .pbtxt graphs.")
    # The positional inputs are MetaGraph files consumed by to_pb, not event files.
    parser.add_argument("infile", nargs="*", help="MetaGraph (.meta) files to convert")
    return parser.parse_args()
19 |
20 |
def to_pb(src):
    """Convert a MetaGraph ``.meta`` file into ``.pb`` (binary) and ``.pbtxt`` (text) graphs.

    Outputs are written next to ``src``, named after ``src`` without its ``.meta``
    suffix. The graph is also added to a summary event file in the same directory
    via ``FileWriter``.

    Each call imports into its own fresh ``tf.Graph``. With the shared default
    graph, converting several files in one run would make every later output also
    contain all earlier graphs.

    The TF1-style APIs are reached through ``tf.compat.v1`` and run inside the
    graph context (graph mode). This lets the tool work under TensorFlow 2.x,
    where eager execution is the default and these calls are unavailable at top level.
    """
    tf_v1 = tf.compat.v1
    # dirname is "" for a bare filename; use the current directory explicitly.
    out_dir = os.path.dirname(src) or os.curdir
    base_name = os.path.basename(src)
    if base_name.endswith(".meta"):
        fname = base_name[:-len(".meta")]
    else:
        fname = os.path.splitext(base_name)[0]

    graph = tf.Graph()
    with graph.as_default():
        tf_v1.train.import_meta_graph(src)
        tf_v1.train.write_graph(graph, out_dir, fname + '.pb', as_text=False)
        tf_v1.train.write_graph(graph, out_dir, fname + '.pbtxt', as_text=True)

        writer = tf_v1.summary.FileWriter(out_dir)
        try:
            writer.add_graph(graph)
        finally:
            writer.close()
33 |
34 |
def main():
    """Command-line entry point: convert every .meta path passed as an argument."""
    for meta_path in get_args().infile:
        to_pb(meta_path)


if __name__ == "__main__":
    main()
43 |
--------------------------------------------------------------------------------
/tutorials/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # tf2onnx tutorials
4 |
5 | The following tutorials show how to convert various models to ONNX.
6 |
7 | ## Image Classifiers
8 | [efficientnet-edge](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientnet-edge.ipynb)
9 |
10 | [efficientnet-lite](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientnet-lite.ipynb)
11 |
12 | [keras-resnet50](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/keras-resnet50.ipynb) - shows how to convert a keras model via python api
13 |
14 | ## Object Detectors
15 | [ssd-mobilenet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/ConvertingSSDMobilenetToONNX.ipynb)
16 |
17 | [efficientdet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/efficientdet.ipynb)
18 |
19 | [mobiledet](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/mobiledet-tflite.ipynb) - shows how to convert a tflite model
20 |
21 | ## NLP
22 | [Huggingface Bert Example](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/huggingface-bert.ipynb)
23 |
24 | [The original TensorFlow BERT model](https://github.com/onnx/tensorflow-onnx/blob/main/tutorials/BertTutorial.ipynb) - deprecated; use the Hugging Face BERT example above instead
25 |
--------------------------------------------------------------------------------