├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── convert_jax
│   └── JAX_MNIST_to_LiteRT.ipynb
└── convert_pytorch
    ├── DIS_segmentation_and_quantization.ipynb
    ├── Image_classification_with_convnext_v2.ipynb
    ├── Image_classification_with_mobile_vit.ipynb
    └── SegNext_segmentation_and_quantization.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | **/.DS_Store
3 | **/.idea
4 | **/.ipynb_checkpoints
5 |
6 | # TensorFlow Lite model
7 | *.tflite
8 |
9 | # MediaPipe Task files
10 | *.tasks
11 |
12 | # Swift
13 |
14 | ## Build generated
15 | build/
16 | DerivedData/
17 | .idea
18 | ## Various settings
19 | *.pbxuser
20 | !default.pbxuser
21 | *.mode1v3
22 | !default.mode1v3
23 | *.mode2v3
24 | !default.mode2v3
25 | *.perspectivev3
26 | !default.perspectivev3
27 | xcuserdata/
28 |
29 | ## Other
30 | *.moved-aside
31 | *.xccheckout
32 | *.xcscmblueprint
33 |
34 | ## Obj-C/Swift specific
35 | *.hmap
36 | *.ipa
37 | *.dSYM.zip
38 | *.dSYM
39 |
40 | # CocoaPods
41 | Pods/
42 | *.xcworkspace
43 |
44 | # mkdocs ignore generated site
45 | site/
46 |
47 | # Android
48 |
49 | # Built application files
50 | *.apk
51 | *.aar
52 | *.ap_
53 | *.aab
54 |
55 | # Files for the ART/Dalvik VM
56 | *.dex
57 |
58 | # Java class files
59 | *.class
60 |
61 | # Generated files
62 | bin/
63 | gen/
64 | out/
65 | # Uncomment the following line in case you need and you don't have the release build type files in your app
66 | # release/
67 |
68 | # Gradle files
69 | .gradle/
70 | build/
71 |
72 | # Local configuration file (sdk path, etc)
73 | local.properties
74 |
75 | # Proguard folder generated by Eclipse
76 | proguard/
77 |
78 | # Log Files
79 | *.log
80 |
81 | # Android Studio Navigation editor temp files
82 | .navigation/
83 |
84 | # Android Studio captures folder
85 | captures/
86 |
87 | # IntelliJ
88 | *.iml
89 | .idea/workspace.xml
90 | .idea/tasks.xml
91 | .idea/gradle.xml
92 | .idea/assetWizardSettings.xml
93 | .idea/dictionaries
94 | .idea/libraries
95 | # Android Studio 3 in .gitignore file.
96 | .idea/caches
97 | .idea/modules.xml
98 | .idea/navEditor.xml
99 |
100 | # Keystore files
101 | # Uncomment the following lines if you do not want to check your keystore files in.
102 | #*.jks
103 | #*.keystore
104 |
105 | # External native build folder generated in Android Studio 2.2 and later
106 | .externalNativeBuild
107 | .cxx/
108 |
109 | # Google Services (e.g. APIs or Firebase)
110 | # google-services.json
111 |
112 | # Version control
113 | vcs.xml
114 |
115 | # lint
116 | lint/intermediates/
117 | lint/generated/
118 | lint/outputs/
119 | lint/tmp/
120 | # lint/reports/
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Edge AI Code of Conduct
2 |
3 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
4 |
5 |
6 | ## Our Standards
7 |
8 | Examples of behavior that contributes to creating a positive environment include:
9 |
10 | * Using welcoming and inclusive language
11 | * Being respectful of differing viewpoints and experiences
12 | * Gracefully accepting constructive criticism
13 | * Focusing on what is best for the community
14 | * Showing empathy towards other community members
15 |
16 | Examples of unacceptable behavior by participants include:
17 |
18 | * The use of sexualized language or imagery and unwelcome sexual attention or advances
19 | * Trolling, insulting/derogatory comments, and personal or political attacks
20 | * Public or private harassment
21 | * Publishing others' private information, such as a physical or electronic address, without explicit permission
22 | * Conduct which could reasonably be considered inappropriate for the forum in which it occurs.
23 |
24 | All AI Edge forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable.
25 |
26 |
27 | ## Our Responsibilities
28 |
29 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
30 |
31 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
32 |
33 |
34 | ## Scope
35 |
36 | This Code of Conduct applies to all content on AI Edge’s GitHub organization, or any other official AI Edge web presence allowing for community interactions, as well as at all official AI Edge events, whether offline or online.
37 |
38 | The Code of Conduct also applies within project spaces and in public spaces whenever an individual is representing AI Edge or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed or de facto representative at an online or offline event.
39 |
40 |
41 | ## Conflict Resolution
42 |
43 | Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between.
44 |
45 | If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below.
46 |
47 | However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute.
48 |
49 | If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict:
50 |
51 | 1. Address the perceived conflict directly with those involved, preferably in a real-time medium.
52 | 2. If this fails, get a third party (e.g. a mutual friend, and/or someone with background on the issue, but not involved in conflict) to intercede.
53 | 3. If you are still unable to resolve the conflict, and you believe it rises to harassment or another code of conduct violation, report it.
54 |
55 |
56 | ## Reporting Violations
57 |
58 | Violations of the Code of Conduct can be reported to Edge AI’s Project Stewards, Paul Trebilcox-Ruiz (ptruiz@google.com) and Joana Carrasqueira (joanafilipa@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report.
59 |
60 | Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report.
61 |
62 |
63 | ## Enforcement
64 |
65 | If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice.
66 |
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.
71 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 |
3 | ## How to become a contributor and submit your own code
4 |
5 | ### Contributor License Agreements
6 |
7 | We'd love to accept your patches! Before we can take them, we have to jump a couple of legal hurdles.
8 |
9 | Please fill out either the individual or corporate Contributor License Agreement (CLA).
10 |
11 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
12 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
13 |
14 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
15 |
16 | ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the main repository.
17 |
18 | ### Contributing code
19 |
20 | If you have improvements to Edge AI, send us your pull requests! For those
21 | just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/).
22 |
23 | Edge AI team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests.
24 | For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally.
25 |
26 | #### License
27 |
28 | Include a license at the top of new files.
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AI Edge Model Conversion/Creating Samples
2 |
3 | This repository is home to samples related to Google AI Edge and the creation/conversion of models. The *convert_jax* and *convert_pytorch* directories contain Colab notebooks that you can run with your personal Google account to download JAX and PyTorch models and convert them to the TensorFlow Lite (LiteRT) format.
4 |
--------------------------------------------------------------------------------
/convert_jax/JAX_MNIST_to_LiteRT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "8vD3L4qeREvg"
7 | },
8 | "source": [
9 | "##### Copyright 2024 The AI Edge Authors."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "qLCxmWRyRMZE",
17 | "cellView": "form"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
22 | "# you may not use this file except in compliance with the License.\n",
23 | "# You may obtain a copy of the License at\n",
24 | "#\n",
25 | "# https://www.apache.org/licenses/LICENSE-2.0\n",
26 | "#\n",
27 | "# Unless required by applicable law or agreed to in writing, software\n",
28 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
29 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
30 | "# See the License for the specific language governing permissions and\n",
31 | "# limitations under the License."
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {
37 | "id": "4k5PoHrgJQOU"
38 | },
39 | "source": [
40 | "# Jax Model Conversion For LiteRT\n",
41 | "## Overview\n",
42 | "Note: This API is new and we recommend using via pip install tf-nightly. Also, the API is still experimental and subject to changes."
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {
48 | "id": "lq-T8XZMJ-zv"
49 | },
50 | "source": [
51 | "## Prerequisites\n",
52 | "It's recommended to try this feature with the newest TensorFlow nightly pip build."
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {
59 | "id": "EV04hKdrnE4f"
60 | },
61 | "outputs": [],
62 | "source": [
63 | "!pip install jax --upgrade\n",
64 | "!pip install ai-edge-litert\n",
65 | "!pip install orbax-export --upgrade\n",
66 | "!pip install tf-nightly --upgrade"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {
73 | "id": "vsilblGuGQa2"
74 | },
75 | "outputs": [],
76 | "source": [
77 | "# Make sure your JAX version is at least 0.4.20 or above.\n",
78 | "import jax\n",
79 | "jax.__version__"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {
86 | "id": "j9_CVA0THQNc"
87 | },
88 | "outputs": [],
89 | "source": [
90 | "from orbax.export import ExportManager\n",
91 | "from orbax.export import JaxModule\n",
92 | "from orbax.export import ServingConfig\n",
93 | "from orbax.export import constants\n",
94 | "\n",
95 | "import tensorflow as tf\n",
96 | "from PIL import Image\n",
97 | "\n",
98 | "import time\n",
99 | "import functools\n",
100 | "import itertools\n",
101 | "\n",
102 | "import numpy as np\n",
103 | "import numpy.random as npr\n",
104 | "\n",
105 | "import jax.numpy as jnp\n",
106 | "from jax import jit, grad, random\n",
107 | "from jax.example_libraries import optimizers\n",
108 | "from jax.example_libraries import stax"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "metadata": {
114 | "id": "QAeY43k9KM55"
115 | },
116 | "source": [
117 | "## Data Preparation\n",
118 | "Download the MNIST data with Keras dataset and run that data through a pre-processing step. This dataset consists of multiple images that are 28x28 pixels and grayscaled (only having one color channel from black to white) representing hand drawn digits from 0 to 9.\n",
119 | "\n",
120 | "During the pre-processing step, the images will be normalized so that their gray color channel will change from 0->255 to 0.0->1.0. This decreases training time.\n",
121 | "\n",
122 | "The model will also use One Hot Encoding. This filters predictions to the most likely prediction."
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {
129 | "id": "hdJIt3Da2Qn1"
130 | },
131 | "outputs": [],
132 | "source": [
133 | "# Create a one-hot encoding of x of size k.\n",
134 | "def _one_hot(x, k, dtype=np.float32):\n",
135 | " return np.array(x[:, None] == np.arange(k), dtype)\n",
136 | "\n",
137 | "# JAX doesn't have its own data loader, so you can use Keras here.\n",
138 | "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
139 | "\n",
140 | "# Normalize the image pixels to a range of 0.0 to 1.0\n",
141 | "train_images, test_images = train_images / 255.0, test_images / 255.0\n",
142 | "train_images = train_images.astype(np.float32)\n",
143 | "test_images = test_images.astype(np.float32)\n",
144 | "\n",
145 | "train_labels = _one_hot(train_labels, 10)\n",
146 | "test_labels = _one_hot(test_labels, 10)"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "source": [
152 | "The following code block is a simple utility to display a set of the MNIST dataset images."
153 | ],
154 | "metadata": {
155 | "id": "J_bJvaWrzXr7"
156 | }
157 | },
158 | {
159 | "cell_type": "code",
160 | "source": [
161 | "# Draws out some of the data in the training dataset.\n",
162 | "import matplotlib.pyplot as plt\n",
163 | "\n",
164 | "rows = 3\n",
165 | "cols = 7\n",
166 | "\n",
167 | "for i in range(rows):\n",
168 | " for j in range(cols):\n",
169 | " index = i * cols + j\n",
170 | " if index < len(train_images):\n",
171 | " plt.subplot(rows, cols, index + 1)\n",
172 | " plt.imshow(train_images[index], cmap='gray')\n",
173 | " plt.title(f\"Label: {np.argmax(train_labels[index])}\")\n",
174 | " plt.axis('off')\n",
175 | "\n",
176 | "plt.tight_layout()\n",
177 | "plt.show()"
178 | ],
179 | "metadata": {
180 | "id": "uyyFkg4Zpe3u"
181 | },
182 | "execution_count": null,
183 | "outputs": []
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "metadata": {
188 | "id": "0eFhx85YKlEY"
189 | },
190 | "source": [
191 | "## Build the MNIST model with Jax"
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "source": [
197 | "This block outlines the loss and accuracy functions for training a new classification model, as well as defines the shape of the model layers."
198 | ],
199 | "metadata": {
200 | "id": "v0lIlmVgz_0g"
201 | }
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {
207 | "id": "mi3TKB9nnQdK"
208 | },
209 | "outputs": [],
210 | "source": [
211 | "# Loss function: Measures how well the model's predictions match expected outputs.\n",
212 | "def loss(params, batch):\n",
213 | " inputs, targets = batch\n",
214 | " preds = predict(params, inputs)\n",
215 | " return -jnp.mean(jnp.sum(preds * targets, axis=1))\n",
216 | "\n",
217 | "# Accuracy function: Average number of times the predictec class matches the true class\n",
218 | "def accuracy(params, batch):\n",
219 | " inputs, targets = batch\n",
220 | " # Finds the highest value in the output targets, which is the true value\n",
221 | " target_class = jnp.argmax(targets, axis=1)\n",
222 | " # Gets the primary predicted value from classification\n",
223 | " predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n",
224 | " return jnp.mean(predicted_class == target_class)\n",
225 | "\n",
226 | "\n",
227 | "init_random_params, predict = stax.serial(\n",
228 | " stax.Flatten, # turns input data into a vector (1D array)\n",
229 | " stax.Dense(1024), stax.Relu, # Create two dense layers with ReLU activation\n",
230 | " stax.Dense(1024), stax.Relu,\n",
231 | " stax.Dense(10), stax.LogSoftmax) # Final layer condenses predictions into one of ten potential output classifications (0->9)\n",
232 | "\n",
233 | "# Pseudo random number generator used for initializing values\n",
234 | "rng = random.PRNGKey(0)"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {
240 | "id": "bRtnOBdJLd63"
241 | },
242 | "source": [
243 | "## Train & Evaluate the model"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {
250 | "id": "SWbYRyj7LYZt"
251 | },
252 | "outputs": [],
253 | "source": [
254 | "step_size = 0.001 # Learning rate - smaller means slower but more stable learning\n",
255 | "num_epochs = 10\n",
256 | "batch_size = 128\n",
257 | "momentum_mass = 0.9 # Momentum optimization algorithm - helps converge faster\n",
258 | "\n",
259 | "# Data setup\n",
260 | "num_train = train_images.shape[0]\n",
261 | "num_complete_batches, leftover = divmod(num_train, batch_size)\n",
262 | "num_batches = num_complete_batches + bool(leftover)\n",
263 | "\n",
264 | "def data_stream():\n",
265 | " rng = npr.RandomState(0)\n",
266 | " while True:\n",
267 | " perm = rng.permutation(num_train)\n",
268 | " for i in range(num_batches):\n",
269 | " batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n",
270 | " yield train_images[batch_idx], train_labels[batch_idx]\n",
271 | "batches = data_stream()\n",
272 | "\n",
273 | "# Optimizer setup\n",
274 | "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n",
275 | "\n",
276 | "# Performs a single training step. Gets the current parameters, calculates\n",
277 | "# gradient of the loss function, then updates the optimizer state and model parameters\n",
278 | "@jit\n",
279 | "def update(i, opt_state, batch):\n",
280 | " params = get_params(opt_state)\n",
281 | " return opt_update(i, grad(loss)(params, batch), opt_state)\n",
282 | "\n",
283 | "# Run the training loop!\n",
284 | "_, init_params = init_random_params(rng, (-1, 28 * 28))\n",
285 | "opt_state = opt_init(init_params)\n",
286 | "itercount = itertools.count()\n",
287 | "\n",
288 | "print(\"\\nStarting training...\")\n",
289 | "for epoch in range(num_epochs):\n",
290 | " start_time = time.time()\n",
291 | " for _ in range(num_batches):\n",
292 | " opt_state = update(next(itercount), opt_state, next(batches))\n",
293 | " epoch_time = time.time() - start_time\n",
294 | "\n",
295 | " params = get_params(opt_state)\n",
296 | " train_acc = accuracy(params, (train_images, train_labels))\n",
297 | " test_acc = accuracy(params, (test_images, test_labels))\n",
298 | " print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n",
299 | " print(\"Training set accuracy {}\".format(train_acc))\n",
300 | " print(\"Test set accuracy {}\".format(test_acc))"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "id": "7Y1OZBhfQhOj"
307 | },
308 | "source": [
309 | "## Convert to a tflite model.\n",
310 | "\n",
311 | "Using the `orbax` library, you can export the newly trained `JAX` model to a TensorFlow `SavedModel` file. Once you have a `SavedModel`, you can convert it to a `.tflite` file that can work with the LiteRT interpreter.\n",
312 | "\n",
313 | "\n",
314 | "\n",
315 | "\n"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "metadata": {
322 | "id": "6pcqKZqdNTmn"
323 | },
324 | "outputs": [],
325 | "source": [
326 | "# This line bridges JAX to TensorFlow\n",
327 | "# Key point: `params` is everything that was learned during training. This is the\n",
328 | "# core part of what you just accomplished.\n",
329 | "# `predict` is the JAX function that does inference.\n",
330 | "jax_module = JaxModule(params, predict, input_polymorphic_shape='b, ...')\n",
331 | "\n",
332 | "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n",
333 | " [\n",
334 | " jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(\n",
335 | " tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n",
336 | " )\n",
337 | " ],\n",
338 | " trackable_obj=tf.function() # Added empty trackable_obj argument\n",
339 | ")\n",
340 | "\n",
341 | "tflite_model = converter.convert()\n",
342 | "with open('jax_mnist.tflite', 'wb') as f:\n",
343 | " f.write(tflite_model)"
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "metadata": {
349 | "id": "sqEhzaJPSPS1"
350 | },
351 | "source": [
352 | "## Check the Converted TFLite Model\n",
353 | "Next you can compare the converted model's results with the Jax model. This first block defines a utility to perform the prediction inference."
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "source": [
359 | "def predict_image_class(image_path, model_path):\n",
360 | "\n",
361 | " try:\n",
362 | " # Load the TFLite model and allocate tensors.\n",
363 | " interpreter = Interpreter(model_path=model_path)\n",
364 | " interpreter.allocate_tensors()\n",
365 | "\n",
366 | " # Get input and output tensors.\n",
367 | " input_details = interpreter.get_input_details()\n",
368 | " output_details = interpreter.get_output_details()\n",
369 | "\n",
370 | " # Load the test image.\n",
371 | " img = Image.open(image_path).convert('L').resize((28, 28))\n",
372 | " img_array = np.array(img)\n",
373 | " img_array = img_array / 255.0\n",
374 | " img_array = np.expand_dims(img_array, axis=0)\n",
375 | " img_array = img_array.astype(np.float32)\n",
376 | "\n",
377 | " # Set the tensor to the input tensor and run inference.\n",
378 | " interpreter.set_tensor(input_details[0]['index'], img_array)\n",
379 | " interpreter.invoke()\n",
380 | "\n",
381 | " # Get the output tensor.\n",
382 | " output_data = interpreter.get_tensor(output_details[0]['index'])\n",
383 | "\n",
384 | " # Get the predicted class\n",
385 | " predicted_class = np.argmax(output_data)\n",
386 | " print(\"Predicted class:\", predicted_class)\n",
387 | "\n",
388 | " except Exception as e:\n",
389 | " print(f\"An error occurred: {e}\")"
390 | ],
391 | "metadata": {
392 | "id": "iKqKZZN7td0b"
393 | },
394 | "execution_count": null,
395 | "outputs": []
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "source": [
400 | "You can download a pre-drawn image for testing that Google has provided, or load your own hand drawn monochronmatic image into the `/content/` directory."
401 | ],
402 | "metadata": {
403 | "id": "xZfUExBQ1Wha"
404 | }
405 | },
406 | {
407 | "cell_type": "code",
408 | "source": [
409 | "!wget https://storage.googleapis.com/ai-edge/models-samples/jax_converter/jax_to_litert_conversion_test/7.png -O /content/7.png"
410 | ],
411 | "metadata": {
412 | "id": "TXa1OkQZ3onF"
413 | },
414 | "execution_count": null,
415 | "outputs": []
416 | },
417 | {
418 | "cell_type": "code",
419 | "source": [
420 | "from ai_edge_litert.interpreter import Interpreter\n",
421 | "\n",
422 | "# Example usage\n",
423 | "# Replace with your image and model paths\n",
424 | "image_path = \"/content/7.png\"\n",
425 | "model_path = \"/content/jax_mnist.tflite\"\n",
426 | "\n",
427 | "predict_image_class(image_path, model_path)"
428 | ],
429 | "metadata": {
430 | "id": "UoohrrHusvS9"
431 | },
432 | "execution_count": null,
433 | "outputs": []
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "metadata": {
438 | "id": "Qy9Gp4H2SjBL"
439 | },
440 | "source": [
441 | "## Optimize the Model\n",
442 | "We will provide a `representative_dataset` to do post-training quantiztion to optimize the model. This will reduce the model size to roughly a quarter.\n",
443 | "\n",
444 | "\n"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "source": [
450 | "def representative_dataset():\n",
451 | " for i in range(1000):\n",
452 | " x = train_images[i:i+1]\n",
453 | " yield [x]\n",
454 | "\n",
455 | "\n",
456 | "# Create a orbax.export.JaxModule that wraps the given JAX function and params into a TF.Module\n",
457 | "jax_module = JaxModule(params, predict)\n",
458 | "\n",
459 | "# Instanciate tf.lite.TFLiteConverter object from the default_signature in the above module\n",
460 | "converter = tf.lite.TFLiteConverter.from_concrete_functions(\n",
461 | " [\n",
462 | " jax_module.methods[constants.DEFAULT_METHOD_KEY].get_concrete_function(\n",
463 | " tf.TensorSpec(shape=(1, 28, 28), dtype=tf.float32, name=\"input\")\n",
464 | " )\n",
465 | " ],\n",
466 | " trackable_obj=tf.function() # Added empty trackable_obj argument\n",
467 | ")\n",
468 | "\n",
469 | "# Apply optimization settings and convert the model\n",
470 | "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
471 | "converter.representative_dataset = representative_dataset\n",
472 | "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
473 | "tflite_quant_model = converter.convert()\n",
474 | "\n",
475 | "# Save the serialized model contents to a .tflite flatbuffer\n",
476 | "with open('jax_mnist_quant.tflite', 'wb') as f:\n",
477 | " f.write(tflite_quant_model)"
478 | ],
479 | "metadata": {
480 | "id": "-sSPO7xsmYjK"
481 | },
482 | "execution_count": null,
483 | "outputs": []
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "metadata": {
488 | "id": "15xQR3JZS8TV"
489 | },
490 | "source": [
491 | "## Evaluate the Optimized Model"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "metadata": {
498 | "id": "X3oOm0OaevD6"
499 | },
500 | "outputs": [],
501 | "source": [
502 | "image_path = \"/content/7.png\"\n",
503 | "model_path = \"/content/jax_mnist_quant.tflite\"\n",
504 | "\n",
505 | "predict_image_class(image_path, model_path)"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "id": "QqHXCNa3myor"
512 | },
513 | "source": [
514 | "## Compare the Quantized Model size\n",
515 | "We should be able to see the quantized model is four times smaller than the original model."
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": null,
521 | "metadata": {
522 | "id": "imFPw007juVG"
523 | },
524 | "outputs": [],
525 | "source": [
526 | "!du -h jax_mnist.tflite\n",
527 | "!du -h jax_mnist_quant.tflite"
528 | ]
529 | },
530 | {
531 | "cell_type": "code",
532 | "source": [
533 | "from google.colab import files\n",
534 | "\n",
535 | "files.download('jax_mnist.tflite')\n",
536 | "files.download('jax_mnist_quant.tflite')"
537 | ],
538 | "metadata": {
539 | "id": "LVxWUI78erIF"
540 | },
541 | "execution_count": null,
542 | "outputs": []
543 | }
544 | ],
545 | "metadata": {
546 | "colab": {
547 | "private_outputs": true,
548 | "provenance": [],
549 | "gpuType": "T4"
550 | },
551 | "kernelspec": {
552 | "display_name": "Python 3",
553 | "name": "python3"
554 | },
555 | "language_info": {
556 | "name": "python"
557 | },
558 | "accelerator": "GPU"
559 | },
560 | "nbformat": 4,
561 | "nbformat_minor": 0
562 | }
--------------------------------------------------------------------------------
/convert_pytorch/DIS_segmentation_and_quantization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "lWoqui4egB0q"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Copyright 2024 The AI Edge Torch Authors.\n",
12 | "#\n",
13 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
14 | "# you may not use this file except in compliance with the License.\n",
15 | "# You may obtain a copy of the License at\n",
16 | "#\n",
17 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
18 | "#\n",
19 | "# Unless required by applicable law or agreed to in writing, software\n",
20 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
21 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
22 | "# See the License for the specific language governing permissions and\n",
23 | "# limitations under the License.\n",
24 | "# =============================================================================="
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "Xvt-8e8eE1da"
31 | },
32 | "source": [
33 | "This demo will teach you how to convert a PyTorch [IS-Net](https://github.com/xuebinqin/DIS) model to a LiteRT model using Google's AI Edge Torch library. You will then run the newly converted `tflite` model locally using the LiteRT API, as well as learn where to find other tools for running your newly converted model on other edge hardware, including mobile devices and web browsers."
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "Mzf2MdHoG-9c"
40 | },
41 | "source": [
42 | "# Prerequisites"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {
48 | "id": "hux_Gsc_G4nl"
49 | },
50 | "source": [
51 | "You can start by importing the necessary dependencies for converting the model, as well as some additional utilities for displaying various information as you progress through this sample."
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "id": "l-9--DWON236"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "!pip install ai-edge-torch"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {
68 | "id": "IUMh9GRk17fV"
69 | },
70 | "source": [
71 | "You will also need to download an image to verify model functionality."
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {
78 | "id": "6TDCmXEplIyB"
79 | },
80 | "outputs": [],
81 | "source": [
82 | "import urllib\n",
83 | "\n",
84 | "IMAGE_FILENAMES = ['astrid_happy_hike.jpg']\n",
85 | "\n",
86 | "for name in IMAGE_FILENAMES:\n",
87 | " url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_segmentation_dis/{name}'\n",
88 | " urllib.request.urlretrieve(url, name)"
89 | ]
90 | },
91 | {
92 | "cell_type": "markdown",
93 | "source": [
94 | "Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below. Additionally, this will allow you to select multiple images to upload and test at each step in this colab."
95 | ],
96 | "metadata": {
97 | "id": "SDb92B05EZi1"
98 | }
99 | },
100 | {
101 | "cell_type": "code",
102 | "source": [
103 | "from google.colab import files\n",
104 | "uploaded = files.upload()\n",
105 | "\n",
106 | "for filename in uploaded:\n",
107 | " content = uploaded[filename]\n",
108 | " with open(filename, 'wb') as f:\n",
109 | " f.write(content)\n",
110 | "IMAGE_FILENAMES = list(uploaded.keys())\n",
111 | "\n",
112 | "print('Uploaded files:', IMAGE_FILENAMES)"
113 | ],
114 | "metadata": {
115 | "id": "rPDSVb3oEb19"
116 | },
117 | "execution_count": null,
118 | "outputs": []
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "source": [
123 | "Now go ahead and verify that the image was loaded successfully"
124 | ],
125 | "metadata": {
126 | "id": "mgc88K-CEeb-"
127 | }
128 | },
129 | {
130 | "cell_type": "code",
131 | "source": [
132 | "import cv2\n",
133 | "from google.colab.patches import cv2_imshow\n",
134 | "import math\n",
135 | "\n",
136 | "DESIRED_HEIGHT = 480\n",
137 | "DESIRED_WIDTH = 480\n",
138 | "\n",
139 | "def resize_and_show(image):\n",
140 | " h, w = image.shape[:2]\n",
141 | " if h < w:\n",
142 | " img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))\n",
143 | " else:\n",
144 | " img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))\n",
145 | " cv2_imshow(img)\n",
146 | "\n",
147 | "\n",
148 | "# Preview the images.\n",
149 | "images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}\n",
150 | "\n",
151 | "for name, image in images.items():\n",
152 | " print(name)\n",
153 | " resize_and_show(image)"
154 | ],
155 | "metadata": {
156 | "id": "D754sE6WEhOw"
157 | },
158 | "execution_count": null,
159 | "outputs": []
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "source": [
164 | "Finally, we've written a few utility functions to help with visualizing each step in this process, as well as one function that performs inference using the various models that can be passed into it. Go ahead and run this cell now so that they're available."
165 | ],
166 | "metadata": {
167 | "id": "bqmpr-BI3tAD"
168 | }
169 | },
170 | {
171 | "cell_type": "code",
172 | "source": [
173 | " def display_two_column_images(title_1, title_2, image_1, image_2):\n",
174 | " f, ax = plt.subplots(1, 2, figsize = (7,7))\n",
175 | " ax[0].imshow(image_1)\n",
176 | " ax[1].imshow(image_2, cmap = 'gray')\n",
177 | " ax[0].set_title(title_1)\n",
178 | " ax[1].set_title(title_2)\n",
179 | " ax[0].axis('off')\n",
180 | " ax[1].axis('off')\n",
181 | " plt.tight_layout()\n",
182 | " plt.show()\n",
183 | "\n",
184 | " def display_three_column_images(title_1, title_2, title_3, image_1, image_2, image_3):\n",
185 | " f, ax = plt.subplots(1, 3, figsize = (10,10))\n",
186 | " ax[0].imshow(image_1) # Original image.\n",
187 | " ax[1].imshow(image_2, cmap = 'gray') # PT segmentation mask.\n",
188 | " ax[2].imshow(image_3, cmap = 'gray') # TFL segmentation mask.\n",
189 | " ax[0].set_title(title_1)\n",
190 | " ax[1].set_title(title_2)\n",
191 | " ax[2].set_title(title_3)\n",
192 | " ax[0].axis('off')\n",
193 | " ax[1].axis('off')\n",
194 | " ax[2].axis('off')\n",
195 | " plt.tight_layout()\n",
196 | " plt.show()\n",
197 | "\n",
198 | " def get_processed_isnet_result(model_output, original_image_hw):\n",
199 | " # Min-max normalization.\n",
200 | " output_min = model_output.min()\n",
201 | " output_max = model_output.max()\n",
202 | " result = (model_output - output_min) / (output_max - output_min)\n",
203 | "\n",
204 | " # Scale [0, 1] -> [0, 255].\n",
205 | " result = (result * 255).astype(np.uint8)\n",
206 | "\n",
207 | " # Restore original image size.\n",
208 | " result = Image.fromarray(result.squeeze(), \"L\")\n",
209 | " return result.resize(original_image_hw, Image.Resampling.BILINEAR)"
210 | ],
211 | "metadata": {
212 | "id": "fzqBAJtU3sb_"
213 | },
214 | "execution_count": null,
215 | "outputs": []
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {
220 | "id": "IBFYQIm-yFz1"
221 | },
222 | "source": [
223 | "# PyTorch model validation"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "BKfMAS7ggB0w"
230 | },
231 | "source": [
232 | "Now that you have your test images and utility functions, it's time to test the original PyTorch model that will be converted to the `tflite` format. You can start by retrieving the PyTorch model from Kaggle, along with the original project from GitHub that will be used for building the model."
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "metadata": {
239 | "id": "ywS-73O6gB0x"
240 | },
241 | "outputs": [],
242 | "source": [
243 | "%cd /content\n",
244 | "!rm -rf DIS sample_data\n",
245 | "\n",
246 | "!git clone https://github.com/xuebinqin/DIS.git\n",
247 | "%cd DIS/IS-Net/\n",
248 | "\n",
249 | "!curl -o ./model.tar.gz -L https://www.kaggle.com/api/v1/models/paulruiz/dis/pyTorch/8-17-22/1/download\n",
250 | "!tar -xvf 'model.tar.gz'"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {
256 | "id": "MW3TdIhyr-ds"
257 | },
258 | "source": [
259 | "Next you will load in that new model build it to run locally."
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {
266 | "id": "bvyEsyNQp7FT"
267 | },
268 | "outputs": [],
269 | "source": [
270 | "import torch\n",
271 | "from models import ISNetDIS\n",
272 | "\n",
273 | "pytorch_model_filename = 'isnet-general-use.pth'\n",
274 | "pt_model = ISNetDIS()\n",
275 | "pt_model.load_state_dict(\n",
276 | " torch.load(pytorch_model_filename, map_location=torch.device('cpu'))\n",
277 | ")\n",
278 | "pt_model.eval();"
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "metadata": {
284 | "id": "B5d4s8SSr8wn"
285 | },
286 | "source": [
287 | "And to finish validating the original model, you can use it to run inference on the test image(s) that you loaded earlier. In this step you will save the generated PyTorch segmentation mask images so they can be compared to your LiteRT segmentation mask images later in this colab."
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {
294 | "id": "XefR4a2nGqmz"
295 | },
296 | "outputs": [],
297 | "source": [
298 | "from io import BytesIO\n",
299 | "import numpy as np\n",
300 | "from skimage import io\n",
301 | "\n",
302 | "import torch.nn as nn\n",
303 | "import torch.nn.functional as F\n",
304 | "from torchvision.transforms.functional import normalize\n",
305 | "\n",
306 | "from matplotlib import pyplot as plt\n",
307 | "\n",
308 | "MODEL_INPUT_HW = (1024, 1024)\n",
309 | "pt_result = []\n",
310 | "images = []\n",
311 | "for index in range(len(IMAGE_FILENAMES)) :\n",
312 | " images.append(io.imread('../../'+IMAGE_FILENAMES[index]))\n",
313 | "\n",
314 | " # BHWC -> BCHW.\n",
315 | " image_tensor = torch.tensor(images[index], dtype=torch.float32).permute(2, 0, 1)\n",
316 | "\n",
317 | " # Resize to meet model input size requirements.\n",
318 | " image_tensor = F.upsample(torch.unsqueeze(image_tensor, 0),\n",
319 | " MODEL_INPUT_HW, mode='bilinear').type(torch.uint8)\n",
320 | "\n",
321 | " # Scale [0, 255] -> [0, 1].\n",
322 | " pt_image = torch.divide(image_tensor, 255.0)\n",
323 | "\n",
324 | " # Normalize.\n",
325 | " pt_image = normalize(pt_image, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])\n",
326 | "\n",
327 | " # Get output with the most accurate prediction.\n",
328 | " pt_result.append(pt_model(pt_image)[0][0])\n",
329 | "\n",
330 | " # Recover the prediction spatial size to the orignal image size.\n",
331 | " pt_result[index] = F.upsample(pt_result[index], images[index].shape[:2], mode='bilinear')\n",
332 | " pt_result[index] = torch.squeeze(pt_result[index], 0)\n",
333 | "\n",
334 | " # Min-max normalization.\n",
335 | " ma = torch.max(pt_result[index])\n",
336 | " mi = torch.min(pt_result[index])\n",
337 | " pt_result[index] = (pt_result[index] - mi) / (ma - mi)\n",
338 | "\n",
339 | " # Scale [0, 1] -> [0, 255].\n",
340 | " pt_result[index] = pt_result[index] * 255\n",
341 | "\n",
342 | " # BCHW -> BHWC.\n",
343 | " pt_result[index] = pt_result[index].permute(1, 2, 0)\n",
344 | "\n",
345 | " # Get numpy array.\n",
346 | " pt_result[index] = pt_result[index].cpu().data.numpy().astype(np.uint8)\n",
347 | "\n",
348 | " display_two_column_images('Original Image', 'Mask', images[index], pt_result[index])"
349 | ]
350 | },
351 | {
352 | "cell_type": "markdown",
353 | "source": [
354 | "# Convert to LiteRT"
355 | ],
356 | "metadata": {
357 | "id": "2v1rYRI68Jdl"
358 | }
359 | },
360 | {
361 | "cell_type": "markdown",
362 | "metadata": {
363 | "id": "qk7zWa2S7eLU"
364 | },
365 | "source": [
366 | "## Add model wrapper"
367 | ]
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "metadata": {
372 | "id": "AHO3K-kXXHWp"
373 | },
374 | "source": [
375 | "The original IS-Net model generates 12 outputs, each corresponding to different stages in the segmentation process. While the official PyTorch model demo provides guidance on selecting the final (best) output, obtaining the desired output from the converted LiteRT model requires additional effort.\n",
376 | "\n",
377 | "One of the methods you can use to get to this final output is to download the `tflite` file after the conversion step in this colab, open it with [Model Explorer](https://ai.google.dev/edge/model-explorer) and confirm which output in the graph has the expected output shape.\n",
378 | "\n",
379 | "That's kind of a lot for this example, so to simplify the process and eliminate this effort, you can use a wrapper for the PyTorch model that narrows the scope to only the final output. This approach ensures that your new LiteRT model has only a single output after the conversion stage.\n",
380 | "\n",
381 | "Additionally, this colab include some extra pre and post-processing steps, such as excluding min-max normalization because `torch.min` and `torch.max` are not currently supported in the conversion process.\n",
382 | "\n",
383 | "You can create the wrapper by running the following cell:"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": null,
389 | "metadata": {
390 | "id": "mr2XESVJGucI"
391 | },
392 | "outputs": [],
393 | "source": [
394 | "class ImageSegmentationModelWrapper(nn.Module):\n",
395 | "\n",
396 | " RESCALING_FACTOR = 255.0\n",
397 | " MEAN = 0.5\n",
398 | " STD = 1.0\n",
399 | "\n",
400 | " def __init__(self, pt_model):\n",
401 | " super().__init__()\n",
402 | " self.model = pt_model\n",
403 | "\n",
404 | " def forward(self, image: torch.Tensor):\n",
405 | " # BHWC -> BCHW.\n",
406 | " image = image.permute(0, 3, 1, 2)\n",
407 | "\n",
408 | " # Rescale [0, 255] -> [0, 1].\n",
409 | " image = image / self.RESCALING_FACTOR\n",
410 | "\n",
411 | " # Normalize.\n",
412 | " image = (image - self.MEAN) / self.STD\n",
413 | "\n",
414 | " # Get result.\n",
415 | " result = self.model(image)[0][0]\n",
416 | "\n",
417 | " # BHWC -> BCHW.\n",
418 | " result = result.permute(0, 2, 3, 1)\n",
419 | "\n",
420 | " return result\n",
421 | "\n",
422 | "\n",
423 | "wrapped_pt_model = ImageSegmentationModelWrapper(pt_model).eval()"
424 | ]
425 | },
426 | {
427 | "cell_type": "markdown",
428 | "metadata": {
429 | "id": "GMBNfgcV7k0f"
430 | },
431 | "source": [
432 | "## Convert to LiteRT"
433 | ]
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "metadata": {
438 | "id": "T2MnULes70W0"
439 | },
440 | "source": [
441 | "Provide sample arguments -- result LiteRT model will expect input of this size -- and convert the model.\n",
442 | "\n",
443 | "Now it's time to perform the conversion! You will need to provide a couple arguments, such as the expected input shape (for example: 1, model input height, model input width, and 3 for the RGB layers of an image) and the wrapper that you created in the last step."
444 | ]
445 | },
446 | {
447 | "cell_type": "code",
448 | "execution_count": null,
449 | "metadata": {
450 | "id": "XOfNPYpnLGrp"
451 | },
452 | "outputs": [],
453 | "source": [
454 | "import ai_edge_torch\n",
455 | "\n",
456 | "sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)\n",
457 | "edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)"
458 | ]
459 | },
460 | {
461 | "cell_type": "markdown",
462 | "metadata": {
463 | "id": "e7II2a_389DH"
464 | },
465 | "source": [
466 | "# Validate converted model with LiteRT Interpreter"
467 | ]
468 | },
469 | {
470 | "cell_type": "markdown",
471 | "metadata": {
472 | "id": "F65_ULYRLkTY"
473 | },
474 | "source": [
475 | "Now that you have a converted model stored in Colab, it's time to test it. Start by preparing the test image(s) that you loaded earlier. Since all of the preprocessing steps were built into the model through the wrapper, you only need to resize and type-cast the input image(s) in this step. At the end of this stage you should see the original image, the PyTorch mask, and the LiteRT mask for your test input."
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "metadata": {
482 | "id": "yQBmo3uqMC8p"
483 | },
484 | "outputs": [],
485 | "source": [
486 | "from PIL import Image\n",
487 | "\n",
488 | "np_images = []\n",
489 | "image_sizes = []\n",
490 | "for index in range(len(IMAGE_FILENAMES)) :\n",
491 | " # Retrieve each image from the file system\n",
492 | " image = Image.open('../../' + IMAGE_FILENAMES[index])\n",
493 | " # Track each image's size here to simplify displaying later\n",
494 | " image_sizes.append(image.size)\n",
495 | " # Convert each image into a NumPy array and save for later\n",
496 | " np_images.append(np.array(image.resize(MODEL_INPUT_HW, Image.Resampling.BILINEAR)))\n",
497 | " np_images[index] = np.expand_dims(np_images[index], axis=0).astype(np.float32)\n",
498 | "\n",
499 | " # Retrieve an output from the converted model\n",
500 | " edge_model_output = edge_model(np_images[index])\n",
501 | "\n",
502 | " # Use the visualization utility created earlier to get a displayable image\n",
503 | " lrt_result = get_processed_isnet_result(edge_model_output, image_sizes[index])\n",
504 | "\n",
505 | " display_three_column_images('Original Image', 'PT Mask', 'TFL Mask', images[index], pt_result[index], lrt_result)"
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "id": "jVcrUu9aaP9W"
512 | },
513 | "source": [
514 | "# Post Training and Dynamic-Range Quantization with LiteRT"
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "source": [
520 | "At this point you should have a working `tflite` model that you have converted from the original PyTorch format. Congratulations! But if you're working with edge devices, then you likely know that model size is an **important** consideration for things like mobile devices. Using a technique called *quantization*, you can reduce a model's size to roughly a quarter of the original size while maintaining a similar level of output quality. To do this with the Google AI Edge PyTorch Converter, you can pass in an optimization flag to the `convert` function to include a step for dynamic-range quantization.\n",
521 | "\n",
522 | "If you'd like to know more about quantization and other optimizations, you can find our official documentation [here](https://www.tensorflow.org/lite/performance/post_training_quantization)."
523 | ],
524 | "metadata": {
525 | "id": "dKQilmWqnzpV"
526 | }
527 | },
528 | {
529 | "cell_type": "code",
530 | "execution_count": null,
531 | "metadata": {
532 | "id": "UDmkx7zLaXn8"
533 | },
534 | "outputs": [],
535 | "source": [
536 | "import tensorflow as tf\n",
537 | "\n",
538 | "\n",
539 | "tfl_converter_flags = {\n",
540 | " \"optimizations\": [tf.lite.Optimize.DEFAULT]\n",
541 | "}\n",
542 | "tfl_drq_model = ai_edge_torch.convert(\n",
543 | " wrapped_pt_model,\n",
544 | " sample_args,\n",
545 | " _ai_edge_converter_flags=tfl_converter_flags\n",
546 | ")"
547 | ]
548 | },
549 | {
550 | "cell_type": "markdown",
551 | "source": [
552 | "After the conversion has finished, you can compare the newly converted and quantized model with the original image and PyTorch mask image from earlier."
553 | ],
554 | "metadata": {
555 | "id": "y-Ty1YsvozQm"
556 | }
557 | },
558 | {
559 | "cell_type": "code",
560 | "execution_count": null,
561 | "metadata": {
562 | "id": "6fJtHyxbaejb"
563 | },
564 | "outputs": [],
565 | "source": [
566 | "for index in range(len(IMAGE_FILENAMES)) :\n",
567 | "\n",
568 | " tfl_drq_model_output = tfl_drq_model(np_images[index])\n",
569 | "\n",
570 | " tfl_drq_result = get_processed_isnet_result(tfl_drq_model_output, image_sizes[index])\n",
571 | "\n",
572 | " display_three_column_images('Original Image', 'PT Mask', 'TFLQ Mask', images[index], pt_result[index], tfl_drq_result)"
573 | ]
574 | },
575 | {
576 | "cell_type": "markdown",
577 | "source": [
578 | "# Post Training and Dynamic-Range Quantization with PT2E"
579 | ],
580 | "metadata": {
581 | "id": "me3y_PzayhyM"
582 | }
583 | },
584 | {
585 | "cell_type": "markdown",
586 | "source": [
587 | "Another available option for dynamic-range quantization is PT2E, a framework-level quantization feature available in PyTorch 2.0. For more details, see the [PyTorch tutorial](https://pytorch.org/tutorials/prototype/quantization_in_pytorch_2_0_export_tutorial.html).\n",
588 | "\n",
589 | "PT2EQuantizer is developed specifically for the AI Edge Torch framework and is configured to quantize models using the various operators and kernels offered by the LiteRT runtime.\n",
590 | "\n",
591 | "You can see how to configure the PT2EQuantizer and use it as an additional parameter in the `convert` function below."
592 | ],
593 | "metadata": {
594 | "id": "Xv6hZkvqmdHj"
595 | }
596 | },
597 | {
598 | "cell_type": "code",
599 | "source": [
600 | "from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config\n",
601 | "from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer\n",
602 | "from ai_edge_torch.quantize.quant_config import QuantConfig\n",
603 | "\n",
604 | "from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e\n",
605 | "from torch._export import capture_pre_autograd_graph\n",
606 | "\n",
607 | "\n",
608 | "pt2e_quantizer = PT2EQuantizer().set_global(\n",
609 | " get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)\n",
610 | ")\n",
611 | "\n",
612 | "# Following are the required steps recommended in the PT2E quantization\n",
613 | "# workflow.\n",
614 | "autograd_torch_model = capture_pre_autograd_graph(wrapped_pt_model, sample_args)\n",
615 | "# 1. Prepare for quantization.\n",
616 | "pt2e_torch_model = prepare_pt2e(autograd_torch_model, pt2e_quantizer)\n",
617 | "# 2. Run the prepared model with sample input data to ensure that internal\n",
618 | "# observers are populated with correct values.\n",
619 | "pt2e_torch_model(*sample_args)\n",
620 | "# 3. Finally, convert (quantize) the prepared model.\n",
621 | "pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)\n",
622 | "\n",
623 | "pt2e_drq_model = ai_edge_torch.convert(\n",
624 | " pt2e_torch_model,\n",
625 | " sample_args,\n",
626 | " quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer)\n",
627 | ")"
628 | ],
629 | "metadata": {
630 | "id": "-AjWfYAoy3D8"
631 | },
632 | "execution_count": null,
633 | "outputs": []
634 | },
635 | {
636 | "cell_type": "markdown",
637 | "source": [
638 | "Once the model has been converted again using the PT2E Quantizer, it's time to review the results so you can compare them to both the original image and the PyTorch inferred mask."
639 | ],
640 | "metadata": {
641 | "id": "ZxQDTvQZnlp0"
642 | }
643 | },
644 | {
645 | "cell_type": "code",
646 | "source": [
647 | "for index in range(len(IMAGE_FILENAMES)) :\n",
648 | "\n",
649 | " pt2e_drq_output = pt2e_drq_model(np_images[index])\n",
650 | "\n",
651 | " pt2e_drq_result = get_processed_isnet_result(pt2e_drq_output, image_sizes[index])\n",
652 | "\n",
653 | " display_three_column_images('Original Image', 'PT Mask', 'PT2E DRQ Mask', images[index], pt_result[index], pt2e_drq_result)"
654 | ],
655 | "metadata": {
656 | "id": "PipB5Og-0dx1"
657 | },
658 | "execution_count": null,
659 | "outputs": []
660 | },
661 | {
662 | "cell_type": "markdown",
663 | "metadata": {
664 | "id": "3AOmkXUaBVUb"
665 | },
666 | "source": [
667 | "# Download converted models"
668 | ]
669 | },
670 | {
671 | "cell_type": "markdown",
672 | "source": [
673 | "Now that you've converted and optimized the DIS model for LiteRT, it's time to save those models. The following cells are set up to download three models: the newly converted `tflite` model without optimizations, the converted model using dynamic-range quantization, and the model that uses PT2E quantization. When you've finished downloading these files, check out their sizes: the original converted model is about 175MB, whereas the quantized models are about 45MB, which is much more manageable for edge devices!"
674 | ],
675 | "metadata": {
676 | "id": "TetqIwvEnu__"
677 | }
678 | },
679 | {
680 | "cell_type": "code",
681 | "execution_count": null,
682 | "metadata": {
683 | "id": "mY00XJQ1BZP3"
684 | },
685 | "outputs": [],
686 | "source": [
687 | "from google.colab import files\n",
688 | "\n",
689 | "tfl_filename = \"isnet.tflite\"\n",
690 | "edge_model.export(tfl_filename)\n",
691 | "\n",
692 | "files.download(tfl_filename)"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": null,
698 | "metadata": {
699 | "id": "XgFa0lDSd7Z5"
700 | },
701 | "outputs": [],
702 | "source": [
703 | "tfl_drq_filename = 'isnet_tfl_drq.tflite'\n",
704 | "tfl_drq_model.export(tfl_drq_filename)\n",
705 | "\n",
706 | "files.download(tfl_drq_filename)"
707 | ]
708 | },
709 | {
710 | "cell_type": "code",
711 | "source": [
712 | "pt2e_drq_filename = 'isnet_pt2e_drq.tflite'\n",
713 | "pt2e_drq_model.export(pt2e_drq_filename)\n",
714 | "\n",
715 | "files.download(pt2e_drq_filename)"
716 | ],
717 | "metadata": {
718 | "id": "-NGABbj-0hiZ"
719 | },
720 | "execution_count": null,
721 | "outputs": []
722 | },
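If you'd like to confirm the size difference before (or instead of) downloading, the cell below is a small optional addition that is not part of the original colab: it prints the on-disk size of each exported file, reusing the filenames defined in the cells above. Exact numbers vary by converter release, but the quantized variants should come in at roughly a quarter of the float model's size.

```python
import os

# Print the on-disk size of each exported LiteRT model in megabytes.
for path in [tfl_filename, tfl_drq_filename, pt2e_drq_filename]:
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f'{path}: {size_mb:.1f} MB')
```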
723 | {
724 | "cell_type": "markdown",
725 | "source": [
726 | "# Next steps\n",
727 | "\n",
728 | "Now that you have learned how to convert this segmentation model from PyTorch to the `tflite` format, it's time to do more with it! You can go over additional LiteRT API samples for multiple platforms, including Android, iOS, and Python, as well as learn more about on-device machine learning inference from the [Google AI Edge official documentation](https://ai.google.dev/edge/). You can find samples to run the output from this colab on Android [here](https://github.com/google-ai-edge/litert-samples/tree/main/examples/image_segmentation_DIS/android) and on iOS [here](https://github.com/google-ai-edge/litert-samples/tree/main/examples/image_segmentation_DIS/ios)."
729 | ],
730 | "metadata": {
731 | "id": "t40oY48RogjW"
732 | }
733 | }
734 | ],
735 | "metadata": {
736 | "colab": {
737 | "provenance": []
738 | },
739 | "kernelspec": {
740 | "display_name": "Python 3 (ipykernel)",
741 | "language": "python",
742 | "name": "python3"
743 | },
744 | "language_info": {
745 | "codemirror_mode": {
746 | "name": "ipython",
747 | "version": 3
748 | },
749 | "file_extension": ".py",
750 | "mimetype": "text/x-python",
751 | "name": "python",
752 | "nbconvert_exporter": "python",
753 | "pygments_lexer": "ipython3",
754 | "version": "3.11.8"
755 | }
756 | },
757 | "nbformat": 4,
758 | "nbformat_minor": 0
759 | }
760 |
--------------------------------------------------------------------------------
/convert_pytorch/Image_classification_with_convnext_v2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4",
8 | "name": "Image_classification_with_convnext_v2.ipynb"
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | },
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "source": [
23 | "# Copyright 2025 The AI Edge Torch Authors.\n",
24 | "#\n",
25 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
26 | "# you may not use this file except in compliance with the License.\n",
27 | "# You may obtain a copy of the License at\n",
28 | "#\n",
29 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
30 | "#\n",
31 | "# Unless required by applicable law or agreed to in writing, software\n",
32 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
33 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
34 | "# See the License for the specific language governing permissions and\n",
35 | "# limitations under the License.\n",
36 | "# =============================================================================="
37 | ],
38 | "metadata": {
39 | "id": "r4lisalb-A5R"
40 | },
41 | "execution_count": null,
42 | "outputs": []
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "source": [
47 | "This demo will teach you how to convert a PyTorch [ConvNext V2](https://huggingface.co/docs/transformers/en/model_doc/convnextv2#overview) model to a LiteRT (formerly TensorFlow Lite) model using Google's AI Edge Torch library."
48 | ],
49 | "metadata": {
50 | "id": "LwrH6f2sGJ6U"
51 | }
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "source": [
56 | "# Prerequisites"
57 | ],
58 | "metadata": {
59 | "id": "Mzf2MdHoG-9c"
60 | }
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "source": [
65 | "Before starting the conversion process, ensure that you have all the necessary dependencies installed and required resources (like test images) available. You can start by importing the necessary dependencies for converting the model, as well as some additional utilities for displaying various information as you progress through this sample."
66 | ],
67 | "metadata": {
68 | "id": "hux_Gsc_G4nl"
69 | }
70 | },
71 | {
72 | "cell_type": "code",
73 | "source": [
74 | "!pip install ai-edge-torch-nightly\n",
75 | "!pip install transformers pillow requests matplotlib"
76 | ],
77 | "metadata": {
78 | "id": "l-9--DWON236"
79 | },
80 | "execution_count": null,
81 | "outputs": []
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "source": [
86 | "You will also need to download an image to verify model functionality."
87 | ],
88 | "metadata": {
89 | "id": "IUMh9GRk17fV"
90 | }
91 | },
92 | {
93 | "cell_type": "code",
94 | "source": [
95 | "import urllib\n",
96 | "\n",
97 | "IMAGE_FILENAMES = ['cat.jpg']\n",
98 | "\n",
99 | "for name in IMAGE_FILENAMES:\n",
100 | " # TODO: Update path to the appropriate task subfolder in the GCS bucket\n",
101 | " url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_classification_mobile_vit/{name}'\n",
102 | " urllib.request.urlretrieve(url, name)"
103 | ],
104 | "metadata": {
105 | "id": "lfdgp-4Id51J"
106 | },
107 | "execution_count": null,
108 | "outputs": []
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "source": [
113 | "Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below."
114 | ],
115 | "metadata": {
116 | "id": "XYQeTVp-qqZ0"
117 | }
118 | },
119 | {
120 | "cell_type": "code",
121 | "source": [
122 | "# from google.colab import files\n",
123 | "# uploaded = files.upload()\n",
124 | "\n",
125 | "# for filename in uploaded:\n",
126 | "# content = uploaded[filename]\n",
127 | "# with open(filename, 'wb') as f:\n",
128 | "# f.write(content)\n",
129 | "# IMAGE_FILENAMES = list(uploaded.keys())\n",
130 | "\n",
131 | "# print('Uploaded files:', IMAGE_FILENAMES)"
132 | ],
133 | "metadata": {
134 | "id": "8X6tiqVGqq0l"
135 | },
136 | "execution_count": null,
137 | "outputs": []
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "source": [
142 | "Now go ahead and verify that the image was loaded successfully"
143 | ],
144 | "metadata": {
145 | "id": "RIGfyYzcVkIB"
146 | }
147 | },
148 | {
149 | "cell_type": "code",
150 | "source": [
151 | "import cv2\n",
152 | "from google.colab.patches import cv2_imshow\n",
153 | "import math\n",
154 | "\n",
155 | "DESIRED_HEIGHT = 480\n",
156 | "DESIRED_WIDTH = 480\n",
157 | "\n",
158 | "def resize_and_show(image):\n",
159 | " h, w = image.shape[:2]\n",
160 | " if h < w:\n",
161 | " img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))\n",
162 | " else:\n",
163 | " img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))\n",
164 | " cv2_imshow(img)\n",
165 | "\n",
166 | "\n",
167 | "# Preview the images.\n",
168 | "images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}\n",
169 | "\n",
170 | "for name, image in images.items():\n",
171 | " print(name)\n",
172 | " resize_and_show(image)"
173 | ],
174 | "metadata": {
175 | "id": "-GMYmZ5jVq6t"
176 | },
177 | "execution_count": null,
178 | "outputs": []
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "source": [
183 | "# PyTorch model validation"
184 | ],
185 | "metadata": {
186 | "id": "IBFYQIm-yFz1"
187 | }
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "source": [
192 | "Now that you have your test images, it's time to validate the PyTorch model (in this case ConvNext V2) that will be converted to the LiteRT format.\n",
193 | "\n",
194 | "Start by retrieving the PyTorch model and the appropriate corresponding processor."
195 | ],
196 | "metadata": {
197 | "id": "g7qbJRCcvQJt"
198 | }
199 | },
200 | {
201 | "cell_type": "code",
202 | "source": [
203 | "from transformers import ConvNextImageProcessor, ConvNextV2ForImageClassification\n",
204 | "\n",
205 | "# Define the Hugging Face model path\n",
206 | "hf_model_path = 'facebook/convnextv2-tiny-1k-224'\n",
207 | "\n",
208 | "# Initialize the image processor\n",
209 | "processor = ConvNextImageProcessor.from_pretrained(\n",
210 | " hf_model_path\n",
211 | ")"
212 | ],
213 | "metadata": {
214 | "id": "flLiQaaL6tU5"
215 | },
216 | "execution_count": null,
217 | "outputs": []
218 | },
219 | {
220 | "cell_type": "code",
221 | "source": [
222 | "# Display the image normalization parameters\n",
223 | "print(\"Image Mean:\", processor.image_mean)\n",
224 | "print(\"Image Std:\", processor.image_std)"
225 | ],
226 | "metadata": {
227 | "id": "RG1y5A2ifk_f"
228 | },
229 | "execution_count": null,
230 | "outputs": []
231 | },
232 | {
233 | "cell_type": "code",
234 | "source": [
235 | "pt_model = ConvNextV2ForImageClassification.from_pretrained(hf_model_path)"
236 | ],
237 | "metadata": {
238 | "id": "rh9oZFK2iydm"
239 | },
240 | "execution_count": null,
241 | "outputs": []
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "source": [
246 | "The `ConvNextImageProcessor` defined below will perform multiple steps on the input image to match the requirements of the `ConvNextV2` model:\n",
247 | "\n",
248 | "1. Rescale the image from the [0, 255] range to the range specified by the pretrained model.\n",
249 | "2. Resize the input image to 224x224 pixels. This differs from the processor's default behaviour (which includes padding and center cropping) to make it easier to validate the converted model with LiteRT (more details in the corresponding section)."
250 | ],
251 | "metadata": {
252 | "id": "4Mik12qkNEU_"
253 | }
254 | },
255 | {
256 | "cell_type": "code",
257 | "source": [
258 | "from PIL import Image\n",
259 | "\n",
260 | "images = []\n",
261 | "for filename in IMAGE_FILENAMES:\n",
262 | " images.append(Image.open(filename))\n",
263 | "\n",
264 | "inputs = processor(\n",
265 | " images=images,\n",
266 | " return_tensors='pt',\n",
267 | " # Adjusts the image to have the shortest edge of 224 pixels\n",
268 | " size={'shortest_edge': 224},\n",
269 | " do_center_crop=False\n",
270 | ")"
271 | ],
272 | "metadata": {
273 | "id": "_-WmB2MYWc-P"
274 | },
275 | "execution_count": null,
276 | "outputs": []
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "source": [
281 | "Now that you have your test data ready and the inputs processed, it's time to validate the classifications. In this step you will loop through your test image(s) and display the top 5 predicted classification categories. This model was trained with ImageNet-1000, so there are 1000 different potential classifications that may be applied to your test data."
282 | ],
283 | "metadata": {
284 | "id": "ZAQG5SSVzVi2"
285 | }
286 | },
287 | {
288 | "cell_type": "code",
289 | "source": [
290 | "import torch\n",
291 | "from torch import nn\n",
292 | "\n",
293 | "for image_index in range(len(IMAGE_FILENAMES)) :\n",
294 | " outputs = pt_model(**inputs)\n",
295 | " logits = outputs.logits\n",
296 | " probs, indices = nn.functional.softmax(logits[image_index], dim=-1).flatten().topk(k=5)\n",
297 | "\n",
298 | " print(IMAGE_FILENAMES[image_index], 'predictions: ')\n",
299 | " for prediction_index in range(len(indices)):\n",
300 | " class_label = pt_model.config.id2label[indices[prediction_index].item()]\n",
301 | " prob = probs[prediction_index].item()\n",
302 | " print(f'{(prob * 100):4.1f}% {class_label}')\n",
303 | " print('\\n')"
304 | ],
305 | "metadata": {
306 | "id": "ofbZW6nVzSrS"
307 | },
308 | "execution_count": null,
309 | "outputs": []
310 | },
311 | {
312 | "cell_type": "markdown",
313 | "source": [
314 | "# Convert to the `tflite` Format"
315 | ],
316 | "metadata": {
317 | "id": "pfJkS3bH7Jpw"
318 | }
319 | },
320 | {
321 | "cell_type": "markdown",
322 | "source": [
323 | "Before converting the PyTorch model to work with the tflite format, you will need to take an extra step to match it to the format expected by LiteRT. Here are the necessary adjustments:\n",
324 | "\n",
325 | "1. **Channel Ordering**: Convert images from channel-first (BCHW) to channel-last (BHWC) format.\n",
326 | "2. **Softmax Layer**: Add a softmax layer on top of the classification logits so that the model outputs class probabilities for this image classification task.\n",
327 | "3. **Preprocessing Wrapper**: Incorporate the preprocessing steps (rescaling and normalization) into a wrapper class, similar to what you did when validating the PyTorch model in the previous section.\n"
328 | ],
329 | "metadata": {
330 | "id": "ci-8lp_55TLi"
331 | }
332 | },
333 | {
334 | "cell_type": "code",
335 | "source": [
336 | "class HF2LiteRT_ImageClassificationModelWrapper(nn.Module):\n",
337 | "\n",
338 | " def __init__(self, hf_image_classification_model, hf_processor):\n",
339 | " super().__init__()\n",
340 | " self.model = hf_image_classification_model\n",
341 | " if hf_processor.do_rescale:\n",
342 | " self.rescale_factor = hf_processor.rescale_factor\n",
343 | " else:\n",
344 | " self.rescale_factor = 1.0\n",
345 | "\n",
346 | " # Initialize image_mean and image_std as instance variables\n",
347 | " self.image_mean = torch.tensor(hf_processor.image_mean).view(1, -1, 1, 1) # Shape: [1, C, 1, 1]\n",
348 | " self.image_std = torch.tensor(hf_processor.image_std).view(1, -1, 1, 1) # Shape: [1, C, 1, 1]\n",
349 | "\n",
350 | " def forward(self, image: torch.Tensor):\n",
351 | " # BHWC -> BCHW.\n",
352 | " image = image.permute(0, 3, 1, 2)\n",
353 | " # Scale [0, 255] -> [0, 1].\n",
354 | " image = image * self.rescale_factor\n",
355 | " # Normalize\n",
356 | " image = (image - self.image_mean) / self.image_std\n",
357 | " logits = self.model(pixel_values=image).logits # [B, 1000] float32.\n",
358 | "        # Apply softmax so the model outputs class probabilities rather than raw logits.\n",
359 | " logits = torch.nn.functional.softmax(logits, dim=-1)\n",
360 | "\n",
361 | " return logits\n",
362 | "\n",
363 | "\n",
364 | "hf_convnext_v2_processor = ConvNextImageProcessor.from_pretrained(hf_model_path)\n",
365 | "hf_convnext_v2_model = ConvNextV2ForImageClassification.from_pretrained(hf_model_path)\n",
366 | "wrapped_pt_model = HF2LiteRT_ImageClassificationModelWrapper(\n",
367 | " hf_convnext_v2_model, hf_convnext_v2_processor\n",
368 | ").eval()"
369 | ],
370 | "metadata": {
371 | "id": "NlBmvShe4Mt0"
372 | },
373 | "execution_count": null,
374 | "outputs": []
375 | },
376 | {
377 | "cell_type": "markdown",
378 | "source": [
379 | "## Convert to `TFLite`"
380 | ],
381 | "metadata": {
382 | "id": "GMBNfgcV7k0f"
383 | }
384 | },
385 | {
386 | "cell_type": "markdown",
387 | "source": [
388 | "Now it's time to perform the conversion! You will need to provide a sample input that defines the expected input shape (in this case a batch of one 224x224 image with three color channels)."
389 | ],
390 | "metadata": {
391 | "id": "T2MnULes70W0"
392 | }
393 | },
394 | {
395 | "cell_type": "code",
396 | "source": [
397 | "import ai_edge_torch\n",
398 | "\n",
399 | "sample_args = (torch.rand((1, 224, 224, 3)),)\n",
400 | "edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)"
401 | ],
402 | "metadata": {
403 | "id": "XOfNPYpnLGrp"
404 | },
405 | "execution_count": null,
406 | "outputs": []
407 | },
408 | {
409 | "cell_type": "markdown",
410 | "source": [
411 | "## Export the Converted Model\n",
412 | "\n",
413 | "Running the following saves the converted model as a **FlatBuffer** file, which is compatible with **LiteRT**.\n"
414 | ],
415 | "metadata": {
416 | "id": "yr6Lhls93tNO"
417 | }
418 | },
419 | {
420 | "cell_type": "code",
421 | "source": [
422 | "from pathlib import Path\n",
423 | "\n",
424 | "TFLITE_MODEL_PATH = 'hf_convnext_v2_mp_image_classification_raw.tflite'\n",
425 | "flatbuffer_file = Path(TFLITE_MODEL_PATH)\n",
426 | "edge_model.export(flatbuffer_file)"
427 | ],
428 | "metadata": {
429 | "id": "cj-DW_2S4aUf"
430 | },
431 | "execution_count": null,
432 | "outputs": []
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "source": [
437 | "# Validate converted model with LiteRT"
438 | ],
439 | "metadata": {
440 | "id": "e7II2a_389DH"
441 | }
442 | },
443 | {
444 | "cell_type": "markdown",
445 | "source": [
446 | "Now it's time to test your newly converted model directly with the LiteRT Interpreter API. Before getting into that code, you can add the following utility functions to improve the output displayed."
447 | ],
448 | "metadata": {
449 | "id": "-3kFtIGK_1qi"
450 | }
451 | },
452 | {
453 | "cell_type": "code",
454 | "source": [
455 | "#@markdown Functions to visualize the image classification results. Run this cell to activate the functions.\n",
456 | "\n",
457 | "from matplotlib import pyplot as plt\n",
458 | "plt.rcParams.update({\n",
459 | " 'axes.spines.top': False,\n",
460 | " 'axes.spines.right': False,\n",
461 | " 'axes.spines.left': False,\n",
462 | " 'axes.spines.bottom': False,\n",
463 | " 'xtick.labelbottom': False,\n",
464 | " 'xtick.bottom': False,\n",
465 | " 'ytick.labelleft': False,\n",
466 | " 'ytick.left': False,\n",
467 | " 'xtick.labeltop': False,\n",
468 | " 'xtick.top': False,\n",
469 | " 'ytick.labelright': False,\n",
470 | " 'ytick.right': False\n",
471 | "})\n",
472 | "\n",
473 | "\n",
474 | "def display_one_image(image, title, subplot, titlesize=16):\n",
475 | " \"\"\"Displays one image along with the predicted category name and score.\"\"\"\n",
476 | " plt.subplot(*subplot)\n",
477 | " plt.imshow(image)\n",
478 | " if len(title) > 0:\n",
479 | " plt.title(title, fontsize=int(titlesize), color='black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))\n",
480 | " return (subplot[0], subplot[1], subplot[2]+1)\n",
481 | "\n",
482 | "def display_batch_of_images(images, predictions):\n",
483 | " \"\"\"Displays a batch of images with the classifications.\"\"\"\n",
484 | " # Auto-squaring: this will drop data that does not fit into square or square-ish rectangle.\n",
485 | " rows = int(math.sqrt(len(images)))\n",
486 | " cols = len(images) // rows\n",
487 | "\n",
488 | " # Size and spacing.\n",
489 | " FIGSIZE = 13.0\n",
490 | " SPACING = 0.1\n",
491 | " subplot=(rows,cols, 1)\n",
492 | " if rows < cols:\n",
493 | " plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))\n",
494 | " else:\n",
495 | " plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))\n",
496 | "\n",
497 | " # Display.\n",
498 | " for i, (image, prediction) in enumerate(zip(images[:rows*cols], predictions[:rows*cols])):\n",
499 | " dynamic_titlesize = FIGSIZE * SPACING / max(rows,cols) * 40 + 3\n",
500 | " subplot = display_one_image(image, prediction, subplot, titlesize=dynamic_titlesize)\n",
501 | "\n",
502 | " # Layout.\n",
503 | " plt.tight_layout()\n",
504 | " plt.subplots_adjust(wspace=SPACING, hspace=SPACING)\n",
505 | " plt.show()"
506 | ],
507 | "metadata": {
508 | "id": "mV2Uo2yg2Lw4",
509 | "cellView": "form"
510 | },
511 | "execution_count": null,
512 | "outputs": []
513 | },
514 | {
515 | "cell_type": "markdown",
516 | "source": [
517 | "## Inference with LiteRT Interpreter\n",
518 | "\n",
519 | "Now it's time to move on to the actual inference code: run inference using the converted LiteRT model, display the highest-confidence classification result, and compare it with the original PyTorch model."
520 | ],
521 | "metadata": {
522 | "id": "5_HCsDSI3cOe"
523 | }
524 | },
525 | {
526 | "cell_type": "code",
527 | "source": [
528 | "import numpy as np\n",
529 | "\n",
530 | "# Load the LiteRT model and allocate tensors.\n",
531 | "from ai_edge_litert.interpreter import Interpreter\n",
532 | "\n",
533 | "# Path to the converted LiteRT model\n",
534 | "TFLITE_MODEL_PATH = 'hf_convnext_v2_mp_image_classification_raw.tflite'\n",
535 | "\n",
536 | "# Initialize the LiteRT interpreter\n",
537 | "interpreter = Interpreter(TFLITE_MODEL_PATH)\n",
538 | "interpreter.allocate_tensors()\n",
539 | "\n",
540 | "# Get input and output tensor details\n",
541 | "input_details = interpreter.get_input_details()\n",
542 | "output_details = interpreter.get_output_details()\n",
543 | "\n",
544 | "print(\"LiteRT Model Input Details:\", input_details)\n",
545 | "print(\"LiteRT Model Output Details:\", output_details)"
546 | ],
547 | "metadata": {
548 | "id": "IuHqd7414MbX"
549 | },
550 | "execution_count": null,
551 | "outputs": []
552 | },
553 | {
554 | "cell_type": "markdown",
555 | "source": [
556 | "## Define Preprocessing and Postprocessing Functions\n",
557 | "Prepare functions to preprocess images for LiteRT and to extract top predictions."
558 | ],
559 | "metadata": {
560 | "id": "u34Rw0VX3NAa"
561 | }
562 | },
563 | {
564 | "cell_type": "code",
565 | "source": [
566 | "def preprocess_image_lite(image_path, size=(224, 224)):\n",
567 | " \"\"\"\n",
568 | " Loads an image, resizes it to the specified size, and converts it to a NumPy array.\n",
569 | " \"\"\"\n",
570 | " image = Image.open(image_path).convert('RGB')\n",
571 | " image_resized = image.resize(size, Image.Resampling.BILINEAR)\n",
572 | " image_array = np.array(image_resized).astype(np.float32)\n",
573 | " # Expand dimensions to match model's expected input shape (1, H, W, C)\n",
574 | " image_array = np.expand_dims(image_array, axis=0)\n",
575 | " return image_array\n",
576 | "\n",
577 | "def get_top_k_predictions_lite(output, k=5):\n",
578 | " \"\"\"\n",
579 | " Returns the top K predictions from the already softmaxed output.\n",
580 | " \"\"\"\n",
581 | " # Convert the numpy array to a PyTorch tensor\n",
582 | " probs_tensor = torch.from_numpy(output)\n",
583 | "\n",
584 | " # Retrieve the top K probabilities and their corresponding indices\n",
585 | " top_probs, top_indices = torch.topk(probs_tensor, k)\n",
586 | "\n",
587 | " # Convert the results back to numpy arrays and flatten them\n",
588 | " return top_probs.numpy().flatten(), top_indices.numpy().flatten()"
589 | ],
590 | "metadata": {
591 | "id": "vvhbSOHM5NNk"
592 | },
593 | "execution_count": null,
594 | "outputs": []
595 | },
596 | {
597 | "cell_type": "markdown",
598 | "source": [
599 | "## Run Inference and Visualize\n",
600 | "Execute the inference process and visualize the predictions.\n",
601 | "\n"
602 | ],
603 | "metadata": {
604 | "id": "Ko1LHnoW3Efv"
605 | }
606 | },
607 | {
608 | "cell_type": "code",
609 | "source": [
610 | "images = []\n",
611 | "predictions = []\n",
612 | "\n",
613 | "for image_name in IMAGE_FILENAMES:\n",
614 | " # STEP 1: Load the input image(s).\n",
615 | " image = np.array(Image.open(image_name).convert('RGB'))\n",
616 | "\n",
617 | " # STEP 2: Load and preprocess the input image\n",
618 | " lite_input = preprocess_image_lite(image_name, size=(224, 224))\n",
619 | "\n",
620 | " # STEP 3: Classify the input image using LiteRT model\n",
621 | " interpreter.set_tensor(input_details[0]['index'], lite_input)\n",
622 | " interpreter.invoke()\n",
623 | " lite_output = interpreter.get_tensor(output_details[0]['index'])\n",
624 | "\n",
625 | " # STEP 4: Process the classification result (get top 5 predictions)\n",
626 | " lite_probs, lite_indices = get_top_k_predictions_lite(lite_output, k=5)\n",
627 | "\n",
628 | " # STEP 5: Get the top category (highest probability) and visualize\n",
629 | " top_prob = lite_probs[0]\n",
630 | " top_idx = lite_indices[0]\n",
631 | " top_category_name = pt_model.config.id2label[top_idx]\n",
632 | " prediction_text = f\"{top_category_name} ({top_prob * 100:.2f}%)\"\n",
633 | "\n",
634 | " images.append(image)\n",
635 | " predictions.append(prediction_text)\n",
636 | "\n",
637 | "# Display the image with prediction\n",
638 | "display_batch_of_images(images, predictions)"
639 | ],
640 | "metadata": {
641 | "id": "4GdTsCQP10To"
642 | },
643 | "execution_count": null,
644 | "outputs": []
645 | },
646 | {
647 | "cell_type": "markdown",
648 | "source": [
649 | "You should now see your loaded test image(s) along with classifications and confidence scores that match the original PyTorch model results! If everything looks good, the final step is to download your newly converted `tflite` model file to your computer so you can use it elsewhere."
650 | ],
651 | "metadata": {
652 | "id": "1bxosFdH_99n"
653 | }
654 | },
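Before downloading, you can optionally make the comparison concrete. The snippet below is not part of the original notebook; it is a sketch that assumes a single test image and reuses the `probs`/`indices` tensors left over from the PyTorch validation loop together with the `lite_probs`/`lite_indices` arrays from the LiteRT run to print the top-1 prediction from each side.

```python
# Compare the top-1 prediction from the PyTorch model and the LiteRT model.
pt_label = pt_model.config.id2label[indices[0].item()]
lite_label = pt_model.config.id2label[lite_indices[0]]

print(f'PyTorch top-1: {pt_label} ({probs[0].item() * 100:.2f}%)')
print(f'LiteRT  top-1: {lite_label} ({lite_probs[0] * 100:.2f}%)')
```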
655 | {
656 | "cell_type": "code",
657 | "source": [
658 | "from google.colab import files\n",
659 | "\n",
660 | "files.download(TFLITE_MODEL_PATH)"
661 | ],
662 | "metadata": {
663 | "id": "mY00XJQ1BZP3"
664 | },
665 | "execution_count": null,
666 | "outputs": []
667 | },
668 | {
669 | "cell_type": "markdown",
670 | "source": [
671 | "# Next steps\n",
672 | "\n",
673 | "Now that you have learned how to convert a PyTorch model to the LiteRT format, it's time to check out the [LiteRT Interpreter API](https://ai.google.dev/edge/litert) for running other custom solutions, and read more about the PyTorch to LiteRT framework with our [official documentation](https://ai.google.dev/edge/lite/models/convert_pytorch)."
674 | ],
675 | "metadata": {
676 | "id": "_6BEXn642zBu"
677 | }
678 | }
679 | ]
680 | }
--------------------------------------------------------------------------------
/convert_pytorch/Image_classification_with_mobile_vit.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "code",
19 | "source": [
20 | "# Copyright 2024 The AI Edge Torch Authors.\n",
21 | "#\n",
22 | "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
23 | "# you may not use this file except in compliance with the License.\n",
24 | "# You may obtain a copy of the License at\n",
25 | "#\n",
26 | "# http://www.apache.org/licenses/LICENSE-2.0\n",
27 | "#\n",
28 | "# Unless required by applicable law or agreed to in writing, software\n",
29 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
30 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
31 | "# See the License for the specific language governing permissions and\n",
32 | "# limitations under the License.\n",
33 | "# =============================================================================="
34 | ],
35 | "metadata": {
36 | "id": "r4lisalb-A5R"
37 | },
38 | "execution_count": null,
39 | "outputs": []
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "source": [
44 | "This demo will teach you how to convert a PyTorch [MobileViT](https://huggingface.co/docs/transformers/en/model_doc/mobilevit#overview) model to a LiteRT (formerly TensorFlow Lite) model intended to run with [MediaPipe](https://developers.google.com/mediapipe/solutions) Tasks using Google's AI Edge Torch library. You will then run the newly converted `tflite` model locally using the MediaPipe Tasks on-device inference tool, as well as learn where to find other tools for running your newly converted model on other edge hardware, including mobile devices and web browsers."
45 | ],
46 | "metadata": {
47 | "id": "LwrH6f2sGJ6U"
48 | }
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "source": [
53 | "# Prerequisites"
54 | ],
55 | "metadata": {
56 | "id": "Mzf2MdHoG-9c"
57 | }
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "source": [
62 | "You can start by importing the necessary dependencies for converting the model, as well as some additional utilities for displaying various information as you progress through this sample."
63 | ],
64 | "metadata": {
65 | "id": "hux_Gsc_G4nl"
66 | }
67 | },
68 | {
69 | "cell_type": "code",
70 | "source": [
71 | "!pip install mediapipe\n",
72 | "!pip install ai-edge-torch\n",
73 | "!pip install transformers pillow requests matplotlib mediapipe"
74 | ],
75 | "metadata": {
76 | "id": "l-9--DWON236"
77 | },
78 | "execution_count": null,
79 | "outputs": []
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "source": [
84 | "You will also need to download an image to verify model functionality."
85 | ],
86 | "metadata": {
87 | "id": "IUMh9GRk17fV"
88 | }
89 | },
90 | {
91 | "cell_type": "code",
92 | "source": [
93 | "import urllib\n",
94 | "\n",
95 | "IMAGE_FILENAMES = ['cat.jpg']\n",
96 | "\n",
97 | "for name in IMAGE_FILENAMES:\n",
98 | " url = f'https://storage.googleapis.com/ai-edge/models-samples/torch_converter/image_classification_mobile_vit/{name}'\n",
99 | " urllib.request.urlretrieve(url, name)"
100 | ],
101 | "metadata": {
102 | "id": "lfdgp-4Id51J"
103 | },
104 | "execution_count": null,
105 | "outputs": []
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "source": [
110 | "Optionally, you can upload your own image. If you want to do so, uncomment and run the cell below."
111 | ],
112 | "metadata": {
113 | "id": "XYQeTVp-qqZ0"
114 | }
115 | },
116 | {
117 | "cell_type": "code",
118 | "source": [
119 | "# from google.colab import files\n",
120 | "# uploaded = files.upload()\n",
121 | "\n",
122 | "# for filename in uploaded:\n",
123 | "# content = uploaded[filename]\n",
124 | "# with open(filename, 'wb') as f:\n",
125 | "# f.write(content)\n",
126 | "# IMAGE_FILENAMES = list(uploaded.keys())\n",
127 | "\n",
128 | "# print('Uploaded files:', IMAGE_FILENAMES)"
129 | ],
130 | "metadata": {
131 | "id": "8X6tiqVGqq0l"
132 | },
133 | "execution_count": null,
134 | "outputs": []
135 | },
136 | {
137 | "cell_type": "markdown",
138 | "source": [
139 | "Now go ahead and verify that the image was loaded successfully"
140 | ],
141 | "metadata": {
142 | "id": "RIGfyYzcVkIB"
143 | }
144 | },
145 | {
146 | "cell_type": "code",
147 | "source": [
148 | "import cv2\n",
149 | "from google.colab.patches import cv2_imshow\n",
150 | "import math\n",
151 | "\n",
152 | "DESIRED_HEIGHT = 480\n",
153 | "DESIRED_WIDTH = 480\n",
154 | "\n",
155 | "def resize_and_show(image):\n",
156 | " h, w = image.shape[:2]\n",
157 | " if h < w:\n",
158 | " img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))\n",
159 | " else:\n",
160 | " img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))\n",
161 | " cv2_imshow(img)\n",
162 | "\n",
163 | "\n",
164 | "# Preview the images.\n",
165 | "images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}\n",
166 | "\n",
167 | "for name, image in images.items():\n",
168 | " print(name)\n",
169 | " resize_and_show(image)"
170 | ],
171 | "metadata": {
172 | "id": "-GMYmZ5jVq6t"
173 | },
174 | "execution_count": null,
175 | "outputs": []
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "source": [
180 | "# PyTorch model validation"
181 | ],
182 | "metadata": {
183 | "id": "IBFYQIm-yFz1"
184 | }
185 | },
186 | {
187 | "cell_type": "markdown",
188 | "source": [
189 | "Now that you have your test images, it's time to validate the PyTorch model (in this case MobileViT) that will be converted to the LiteRT format.\n",
190 | "\n",
191 | "Start by retrieving the PyTorch model and the appropriate corresponding processor."
192 | ],
193 | "metadata": {
194 | "id": "g7qbJRCcvQJt"
195 | }
196 | },
197 | {
198 | "cell_type": "code",
199 | "source": [
200 | "from transformers import MobileViTImageProcessor, MobileViTForImageClassification\n",
201 | "\n",
202 | "hf_model_path = 'apple/mobilevit-small'\n",
203 | "processor = MobileViTImageProcessor.from_pretrained(hf_model_path)\n",
204 | "pt_model = MobileViTForImageClassification.from_pretrained(hf_model_path)"
205 | ],
206 | "metadata": {
207 | "id": "flLiQaaL6tU5"
208 | },
209 | "execution_count": null,
210 | "outputs": []
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "source": [
215 | "The MobileViTImageProcessor defined below will perform multiple steps on the input image to match the requirements of the MobileViT model:\n",
216 | "\n",
217 | "1. Convert the image from RGB to BGR.\n",
218 | "2. Rescale the image from the [0, 255] range to the [0, 1] range.\n",
219 | "3. Resize the input image to 256x256 pixels. This differs from the processor's default behaviour (which includes padding and center cropping) to make it easier to validate the converted model with MediaPipe Tasks (more details in the corresponding section)."
220 | ],
221 | "metadata": {
222 | "id": "4Mik12qkNEU_"
223 | }
224 | },
225 | {
226 | "cell_type": "code",
227 | "source": [
228 | "from PIL import Image\n",
229 | "\n",
230 | "images = []\n",
231 | "for filename in IMAGE_FILENAMES:\n",
232 | " images.append(Image.open(filename))\n",
233 | "\n",
234 | "inputs = processor(\n",
235 | " images=images,\n",
236 | " return_tensors='pt',\n",
237 | " size={'height': 256, 'width': 256},\n",
238 | " do_center_crop=False\n",
239 | ")"
240 | ],
241 | "metadata": {
242 | "id": "_-WmB2MYWc-P"
243 | },
244 | "execution_count": null,
245 | "outputs": []
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "source": [
250 | "Now that you have your test data ready and the inputs processed, it's time to validate the classifications. In this step you will loop through your test image(s) and display the top 5 predicted classification categories. This model was trained with ImageNet-1000, so there are 1000 different potential classifications that may be applied to your test data."
251 | ],
252 | "metadata": {
253 | "id": "ZAQG5SSVzVi2"
254 | }
255 | },
256 | {
257 | "cell_type": "code",
258 | "source": [
259 | "import torch\n",
260 | "from torch import nn\n",
261 | "\n",
262 | "for image_index in range(len(IMAGE_FILENAMES)) :\n",
263 | " outputs = pt_model(**inputs)\n",
264 | " logits = outputs.logits\n",
265 | " probs, indices = nn.functional.softmax(logits[image_index], dim=-1).flatten().topk(k=5)\n",
266 | "\n",
267 | " print(IMAGE_FILENAMES[image_index], 'predictions: ')\n",
268 | " for prediction_index in range(len(indices)):\n",
269 | " class_label = pt_model.config.id2label[indices[prediction_index].item()]\n",
270 | " prob = probs[prediction_index].item()\n",
271 | " print(f'{(prob * 100):4.1f}% {class_label}')\n",
272 | " print('\\n')"
273 | ],
274 | "metadata": {
275 | "id": "ofbZW6nVzSrS"
276 | },
277 | "execution_count": null,
278 | "outputs": []
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "source": [
283 | "# Convert to the `tflite` Format"
284 | ],
285 | "metadata": {
286 | "id": "pfJkS3bH7Jpw"
287 | }
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "source": [
292 | "Before converting the PyTorch model to work with the tflite format, you will need to take an extra step to match it to the format expected by MediaPipe (MP) Tasks. Here are the necessary adjustments:\n",
293 | "\n",
294 | "1. MediaPipe Tasks require channel-last images (BHWC) while PyTorch uses channel-first (BCHW).\n",
295 | "\n",
296 | "2. For the Image Classification Task, MediaPipe expects an additional softmax layer on top of the classification logits.\n",
297 | "\n",
298 | "You can also incorporate preprocessing steps into a wrapper, such as converting from RGB to BGR and rescaling, similar to what you did when validating the PyTorch model in the previous section."
299 | ],
300 | "metadata": {
301 | "id": "ci-8lp_55TLi"
302 | }
303 | },
304 | {
305 | "cell_type": "code",
306 | "source": [
307 | "class HF2MP_ImageClassificationModelWrapper(nn.Module):\n",
308 | "\n",
309 | " def __init__(self, hf_image_classification_model, hf_processor):\n",
310 | " super().__init__()\n",
311 | " self.model = hf_image_classification_model\n",
312 | " if hf_processor.do_rescale:\n",
313 | " self.rescale_factor = hf_processor.rescale_factor\n",
314 | " else:\n",
315 | " self.rescale_factor = 1.0\n",
316 | "\n",
317 | " def forward(self, image: torch.Tensor):\n",
318 | " # BHWC -> BCHW.\n",
319 | " image = image.permute(0, 3, 1, 2)\n",
320 | " # RGB -> BGR.\n",
321 | " image = image.flip(dims=(1,))\n",
322 | " # Scale [0, 255] -> [0, 1].\n",
323 | " image = image * self.rescale_factor\n",
324 | " logits = self.model(pixel_values=image).logits # [B, 1000] float32.\n",
325 | " # Softmax is required for MediaPipe classification model.\n",
326 | " logits = torch.nn.functional.softmax(logits, dim=-1)\n",
327 | "\n",
328 | " return logits\n",
329 | "\n",
330 | "\n",
331 | "hf_model_path = 'apple/mobilevit-small'\n",
332 | "hf_mobile_vit_processor = MobileViTImageProcessor.from_pretrained(hf_model_path)\n",
333 | "hf_mobile_vit_model = MobileViTForImageClassification.from_pretrained(hf_model_path)\n",
334 | "wrapped_pt_model = HF2MP_ImageClassificationModelWrapper(\n",
335 | "hf_mobile_vit_model, hf_mobile_vit_processor).eval()"
336 | ],
337 | "metadata": {
338 | "id": "NlBmvShe4Mt0"
339 | },
340 | "execution_count": null,
341 | "outputs": []
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "source": [
346 | "## Convert to `tflite`"
347 | ],
348 | "metadata": {
349 | "id": "GMBNfgcV7k0f"
350 | }
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "source": [
355 | "Now it's time to perform the conversion! You will need to provide a sample input that defines the expected input shape (in this case a batch of one 256x256 image with three color channels)."
356 | ],
357 | "metadata": {
358 | "id": "T2MnULes70W0"
359 | }
360 | },
361 | {
362 | "cell_type": "code",
363 | "source": [
364 | "import ai_edge_torch\n",
365 | "\n",
366 | "sample_args = (torch.rand((1, 256, 256, 3)),)\n",
367 | "edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)"
368 | ],
369 | "metadata": {
370 | "id": "XOfNPYpnLGrp"
371 | },
372 | "execution_count": null,
373 | "outputs": []
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "source": [
378 | "Once the conversion is finished and you have a new `tflite` model file, you will convert the raw `tflite` file into a *model buffer* so that you can do some additional processing to prepare the file for use with MediaPipe Tasks. This includes attaching the labels for the model to the new `tflite` model so that it can be used with MediaPipe Tasks Image Classification."
379 | ],
380 | "metadata": {
381 | "id": "HPbeMLwbLZb7"
382 | }
383 | },
384 | {
385 | "cell_type": "code",
386 | "source": [
387 | "from mediapipe.tasks.python.metadata.metadata_writers import image_classifier\n",
388 | "from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer\n",
389 | "from mediapipe.tasks.python.vision.image_classifier import ImageClassifier\n",
390 | "from pathlib import Path\n",
391 | "\n",
392 | "flatbuffer_file = Path('hf_mobile_vit_mp_image_classification_raw.tflite')\n",
393 | "edge_model.export(flatbuffer_file)\n",
394 | "tflite_model_buffer = flatbuffer_file.read_bytes()\n",
395 | "\n",
396 | "labels = list(hf_mobile_vit_model.config.id2label.values())\n",
397 | "\n",
398 | "writer = image_classifier.MetadataWriter.create(\n",
399 | " tflite_model_buffer,\n",
400 | " input_norm_mean=[0.0], # Normalization is not needed for this model.\n",
401 | " input_norm_std=[1.0],\n",
402 | " labels=metadata_writer.Labels().add(labels),\n",
403 | ")\n",
404 | "tflite_model_buffer, _ = writer.populate()"
405 | ],
406 | "metadata": {
407 | "id": "1mDOCFdG7H16"
408 | },
409 | "execution_count": null,
410 | "outputs": []
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "source": [
415 | "After attaching the metadata to the intermediate model buffer object, you can convert the buffer back into a `tflite` file."
416 | ],
417 | "metadata": {
418 | "id": "AMk6rSsfDykf"
419 | }
420 | },
421 | {
422 | "cell_type": "code",
423 | "source": [
424 | "tflite_filename = 'hf_mobile_vit_mp_image_classification.tflite'\n",
425 | "# Save converted model to Colab's local file system.\n",
426 | "with open(tflite_filename, 'wb') as f:\n",
427 | " f.write(tflite_model_buffer)"
428 | ],
429 | "metadata": {
430 | "id": "jpQ8R2pxQrIW"
431 | },
432 | "execution_count": null,
433 | "outputs": []
434 | },
435 | {
436 | "cell_type": "markdown",
437 | "source": [
438 | "Before moving on to *using* the converted model, it's always a good idea to make sure the model was successfully saved."
439 | ],
440 | "metadata": {
441 | "id": "1KVR8e4V8pou"
442 | }
443 | },
444 | {
445 | "cell_type": "code",
446 | "source": [
447 | "!ls -l /content/hf_mobile_vit_mp_image_classification.tflite"
448 | ],
449 | "metadata": {
450 | "id": "wuwP7uMzCAS5"
451 | },
452 | "execution_count": null,
453 | "outputs": []
454 | },
455 | {
456 | "cell_type": "markdown",
457 | "source": [
458 | "# Validate converted model with MediaPipe Tasks"
459 | ],
460 | "metadata": {
461 | "id": "e7II2a_389DH"
462 | }
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "source": [
467 | "Now it's time to test your newly converted model directly with the MediaPipe Image Classification Task. Before getting into that code, you can add the following utility functions to improve the output displayed."
468 | ],
469 | "metadata": {
470 | "id": "-3kFtIGK_1qi"
471 | }
472 | },
473 | {
474 | "cell_type": "code",
475 | "source": [
476 | "from matplotlib import pyplot as plt\n",
477 | "plt.rcParams.update({\n",
478 | " 'axes.spines.top': False,\n",
479 | " 'axes.spines.right': False,\n",
480 | " 'axes.spines.left': False,\n",
481 | " 'axes.spines.bottom': False,\n",
482 | " 'xtick.labelbottom': False,\n",
483 | " 'xtick.bottom': False,\n",
484 | " 'ytick.labelleft': False,\n",
485 | " 'ytick.left': False,\n",
486 | " 'xtick.labeltop': False,\n",
487 | " 'xtick.top': False,\n",
488 | " 'ytick.labelright': False,\n",
489 | " 'ytick.right': False\n",
490 | "})\n",
491 | "\n",
492 | "\n",
493 | "def display_one_image(image, title, subplot, titlesize=16):\n",
494 | " \"\"\"Displays one image along with the predicted category name and score.\"\"\"\n",
495 | " plt.subplot(*subplot)\n",
496 | " plt.imshow(image)\n",
497 | " if len(title) > 0:\n",
498 | " plt.title(title, fontsize=int(titlesize), color='black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))\n",
499 | " return (subplot[0], subplot[1], subplot[2]+1)\n",
500 | "\n",
501 | "def display_batch_of_images(images, predictions):\n",
502 | " \"\"\"Displays a batch of images with the classifications.\"\"\"\n",
503 | " # Images and predictions.\n",
504 | " images = [image.numpy_view() for image in images]\n",
505 | "\n",
506 | " # Auto-squaring: this will drop data that does not fit into square or square-ish rectangle.\n",
507 | " rows = int(math.sqrt(len(images)))\n",
508 | " cols = len(images) // rows\n",
509 | "\n",
510 | " # Size and spacing.\n",
511 | " FIGSIZE = 13.0\n",
512 | " SPACING = 0.1\n",
513 | " subplot=(rows,cols, 1)\n",
514 | " if rows < cols:\n",
515 | " plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))\n",
516 | " else:\n",
517 | " plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))\n",
518 | "\n",
519 | " # Display.\n",
520 | " for i, (image, prediction) in enumerate(zip(images[:rows*cols], predictions[:rows*cols])):\n",
521 | " dynamic_titlesize = FIGSIZE * SPACING / max(rows,cols) * 40 + 3\n",
522 | " subplot = display_one_image(image, prediction, subplot, titlesize=dynamic_titlesize)\n",
523 | "\n",
524 | " # Layout.\n",
525 | " plt.tight_layout()\n",
526 | " plt.subplots_adjust(wspace=SPACING, hspace=SPACING)\n",
527 | " plt.show()"
528 | ],
529 | "metadata": {
530 | "id": "mV2Uo2yg2Lw4"
531 | },
532 | "execution_count": null,
533 | "outputs": []
534 | },
535 | {
536 | "cell_type": "markdown",
537 | "source": [
538 | "Now it's time to move on to the actual inference code and display the highest confidence classification result.\n",
539 | "\n",
540 | "While the converted model expects a square input image with a height of 256 pixels and a width of 256 pixels, the MediaPipe Image Classification Task automatically resizes and adds padding to the input image to meet the model's input requirements.\n",
541 | "\n",
542 | "During this validation step, you will ensure that the converted model produces roughly the same output as the original PyTorch model for the same input. One thing worth noting is that since the resizing and padding in MediaPipe differ from those performed by MobileViTImageProcessor, there will likely be some minor differences in prediction confidences. To account for this, you can bypass the automatic resizing and padding by resizing the input image manually before feeding it to the image classifier (see the sketch after this cell)."
543 | ],
544 | "metadata": {
545 | "id": "GY-UcEls4TH6"
546 | }
547 | },
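The next cell keeps things simple by loading each image with `mp.Image.create_from_file` and letting the task handle sizing. If you do want to resize manually as described above, the sketch below (not part of the original notebook) shows one way to do it with PIL; the helper name `load_resized_mp_image` is hypothetical, and you could pass the returned `mp.Image` to `classifier.classify` in place of the image loaded from the file.

```python
import mediapipe as mp
import numpy as np
from PIL import Image as PILImage

def load_resized_mp_image(path, size=(256, 256)):
    """Loads an image, resizes it to the model's input size, and wraps it in an mp.Image."""
    rgb = PILImage.open(path).convert('RGB').resize(size, PILImage.Resampling.BILINEAR)
    return mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(rgb))

# Example usage (hypothetical): image = load_resized_mp_image(IMAGE_FILENAMES[0])
```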
548 | {
549 | "cell_type": "code",
550 | "source": [
551 | "import mediapipe as mp\n",
552 | "from mediapipe.tasks import python\n",
553 | "from mediapipe.tasks.python.components import processors\n",
554 | "from mediapipe.tasks.python import vision\n",
555 | "\n",
556 | "# STEP 1: Create an ImageClassifier object.\n",
557 | "\n",
558 | "base_options= python.BaseOptions(\n",
559 | " model_asset_path=f'/content/{tflite_filename}')\n",
560 | "\n",
561 | "options = vision.ImageClassifierOptions(\n",
562 | " base_options=base_options,\n",
563 | " max_results=5)\n",
564 | "\n",
565 | "classifier = vision.ImageClassifier.create_from_options(options)\n",
566 | "\n",
567 | "images = []\n",
568 | "predictions = []\n",
569 | "for image_name in IMAGE_FILENAMES:\n",
570 | " # STEP 2: Load the input image(s).\n",
571 | " image = mp.Image.create_from_file(image_name)\n",
572 | "\n",
573 | " # STEP 3: Classify the input image(s).\n",
574 | " classification_result = classifier.classify(image)\n",
575 | "\n",
576 | " # STEP 4: Process the classification result. In this case, visualize it.\n",
577 | " images.append(image)\n",
578 | " top_category = classification_result.classifications[0].categories[0]\n",
579 | " predictions.append(f\"{top_category.category_name} ({top_category.score:.2f})\")\n",
580 | "\n",
581 | "display_batch_of_images(images, predictions)"
582 | ],
583 | "metadata": {
584 | "id": "4GdTsCQP10To"
585 | },
586 | "execution_count": null,
587 | "outputs": []
588 | },
589 | {
590 | "cell_type": "markdown",
591 | "source": [
592 | "You should now see your loaded test images and their confidence scores/classifications that match the original PyTorch model results! If everything looks good, the final step should be downloading your newly converted `tflite` model file to your computer so you can use it elsewhere."
593 | ],
594 | "metadata": {
595 | "id": "1bxosFdH_99n"
596 | }
597 | },
598 | {
599 | "cell_type": "code",
600 | "source": [
601 | "from google.colab import files\n",
602 | "\n",
603 | "files.download(tflite_filename)"
604 | ],
605 | "metadata": {
606 | "id": "mY00XJQ1BZP3"
607 | },
608 | "execution_count": null,
609 | "outputs": []
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "source": [
614 | "# Next steps\n",
615 | "\n",
616 | "Now that you have learned how to convert a PyTorch model to the LiteRT format, it's time to do more with it! You can go over additional [MediaPipe](https://github.com/google-ai-edge/mediapipe-samples) samples for Android, iOS, web, and Python (including the Raspberry Pi!) to try your new model on multiple platforms, check out the [LiteRT Interpreter API](https://ai.google.dev/edge/litert) for running custom solutions, and read more about the PyTorch to LiteRT framework with our [official documentation](https://ai.google.dev/edge/lite/models/convert_pytorch)."
617 | ],
618 | "metadata": {
619 | "id": "3AOmkXUaBVUb"
620 | }
621 | }
622 | ]
623 | }
--------------------------------------------------------------------------------
/convert_pytorch/SegNext_segmentation_and_quantization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 5,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "display_name": "Python 3",
11 | "name": "python3"
12 | },
13 | "language_info": {
14 | "name": "python",
15 | "version": ""
16 | },
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "ln-0JGdPeeyz"
24 | },
25 | "source": [
26 | "# Convert a SegNeXt PyTorch Model to LiteRT\n",
27 | "\n",
28 | "This notebook demonstrates how to convert a **SegNeXt** model (originally trained and published in PyTorch) into a LiteRT model using [AI Edge Torch](https://ai.google.dev/edge). The sample also shows how to optimize the resulting model with dynamic-range quantization using [AI Edge Quantizer](https://github.com/google-ai-edge/ai-edge-quantizer).\n",
29 | "\n",
30 | "## What you'll learn\n",
31 | "\n",
32 | "- **Setup:** Installing necessary libraries and tools to download and load the SegNeXt model.\n",
33 | "- **Inference Validation:** Running the PyTorch model for segmentation.\n",
34 | "- **Model Conversion:** Converting a SegNext model to LiteRT using AI Edge Torch.\n",
35 | "- **Verifying Results:** Comparing outputs between PyTorch and LiteRT models.\n",
36 | "- **Quantization:** Applying post-training quantization techniques to reduce model size.\n",
37 | "- **Export and Download**: Download your newly created or optimized LiteRT model."
38 | ],
39 | "id": "ln-0JGdPeeyz"
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "04Y9CioSeey1"
45 | },
46 | "source": [
47 | "## Install and Import Dependencies\n",
48 | "\n",
49 | "You can start by importing the necessary dependencies for converting the model, as well as some additional tweaks to get the `mmsegmentation` library working as expected with the AI Torch Edge Converter.\n",
50 | "\n",
51 | "Make sure to run the following cells to set up the environment with the required libraries:"
52 | ],
53 | "id": "04Y9CioSeey1"
54 | },
55 | {
56 | "cell_type": "code",
57 | "metadata": {
58 | "id": "installation_code",
59 | "language_info": {
60 | "name": "python"
61 | }
62 | },
63 | "execution_count": null,
64 | "outputs": [],
65 | "source": [
66 | "# Install MMCV and its dependencies.\n",
67 | "!pip install openmim -q\n",
68 | "!mim install mmengine -q\n",
69 | "!mim install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html\n",
70 | "!pip install ftfy\n",
71 | "\n",
72 | "# Install AI Torch Edge and Quantizer.\n",
73 | "!pip install ai-edge-torch-nightly -q\n",
74 | "!pip install ai-edge-quantizer-nightly -q"
75 | ],
76 | "id": "installation_code"
77 | },
78 | {
79 | "cell_type": "code",
80 | "source": [
81 | "# Clone the MMSegmentation GitHub repository.\n",
82 | "!git clone -b v1.2.2 https://github.com/open-mmlab/mmsegmentation.git\n",
83 | "\n",
84 | "# Patch the version constraints in mmseg/__init__.py\n",
85 | "!sed -i \"s/MMCV_MAX = '2.2.0'/MMCV_MAX = '6.5.0'/g\" mmsegmentation/mmseg/__init__.py\n",
86 | "\n",
87 | "# Install MMSegmentation\n",
88 | "%cd mmsegmentation\n",
89 | "!pip install -e ."
90 | ],
91 | "metadata": {
92 | "id": "u1AlddWxioAh"
93 | },
94 | "id": "u1AlddWxioAh",
95 | "execution_count": null,
96 | "outputs": []
97 | },
98 | {
99 | "cell_type": "code",
100 | "source": [
101 | "import urllib\n",
102 | "import cv2\n",
103 | "import math\n",
104 | "import sys\n",
105 | "\n",
106 | "# PyTorch, Vision, and AI Edge Torch.\n",
107 | "import torch\n",
108 | "import torchvision.transforms as T\n",
109 | "import ai_edge_torch\n",
110 | "\n",
111 | "# PIL, NumPy, IPython display.\n",
112 | "import numpy as np\n",
113 | "from PIL import Image\n",
114 | "from IPython import display\n",
115 | "\n",
116 | "# Google Colab utilities.\n",
117 | "from google.colab import files\n",
118 | "from google.colab.patches import cv2_imshow\n",
119 | "\n",
120 | "# Matplotlib for visualization.\n",
121 | "from matplotlib import gridspec\n",
122 | "from matplotlib import pyplot as plt\n",
123 | "\n",
124 | "# AI Edge Torch Quantization utilities.\n",
125 | "import ai_edge_litert\n",
126 | "from ai_edge_quantizer import quantizer, recipe"
127 | ],
128 | "metadata": {
129 | "id": "vxicg9FaiWrC"
130 | },
131 | "id": "vxicg9FaiWrC",
132 | "execution_count": null,
133 | "outputs": []
134 | },
135 | {
136 | "cell_type": "markdown",
137 | "source": [
138 | "### Patch the `MMEngine` registry\n",
139 | "We'll also patch the `Registry` to address potential naming collisions in the mmseg registry, then import our classes and create an inference object."
140 | ],
141 | "metadata": {
142 | "id": "NWIrsDX2ipCq"
143 | },
144 | "id": "NWIrsDX2ipCq"
145 | },
146 | {
147 | "cell_type": "code",
148 | "source": [
149 | "# @markdown We implemented some functions to patch the mmengine registry.
Run the following cell to activate the functions.\n",
150 | "%%writefile patch_registry.py\n",
151 | "import logging\n",
152 | "\n",
153 | "from mmengine.registry import Registry\n",
154 | "from mmengine.logging import print_log\n",
155 | "from typing import Type, Optional, Union, List\n",
156 | "\n",
157 | "def _register_module(self,\n",
158 | " module: Type,\n",
159 | " module_name: Optional[Union[str, List[str]]] = None,\n",
160 | " force: bool = False) -> None:\n",
161 | " \"\"\"Register a module.\n",
162 | "\n",
163 | " Args:\n",
164 | " module (type): Module to be registered. Typically a class or a\n",
165 | " function, but generally all ``Callable`` are acceptable.\n",
166 | " module_name (str or list of str, optional): The module name to be\n",
167 | " registered. If not specified, the class name will be used.\n",
168 | " Defaults to None.\n",
169 | " force (bool): Whether to override an existing class with the same\n",
170 | " name. Defaults to False.\n",
171 | " \"\"\"\n",
172 | " if not callable(module):\n",
173 | " raise TypeError(f'module must be Callable, but got {type(module)}')\n",
174 | "\n",
175 | " if module_name is None:\n",
176 | " module_name = module.__name__\n",
177 | " if isinstance(module_name, str):\n",
178 | " module_name = [module_name]\n",
179 | " for name in module_name:\n",
180 | " if not force and name in self._module_dict:\n",
181 | " existed_module = self.module_dict[name]\n",
182 | " print_log(\n",
183 | " f'{name} is already registered in {self.name} '\n",
184 | " f'at {existed_module.__module__}. Registration ignored.',\n",
185 | " logger='current',\n",
186 | " level=logging.INFO\n",
187 | " )\n",
188 | " self._module_dict[name] = module\n",
189 | "\n",
190 | "Registry._register_module = _register_module\n"
191 | ],
192 | "metadata": {
193 | "cellView": "form",
194 | "id": "vQ6zJOsniihd"
195 | },
196 | "id": "vQ6zJOsniihd",
197 | "execution_count": null,
198 | "outputs": []
199 | },
200 | {
201 | "cell_type": "code",
202 | "source": [
203 | "# Patch the MMEngine registry.\n",
204 | "import patch_registry\n",
205 | "\n",
206 | "# Check MMSegmentation installation.\n",
207 | "import mmseg\n",
208 | "print(mmseg.__version__)\n",
209 | "\n",
210 | "# Import the `apis` and `datasets` modules.\n",
211 | "from mmseg import apis, datasets"
212 | ],
213 | "metadata": {
214 | "id": "XZ73z0fYi1l7"
215 | },
216 | "id": "XZ73z0fYi1l7",
217 | "execution_count": null,
218 | "outputs": []
219 | },
220 | {
221 | "cell_type": "markdown",
222 | "metadata": {
223 | "id": "hp0ZAmzYeey2"
224 | },
225 | "source": [
226 | "### Download a Sample Image\n",
227 | "We'll retrieve an image that we'll use for our segmentation demo. Feel free to upload your own image(s) if desired."
228 | ],
229 | "id": "hp0ZAmzYeey2"
230 | },
231 | {
232 | "cell_type": "code",
233 | "metadata": {
234 | "id": "download_image"
235 | },
236 | "execution_count": null,
237 | "outputs": [],
238 | "source": [
239 | "IMAGE_FILENAMES = ['Bruce_car.JPG']\n",
240 | "\n",
241 | "for name in IMAGE_FILENAMES:\n",
242 | " # TODO: Update path to use the appropriate task subfolder in the AI Edge GCS bucket\n",
243 | " url = f'https://upload.wikimedia.org/wikipedia/commons/9/9c/{name}'\n",
244 | " urllib.request.urlretrieve(url, name)"
245 | ],
246 | "id": "download_image"
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {
251 | "id": "JnZkAy-Peey2"
252 | },
253 | "source": [
254 | "If you want to upload additional images, uncomment and run the cell below. Then update `IMAGE_FILENAMES` to match your uploaded file(s)."
255 | ],
256 | "id": "JnZkAy-Peey2"
257 | },
258 | {
259 | "cell_type": "code",
260 | "metadata": {
261 | "id": "upload_images_example"
262 | },
263 | "execution_count": null,
264 | "outputs": [],
265 | "source": [
266 | "# from google.colab import files\n",
267 | "# uploaded = files.upload()\n",
268 | "#\n",
269 | "# for filename in uploaded:\n",
270 | "# content = uploaded[filename]\n",
271 | "# with open(filename, 'wb') as f:\n",
272 | "# f.write(content)\n",
273 | "#\n",
274 | "# IMAGE_FILENAMES = list(uploaded.keys())\n",
275 | "# print('Uploaded files:', IMAGE_FILENAMES)"
276 | ],
277 | "id": "upload_images_example"
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "metadata": {
282 | "id": "nYPROgSgeey3"
283 | },
284 | "source": [
285 | "Quickly display the loaded image(s) to confirm."
286 | ],
287 | "id": "nYPROgSgeey3"
288 | },
289 | {
290 | "cell_type": "code",
291 | "metadata": {
292 | "id": "verify_images"
293 | },
294 | "execution_count": null,
295 | "outputs": [],
296 | "source": [
297 | "DESIRED_HEIGHT = 480\n",
298 | "DESIRED_WIDTH = 480\n",
299 | "\n",
300 | "def resize_and_show(image):\n",
301 | " h, w = image.shape[:2]\n",
302 | " if h < w:\n",
303 | " img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))\n",
304 | " else:\n",
305 | " img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))\n",
306 | " cv2_imshow(img)\n",
307 | "\n",
308 | "# Preview the images.\n",
309 | "images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}\n",
310 | "\n",
311 | "for name, image in images.items():\n",
312 | " print(name)\n",
313 | " resize_and_show(image)"
314 | ],
315 | "id": "verify_images"
316 | },
317 | {
318 | "cell_type": "markdown",
319 | "metadata": {
320 | "id": "fMJ7KVrUeey3"
321 | },
322 | "source": [
323 | "## Load SegNext\n",
324 | "We'll clone the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) repo, install it, and then load a SegNeXt model trained on [ADE20K](https://ade20k.csail.mit.edu/). In this example, we're using the [SegNeXt mscan-b ADE20K model](https://github.com/open-mmlab/mmsegmentation/tree/main/configs/segnext)."
325 | ],
326 | "id": "fMJ7KVrUeey3"
327 | },
328 | {
329 | "cell_type": "code",
330 | "metadata": {
331 | "id": "imports_mmseg"
332 | },
333 | "execution_count": null,
334 | "outputs": [],
335 | "source": [
336 | "# Load the SegNext PyTorch model while setting the device to CPU.\n",
337 | "inferencer = apis.MMSegInferencer(model='segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512', device='cpu')\n",
338 | "\n",
339 | "# Retrieve the actual PyTorch model.\n",
340 | "pt_model = inferencer.model\n",
341 | "pt_model.eval()"
342 | ],
343 | "id": "imports_mmseg"
344 | },
345 | {
346 | "cell_type": "markdown",
347 | "source": [
348 | "## The MIT ADE20K scene parsing dataset \n",
349 | "`ADE20K` is composed of more than 27K images from the [SUN](https://groups.csail.mit.edu/vision/SUN/hierarchy.html) and [Places](https://www.csail.mit.edu/research/places-database-scene-recognition) databases. Images are fully annotated with objects, spanning over 3K object categories."
350 | ],
351 | "metadata": {
352 | "id": "Kh2jlhm33EcH"
353 | },
354 | "id": "Kh2jlhm33EcH"
355 | },
356 | {
357 | "cell_type": "code",
358 | "source": [
359 | "classes = datasets.ADE20KDataset.METAINFO['classes']\n",
360 | "palette = datasets.ADE20KDataset.METAINFO['palette']"
361 | ],
362 | "metadata": {
363 | "id": "iUgM8mNh3Wa5"
364 | },
365 | "id": "iUgM8mNh3Wa5",
366 | "execution_count": null,
367 | "outputs": []
368 | },
369 | {
370 | "cell_type": "markdown",
371 | "source": [
372 | "For the dataset, we extract the class labels and the color palette from its metadata. We also retrieve the mean and standard deviation values from the data preprocessor configuration (via `inferencer.cfg`), which will be essential when using the converter later."
373 | ],
374 | "metadata": {
375 | "id": "9T8phj8m16Lj"
376 | },
377 | "id": "9T8phj8m16Lj"
378 | },
379 | {
380 | "cell_type": "code",
381 | "source": [
382 | "data_preprocessor_dict = inferencer.cfg.to_dict()['data_preprocessor']\n",
383 | "data_preprocessor_dict['mean'], data_preprocessor_dict['std']"
384 | ],
385 | "metadata": {
386 | "id": "q4ChjHSm1-NM"
387 | },
388 | "id": "q4ChjHSm1-NM",
389 | "execution_count": null,
390 | "outputs": []
391 | },
392 | {
393 | "cell_type": "markdown",
394 | "metadata": {
395 | "id": "_fQSgQFveey4"
396 | },
397 | "source": [
398 | "## Inference using the PyTorch Model\n",
399 | "Let's verify the PyTorch model by doing a quick inference and visualizing the segmentation output.\n"
400 | ],
401 | "id": "_fQSgQFveey4"
402 | },
403 | {
404 | "cell_type": "code",
405 | "metadata": {
406 | "id": "test_pt_model"
407 | },
408 | "execution_count": null,
409 | "outputs": [],
410 | "source": [
411 | "# The output mask is saved under 'outputs/vis/.jpg'\n",
412 | "\n",
413 | "for index in range(len(IMAGE_FILENAMES)):\n",
414 | " inferencer(IMAGE_FILENAMES[index], out_dir='outputs', img_out_dir='vis', return_vis=True)"
415 | ],
416 | "id": "test_pt_model"
417 | },
418 | {
419 | "cell_type": "markdown",
420 | "source": [
421 | "Let's visualize one of the results."
422 | ],
423 | "metadata": {
424 | "id": "h6V7rRuFQzRb"
425 | },
426 | "id": "h6V7rRuFQzRb"
427 | },
428 | {
429 | "cell_type": "code",
430 | "source": [
431 | "display.Image(f'outputs/vis/{IMAGE_FILENAMES[0]}')"
432 | ],
433 | "metadata": {
434 | "id": "OmK2xxilQuLT"
435 | },
436 | "id": "OmK2xxilQuLT",
437 | "execution_count": null,
438 | "outputs": []
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "source": [
443 | "We have now confirmed that the original PyTorch model can generate valid segmentation predictions and that it runs properly in Python."
444 | ],
445 | "metadata": {
446 | "id": "rdcp23wEe-1w"
447 | },
448 | "id": "rdcp23wEe-1w"
449 | },
450 | {
451 | "cell_type": "markdown",
452 | "metadata": {
453 | "id": "aD_DyvBheey4"
454 | },
455 | "source": [
456 | "\n",
457 | "## Create a Model Wrapper\n",
458 | "To simplify the model output and ensure a single output node during conversion, we'll create a wrapper module. We'll also handle the typical mean/std normalization manually within this wrapper (since some methods, like `torch.min` or `torch.max`, might not be fully supported in the LiteRT conversion)."
459 | ],
460 | "id": "aD_DyvBheey4"
461 | },
462 | {
463 | "cell_type": "code",
464 | "metadata": {
465 | "id": "create_wrapper"
466 | },
467 | "execution_count": null,
468 | "outputs": [],
469 | "source": [
470 | "class ImageSegmentationModelWrapper(torch.nn.Module):\n",
471 | " def __init__(self, pt_model, mmseg_cfg):\n",
472 | " super().__init__()\n",
473 | " self.model = pt_model\n",
474 | " data_preprocessor_dict = mmseg_cfg.to_dict()['data_preprocessor']\n",
475 | " # Convert the mean and std from shape (3,) to (1, 3, 1, 1)\n",
476 | " self.image_mean = torch.tensor(data_preprocessor_dict['mean']).view(1, -1, 1, 1)\n",
477 | " self.image_std = torch.tensor(data_preprocessor_dict['std']).view(1, -1, 1, 1)\n",
478 | "\n",
479 | " def forward(self, image: torch.Tensor):\n",
480 | " # Input shape: (N, H, W, C)\n",
481 | " # Convert BHWC to BCHW.\n",
482 | " image = image.permute(0, 3, 1, 2)\n",
483 | "\n",
484 | " # Normalize.\n",
485 | " image = (image - self.image_mean) / self.image_std\n",
486 | "\n",
487 | " # Model output is typically (N, C, H, W).\n",
488 | " result = self.model(image)\n",
489 | "\n",
490 | " # Convert from NCHW to NHWC.\n",
491 | " result = result.permute(0, 2, 3, 1)\n",
492 | "\n",
493 | " return result\n",
494 | "\n",
495 | "# Create the wrapped model.\n",
496 | "wrapped_pt_model = ImageSegmentationModelWrapper(pt_model, inferencer.cfg).eval()"
497 | ],
498 | "id": "create_wrapper"
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "metadata": {
503 | "id": "LmcvaKlheey5"
504 | },
505 | "source": [
506 | "## Convert to LiteRT\n",
507 | "\n",
508 | "One of the methods you can use to get to this final output is to download the `tflite` file after the conversion step in this colab, open it with [Model Explorer](https://ai.google.dev/edge/model-explorer) and confirm which output in the graph has the expected output shape.\n",
509 | "\n",
510 | "That's kind of a lot for this example, so to simplify the process and eliminate this effort, you can use a wrapper for the PyTorch model that narrows the scope to only the final output. This approach ensures that your new LiteRT model has only a single output after the conversion stage.\n",
511 | "\n",
512 | "We'll use AI Edge Torch to convert our PyTorch model. We pass in a sample input of appropriate shape to guide the conversion. (This shape also becomes your expected inference shape in the resulting LiteRT model.)"
513 | ],
514 | "id": "LmcvaKlheey5"
515 | },
516 | {
517 | "cell_type": "code",
518 | "metadata": {
519 | "id": "convert_to_litert"
520 | },
521 | "execution_count": null,
522 | "outputs": [],
523 | "source": [
524 | "MODEL_INPUT_HW = (512, 512)\n",
525 | "sample_args = (torch.rand((1, *MODEL_INPUT_HW, 3)),)\n",
526 | "\n",
527 | "edge_model = ai_edge_torch.convert(wrapped_pt_model, sample_args)"
528 | ],
529 | "id": "convert_to_litert"
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {
534 | "id": "jC-HGLCoeey5"
535 | },
536 | "source": [
537 | "## Validate Converted Model with LiteRT Interpreter\n",
538 | "We can test the converted LiteRT model's output. Since our pre-processing is embedded within the wrapper, we'll only resize and cast the input image.\n"
539 | ],
540 | "id": "jC-HGLCoeey5"
541 | },
542 | {
543 | "cell_type": "code",
544 | "metadata": {
545 | "id": "vis_utils",
546 | "cellView": "form"
547 | },
548 | "execution_count": null,
549 | "outputs": [],
550 | "source": [
551 | "# @markdown We implemented some functions to visualize the segmentation results.
Run the following cell to activate the functions.\n",
552 | "\n",
553 | "# Visualization utilities\n",
554 | "def label_to_color_image(label, palette):\n",
555 | " if label.ndim != 2:\n",
556 | " raise ValueError('Expect 2-D input label')\n",
557 | " colormap = np.asarray(palette)\n",
558 | " if np.max(label) >= len(colormap):\n",
559 | " raise ValueError('label value too large.')\n",
560 | " return colormap[label]\n",
561 | "\n",
562 | "def vis_segmentation(image, seg_map, palette, label_names):\n",
563 | " plt.figure(figsize=(15, 5))\n",
564 | " grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])\n",
565 | "\n",
566 | " plt.subplot(grid_spec[0])\n",
567 | " plt.imshow(image)\n",
568 | " plt.axis('off')\n",
569 | " plt.title('input image')\n",
570 | "\n",
571 | " plt.subplot(grid_spec[1])\n",
572 | " seg_image = label_to_color_image(seg_map, palette).astype(np.uint8)\n",
573 | " plt.imshow(seg_image)\n",
574 | " plt.axis('off')\n",
575 | " plt.title('segmentation map')\n",
576 | "\n",
577 | " H, W = image.shape[:2]\n",
578 | " plt.subplot(grid_spec[2])\n",
579 | " plt.imshow(image, extent=(0, W, H, 0))\n",
580 | " plt.imshow(seg_image, alpha=0.7, extent=(0, W, H, 0))\n",
581 | " plt.axis('off')\n",
582 | " plt.title('segmentation overlay')\n",
583 | "\n",
584 | " unique_labels = np.unique(seg_map)\n",
585 | " ax = plt.subplot(grid_spec[3])\n",
586 | " full_color_map = label_to_color_image(\n",
587 | " np.arange(len(label_names)).reshape(len(label_names), 1),\n",
588 | " palette\n",
589 | " )\n",
590 | " plt.imshow(full_color_map[unique_labels].astype(np.uint8), interpolation='nearest')\n",
591 | " ax.yaxis.tick_right()\n",
592 | " plt.yticks(range(len(unique_labels)), label_names[unique_labels])\n",
593 | " plt.xticks([], [])\n",
594 | " ax.tick_params(width=0.0)\n",
595 | " plt.grid('off')\n",
596 | " plt.show()\n",
597 | "\n",
598 | "LABEL_NAMES = np.asarray(classes)\n",
599 | "PALETTE = palette"
600 | ],
601 | "id": "vis_utils"
602 | },
603 | {
604 | "cell_type": "code",
605 | "metadata": {
606 | "id": "compare_litert_pt"
607 | },
608 | "execution_count": null,
609 | "outputs": [],
610 | "source": [
611 | "np_images = []\n",
612 | "image_sizes = []\n",
613 | "\n",
614 | "for index in range(len(IMAGE_FILENAMES)):\n",
615 | " # Retrieve each image from the file system.\n",
616 | " image = Image.open(IMAGE_FILENAMES[index])\n",
617 | " # Save the size for reference.\n",
618 | " image_sizes.append(image.size)\n",
619 | " # Convert each image into a NumPy array with shape (1, H, W, 3)\n",
620 | " np_image = np.array(image.resize(MODEL_INPUT_HW, Image.Resampling.BILINEAR))\n",
621 | " np_image = np.expand_dims(np_image, axis=0).astype(np.float32)\n",
622 | " np_images.append(np_image)\n",
623 | "\n",
624 | " # Retrieve an output from the converted model.\n",
625 | " edge_model_output = edge_model(np_image)\n",
626 | " segmentation_map = edge_model_output.squeeze()\n",
627 | "\n",
628 | " # Visualize.\n",
629 | " vis_segmentation(\n",
630 | " np_images[index][0].astype(np.uint8),\n",
631 | " np.argmax(segmentation_map, axis=-1),\n",
632 | " PALETTE,\n",
633 | " LABEL_NAMES\n",
634 | " )"
635 | ],
636 | "id": "compare_litert_pt"
637 | },
638 | {
639 | "cell_type": "code",
640 | "source": [
641 | "# Serialize the LiteRT model.\n",
642 | "edge_model.export('segnext.tflite')"
643 | ],
644 | "metadata": {
645 | "id": "5rgUpVFInolB"
646 | },
647 | "id": "5rgUpVFInolB",
648 | "execution_count": null,
649 | "outputs": []
650 | },
651 | {
652 | "cell_type": "markdown",
653 | "metadata": {
654 | "id": "kk3Dd3ZZeey5"
655 | },
656 | "source": [
657 | "## Apply Quantization\n",
658 | "Model size matters on edge devices. Post-training quantization can significantly reduce the size of your `tflite` model. This section demonstrates how to use **dynamic-range quantization** through AI Edge Quantizer."
659 | ],
660 | "id": "kk3Dd3ZZeey5"
661 | },
662 | {
663 | "cell_type": "markdown",
664 | "metadata": {
665 | "id": "pGbRHy2jeey5"
666 | },
667 | "source": [
668 | "### Quantizing the model with dynamic quantization (AI Edge Quantizer)\n",
669 | "\n",
670 | "To use the `Quantizer`, we need to\n",
671 | "* Instantiate a Quantizer class. This is the entry point to the quantizer's functionalities.\n",
672 | "* Load a desired quantization recipe.\n",
673 | "* Quantize (and save) the model. This is where most of the quantizer's internal logic works."
674 | ],
675 | "id": "pGbRHy2jeey5"
676 | },
677 | {
678 | "cell_type": "code",
679 | "source": [
680 | "# Quantization (API will quantize and save a flatbuffer as *.tflite).\n",
681 | "quantizer = quantizer.Quantizer(float_model='segnext.tflite')\n",
682 | "quantizer.load_quantization_recipe(recipe=recipe.dynamic_wi8_afp32())\n",
683 | "\n",
684 | "quantization_result = quantizer.quantize()\n",
685 | "quantization_result.export_model('segnext_dynamic_wi8_afp32.tflite')"
686 | ],
687 | "metadata": {
688 | "id": "KgEnwA_-mvRS"
689 | },
690 | "id": "KgEnwA_-mvRS",
691 | "execution_count": null,
692 | "outputs": []
693 | },
694 | {
695 | "cell_type": "markdown",
696 | "source": [
697 | "`quantization_result` has two components\n",
698 | "\n",
699 | "\n",
700 | "* quantized LiteRT model (in bytearray) and\n",
701 | "* the corresponding quantization recipe"
702 | ],
703 | "metadata": {
704 | "id": "9EdSnuPsq8yn"
705 | },
706 | "id": "9EdSnuPsq8yn"
707 | },
708 | {
709 | "cell_type": "markdown",
710 | "source": [
711 | "Let's compare the size of flatbuffers\n"
712 | ],
713 | "metadata": {
714 | "id": "AeXKNrJDq-wK"
715 | },
716 | "id": "AeXKNrJDq-wK"
717 | },
718 | {
719 | "cell_type": "code",
720 | "source": [
721 | "!ls -lh *.tflite"
722 | ],
723 | "metadata": {
724 | "id": "5W5rk99EnTZA"
725 | },
726 | "id": "5W5rk99EnTZA",
727 | "execution_count": null,
728 | "outputs": []
729 | },
730 | {
731 | "cell_type": "markdown",
732 | "source": [
733 | "Let's take a look at what in this recipe\n",
734 | "\n"
735 | ],
736 | "metadata": {
737 | "id": "jljfKwytoQQa"
738 | },
739 | "id": "jljfKwytoQQa"
740 | },
741 | {
742 | "cell_type": "code",
743 | "source": [
744 | "quantization_result.recipe"
745 | ],
746 | "metadata": {
747 | "id": "Pq9JjrMxoRWP"
748 | },
749 | "id": "Pq9JjrMxoRWP",
750 | "execution_count": null,
751 | "outputs": []
752 | },
753 | {
754 | "cell_type": "markdown",
755 | "source": [
756 | "Here the recipe means: apply the naive min/max uniform algorithm (`min_max_uniform_quantize`) for all ops supported by the AI Edge Quantizer (indicated by `*`) under layers satisfying regex `.*` (i.e., all layers). We want the weights of these ops to be quantized as int8, symmetric, channel_wise, and we want to execute the ops in `Integer` mode."
757 | ],
758 | "metadata": {
759 | "id": "Y0Pm7OZ9odeZ"
760 | },
761 | "id": "Y0Pm7OZ9odeZ"
762 | },
763 | {
764 | "cell_type": "markdown",
765 | "source": [
766 | "\n",
767 | "Next, you'll create a function using LiteRT to run the newly generated quantized model.\n"
768 | ],
769 | "metadata": {
770 | "id": "ItdDEAZ7sm-v"
771 | },
772 | "id": "ItdDEAZ7sm-v"
773 | },
774 | {
775 | "cell_type": "code",
776 | "metadata": {
777 | "id": "drq_convert",
778 | "cellView": "form"
779 | },
780 | "execution_count": null,
781 | "outputs": [],
782 | "source": [
783 | "# @markdown We implemented some functions to run segmentation on the quantized model.
Run the following cell to activate the functions.\n",
784 | "def run_segmentation(image, model_path):\n",
785 | " \"\"\"Get segmentation mask of the image.\"\"\"\n",
786 | " image = np.expand_dims(image, axis=0)\n",
787 | " interpreter = ai_edge_litert.interpreter.Interpreter(model_path=model_path)\n",
788 | " interpreter.allocate_tensors()\n",
789 | "\n",
790 | " input_details = interpreter.get_input_details()[0]\n",
791 | " interpreter.set_tensor(input_details['index'], image)\n",
792 | " interpreter.invoke()\n",
793 | "\n",
794 | " output_details = interpreter.get_output_details()\n",
795 | " output_index = 0\n",
796 | " outputs = []\n",
797 | " for detail in output_details:\n",
798 | " outputs.append(interpreter.get_tensor(detail['index']))\n",
799 | " mask = np.squeeze(outputs[output_index])\n",
800 | " return mask"
801 | ],
802 | "id": "drq_convert"
803 | },
804 | {
805 | "cell_type": "markdown",
806 | "source": [
807 | "Now let try running the newly quantized model and see how they compare."
808 | ],
809 | "metadata": {
810 | "id": "i8MDev7IrH1E"
811 | },
812 | "id": "i8MDev7IrH1E"
813 | },
814 | {
815 | "cell_type": "code",
816 | "metadata": {
817 | "id": "drq_compare"
818 | },
819 | "execution_count": null,
820 | "outputs": [],
821 | "source": [
822 | "# Validate the model.\n",
823 | "for index in range(len(IMAGE_FILENAMES)):\n",
824 | " quantized_model_output = run_segmentation(np_images[index][0],\n",
825 | " 'segnext_dynamic_wi8_afp32.tflite')\n",
826 | " vis_segmentation(\n",
827 | " np_images[index][0].astype(np.uint8),\n",
828 | " np.argmax(quantized_model_output, axis=-1),\n",
829 | " PALETTE,\n",
830 | " LABEL_NAMES\n",
831 | " )"
832 | ],
833 | "id": "drq_compare"
834 | },
835 | {
836 | "cell_type": "markdown",
837 | "metadata": {
838 | "id": "yrw5gdWLeey6"
839 | },
840 | "source": [
841 | "## Export and Download Models\n",
842 | "Let's save and download the converted `tflite` model, along with the dynamic-range quantized version."
843 | ],
844 | "id": "yrw5gdWLeey6"
845 | },
846 | {
847 | "cell_type": "code",
848 | "metadata": {
849 | "id": "download_models"
850 | },
851 | "execution_count": null,
852 | "outputs": [],
853 | "source": [
854 | "files.download('segnext.tflite')"
855 | ],
856 | "id": "download_models"
857 | },
858 | {
859 | "cell_type": "code",
860 | "source": [
861 | "files.download('segnext_dynamic_wi8_afp32.tflite')"
862 | ],
863 | "metadata": {
864 | "id": "DeowMFHIfke1"
865 | },
866 | "id": "DeowMFHIfke1",
867 | "execution_count": null,
868 | "outputs": []
869 | },
870 | {
871 | "cell_type": "markdown",
872 | "metadata": {
873 | "id": "VtUKRJQAeey6"
874 | },
875 | "source": [
876 | "## Next Steps\n",
877 | "Now you've got a fully converted (and optionally quantized!) `tflite` model. Here are some ideas on what to do next:\n",
878 | "\n",
879 | "- Explore [AI Edge Torch documentation](https://ai.google.dev/edge) for additional use cases or advanced topics.\n",
880 | "- Try out your new model on mobile or web using the [LiteRT API samples](https://ai.google.dev/edge/docs/litert).\n",
881 | "- Further tune or calibrate your quantization techniques to achieve the desired balance between model size and accuracy.\n",
882 | "\n",
883 | "Have fun deploying your model to the edge!"
884 | ],
885 | "id": "VtUKRJQAeey6"
886 | }
887 | ]
888 | }
--------------------------------------------------------------------------------