├── .gitignore
├── LICENSE
├── README.md
├── app
├── .gitignore
├── build.gradle.kts
├── proguard-rules.pro
└── src
│ └── main
│ ├── AndroidManifest.xml
│ ├── assets
│ └── .gitkeep
│ ├── java
│ └── io
│ │ └── shubham0204
│ │ └── sam_android
│ │ ├── MainActivity.kt
│ │ ├── MainActivityViewModel.kt
│ │ ├── sam
│ │ ├── SAMDecoder.kt
│ │ └── SAMEncoder.kt
│ │ └── ui
│ │ ├── components
│ │ ├── AppAlertDialog.kt
│ │ └── AppProgressDialog.kt
│ │ └── theme
│ │ ├── Color.kt
│ │ ├── Theme.kt
│ │ └── Type.kt
│ └── res
│ ├── drawable
│ ├── ic_launcher_background.xml
│ └── ic_launcher_foreground.xml
│ ├── mipmap-anydpi
│ ├── ic_launcher.xml
│ └── ic_launcher_round.xml
│ ├── mipmap-hdpi
│ ├── ic_launcher.webp
│ └── ic_launcher_round.webp
│ ├── mipmap-mdpi
│ ├── ic_launcher.webp
│ └── ic_launcher_round.webp
│ ├── mipmap-xhdpi
│ ├── ic_launcher.webp
│ └── ic_launcher_round.webp
│ ├── mipmap-xxhdpi
│ ├── ic_launcher.webp
│ └── ic_launcher_round.webp
│ ├── mipmap-xxxhdpi
│ ├── ic_launcher.webp
│ └── ic_launcher_round.webp
│ ├── values
│ ├── colors.xml
│ ├── strings.xml
│ └── themes.xml
│ └── xml
│ ├── backup_rules.xml
│ └── data_extraction_rules.xml
├── build.gradle.kts
├── gradle.properties
├── gradle
├── libs.versions.toml
└── wrapper
│ ├── gradle-wrapper.jar
│ └── gradle-wrapper.properties
├── gradlew
├── gradlew.bat
├── notebooks
└── SAM2_ONNX_Export.ipynb
└── settings.gradle.kts
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | .kotlin
4 | /local.properties
5 | /.idea
6 | .DS_Store
7 | /build
8 | /captures
9 | .externalNativeBuild
10 | .cxx
11 | local.properties
12 | app/release
13 | app/src/main/assets/encoder_base_plus.onnx
14 | app/src/main/assets/decoder_base_plus.onnx
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 Shubham Panchal
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Segment-Anything (SAM) and SAM v2 Inference In Android
2 |
3 | 
4 |
5 | - On-device inference of SAM/SAM2 with `onnxruntime`
6 | - Clean Kotlin-only implementation, with no additional code compilation
7 | - No support for text-prompt as an input to the model
8 | - The inference time is *quite high* even with `float16` quantization enabled
9 |
10 | Download the [APK](https://github.com/shubham0204/Segment-Anything-Android/releases/tag/release_apk) or [setup the project locally](#setup)
11 |
12 | ## About Segment-Anything
13 |
14 | 
15 |
16 | - Large language models have demonstrated significant performance gains on numerous NLP tasks in zero-shot or few-shot settings. The prompt, a text given at inference time to the LLM, guides the generation of the output.
17 | - Foundation models like CLIP and ALIGN have been popular due to wide adaptability and fine-tuning capabilities for downstream tasks.
18 | - The goal of the authors is to build a **foundation model for image segmentation**.
19 |
20 | #### Task
21 | - Authors define a **promptable image segmentation task**.
22 | - The **prompt** could be **spatial or textual information** which guides the model to generate the desired segmentation mask.
23 |
24 | #### Model
25 | - A powerful **image encoder** is used to produce image embeddings and a **prompt encoder** embeds prompts, both of which are combined with a **mask decoder**.
26 | - The authors focus on **point, box and mask prompts** with initial results on free-form text prompts.
27 | - **Image Encoder**: MAE (Masked Autoencoder) pre-trained Vision Transformer
28 | - **Prompt Encoder**: Points and boxes are represented by positional encodings, masks are embedded with convolutional layers, and free-form text with an encoder like CLIP
29 | - **Mask Decoder**: Transformer-based decoder model
30 |
31 | #### Data Engine
32 | - To achieve strong generalization on unknown datasets, authors propose a model-in-the-loop data annotation process with three phases.
33 | - In the ***assisted-manual phase***, SAM helps annotators in annotating masks.
34 | - In the ***semi-automatic phase***, SAM automatically generates masks for certain objects, by prompting their locations in the image.
35 | - In the ***fully-automatic phase***, SAM is prompted with a regular grid of foreground points, each of which yields a segmentation mask.
36 |
37 | ## Setup
38 |
39 | 1. Clone the project from GitHub and open the resulting directory in Android Studio.
40 |
41 | ```text
42 | git clone --depth=1 https://github.com/shubham0204/Segment-Anything-Android
43 | ```
44 |
45 | 2. Android Studio starts building the project automatically. If not, select **Build > Rebuild Project** to start a project build.
46 |
47 | 3. After a successful project build, [connect an Android device](https://developer.android.com/studio/run/device) to your system. Once connected, the name of the device should be visible in the top menu bar of Android Studio.
48 |
49 | 4. Download any `*_encoder.onnx` and corresponding `*_decoder.onnx` models from the [HuggingFace repository](https://huggingface.co/shubham0204/sam2-onnx-models). The models can then be made available to the app in one of two possible ways
50 |
51 | #### Store the ONNX models in the `assets` folder
52 |
53 | By placing the `*_encoder.onnx` and `*_decoder.onnx` in the `app/src/main/assets` folder, the models are packaged with the APK, which increases the overall size of the APK but avoids any additional setup to bring the models to the device. Make sure you change the names of the encoder and decoder models in `MainActivity.kt`,
54 |
55 | ```kotlin
56 |
57 | class MainActivity : ComponentActivity() {
58 |
59 | private val encoder = SAMEncoder()
60 | private val decoder = SAMDecoder()
61 |
62 | // The app will look for models with these file-names
63 | // in the assets folder
64 | private val encoderFileName = "encoder_base_plus.onnx"
65 | private val decoderFileName = "decoder_base_plus.onnx"
66 |
67 | // ...
68 | }
69 | ```
70 |
71 | #### Store the ONNX models in the device's temporary storage
72 |
73 | Using the `adb` CLI tool, insert the ONNX models in the device's storage,
74 |
75 | ```text
76 | adb push sam2_hiera_small_encoder.onnx /data/local/tmp/sam/encoder.onnx
77 | adb push sam2_hiera_small_decoder.onnx /data/local/tmp/sam/decoder.onnx
78 | ```
79 |
80 | Replace `sam2_hiera_small_decoder.onnx` and `sam2_hiera_small_encoder.onnx` with the names of the models downloaded from the HF repository in step (4).
81 |
82 | Update the model paths and set other options in `MainActivity.kt`,
83 |
84 | ```kotlin
85 | class MainActivity : ComponentActivity() {
86 |
87 | // ...
88 |
89 | override fun onCreate(savedInstanceState: Bundle?) {
90 | super.onCreate(savedInstanceState)
91 | enableEdgeToEdge()
92 |
93 | setContent {
94 | SAMAndroidTheme {
95 | Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding ->
96 | Column(
97 | // ...
98 | ) {
99 |
100 | // ...
101 |
102 | LaunchedEffect(0) {
103 | // ...
104 | // The paths below should match the destination
105 | // paths used in the `adb push` commands above
106 | encoder.init(
107 | "/data/local/tmp/sam/encoder.onnx",
108 | useXNNPack = true, // XNNPack delegate for onnxruntime
109 | useFP16 = true
110 | )
111 | decoder.init(
112 | "/data/local/tmp/sam/decoder.onnx",
113 | useXNNPack = true,
114 | useFP16 = true
115 | )
116 | // ...
117 | }
118 |
119 | // ...
120 | }
121 | }
122 | }
123 | }
124 | }
125 | }
126 | ```
127 |
128 | ## Resources
129 |
130 | - [ONNX-SAM2-Segment-Anything](https://github.com/ibaiGorordo/ONNX-SAM2-Segment-Anything): ONNX models were derived from the Colab notebook linked in the `README.md` of this project.
131 | - [Segment Anything - arxiv](https://arxiv.org/abs/2304.02643)
132 | - [SAM 2: Segment Anything in Images and Videos - arxiv](https://arxiv.org/abs/2408.00714)
133 |
134 | ## Citations
135 |
136 | ```text
137 | @misc{ravi2024sam2segmentimages,
138 | title={SAM 2: Segment Anything in Images and Videos},
139 | author={Nikhila Ravi and Valentin Gabeur and Yuan-Ting Hu and Ronghang Hu and Chaitanya Ryali and Tengyu Ma and Haitham Khedr and Roman Rädle and Chloe Rolland and Laura Gustafson and Eric Mintun and Junting Pan and Kalyan Vasudev Alwala and Nicolas Carion and Chao-Yuan Wu and Ross Girshick and Piotr Dollár and Christoph Feichtenhofer},
140 | year={2024},
141 | eprint={2408.00714},
142 | archivePrefix={arXiv},
143 | primaryClass={cs.CV},
144 | url={https://arxiv.org/abs/2408.00714},
145 | }
146 | ```
147 |
148 | ```text
149 | @misc{kirillov2023segment,
150 | title={Segment Anything},
151 | author={Alexander Kirillov and Eric Mintun and Nikhila Ravi and Hanzi Mao and Chloe Rolland and Laura Gustafson and Tete Xiao and Spencer Whitehead and Alexander C. Berg and Wan-Yen Lo and Piotr Dollár and Ross Girshick},
152 | year={2023},
153 | eprint={2304.02643},
154 | archivePrefix={arXiv},
155 | primaryClass={cs.CV},
156 | url={https://arxiv.org/abs/2304.02643},
157 | }
158 | ```
159 |
--------------------------------------------------------------------------------
/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
--------------------------------------------------------------------------------
/app/build.gradle.kts:
--------------------------------------------------------------------------------
1 | plugins {
2 | alias(libs.plugins.android.application)
3 | alias(libs.plugins.jetbrains.kotlin.android)
4 | alias(libs.plugins.compose.compiler)
5 | }
6 |
7 | android {
8 | namespace = "io.shubham0204.sam_android"
9 | compileSdk = 34 // compile against Android 14 (API level 34)
10 |
11 | defaultConfig {
12 | applicationId = "io.shubham0204.sam_android"
13 | minSdk = 26 // minimum supported: Android 8.0 (Oreo)
14 | targetSdk = 34
15 | versionCode = 1
16 | versionName = "1.0"
17 |
18 | testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
19 | vectorDrawables {
20 | useSupportLibrary = true
21 | }
22 | }
23 |
24 | buildTypes {
25 | release {
26 | isMinifyEnabled = false // code shrinking/obfuscation disabled for release builds
27 | proguardFiles(
28 | getDefaultProguardFile("proguard-android-optimize.txt"),
29 | "proguard-rules.pro"
30 | )
31 | }
32 | }
33 | compileOptions {
34 | sourceCompatibility = JavaVersion.VERSION_1_8
35 | targetCompatibility = JavaVersion.VERSION_1_8
36 | }
37 | kotlinOptions {
38 | jvmTarget = "1.8" // must match compileOptions above
39 | }
40 | buildFeatures {
41 | compose = true // enable Jetpack Compose
42 | }
43 | composeOptions {
44 | kotlinCompilerExtensionVersion = "1.5.1" // NOTE(review): with the Compose compiler Gradle plugin applied above, this option may be ignored by AGP — confirm and remove if redundant
45 | }
46 | packaging {
47 | resources {
48 | excludes += "/META-INF/{AL2.0,LGPL2.1}"
49 | }
50 | }
51 | }
52 |
53 | dependencies {
54 | implementation(libs.androidx.core.ktx)
55 | implementation(libs.androidx.lifecycle.runtime.ktx)
56 |
57 | implementation(platform(libs.androidx.compose.bom))
58 | implementation(libs.androidx.activity.compose)
59 | implementation(libs.androidx.ui)
60 | implementation(libs.androidx.ui.graphics)
61 | implementation(libs.androidx.compose.material3.icons.extended)
62 | implementation(libs.androidx.compose.runtime.livedata)
63 | implementation(libs.androidx.material3)
64 | implementation(libs.androidx.lifecycle.viewmodel.compose)
65 |
66 | implementation("com.microsoft.onnxruntime:onnxruntime-android:1.17.0") // ONNX Runtime for on-device SAM encoder/decoder inference; pinned version, not managed via the version catalog
67 |
68 | implementation(libs.androidx.ui.tooling.preview)
69 | implementation(libs.androidx.exifinterface)
70 | testImplementation(libs.junit)
71 | androidTestImplementation(libs.androidx.junit)
72 | androidTestImplementation(libs.androidx.espresso.core)
73 | androidTestImplementation(platform(libs.androidx.compose.bom))
74 | androidTestImplementation(libs.androidx.ui.test.junit4)
75 | debugImplementation(libs.androidx.ui.tooling)
76 | debugImplementation(libs.androidx.ui.test.manifest)
77 | }
--------------------------------------------------------------------------------
/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
--------------------------------------------------------------------------------
/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
15 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/app/src/main/assets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/assets/.gitkeep
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/MainActivity.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android
18 |
19 | import AppProgressDialog
20 | import android.graphics.Bitmap
21 | import android.graphics.BitmapFactory
22 | import android.graphics.Matrix
23 | import android.graphics.PointF
24 | import android.net.Uri
25 | import android.os.Bundle
26 | import android.util.Log
27 | import android.widget.Toast
28 | import androidx.activity.ComponentActivity
29 | import androidx.activity.compose.rememberLauncherForActivityResult
30 | import androidx.activity.compose.setContent
31 | import androidx.activity.enableEdgeToEdge
32 | import androidx.activity.result.PickVisualMediaRequest
33 | import androidx.activity.result.contract.ActivityResultContracts
34 | import androidx.compose.foundation.Image
35 | import androidx.compose.foundation.background
36 | import androidx.compose.foundation.clickable
37 | import androidx.compose.foundation.gestures.detectTapGestures
38 | import androidx.compose.foundation.layout.Box
39 | import androidx.compose.foundation.layout.Column
40 | import androidx.compose.foundation.layout.Row
41 | import androidx.compose.foundation.layout.Spacer
42 | import androidx.compose.foundation.layout.fillMaxSize
43 | import androidx.compose.foundation.layout.fillMaxWidth
44 | import androidx.compose.foundation.layout.height
45 | import androidx.compose.foundation.layout.padding
46 | import androidx.compose.foundation.lazy.LazyColumn
47 | import androidx.compose.foundation.lazy.itemsIndexed
48 | import androidx.compose.foundation.rememberScrollState
49 | import androidx.compose.foundation.verticalScroll
50 | import androidx.compose.material.icons.Icons
51 | import androidx.compose.material.icons.filled.Close
52 | import androidx.compose.material.icons.filled.Image
53 | import androidx.compose.material.icons.filled.Layers
54 | import androidx.compose.material.icons.filled.Tag
55 | import androidx.compose.material3.Button
56 | import androidx.compose.material3.ExperimentalMaterial3Api
57 | import androidx.compose.material3.Icon
58 | import androidx.compose.material3.IconButton
59 | import androidx.compose.material3.ModalBottomSheet
60 | import androidx.compose.material3.Scaffold
61 | import androidx.compose.material3.Text
62 | import androidx.compose.material3.rememberModalBottomSheetState
63 | import androidx.compose.runtime.Composable
64 | import androidx.compose.runtime.LaunchedEffect
65 | import androidx.compose.runtime.getValue
66 | import androidx.compose.runtime.mutableStateOf
67 | import androidx.compose.runtime.remember
68 | import androidx.compose.runtime.rememberCoroutineScope
69 | import androidx.compose.runtime.setValue
70 | import androidx.compose.ui.Modifier
71 | import androidx.compose.ui.draw.drawWithCache
72 | import androidx.compose.ui.geometry.Offset
73 | import androidx.compose.ui.geometry.Size
74 | import androidx.compose.ui.graphics.Color
75 | import androidx.compose.ui.graphics.asImageBitmap
76 | import androidx.compose.ui.input.pointer.pointerInput
77 | import androidx.compose.ui.layout.ContentScale
78 | import androidx.compose.ui.layout.onGloballyPositioned
79 | import androidx.compose.ui.text.style.TextAlign
80 | import androidx.compose.ui.unit.dp
81 | import androidx.compose.ui.unit.sp
82 | import androidx.compose.ui.unit.toSize
83 | import androidx.exifinterface.media.ExifInterface
84 | import androidx.lifecycle.viewmodel.compose.viewModel
85 | import hideProgressDialog
86 | import io.shubham0204.sam_android.sam.SAMDecoder
87 | import io.shubham0204.sam_android.sam.SAMEncoder
88 | import io.shubham0204.sam_android.ui.components.AppAlertDialog
89 | import io.shubham0204.sam_android.ui.components.createAlertDialog
90 | import io.shubham0204.sam_android.ui.theme.SAMAndroidTheme
91 | import kotlinx.coroutines.CoroutineScope
92 | import kotlinx.coroutines.Dispatchers
93 | import kotlinx.coroutines.launch
94 | import kotlinx.coroutines.withContext
95 | import setProgressDialogText
96 | import showProgressDialog
97 | import java.io.File
98 | import java.nio.FloatBuffer
99 | import java.nio.file.Paths
100 | import kotlin.time.DurationUnit
101 | import kotlin.time.measureTimedValue
102 |
103 | class MainActivity : ComponentActivity() {
104 | private val encoder = SAMEncoder()
105 | private val decoder = SAMDecoder()
106 | private val encoderFileName = "encoder_base_plus.onnx"
107 | private val decoderFileName = "decoder_base_plus.onnx"
108 |
109 | override fun onCreate(savedInstanceState: Bundle?) {
110 | super.onCreate(savedInstanceState)
111 | enableEdgeToEdge()
112 |
113 | setContent {
114 | SAMAndroidTheme {
115 | Scaffold(modifier = Modifier.fillMaxSize()) { innerPadding ->
116 | Column(
117 | modifier =
118 | Modifier
119 | .verticalScroll(rememberScrollState())
120 | .padding(innerPadding),
121 | ) {
122 | val viewModel = viewModel()
123 |
124 | var image by remember { mutableStateOf(null) }
125 | val outputImages = remember { viewModel.images }
126 | val points = remember { viewModel.points }
127 | var isReady by remember { mutableStateOf(false) }
128 | var viewPortDims by remember { mutableStateOf(null) }
129 |
130 | LaunchedEffect(0) {
131 | try {
132 | showProgressDialog()
133 | setProgressDialogText("Loading models...")
134 | if (isModelInAssets(encoderFileName) && isModelInAssets(decoderFileName)) {
135 | copyModelToStorage(encoderFileName)
136 | copyModelToStorage(decoderFileName)
137 | encoder.init(Paths.get(filesDir.absolutePath, encoderFileName).toString())
138 | decoder.init(Paths.get(filesDir.absolutePath, decoderFileName).toString())
139 | } else {
140 | encoder.init("/data/local/tmp/sam/encoder_base_plus.onnx")
141 | decoder.init("/data/local/tmp/sam/decoder_base_plus.onnx")
142 | }
143 | isReady = true
144 | hideProgressDialog()
145 | } catch (e: Exception) {
146 | hideProgressDialog()
147 | createAlertDialog(
148 | dialogTitle = "Error",
149 | dialogText = "An error occurred: ${e.message}",
150 | dialogPositiveButtonText = "Close",
151 | dialogNegativeButtonText = null,
152 | onPositiveButtonClick = { finish() },
153 | onNegativeButtonClick = null,
154 | )
155 | }
156 | }
157 |
158 | val pickMediaLauncher =
159 | rememberLauncherForActivityResult(
160 | contract = ActivityResultContracts.PickVisualMedia(),
161 | ) {
162 | if (it != null) {
163 | val bitmap = getFixedBitmap(it)
164 | image = bitmap
165 | viewModel.reset()
166 | }
167 | }
168 |
169 | Row(
170 | modifier =
171 | Modifier
172 | .padding(horizontal = 8.dp)
173 | .fillMaxWidth(),
174 | ) {
175 | Button(
176 | modifier =
177 | Modifier
178 | .fillMaxWidth()
179 | .padding(4.dp)
180 | .weight(1f),
181 | onClick = {
182 | viewModel.showBottomSheet.value = true
183 | },
184 | ) {
185 | Icon(
186 | imageVector = Icons.Default.Tag,
187 | contentDescription = "Choose Label For Points",
188 | )
189 | Text(text = "Choose Label For Points")
190 | }
191 | }
192 |
193 | Row(
194 | modifier =
195 | Modifier
196 | .padding(horizontal = 8.dp)
197 | .fillMaxWidth(),
198 | ) {
199 | Button(
200 | modifier =
201 | Modifier
202 | .fillMaxWidth()
203 | .padding(4.dp)
204 | .weight(1f),
205 | enabled = isReady && (image != null),
206 | onClick = {
207 | image?.let { bitmap ->
208 | processInputPoints(
209 | bitmap,
210 | points,
211 | viewPortDims,
212 | viewModel,
213 | )
214 | }
215 | },
216 | ) {
217 | Icon(
218 | imageVector = Icons.Default.Layers,
219 | contentDescription = "Segment",
220 | )
221 | Text(text = "Segment!")
222 | }
223 | Button(
224 | modifier =
225 | Modifier
226 | .fillMaxWidth()
227 | .padding(4.dp)
228 | .weight(1f),
229 | enabled = isReady,
230 | onClick = {
231 | pickMediaLauncher.launch(
232 | PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly),
233 | )
234 | },
235 | ) {
236 | Icon(
237 | imageVector = Icons.Default.Image,
238 | contentDescription = "Select Image",
239 | )
240 | Text(text = "Select Image")
241 | }
242 | }
243 |
244 | if (image != null) {
245 | Spacer(modifier = Modifier.height(4.dp))
246 | Text(
247 | text = "Currently selected label: Label ${viewModel.selectedLabelIndex.intValue}",
248 | fontSize = 12.sp,
249 | color = Color.DarkGray,
250 | textAlign = TextAlign.Center,
251 | modifier =
252 | Modifier
253 | .fillMaxWidth()
254 | .padding(4.dp),
255 | )
256 | Spacer(modifier = Modifier.height(4.dp))
257 | Box {
258 | Image(
259 | bitmap = image!!.asImageBitmap(),
260 | contentScale = ContentScale.Fit,
261 | contentDescription = "Selected Image",
262 | modifier =
263 | Modifier
264 | .pointerInput(Unit) {
265 | detectTapGestures(onLongPress = {
266 | val newPoints =
267 | points.filter { it.label != viewModel.selectedLabelIndex.intValue }
268 | points.clear()
269 | points.addAll(newPoints)
270 | Toast
271 | .makeText(
272 | this@MainActivity,
273 | "All guide-points removed",
274 | Toast.LENGTH_LONG,
275 | ).show()
276 | }, onTap = { offset ->
277 | points.add(
278 | LabelPoint(
279 | viewModel.selectedLabelIndex.intValue,
280 | PointF(offset.x, offset.y),
281 | ),
282 | )
283 | })
284 | }.onGloballyPositioned {
285 | viewPortDims = it.size.toSize()
286 | },
287 | )
288 | Spacer(
289 | modifier =
290 | Modifier
291 | .fillMaxSize()
292 | .drawWithCache {
293 | onDrawBehind {
294 | points
295 | .filter { labelPoint -> labelPoint.label == viewModel.selectedLabelIndex.intValue }
296 | .forEach { labelPoint ->
297 | drawCircle(
298 | color = Color.Black,
299 | radius = 15f,
300 | center =
301 | Offset(
302 | labelPoint.point.x,
303 | labelPoint.point.y,
304 | ),
305 | )
306 | drawCircle(
307 | color = Color.Yellow,
308 | radius = 12f,
309 | center =
310 | Offset(
311 | labelPoint.point.x,
312 | labelPoint.point.y,
313 | ),
314 | )
315 | }
316 | }
317 | },
318 | )
319 | }
320 | Spacer(modifier = Modifier.height(4.dp))
321 | Text(
322 | text = "Tap on the image to insert a guide-point\nLong-press to remove all guide-points for the current label",
323 | fontSize = 12.sp,
324 | color = Color.DarkGray,
325 | textAlign = TextAlign.Center,
326 | modifier =
327 | Modifier
328 | .fillMaxWidth()
329 | .padding(16.dp),
330 | )
331 | }
332 | if (outputImages.isNotEmpty()) {
333 | Text(
334 | modifier = Modifier.padding(4.dp),
335 | fontSize = 18.sp,
336 | text = "Segmented Images (${viewModel.inferenceTime.intValue} s)",
337 | )
338 | }
339 | outputImages.forEach {
340 | Image(
341 | modifier = Modifier.background(Color.Black.copy(green = 1.0f)),
342 | bitmap = it.asImageBitmap(),
343 | contentDescription = "Segmented image",
344 | )
345 | }
346 |
347 | AppAlertDialog()
348 | AppProgressDialog()
349 | ManageLabelsBottomSheet(viewModel)
350 | }
351 | }
352 | }
353 | }
354 | }
355 |
356 | @OptIn(ExperimentalMaterial3Api::class)
357 | @Composable
358 | private fun ManageLabelsBottomSheet(viewModel: MainActivityViewModel) {
359 | val sheetState = rememberModalBottomSheetState()
360 | val scope = rememberCoroutineScope()
361 | var showBottomSheet by remember { viewModel.showBottomSheet }
362 | val labels = remember { viewModel.labels }
363 | var lastAddedLabel by remember { viewModel.lastAddedLabel }
364 | var selectedLabelIndex by remember { viewModel.selectedLabelIndex }
365 |
366 | if (showBottomSheet) {
367 | ModalBottomSheet(
368 | containerColor = Color.White,
369 | onDismissRequest = { showBottomSheet = false },
370 | sheetState = sheetState,
371 | ) {
372 | Column(
373 | modifier = Modifier.padding(horizontal = 16.dp),
374 | ) {
375 | Row {
376 | Text(
377 | text = "Manage Labels",
378 | fontSize = 18.sp,
379 | modifier =
380 | Modifier
381 | .padding(8.dp)
382 | .weight(1f),
383 | )
384 | IconButton(onClick = {
385 | scope.launch { sheetState.hide() }.invokeOnCompletion {
386 | if (!sheetState.isVisible) {
387 | showBottomSheet = false
388 | }
389 | }
390 | }) {
391 | Icon(
392 | imageVector = Icons.Default.Close,
393 | contentDescription = "Close Panel",
394 | )
395 | }
396 | }
397 | LazyColumn {
398 | itemsIndexed(labels) { index, item ->
399 | Text(
400 | text = item,
401 | modifier =
402 | Modifier
403 | .clickable {
404 | selectedLabelIndex = index
405 | }.background(if (selectedLabelIndex == index) Color.Cyan else Color.White)
406 | .padding(8.dp)
407 | .fillMaxWidth(),
408 | )
409 | }
410 | }
411 | Spacer(modifier = Modifier.height(4.dp))
412 | Row {
413 | Button(
414 | modifier =
415 | Modifier
416 | .padding(4.dp)
417 | .fillMaxWidth()
418 | .weight(1f),
419 | onClick = {
420 | lastAddedLabel += 1
421 | labels.add("Label $lastAddedLabel")
422 | },
423 | ) {
424 | Text(text = "Add Label")
425 | }
426 | Button(
427 | modifier =
428 | Modifier
429 | .padding(4.dp)
430 | .fillMaxWidth()
431 | .weight(1f),
432 | onClick = {
433 | labels.removeAt(selectedLabelIndex)
434 | },
435 | ) {
436 | Text(text = "Remove Label")
437 | }
438 | }
439 | }
440 | }
441 | }
442 | }
443 |
444 | private fun processInputPoints(
445 | bitmap: Bitmap,
446 | points: List,
447 | viewPortDims: Size?,
448 | viewModel: MainActivityViewModel,
449 | ) {
450 | CoroutineScope(Dispatchers.Default).launch {
451 | try {
452 | showProgressDialog()
453 | setProgressDialogText("Performing image segmentation...")
454 | val pointsGroupByLabel = points.groupBy { it.label }
455 | val maxPoints = pointsGroupByLabel.maxOfOrNull { it.value.size } ?: return@launch
456 | val labelsCount = pointsGroupByLabel.keys.size
457 |
458 | val labelsBuffer = FloatBuffer.allocate(1 * labelsCount * maxPoints)
459 | val pointsBuffer = FloatBuffer.allocate(1 * labelsCount * maxPoints * 2)
460 |
461 | for ((label, labelPoints) in pointsGroupByLabel) {
462 | labelPoints.forEach {
463 | pointsBuffer.put((it.point.x / viewPortDims?.width!!) * 1024f)
464 | pointsBuffer.put((it.point.y / viewPortDims.height) * 1024f)
465 | }
466 | repeat(maxPoints - labelPoints.size) {
467 | pointsBuffer.put(0f)
468 | pointsBuffer.put(0f)
469 | }
470 | repeat(labelPoints.size) {
471 | labelsBuffer.put((label + 1).toFloat())
472 | }
473 | repeat(maxPoints - labelPoints.size) {
474 | labelsBuffer.put(-1f)
475 | }
476 | }
477 | pointsBuffer.rewind()
478 | labelsBuffer.rewind()
479 |
480 | val (imagesWithMask, time) =
481 | measureTimedValue {
482 | decoder.execute(
483 | encoder.execute(bitmap),
484 | pointsBuffer,
485 | labelsBuffer,
486 | labelsCount.toLong(),
487 | maxPoints.toLong(),
488 | bitmap,
489 | )
490 | }
491 | withContext(Dispatchers.Main) {
492 | viewModel.inferenceTime.intValue = time.toInt(DurationUnit.SECONDS)
493 | hideProgressDialog()
494 | viewModel.images.clear()
495 | viewModel.images.addAll(imagesWithMask)
496 | }
497 | } catch (e: Exception) {
498 | hideProgressDialog()
499 | createAlertDialog(
500 | dialogTitle = "Error",
501 | dialogText = "An error occurred: ${e.message}",
502 | dialogPositiveButtonText = "Close",
503 | dialogNegativeButtonText = null,
504 | onPositiveButtonClick = { finish() },
505 | onNegativeButtonClick = null,
506 | )
507 | }
508 | }
509 | }
510 |
511 | private fun isModelInAssets(modelFileName: String): Boolean = (assets.list("") ?: emptyArray()).contains(modelFileName)
512 |
513 | private fun copyModelToStorage(modelFileName: String) {
514 | val modelFile = File(filesDir, modelFileName)
515 | if (!modelFile.exists()) {
516 | assets.open(modelFileName).use { inputStream ->
517 | openFileOutput(modelFileName, MODE_PRIVATE).use { outputStream ->
518 | inputStream.copyTo(outputStream)
519 | }
520 | }
521 | Log.i(MainActivity::class.simpleName, "$modelFileName copied from assets to app storage")
522 | }
523 | }
524 |
525 | private fun getFixedBitmap(imageFileUri: Uri): Bitmap {
526 | var imageBitmap = BitmapFactory.decodeStream(contentResolver.openInputStream(imageFileUri))
527 | val exifInterface = ExifInterface(contentResolver.openInputStream(imageFileUri)!!)
528 | imageBitmap =
529 | when (
530 | exifInterface.getAttributeInt(
531 | ExifInterface.TAG_ORIENTATION,
532 | ExifInterface.ORIENTATION_UNDEFINED,
533 | )
534 | ) {
535 | ExifInterface.ORIENTATION_ROTATE_90 -> rotateBitmap(imageBitmap, 90f)
536 | ExifInterface.ORIENTATION_ROTATE_180 -> rotateBitmap(imageBitmap, 180f)
537 | ExifInterface.ORIENTATION_ROTATE_270 -> rotateBitmap(imageBitmap, 270f)
538 | else -> imageBitmap
539 | }
540 | return imageBitmap
541 | }
542 |
543 | private fun rotateBitmap(
544 | source: Bitmap,
545 | degrees: Float,
546 | ): Bitmap {
547 | val matrix = Matrix()
548 | matrix.postRotate(degrees)
549 | return Bitmap.createBitmap(source, 0, 0, source.width, source.height, matrix, false)
550 | }
551 |
    // A guide-point placed by the user on the displayed image.
    // `label` is the index of the label the point belongs to; `point` holds the
    // tap position in viewport (on-screen) coordinates — it is rescaled to
    // model-space coordinates in processInputPoints().
    data class LabelPoint(
        val label: Int,
        val point: PointF,
    )
556 | }
557 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/MainActivityViewModel.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android
18 |
19 | import android.graphics.Bitmap
20 | import androidx.compose.runtime.mutableIntStateOf
21 | import androidx.compose.runtime.mutableStateListOf
22 | import androidx.compose.runtime.mutableStateOf
23 | import androidx.lifecycle.ViewModel
24 |
/**
 * Holds all UI state for [MainActivity]: the label list and selection, the
 * guide-points placed on the image, the segmented output bitmaps, and the
 * last measured inference time (seconds).
 */
class MainActivityViewModel : ViewModel() {
    // Whether the "Manage Labels" bottom sheet is visible
    val showBottomSheet = mutableStateOf(false)
    // Index (into `labels`) of the label new guide-points are assigned to
    val selectedLabelIndex = mutableIntStateOf(0)
    // Monotonic counter used to derive unique names for newly added labels
    val lastAddedLabel = mutableIntStateOf(0)
    val labels = mutableStateListOf("Label 0")
    // Guide-points placed on the image, in viewport coordinates
    val points = mutableStateListOf<MainActivity.LabelPoint>()
    // Segmented output images produced by the decoder
    val images = mutableStateListOf<Bitmap>()
    // Last inference duration, in whole seconds
    val inferenceTime = mutableIntStateOf(0)

    // Restores the state shown for a freshly selected image.
    fun reset() {
        images.clear()
        points.clear()
        labels.clear()
        labels.add("Label 0")
        selectedLabelIndex.intValue = 0
        lastAddedLabel.intValue = 0
        showBottomSheet.value = false
        inferenceTime.intValue = 0
    }
}
45 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/sam/SAMDecoder.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.sam
18 |
19 | import ai.onnxruntime.OnnxTensor
20 | import ai.onnxruntime.OrtEnvironment
21 | import ai.onnxruntime.OrtSession
22 | import ai.onnxruntime.providers.NNAPIFlags
23 | import android.content.Context
24 | import android.graphics.Bitmap
25 | import android.graphics.Color
26 | import android.util.Log
27 | import androidx.core.graphics.get
28 | import androidx.core.graphics.set
29 | import kotlinx.coroutines.Dispatchers
30 | import kotlinx.coroutines.joinAll
31 | import kotlinx.coroutines.launch
32 | import kotlinx.coroutines.withContext
33 | import java.io.File
34 | import java.io.FileOutputStream
35 | import java.nio.FloatBuffer
36 | import java.nio.IntBuffer
37 | import java.util.Collections
38 | import java.util.EnumSet
39 |
class SAMDecoder {
    private lateinit var ortEnvironment: OrtEnvironment
    private lateinit var ortSession: OrtSession

    // input and output node names for the decoder
    // ONNX model
    private lateinit var maskOutputName: String
    private lateinit var scoresOutputName: String

    private lateinit var imageEmbeddingInputName: String
    private lateinit var highResFeature0InputName: String
    private lateinit var highResFeature1InputName: String
    private lateinit var pointCoordinatesInputName: String
    private lateinit var pointLabelsInputName: String
    private lateinit var maskInputName: String
    private lateinit var hasMaskInputName: String

    /**
     * Creates the ORT session for the decoder model at [modelPath] and caches
     * the model's input/output node names.
     *
     * NOTE(review): node names are resolved purely by position in the exported
     * model's input/output lists — verify the order against the ONNX export.
     *
     * @param useFP16 enable the NNAPI execution provider with FP16 relaxation.
     * @param useXNNPack enable the XNNPACK execution provider (2 intra-op threads).
     */
    suspend fun init(
        modelPath: String,
        useFP16: Boolean = false,
        useXNNPack: Boolean = false,
    ) = withContext(Dispatchers.IO) {
        ortEnvironment = OrtEnvironment.getEnvironment()
        val options =
            OrtSession.SessionOptions().apply {
                if (useFP16) {
                    addNnapi(EnumSet.of(NNAPIFlags.USE_FP16))
                }
                if (useXNNPack) {
                    addXnnpack(
                        mapOf(
                            "intra_op_num_threads" to "2",
                        ),
                    )
                }
            }
        ortSession = ortEnvironment.createSession(modelPath, options)
        val decoderInputNames = ortSession.inputNames.toList()
        val decoderOutputNames = ortSession.outputNames.toList()
        Log.i(SAMDecoder::class.simpleName, "Decoder input names: $decoderInputNames")
        Log.i(SAMDecoder::class.simpleName, "Decoder output names: $decoderOutputNames")
        imageEmbeddingInputName = decoderInputNames[0]
        highResFeature0InputName = decoderInputNames[1]
        highResFeature1InputName = decoderInputNames[2]
        pointCoordinatesInputName = decoderInputNames[3]
        pointLabelsInputName = decoderInputNames[4]
        maskInputName = decoderInputNames[5]
        hasMaskInputName = decoderInputNames[6]

        maskOutputName = decoderOutputNames[0]
        scoresOutputName = decoderOutputNames[1]
    }

    /**
     * Runs the decoder for the given encoder features and point prompts, and
     * returns one bitmap per label: a copy of [inputImage] whose pixels inside
     * the predicted mask are fully transparent (alpha = 0) and fully opaque
     * elsewhere.
     *
     * @param pointCoordinates (numLabels, numPoints, 2) model-space coordinates.
     * @param pointLabels (numLabels, numPoints) point labels (-1 = padding).
     */
    suspend fun execute(
        encoderResults: SAMEncoder.SAMEncoderResults,
        pointCoordinates: FloatBuffer,
        pointLabels: FloatBuffer,
        numLabels: Long,
        numPoints: Long,
        inputImage: Bitmap,
    ): List<Bitmap> =
        withContext(Dispatchers.Default) {
            val imgHeight = inputImage.height
            val imgWidth = inputImage.width

            val imageEmbeddingTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    encoderResults.imageEmbedding,
                    longArrayOf(1, 256, 64, 64),
                )
            val highResFeature0Tensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    encoderResults.highResFeature0,
                    longArrayOf(1, 32, 256, 256),
                )
            val highResFeature1Tensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    encoderResults.highResFeature1,
                    longArrayOf(1, 64, 128, 128),
                )

            val pointCoordinatesTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    pointCoordinates,
                    longArrayOf(numLabels, numPoints, 2),
                )
            val pointLabelsTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    pointLabels,
                    longArrayOf(numLabels, numPoints),
                )

            // No previous mask is supplied: an all-zeros mask with has_mask = 0
            val maskTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    FloatBuffer.wrap(FloatArray(numLabels.toInt() * 1 * 256 * 256) { 0f }),
                    longArrayOf(numLabels, 1, 256, 256),
                )
            val hasMaskTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    FloatBuffer.wrap(floatArrayOf(0.0f)),
                    longArrayOf(1),
                )
            val origImageSizeTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    IntBuffer.wrap(intArrayOf(imgHeight, imgWidth)),
                    longArrayOf(2),
                )

            val outputs =
                ortSession.run(
                    mapOf(
                        imageEmbeddingInputName to imageEmbeddingTensor,
                        highResFeature0InputName to highResFeature0Tensor,
                        highResFeature1InputName to highResFeature1Tensor,
                        pointCoordinatesInputName to pointCoordinatesTensor,
                        pointLabelsInputName to pointLabelsTensor,
                        maskInputName to maskTensor,
                        hasMaskInputName to hasMaskTensor,
                        "orig_im_size" to origImageSizeTensor,
                    ),
                )
            val mask = (outputs[maskOutputName].get() as OnnxTensor).floatBuffer
            val scores = (outputs[scoresOutputName].get() as OnnxTensor).floatBuffer.array()
            Log.i(SAMDecoder::class.simpleName, "scores: ${scores.contentToString()}")

            // We apply masks to the input image in a parallel manner
            // by dispatching each (mask,image) pair to a new coroutine
            val bitmaps = Collections.synchronizedList(mutableListOf<Bitmap>())

            val numPredictedMasks = scores.size / numLabels.toInt()
            Log.i(SAMDecoder::class.simpleName, "Num predicted masks: $numPredictedMasks")
            Log.i(SAMDecoder::class.simpleName, "Mask size: ${mask.capacity()}")

            (0..<numLabels.toInt())
                .map { labelIndex ->
                    launch(Dispatchers.Default) {
                        // Apply mask to the input image
                        // The 'on' pixels (val > 0) in the mask, will deliver an pixel
                        // with alpha = 0 in the final image
                        // NOTE(review): mask layout assumed to be
                        // (numLabels, numPredictedMasks, H, W) row-major, using only the
                        // first predicted mask per label — confirm against the ONNX export
                        val maskStartIndex = labelIndex * numPredictedMasks * imgHeight * imgWidth
                        val colorBitmap = Bitmap.createBitmap(imgWidth, imgHeight, Bitmap.Config.ARGB_8888)
                        for (i in 0..<imgHeight) {
                            for (j in 0..<imgWidth) {
                                colorBitmap[j, i] =
                                    Color.argb(
                                        if (mask.get(maskStartIndex + i * imgWidth + j) > 0) {
                                            0
                                        } else {
                                            255
                                        },
                                        Color.red(inputImage[j, i]),
                                        Color.green(inputImage[j, i]),
                                        Color.blue(inputImage[j, i]),
                                    )
                            }
                        }
                        bitmaps.add(colorBitmap)
                    }
                }.joinAll()

            return@withContext bitmaps
        }

    // Debug helper: writes `image` as a PNG to the app's private files directory.
    // The output stream is now closed via `use { }` (the original leaked it).
    private fun saveBitmap(
        context: Context,
        image: Bitmap,
        name: String,
    ) {
        FileOutputStream(File(context.filesDir.absolutePath + "/$name.png")).use { stream ->
            image.compress(Bitmap.CompressFormat.PNG, 100, stream)
        }
    }
}
220 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/sam/SAMEncoder.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.sam
18 |
19 | import ai.onnxruntime.OnnxTensor
20 | import ai.onnxruntime.OrtEnvironment
21 | import ai.onnxruntime.OrtSession
22 | import ai.onnxruntime.providers.NNAPIFlags
23 | import android.graphics.Bitmap
24 | import android.graphics.Color
25 | import android.util.Log
26 | import androidx.core.graphics.get
27 | import kotlinx.coroutines.Dispatchers
28 | import kotlinx.coroutines.withContext
29 | import java.nio.FloatBuffer
30 | import java.util.EnumSet
31 |
class SAMEncoder {
    // Outputs of the encoder, consumed by SAMDecoder.execute()
    data class SAMEncoderResults(
        val imageEmbedding: FloatBuffer,
        val highResFeature0: FloatBuffer,
        val highResFeature1: FloatBuffer,
    )

    // Encoder expects a square 1024x1024 input image
    private val inputDim = 1024
    private lateinit var ortEnvironment: OrtEnvironment
    private lateinit var ortSession: OrtSession
    private lateinit var inputName: String
    private lateinit var imageEmbeddingOutputName: String
    private lateinit var highResFeature0OutputName: String
    private lateinit var highResFeature1OutputName: String

    // ImageNet per-channel normalization constants, in R, G, B order
    private val mean =
        floatArrayOf(
            0.485f,
            0.456f,
            0.406f,
        )
    private val std =
        floatArrayOf(
            0.229f,
            0.224f,
            0.225f,
        )

    /**
     * Creates the ORT session for the encoder model at [modelPath] and caches
     * the model's input/output node names (resolved by position — verify the
     * order against the ONNX export).
     *
     * @param useFP16 enable the NNAPI execution provider with FP16 relaxation.
     * @param useXNNPack enable the XNNPACK execution provider (2 intra-op threads).
     */
    suspend fun init(
        modelPath: String,
        useFP16: Boolean = false,
        useXNNPack: Boolean = false,
    ) = withContext(Dispatchers.IO) {
        ortEnvironment = OrtEnvironment.getEnvironment()
        val options =
            OrtSession.SessionOptions().apply {
                if (useFP16) {
                    addNnapi(EnumSet.of(NNAPIFlags.USE_FP16))
                }
                if (useXNNPack) {
                    addXnnpack(
                        mapOf(
                            "intra_op_num_threads" to "2",
                        ),
                    )
                }
            }
        ortSession = ortEnvironment.createSession(modelPath, options)
        inputName = ortSession.inputNames.first()
        val outputNames = ortSession.outputNames.toList()
        Log.i(SAMEncoder::class.simpleName, "Encoder input names: $inputName")
        Log.i(SAMEncoder::class.simpleName, "Encoder output names: $outputNames")
        highResFeature0OutputName = outputNames[0]
        highResFeature1OutputName = outputNames[1]
        imageEmbeddingOutputName = outputNames[2]
    }

    /**
     * Resizes [inputImage] to 1024x1024, normalizes it into a channel-planar
     * (1, 3, H, W) tensor, and runs the encoder.
     */
    suspend fun execute(inputImage: Bitmap): SAMEncoderResults =
        withContext(Dispatchers.IO) {
            // Resize the image to the model's required input size
            val resizedImage =
                Bitmap.createScaledBitmap(
                    inputImage,
                    inputDim,
                    inputDim,
                    true,
                )

            // Create a FloatBuffer to store the normalized image pixels
            // The model requires the image in the shape (1, C, H, W)
            val imagePixels = FloatBuffer.allocate(1 * resizedImage.width * resizedImage.height * 3)
            imagePixels.rewind()
            // Fill the channel planes in R, G, B order, each normalized with its own
            // mean/std. FIX: the original wrote the planes as R, B, G — pairing the
            // blue channel with the green channel's mean/std and vice versa — which
            // both scrambled the channel order and used the wrong constants.
            val channelExtractors: Array<(Int) -> Int> =
                arrayOf(
                    { pixel -> Color.red(pixel) },
                    { pixel -> Color.green(pixel) },
                    { pixel -> Color.blue(pixel) },
                )
            for (c in 0 until 3) {
                for (i in 0 until resizedImage.height) {
                    for (j in 0 until resizedImage.width) {
                        val value = channelExtractors[c](resizedImage[j, i]).toFloat() / 255.0f
                        imagePixels.put((value - mean[c]) / std[c])
                    }
                }
            }
            imagePixels.rewind()

            // Perform inference, and return the output tensors
            val imageTensor =
                OnnxTensor.createTensor(
                    ortEnvironment,
                    imagePixels,
                    longArrayOf(1, 3, inputDim.toLong(), inputDim.toLong()),
                )
            val outputs = ortSession.run(mapOf(inputName to imageTensor))
            val highResFeature0 = outputs[highResFeature0OutputName].get() as OnnxTensor
            val highResFeature1 = outputs[highResFeature1OutputName].get() as OnnxTensor
            val imageEmbedding = outputs[imageEmbeddingOutputName].get() as OnnxTensor
            return@withContext SAMEncoderResults(
                imageEmbedding.floatBuffer,
                highResFeature0.floatBuffer,
                highResFeature1.floatBuffer,
            )
        }
}
145 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/ui/components/AppAlertDialog.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.ui.components
18 |
19 | import androidx.compose.material3.AlertDialog
20 | import androidx.compose.material3.Text
21 | import androidx.compose.material3.TextButton
22 | import androidx.compose.runtime.Composable
23 | import androidx.compose.runtime.getValue
24 | import androidx.compose.runtime.mutableStateOf
25 | import androidx.compose.runtime.remember
26 |
27 | private var title = ""
28 | private var text = ""
29 | private var positiveButtonText = ""
30 | private var negativeButtonText = ""
31 | private lateinit var positiveButtonOnClick: (() -> Unit)
32 | private lateinit var negativeButtonOnClick: (() -> Unit)
33 | private val alertDialogShowStatus = mutableStateOf(false)
34 |
// Renders the app-wide alert dialog. Visibility and all button text/handlers
// live in the module-level state above, populated by createAlertDialog().
@Composable
fun AppAlertDialog() {
    val visible by remember { alertDialogShowStatus }
    if (visible) {
        AlertDialog(
            title = { Text(text = title) },
            text = { Text(text = text) },
            onDismissRequest = { /* All alert dialogs are non-cancellable */ },
            confirmButton = {
                // Button is rendered only when a caption was supplied
                if (positiveButtonText.isNotEmpty()) {
                    TextButton(
                        onClick = {
                            // Hide the dialog before invoking the caller's handler
                            alertDialogShowStatus.value = false
                            positiveButtonOnClick()
                        },
                    ) {
                        Text(text = positiveButtonText)
                    }
                }
            },
            dismissButton = {
                if (negativeButtonText.isNotEmpty()) {
                    TextButton(
                        onClick = {
                            alertDialogShowStatus.value = false
                            negativeButtonOnClick()
                        },
                    ) {
                        Text(text = negativeButtonText)
                    }
                }
            },
        )
    }
}
70 |
/**
 * Configures and shows the app-wide alert dialog rendered by [AppAlertDialog].
 *
 * A null [dialogNegativeButtonText] / [onNegativeButtonClick] now explicitly
 * resets the negative-button state. The original only assigned these when
 * non-null, so a dialog created without a negative button could still display
 * the PREVIOUS dialog's stale negative button (and could crash on the
 * uninitialized lateinit handler if no dialog had ever set one).
 */
fun createAlertDialog(
    dialogTitle: String,
    dialogText: String,
    dialogPositiveButtonText: String,
    dialogNegativeButtonText: String?,
    onPositiveButtonClick: (() -> Unit),
    onNegativeButtonClick: (() -> Unit)?,
) {
    title = dialogTitle
    text = dialogText
    positiveButtonText = dialogPositiveButtonText
    positiveButtonOnClick = onPositiveButtonClick
    // Empty text hides the negative button in AppAlertDialog
    negativeButtonText = dialogNegativeButtonText ?: ""
    negativeButtonOnClick = onNegativeButtonClick ?: {}
    alertDialogShowStatus.value = true
}
87 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/ui/components/AppProgressDialog.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | import androidx.compose.foundation.background
18 | import androidx.compose.foundation.layout.Box
19 | import androidx.compose.foundation.layout.Column
20 | import androidx.compose.foundation.layout.Spacer
21 | import androidx.compose.foundation.layout.fillMaxWidth
22 | import androidx.compose.foundation.layout.padding
23 | import androidx.compose.foundation.shape.RoundedCornerShape
24 | import androidx.compose.material3.LinearProgressIndicator
25 | import androidx.compose.material3.Text
26 | import androidx.compose.runtime.Composable
27 | import androidx.compose.runtime.getValue
28 | import androidx.compose.runtime.mutableStateOf
29 | import androidx.compose.runtime.remember
30 | import androidx.compose.ui.Alignment
31 | import androidx.compose.ui.Modifier
32 | import androidx.compose.ui.graphics.Color
33 | import androidx.compose.ui.text.style.TextAlign
34 | import androidx.compose.ui.unit.dp
35 | import androidx.compose.ui.window.Dialog
36 |
37 | private val progressDialogVisibleState = mutableStateOf(false)
38 | private val progressDialogText = mutableStateOf("")
39 |
// Renders the app-wide, non-cancellable progress dialog: an indeterminate
// progress bar above a status message. Visibility and the message live in the
// module-level state above, driven by show/hide/setProgressDialogText().
@Composable
fun AppProgressDialog() {
    val isVisible by remember { progressDialogVisibleState }
    if (isVisible) {
        Dialog(onDismissRequest = { /* Progress dialogs are non-cancellable */ }) {
            Box(
                contentAlignment = Alignment.Center,
                modifier =
                    Modifier
                        .fillMaxWidth()
                        .background(Color.White, shape = RoundedCornerShape(8.dp)),
            ) {
                Column(
                    horizontalAlignment = Alignment.CenterHorizontally,
                    modifier = Modifier.padding(vertical = 24.dp),
                ) {
                    LinearProgressIndicator(modifier = Modifier.fillMaxWidth())
                    Spacer(modifier = Modifier.padding(4.dp))
                    Text(
                        text = progressDialogText.value,
                        textAlign = TextAlign.Center,
                        modifier = Modifier.fillMaxWidth().padding(horizontal = 16.dp),
                    )
                }
            }
        }
    }
}
68 |
// Updates the message shown beneath the progress indicator.
fun setProgressDialogText(message: String) {
    progressDialogText.value = message
}
72 |
// Shows the progress dialog with a cleared message; callers set the text
// afterwards via setProgressDialogText().
fun showProgressDialog() {
    progressDialogVisibleState.value = true
    progressDialogText.value = ""
}
77 |
// Dismisses the progress dialog.
fun hideProgressDialog() {
    progressDialogVisibleState.value = false
}
81 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/ui/theme/Color.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.ui.theme
18 |
19 | import androidx.compose.ui.graphics.Color
20 |
// Material 3 seed colors: the *80 variants back the dark color scheme and the
// *40 variants back the light color scheme (see Theme.kt).
val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)

val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)
28 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/ui/theme/Theme.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.ui.theme
18 |
19 | import android.os.Build
20 | import androidx.compose.foundation.isSystemInDarkTheme
21 | import androidx.compose.material3.MaterialTheme
22 | import androidx.compose.material3.darkColorScheme
23 | import androidx.compose.material3.dynamicDarkColorScheme
24 | import androidx.compose.material3.dynamicLightColorScheme
25 | import androidx.compose.material3.lightColorScheme
26 | import androidx.compose.runtime.Composable
27 | import androidx.compose.ui.platform.LocalContext
28 |
29 | private val DarkColorScheme = // Static dark palette; used when dynamic color is off or unavailable (see SAMAndroidTheme).
30 | darkColorScheme(
31 | primary = Purple80,
32 | secondary = PurpleGrey80,
33 | tertiary = Pink80,
34 | )
35 |
36 | private val LightColorScheme = // Static light palette; used when dynamic color is off or unavailable (see SAMAndroidTheme).
37 | lightColorScheme(
38 | primary = Purple40,
39 | secondary = PurpleGrey40,
40 | tertiary = Pink40,
41 | /* Other default colors to override
42 | background = Color(0xFFFFFBFE),
43 | surface = Color(0xFFFFFBFE),
44 | onPrimary = Color.White,
45 | onSecondary = Color.White,
46 | onTertiary = Color.White,
47 | onBackground = Color(0xFF1C1B1F),
48 | onSurface = Color(0xFF1C1B1F),
49 | */
50 | )
51 |
52 | @Composable
53 | fun SAMAndroidTheme( // App-wide Material3 theme wrapper for all composable content.
54 | darkTheme: Boolean = isSystemInDarkTheme(), // defaults to the system dark-mode setting
55 | // Dynamic color is available on Android 12+
56 | dynamicColor: Boolean = true,
57 | content: @Composable () -> Unit,
58 | ) {
59 | val colorScheme =
60 | when {
61 | dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> { // dynamic colors require API 31 (Android 12)
62 | val context = LocalContext.current
63 | if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context) // system-provided dynamic palette
64 | }
65 |
66 | darkTheme -> DarkColorScheme // static fallback palettes defined above
67 | else -> LightColorScheme
68 | }
69 |
70 | MaterialTheme(
71 | colorScheme = colorScheme,
72 | typography = Typography, // defined in Type.kt
73 | content = content,
74 | )
75 | }
76 |
--------------------------------------------------------------------------------
/app/src/main/java/io/shubham0204/sam_android/ui/theme/Type.kt:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2025 Shubham Panchal
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package io.shubham0204.sam_android.ui.theme
18 |
19 | import androidx.compose.material3.Typography
20 | import androidx.compose.ui.text.TextStyle
21 | import androidx.compose.ui.text.font.FontFamily
22 | import androidx.compose.ui.text.font.FontWeight
23 | import androidx.compose.ui.unit.sp
24 |
25 | // Material3 typography overrides: only bodyLarge is customized here; every other text style keeps its Material default.
26 | val Typography =
27 | Typography(
28 | bodyLarge =
29 | TextStyle(
30 | fontFamily = FontFamily.Default,
31 | fontWeight = FontWeight.Normal,
32 | fontSize = 16.sp,
33 | lineHeight = 24.sp,
34 | letterSpacing = 0.5.sp,
35 | ),
36 | /* Other default text styles to override
37 | titleLarge = TextStyle(
38 | fontFamily = FontFamily.Default,
39 | fontWeight = FontWeight.Normal,
40 | fontSize = 22.sp,
41 | lineHeight = 28.sp,
42 | letterSpacing = 0.sp
43 | ),
44 | labelSmall = TextStyle(
45 | fontFamily = FontFamily.Default,
46 | fontWeight = FontWeight.Medium,
47 | fontSize = 11.sp,
48 | lineHeight = 16.sp,
49 | letterSpacing = 0.5.sp
50 | )
51 | */
52 | )
53 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_launcher_background.xml:
--------------------------------------------------------------------------------
1 |
2 |
17 |
18 |
23 |
26 |
31 |
36 |
41 |
46 |
51 |
56 |
61 |
66 |
71 |
76 |
81 |
86 |
91 |
96 |
101 |
106 |
111 |
116 |
121 |
126 |
131 |
136 |
141 |
146 |
151 |
156 |
161 |
166 |
171 |
176 |
181 |
186 |
187 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
23 |
24 |
25 |
31 |
34 |
37 |
38 |
39 |
40 |
46 |
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-anydpi/ic_launcher.xml:
--------------------------------------------------------------------------------
1 |
2 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml:
--------------------------------------------------------------------------------
1 |
2 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 |
2 |
17 |
18 |
19 | #FFBB86FC
20 | #FF6200EE
21 | #FF3700B3
22 | #FF03DAC5
23 | #FF018786
24 | #FF000000
25 | #FFFFFFFF
26 |
--------------------------------------------------------------------------------
/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
18 | SAM-Android
19 |
--------------------------------------------------------------------------------
/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------
1 |
2 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/app/src/main/res/xml/backup_rules.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
24 |
25 |
29 |
--------------------------------------------------------------------------------
/app/src/main/res/xml/data_extraction_rules.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
22 |
23 |
24 |
28 |
29 |
35 |
--------------------------------------------------------------------------------
/build.gradle.kts:
--------------------------------------------------------------------------------
1 | // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 | plugins {
3 | alias(libs.plugins.android.application) apply false // `apply false`: versions are pinned here via gradle/libs.versions.toml; each module applies the plugin itself
4 | alias(libs.plugins.jetbrains.kotlin.android) apply false
5 | alias(libs.plugins.compose.compiler) apply false
6 | }
--------------------------------------------------------------------------------
/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. For more details, visit
12 | # https://developer.android.com/r/tools/gradle-multi-project-decoupled-projects
13 | # org.gradle.parallel=true
14 | # AndroidX package structure to make it clearer which packages are bundled with the
15 | # Android operating system, and which are packaged with your app's APK
16 | # https://developer.android.com/topic/libraries/support-library/androidx-rn
17 | android.useAndroidX=true
18 | # Kotlin code style for this project: "official" or "obsolete":
19 | kotlin.code.style=official
20 | # Enables namespacing of each library's R class so that its R class includes only the
21 | # resources declared in the library itself and none from the library's dependencies,
22 | # thereby reducing the size of the R class for that library
23 | android.nonTransitiveRClass=true
24 | org.gradle.configuration-cache=true
--------------------------------------------------------------------------------
/gradle/libs.versions.toml:
--------------------------------------------------------------------------------
1 | [versions]
2 | agp = "8.5.2"
3 | kotlin = "2.0.0"
4 | coreKtx = "1.15.0"
5 | junit = "4.13.2"
6 | junitVersion = "1.2.1"
7 | espressoCore = "3.6.1"
8 | lifecycleRuntimeKtx = "2.8.7"
9 | activityCompose = "1.10.0"
10 | composeBom = "2025.01.01"
11 | exifinterface = "1.3.7"
12 |
13 | [libraries]
14 | androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
15 | androidx-lifecycle-viewmodel-compose = { module = "androidx.lifecycle:lifecycle-viewmodel-compose", version.ref = "lifecycleRuntimeKtx" }
16 | junit = { group = "junit", name = "junit", version.ref = "junit" }
17 | androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "junitVersion" }
18 | androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espressoCore" }
19 | androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycleRuntimeKtx" }
20 | androidx-activity-compose = { group = "androidx.activity", name = "activity-compose", version.ref = "activityCompose" }
21 | androidx-compose-bom = { group = "androidx.compose", name = "compose-bom", version.ref = "composeBom" }
22 | androidx-ui = { group = "androidx.compose.ui", name = "ui" }
23 | androidx-ui-graphics = { group = "androidx.compose.ui", name = "ui-graphics" }
24 | androidx-ui-tooling = { group = "androidx.compose.ui", name = "ui-tooling" }
25 | androidx-ui-tooling-preview = { group = "androidx.compose.ui", name = "ui-tooling-preview" }
26 | androidx-ui-test-manifest = { group = "androidx.compose.ui", name = "ui-test-manifest" }
27 | androidx-ui-test-junit4 = { group = "androidx.compose.ui", name = "ui-test-junit4" }
28 | androidx-material3 = { group = "androidx.compose.material3", name = "material3" }
29 | androidx-compose-material3-icons-extended = { module = "androidx.compose.material:material-icons-extended" }
30 | androidx-compose-runtime-livedata = { module = "androidx.compose.runtime:runtime-livedata" }
31 | androidx-exifinterface = { group = "androidx.exifinterface", name = "exifinterface", version.ref = "exifinterface" }
32 |
33 | [plugins]
34 | android-application = { id = "com.android.application", version.ref = "agp" }
35 | jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
36 | compose-compiler = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
37 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shubham0204/Segment-Anything-Android/5efe3b59b5478c5f3bcd62cda062e95554c53936/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Fri Aug 09 09:58:09 IST 2024
2 | distributionBase=GRADLE_USER_HOME
3 | distributionPath=wrapper/dists
4 | distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip
5 | zipStoreBase=GRADLE_USER_HOME
6 | zipStorePath=wrapper/dists
7 |
--------------------------------------------------------------------------------
/gradlew:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | #
4 | # Copyright 2015 the original author or authors.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # https://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | #
18 |
19 | ##############################################################################
20 | ##
21 | ## Gradle start up script for UN*X
22 | ##
23 | ##############################################################################
24 |
25 | # Attempt to set APP_HOME
26 | # Resolve links: $0 may be a link
27 | PRG="$0"
28 | # Need this for relative symlinks.
29 | while [ -h "$PRG" ] ; do
30 | ls=`ls -ld "$PRG"`
31 | link=`expr "$ls" : '.*-> \(.*\)$'`
32 | if expr "$link" : '/.*' > /dev/null; then
33 | PRG="$link"
34 | else
35 | PRG=`dirname "$PRG"`"/$link"
36 | fi
37 | done
38 | SAVED="`pwd`"
39 | cd "`dirname \"$PRG\"`/" >/dev/null
40 | APP_HOME="`pwd -P`"
41 | cd "$SAVED" >/dev/null
42 |
43 | APP_NAME="Gradle"
44 | APP_BASE_NAME=`basename "$0"`
45 |
46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
48 |
49 | # Use the maximum available, or set MAX_FD != -1 to use that value.
50 | MAX_FD="maximum"
51 |
52 | warn () {
53 | echo "$*"
54 | }
55 |
56 | die () {
57 | echo
58 | echo "$*"
59 | echo
60 | exit 1
61 | }
62 |
63 | # OS specific support (must be 'true' or 'false').
64 | cygwin=false
65 | msys=false
66 | darwin=false
67 | nonstop=false
68 | case "`uname`" in
69 | CYGWIN* )
70 | cygwin=true
71 | ;;
72 | Darwin* )
73 | darwin=true
74 | ;;
75 | MINGW* )
76 | msys=true
77 | ;;
78 | NONSTOP* )
79 | nonstop=true
80 | ;;
81 | esac
82 |
83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
84 |
85 |
86 | # Determine the Java command to use to start the JVM.
87 | if [ -n "$JAVA_HOME" ] ; then
88 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
89 | # IBM's JDK on AIX uses strange locations for the executables
90 | JAVACMD="$JAVA_HOME/jre/sh/java"
91 | else
92 | JAVACMD="$JAVA_HOME/bin/java"
93 | fi
94 | if [ ! -x "$JAVACMD" ] ; then
95 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
96 |
97 | Please set the JAVA_HOME variable in your environment to match the
98 | location of your Java installation."
99 | fi
100 | else
101 | JAVACMD="java"
102 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
103 |
104 | Please set the JAVA_HOME variable in your environment to match the
105 | location of your Java installation."
106 | fi
107 |
108 | # Increase the maximum file descriptors if we can.
109 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
110 | MAX_FD_LIMIT=`ulimit -H -n`
111 | if [ $? -eq 0 ] ; then
112 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
113 | MAX_FD="$MAX_FD_LIMIT"
114 | fi
115 | ulimit -n $MAX_FD
116 | if [ $? -ne 0 ] ; then
117 | warn "Could not set maximum file descriptor limit: $MAX_FD"
118 | fi
119 | else
120 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
121 | fi
122 | fi
123 |
124 | # For Darwin, add options to specify how the application appears in the dock
125 | if $darwin; then
126 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
127 | fi
128 |
129 | # For Cygwin or MSYS, switch paths to Windows format before running java
130 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
131 | APP_HOME=`cygpath --path --mixed "$APP_HOME"`
132 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
133 |
134 | JAVACMD=`cygpath --unix "$JAVACMD"`
135 |
136 | # We build the pattern for arguments to be converted via cygpath
137 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
138 | SEP=""
139 | for dir in $ROOTDIRSRAW ; do
140 | ROOTDIRS="$ROOTDIRS$SEP$dir"
141 | SEP="|"
142 | done
143 | OURCYGPATTERN="(^($ROOTDIRS))"
144 | # Add a user-defined pattern to the cygpath arguments
145 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then
146 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
147 | fi
148 | # Now convert the arguments - kludge to limit ourselves to /bin/sh
149 | i=0
150 | for arg in "$@" ; do
151 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
152 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
153 |
154 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
155 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
156 | else
157 | eval `echo args$i`="\"$arg\""
158 | fi
159 | i=`expr $i + 1`
160 | done
161 | case $i in
162 | 0) set -- ;;
163 | 1) set -- "$args0" ;;
164 | 2) set -- "$args0" "$args1" ;;
165 | 3) set -- "$args0" "$args1" "$args2" ;;
166 | 4) set -- "$args0" "$args1" "$args2" "$args3" ;;
167 | 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
168 | 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
169 | 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
170 | 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
171 | 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
172 | esac
173 | fi
174 |
175 | # Escape application args
176 | save () {
177 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
178 | echo " "
179 | }
180 | APP_ARGS=`save "$@"`
181 |
182 | # Collect all arguments for the java command, following the shell quoting and substitution rules
183 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
184 |
185 | exec "$JAVACMD" "$@"
186 |
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 |
17 | @if "%DEBUG%" == "" @echo off
18 | @rem ##########################################################################
19 | @rem
20 | @rem Gradle startup script for Windows
21 | @rem
22 | @rem ##########################################################################
23 |
24 | @rem Set local scope for the variables with windows NT shell
25 | if "%OS%"=="Windows_NT" setlocal
26 |
27 | set DIRNAME=%~dp0
28 | if "%DIRNAME%" == "" set DIRNAME=.
29 | set APP_BASE_NAME=%~n0
30 | set APP_HOME=%DIRNAME%
31 |
32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter.
33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
34 |
35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
37 |
38 | @rem Find java.exe
39 | if defined JAVA_HOME goto findJavaFromJavaHome
40 |
41 | set JAVA_EXE=java.exe
42 | %JAVA_EXE% -version >NUL 2>&1
43 | if "%ERRORLEVEL%" == "0" goto execute
44 |
45 | echo.
46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
47 | echo.
48 | echo Please set the JAVA_HOME variable in your environment to match the
49 | echo location of your Java installation.
50 |
51 | goto fail
52 |
53 | :findJavaFromJavaHome
54 | set JAVA_HOME=%JAVA_HOME:"=%
55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
56 |
57 | if exist "%JAVA_EXE%" goto execute
58 |
59 | echo.
60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
61 | echo.
62 | echo Please set the JAVA_HOME variable in your environment to match the
63 | echo location of your Java installation.
64 |
65 | goto fail
66 |
67 | :execute
68 | @rem Setup the command line
69 |
70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
71 |
72 |
73 | @rem Execute Gradle
74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
75 |
76 | :end
77 | @rem End local scope for the variables with windows NT shell
78 | if "%ERRORLEVEL%"=="0" goto mainEnd
79 |
80 | :fail
81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
82 | rem the _cmd.exe /c_ return code!
83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
84 | exit /b 1
85 |
86 | :mainEnd
87 | if "%OS%"=="Windows_NT" endlocal
88 |
89 | :omega
90 |
--------------------------------------------------------------------------------
/notebooks/SAM2_ONNX_Export.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "gpuType": "T4"
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "# **Segment Anything Model 2 (SAM 2)**\n",
23 | ""
24 | ],
25 | "metadata": {
26 | "id": "qlsCAu5JBIcn"
27 | }
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "source": [
32 | "## Installation (requires a GPU runtime)"
33 | ],
34 | "metadata": {
35 | "id": "TVR_BlMZD4XJ"
36 | }
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "colab": {
43 | "base_uri": "https://localhost:8080/",
44 | "height": 1000
45 | },
46 | "id": "sBSQnXVFAiwg",
47 | "outputId": "33994785-24fb-48f0-9fa6-7c554f19b21d"
48 | },
49 | "outputs": [
50 | {
51 | "output_type": "stream",
52 | "name": "stdout",
53 | "text": [
54 | "/content\n",
55 | "fatal: destination path 'segment-anything-2' already exists and is not an empty directory.\n",
56 | "/content/segment-anything-2\n",
57 | "Obtaining file:///content/segment-anything-2\n",
58 | " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
59 | " Checking if build backend supports build_editable ... \u001b[?25l\u001b[?25hdone\n",
60 | " Getting requirements to build editable ... \u001b[?25l\u001b[?25hdone\n",
61 | " Preparing editable metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
62 | "Requirement already satisfied: torch>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (2.3.1+cu121)\n",
63 | "Requirement already satisfied: torchvision>=0.18.1 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (0.18.1+cu121)\n",
64 | "Requirement already satisfied: numpy>=1.24.4 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (1.26.4)\n",
65 | "Requirement already satisfied: tqdm>=4.66.1 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (4.66.5)\n",
66 | "Requirement already satisfied: hydra-core>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (1.3.2)\n",
67 | "Requirement already satisfied: iopath>=0.1.10 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (0.1.10)\n",
68 | "Requirement already satisfied: pillow>=9.4.0 in /usr/local/lib/python3.10/dist-packages (from SAM-2==1.0) (9.4.0)\n",
69 | "Requirement already satisfied: omegaconf<2.4,>=2.2 in /usr/local/lib/python3.10/dist-packages (from hydra-core>=1.3.2->SAM-2==1.0) (2.3.0)\n",
70 | "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.10/dist-packages (from hydra-core>=1.3.2->SAM-2==1.0) (4.9.3)\n",
71 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from hydra-core>=1.3.2->SAM-2==1.0) (24.1)\n",
72 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from iopath>=0.1.10->SAM-2==1.0) (4.12.2)\n",
73 | "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from iopath>=0.1.10->SAM-2==1.0) (2.10.1)\n",
74 | "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (3.15.4)\n",
75 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (1.13.1)\n",
76 | "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (3.3)\n",
77 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (3.1.4)\n",
78 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (2024.6.1)\n",
79 | "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.105)\n",
80 | "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.105)\n",
81 | "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.105)\n",
82 | "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (8.9.2.26)\n",
83 | "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.3.1)\n",
84 | "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (11.0.2.54)\n",
85 | "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (10.3.2.106)\n",
86 | "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (11.4.5.107)\n",
87 | "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.0.106)\n",
88 | "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (2.20.5)\n",
89 | "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (12.1.105)\n",
90 | "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.1->SAM-2==1.0) (2.3.1)\n",
91 | "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2.3.1->SAM-2==1.0) (12.6.20)\n",
92 | "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.10/dist-packages (from omegaconf<2.4,>=2.2->hydra-core>=1.3.2->SAM-2==1.0) (6.0.2)\n",
93 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2.3.1->SAM-2==1.0) (2.1.5)\n",
94 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.3.1->SAM-2==1.0) (1.3.0)\n",
95 | "Building wheels for collected packages: SAM-2\n",
96 | " Building editable for SAM-2 (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
97 | " Created wheel for SAM-2: filename=SAM_2-1.0-0.editable-cp310-cp310-linux_x86_64.whl size=12322 sha256=5a4c2a3993372d4f3cb506a6cf1b922e3f20621e2ebf44804bb5868e05090689\n",
98 | " Stored in directory: /tmp/pip-ephem-wheel-cache-7zzz370q/wheels/7d/af/fe/c05425a1fdc391329545b53111d5cabdfc241ee07cab053945\n",
99 | "Successfully built SAM-2\n",
100 | "Installing collected packages: SAM-2\n",
101 | " Attempting uninstall: SAM-2\n",
102 | " Found existing installation: SAM-2 1.0\n",
103 | " Uninstalling SAM-2-1.0:\n",
104 | " Successfully uninstalled SAM-2-1.0\n",
105 | "Successfully installed SAM-2-1.0\n"
106 | ]
107 | },
108 | {
109 | "output_type": "display_data",
110 | "data": {
111 | "application/vnd.colab-display-data+json": {
112 | "pip_warning": {
113 | "packages": [
114 | "sam2",
115 | "sam2_configs"
116 | ]
117 | },
118 | "id": "aa560bb501504030ab74cbd178f01cd9"
119 | }
120 | },
121 | "metadata": {}
122 | },
123 | {
124 | "output_type": "stream",
125 | "name": "stdout",
126 | "text": [
127 | "Requirement already satisfied: onnx in /usr/local/lib/python3.10/dist-packages (1.16.2)\n",
128 | "Requirement already satisfied: onnxscript in /usr/local/lib/python3.10/dist-packages (0.1.0.dev20240818)\n",
129 | "Requirement already satisfied: onnxsim in /usr/local/lib/python3.10/dist-packages (0.4.36)\n",
130 | "Requirement already satisfied: onnxruntime in /usr/local/lib/python3.10/dist-packages (1.19.0)\n",
131 | "Requirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from onnx) (1.26.4)\n",
132 | "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
133 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from onnxscript) (4.12.2)\n",
134 | "Requirement already satisfied: ml-dtypes in /usr/local/lib/python3.10/dist-packages (from onnxscript) (0.4.0)\n",
135 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxscript) (24.1)\n",
136 | "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from onnxsim) (13.7.1)\n",
137 | "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (15.0.1)\n",
138 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n",
139 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (1.13.1)\n",
140 | "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime) (10.0)\n",
141 | "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->onnxsim) (3.0.0)\n",
142 | "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->onnxsim) (2.16.1)\n",
143 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime) (1.3.0)\n",
144 | "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->onnxsim) (0.1.2)\n"
145 | ]
146 | }
147 | ],
148 | "source": [
149 | "%cd /content\n",
150 | "!git clone https://github.com/facebookresearch/segment-anything-2.git\n",
151 | "%cd /content/segment-anything-2\n",
152 | "!pip3 install -e .\n",
153 | "!pip3 install onnx onnxscript onnxsim onnxruntime"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "source": [
159 | "%cd /content/segment-anything-2/checkpoints\n",
160 | "!./download_ckpts.sh"
161 | ],
162 | "metadata": {
163 | "id": "Vtld9UUcAxH_",
164 | "colab": {
165 | "base_uri": "https://localhost:8080/"
166 | },
167 | "outputId": "da05c017-291c-4b56-88c8-aec6f472f338"
168 | },
169 | "execution_count": null,
170 | "outputs": [
171 | {
172 | "output_type": "stream",
173 | "name": "stdout",
174 | "text": [
175 | "/content/segment-anything-2/checkpoints\n",
176 | "Downloading sam2_hiera_tiny.pt checkpoint...\n",
177 | "--2024-08-18 06:36:38-- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt\n",
178 | "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.165.83.35, 18.165.83.91, 18.165.83.44, ...\n",
179 | "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.165.83.35|:443... connected.\n",
180 | "HTTP request sent, awaiting response... 200 OK\n",
181 | "Length: 155906050 (149M) [application/vnd.snesdev-page-table]\n",
182 | "Saving to: ‘sam2_hiera_tiny.pt’\n",
183 | "\n",
184 | "sam2_hiera_tiny.pt 100%[===================>] 148.68M 140MB/s in 1.1s \n",
185 | "\n",
186 | "2024-08-18 06:36:39 (140 MB/s) - ‘sam2_hiera_tiny.pt’ saved [155906050/155906050]\n",
187 | "\n",
188 | "Downloading sam2_hiera_small.pt checkpoint...\n",
189 | "--2024-08-18 06:36:39-- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt\n",
190 | "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.165.83.35, 18.165.83.91, 18.165.83.44, ...\n",
191 | "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.165.83.35|:443... connected.\n",
192 | "HTTP request sent, awaiting response... 200 OK\n",
193 | "Length: 184309650 (176M) [application/vnd.snesdev-page-table]\n",
194 | "Saving to: ‘sam2_hiera_small.pt’\n",
195 | "\n",
196 | "sam2_hiera_small.pt 100%[===================>] 175.77M 144MB/s in 1.2s \n",
197 | "\n",
198 | "2024-08-18 06:36:40 (144 MB/s) - ‘sam2_hiera_small.pt’ saved [184309650/184309650]\n",
199 | "\n",
200 | "Downloading sam2_hiera_base_plus.pt checkpoint...\n",
201 | "--2024-08-18 06:36:40-- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt\n",
202 | "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.165.83.35, 18.165.83.91, 18.165.83.44, ...\n",
203 | "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.165.83.35|:443... connected.\n",
204 | "HTTP request sent, awaiting response... 200 OK\n",
205 | "Length: 323493298 (309M) [application/vnd.snesdev-page-table]\n",
206 | "Saving to: ‘sam2_hiera_base_plus.pt’\n",
207 | "\n",
208 | "sam2_hiera_base_plu 100%[===================>] 308.51M 176MB/s in 1.8s \n",
209 | "\n",
210 | "2024-08-18 06:36:42 (176 MB/s) - ‘sam2_hiera_base_plus.pt’ saved [323493298/323493298]\n",
211 | "\n",
212 | "Downloading sam2_hiera_large.pt checkpoint...\n",
213 | "--2024-08-18 06:36:42-- https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt\n",
214 | "Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.165.83.35, 18.165.83.91, 18.165.83.44, ...\n",
215 | "Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.165.83.35|:443... connected.\n",
216 | "HTTP request sent, awaiting response... 200 OK\n",
217 | "Length: 897952466 (856M) [application/vnd.snesdev-page-table]\n",
218 | "Saving to: ‘sam2_hiera_large.pt’\n",
219 | "\n",
220 | "sam2_hiera_large.pt 100%[===================>] 856.35M 119MB/s in 8.7s \n",
221 | "\n",
222 | "2024-08-18 06:36:51 (97.9 MB/s) - ‘sam2_hiera_large.pt’ saved [897952466/897952466]\n",
223 | "\n",
224 | "All checkpoints are downloaded successfully.\n"
225 | ]
226 | }
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "source": [
232 | "%cd /content/segment-anything-2/\n",
233 | "from typing import Optional, Tuple, Any\n",
234 | "import torch\n",
235 | "from torch import nn\n",
236 | "import torch.nn.functional as F\n",
237 | "from torch.nn.init import trunc_normal_\n",
238 | "\n",
239 | "\n",
240 | "from sam2.modeling.sam2_base import SAM2Base\n",
241 | "\n",
242 | "class SAM2ImageEncoder(nn.Module):\n",
243 | " def __init__(self, sam_model: SAM2Base) -> None:\n",
244 | " super().__init__()\n",
245 | " self.model = sam_model\n",
246 | " self.image_encoder = sam_model.image_encoder\n",
247 | " self.no_mem_embed = sam_model.no_mem_embed\n",
248 | "\n",
249 | " def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:\n",
250 | " backbone_out = self.image_encoder(x)\n",
251 | " backbone_out[\"backbone_fpn\"][0] = self.model.sam_mask_decoder.conv_s0(\n",
252 | " backbone_out[\"backbone_fpn\"][0]\n",
253 | " )\n",
254 | " backbone_out[\"backbone_fpn\"][1] = self.model.sam_mask_decoder.conv_s1(\n",
255 | " backbone_out[\"backbone_fpn\"][1]\n",
256 | " )\n",
257 | "\n",
258 | " feature_maps = backbone_out[\"backbone_fpn\"][-self.model.num_feature_levels:]\n",
259 | " vision_pos_embeds = backbone_out[\"vision_pos_enc\"][-self.model.num_feature_levels:]\n",
260 | "\n",
261 | " feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]\n",
262 | "\n",
263 | " # flatten NxCxHxW to HWxNxC\n",
264 | " vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]\n",
265 | " vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]\n",
266 | "\n",
267 | " vision_feats[-1] = vision_feats[-1] + self.no_mem_embed\n",
268 | "\n",
269 | " feats = [feat.permute(1, 2, 0).reshape(1, -1, *feat_size)\n",
270 | " for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])][::-1]\n",
271 | "\n",
272 | " return feats[0], feats[1], feats[2]\n",
273 | "\n",
274 | "\n",
275 | "class SAM2ImageDecoder(nn.Module):\n",
276 | " def __init__(\n",
277 | " self,\n",
278 | " sam_model: SAM2Base,\n",
279 | " multimask_output: bool\n",
280 | " ) -> None:\n",
281 | " super().__init__()\n",
282 | " self.mask_decoder = sam_model.sam_mask_decoder\n",
283 | " self.prompt_encoder = sam_model.sam_prompt_encoder\n",
284 | " self.model = sam_model\n",
285 | " self.multimask_output = multimask_output\n",
286 | "\n",
287 | " @torch.no_grad()\n",
288 | " def forward(\n",
289 | " self,\n",
290 | " image_embed: torch.Tensor,\n",
291 | " high_res_feats_0: torch.Tensor,\n",
292 | " high_res_feats_1: torch.Tensor,\n",
293 | " point_coords: torch.Tensor,\n",
294 | " point_labels: torch.Tensor,\n",
295 | " mask_input: torch.Tensor,\n",
296 | " has_mask_input: torch.Tensor,\n",
297 | " img_size: torch.Tensor\n",
298 | " ):\n",
299 | " sparse_embedding = self._embed_points(point_coords, point_labels)\n",
300 | " self.sparse_embedding = sparse_embedding\n",
301 | " dense_embedding = self._embed_masks(mask_input, has_mask_input)\n",
302 | "\n",
303 | " high_res_feats = [high_res_feats_0, high_res_feats_1]\n",
304 | " image_embed = image_embed\n",
305 | "\n",
306 | " masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(\n",
307 | " image_embeddings=image_embed,\n",
308 | " image_pe=self.prompt_encoder.get_dense_pe(),\n",
309 | " sparse_prompt_embeddings=sparse_embedding,\n",
310 | " dense_prompt_embeddings=dense_embedding,\n",
311 | " repeat_image=False,\n",
312 | " high_res_features=high_res_feats,\n",
313 | " )\n",
314 | "\n",
315 | " if self.multimask_output:\n",
316 | " masks = masks[:, 1:, :, :]\n",
317 | " iou_predictions = iou_predictions[:, 1:]\n",
318 | " else:\n",
319 | " masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(masks, iou_predictions)\n",
320 | "\n",
321 | " masks = torch.clamp(masks, -32.0, 32.0)\n",
322 | " masks = masks > 0.0\n",
323 | " masks = masks.to(torch.float32)\n",
324 | " masks = masks * 255.0\n",
325 | "\n",
326 | " masks = F.interpolate(masks, (img_size[0], img_size[1]), mode=\"bilinear\", align_corners=False)\n",
327 | "\n",
328 | " return masks, iou_predictions\n",
329 | "\n",
330 | " def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:\n",
331 | "\n",
332 | " point_coords = point_coords + 0.5\n",
333 | "\n",
334 | " padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)\n",
335 | " padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)\n",
336 | " point_coords = torch.cat([point_coords, padding_point], dim=1)\n",
337 | " point_labels = torch.cat([point_labels, padding_label], dim=1)\n",
338 | "\n",
339 | " point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size\n",
340 | " point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size\n",
341 | "\n",
342 | " point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)\n",
343 | " point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)\n",
344 | "\n",
345 | " point_embedding = point_embedding * (point_labels != -1)\n",
346 | " point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (\n",
347 | " point_labels == -1\n",
348 | " )\n",
349 | "\n",
350 | " for i in range(self.prompt_encoder.num_point_embeddings):\n",
351 | " point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)\n",
352 | "\n",
353 | " return point_embedding\n",
354 | "\n",
355 | " def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:\n",
356 | " mask_embedding = has_mask_input * self.prompt_encoder.mask_downscaling(input_mask)\n",
357 | " mask_embedding = mask_embedding + (\n",
358 | " 1 - has_mask_input\n",
359 | " ) * self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)\n",
360 | " return mask_embedding"
361 | ],
362 | "metadata": {
363 | "id": "xQAznoeDIyak",
364 | "colab": {
365 | "base_uri": "https://localhost:8080/"
366 | },
367 | "outputId": "b897011f-28f9-48fc-d341-b0dff6c12d82"
368 | },
369 | "execution_count": null,
370 | "outputs": [
371 | {
372 | "output_type": "stream",
373 | "name": "stdout",
374 | "text": [
375 | "/content/segment-anything-2\n"
376 | ]
377 | }
378 | ]
379 | },
380 | {
381 | "cell_type": "markdown",
382 | "source": [
383 | "## Select model parameters"
384 | ],
385 | "metadata": {
386 | "id": "DGJ8EjKE7zd2"
387 | }
388 | },
389 | {
390 | "cell_type": "code",
391 | "source": [
392 | "model_type = 'sam2_hiera_base_plus' #@param [\"sam2_hiera_tiny\", \"sam2_hiera_small\", \"sam2_hiera_large\", \"sam2_hiera_base_plus\"]\n",
393 | "# input_size = 768 #@param {type:\"slider\", min:160, max:4102, step:8}\n",
394 | "input_size = 1024 # Bad output if anything else (for now)\n",
395 | "multimask_output = False\n",
396 | "\n",
397 | "if model_type == \"sam2_hiera_tiny\":\n",
398 | " model_cfg = \"sam2_hiera_t.yaml\"\n",
399 | "elif model_type == \"sam2_hiera_small\":\n",
400 | " model_cfg = \"sam2_hiera_s.yaml\"\n",
401 | "elif model_type == \"sam2_hiera_base_plus\":\n",
402 | " model_cfg = \"sam2_hiera_b+.yaml\"\n",
403 | "else:\n",
404 | " model_cfg = \"sam2_hiera_l.yaml\"\n"
405 | ],
406 | "metadata": {
407 | "id": "K-Ll5Iwh7428"
408 | },
409 | "execution_count": null,
410 | "outputs": []
411 | },
412 | {
413 | "cell_type": "markdown",
414 | "source": [
415 | "## Export Encoder"
416 | ],
417 | "metadata": {
418 | "id": "t46_lYeIy0ta"
419 | }
420 | },
421 | {
422 | "cell_type": "code",
423 | "source": [
424 | "\n",
425 | "%cd /content/segment-anything-2/\n",
426 | "import torch\n",
427 | "from sam2.build_sam import build_sam2\n",
428 | "\n",
429 | "sam2_checkpoint = f\"checkpoints/{model_type}.pt\"\n",
430 | "\n",
431 | "sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=\"cpu\")\n",
432 | "\n",
433 | "img=torch.randn(1, 3, input_size, input_size).cpu()\n",
434 | "\n",
435 | "sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()\n",
436 | "high_res_feats_0, high_res_feats_1, image_embed = sam2_encoder(img)\n",
437 | "print(high_res_feats_0.shape)\n",
438 | "print(high_res_feats_1.shape)\n",
439 | "print(image_embed.shape)\n",
440 | "\n",
441 | "torch.onnx.export(sam2_encoder,\n",
442 | " img,\n",
443 | " f\"{model_type}_encoder.onnx\",\n",
444 | " export_params=True,\n",
445 | " opset_version=17,\n",
446 | " do_constant_folding=True,\n",
447 | " input_names = ['image'],\n",
448 | " output_names = ['high_res_feats_0', 'high_res_feats_1', 'image_embed']\n",
449 | " )"
450 | ],
451 | "metadata": {
452 | "id": "IgHx4lbupej-",
453 | "colab": {
454 | "base_uri": "https://localhost:8080/"
455 | },
456 | "outputId": "9b44ead8-e488-4b4e-dd88-4ace505beda9"
457 | },
458 | "execution_count": null,
459 | "outputs": [
460 | {
461 | "output_type": "stream",
462 | "name": "stdout",
463 | "text": [
464 | "/content/segment-anything-2\n",
465 | "torch.Size([1, 32, 256, 256])\n",
466 | "torch.Size([1, 64, 128, 128])\n",
467 | "torch.Size([1, 256, 64, 64])\n"
468 | ]
469 | },
470 | {
471 | "output_type": "stream",
472 | "name": "stderr",
473 | "text": [
474 | "/content/segment-anything-2/sam2/modeling/backbones/utils.py:30: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
475 | " if pad_h > 0 or pad_w > 0:\n",
476 | "/content/segment-anything-2/sam2/modeling/backbones/utils.py:60: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
477 | " if Hp > H or Wp > W:\n"
478 | ]
479 | }
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "source": [
485 | "## Export Decoder"
486 | ],
487 | "metadata": {
488 | "id": "JX1N64Y6y2-c"
489 | }
490 | },
491 | {
492 | "cell_type": "code",
493 | "source": [
494 | "%cd /content/segment-anything-2/\n",
495 | "\n",
496 | "\n",
497 | "sam2_decoder = SAM2ImageDecoder(sam2_model, multimask_output=multimask_output).cpu()\n",
498 | "\n",
499 | "embed_dim = sam2_model.sam_prompt_encoder.embed_dim\n",
500 | "embed_size = (sam2_model.image_size // sam2_model.backbone_stride, sam2_model.image_size // sam2_model.backbone_stride)\n",
501 | "mask_input_size = [4 * x for x in embed_size]\n",
502 | "print(embed_dim, embed_size, mask_input_size)\n",
503 | "\n",
504 | "point_coords = torch.randint(low=0, high=input_size, size=(1, 5, 2), dtype=torch.float)\n",
505 | "point_labels = torch.randint(low=0, high=1, size=(1, 5), dtype=torch.float)\n",
506 | "mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float)\n",
507 | "has_mask_input = torch.tensor([1], dtype=torch.float)\n",
508 | "orig_im_size = torch.tensor([input_size, input_size], dtype=torch.int32)\n",
509 | "\n",
510 | "masks, scores = sam2_decoder(image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size)\n",
511 | "\n",
512 | "\n",
513 | "torch.onnx.export(sam2_decoder,\n",
514 | " (image_embed, high_res_feats_0, high_res_feats_1, point_coords, point_labels, mask_input, has_mask_input, orig_im_size),\n",
515 | " f\"{model_type}_decoder.onnx\",\n",
516 | " export_params=True,\n",
517 | " opset_version=16,\n",
518 | " do_constant_folding=True,\n",
519 | " input_names = ['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size'],\n",
520 | " output_names = ['masks', 'iou_predictions'],\n",
521 | " dynamic_axes = {\"point_coords\": {0: \"num_labels\", 1: \"num_points\"},\n",
522 | " \"point_labels\": {0: \"num_labels\", 1: \"num_points\"},\n",
523 | " \"mask_input\": {0: \"num_labels\"},\n",
524 | " \"has_mask_input\": {0: \"num_labels\"}\n",
525 | " }\n",
526 | " )\n"
527 | ],
528 | "metadata": {
529 | "id": "KKqrn0sHaQYu",
530 | "colab": {
531 | "base_uri": "https://localhost:8080/"
532 | },
533 | "outputId": "ed6e613d-a9d8-473a-b551-cabc40441082"
534 | },
535 | "execution_count": null,
536 | "outputs": [
537 | {
538 | "output_type": "stream",
539 | "name": "stdout",
540 | "text": [
541 | "/content/segment-anything-2\n",
542 | "256 (64, 64) [256, 256]\n"
543 | ]
544 | },
545 | {
546 | "output_type": "stream",
547 | "name": "stderr",
548 | "text": [
549 | "/content/segment-anything-2/sam2/modeling/sam/mask_decoder.py:203: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
550 | " assert image_embeddings.shape[0] == tokens.shape[0]\n",
551 | "/content/segment-anything-2/sam2/modeling/sam/mask_decoder.py:207: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
552 | " image_pe.size(0) == 1\n",
553 | "/usr/local/lib/python3.10/dist-packages/torch/onnx/symbolic_opset9.py:5858: UserWarning: Exporting aten::index operator of advanced indexing in opset 16 is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.\n",
554 | " warnings.warn(\n"
555 | ]
556 | }
557 | ]
558 | },
559 | {
560 | "cell_type": "markdown",
561 | "source": [
562 | "## Simplify models"
563 | ],
564 | "metadata": {
565 | "id": "dMgwgcWO18hY"
566 | }
567 | },
568 | {
569 | "cell_type": "code",
570 | "source": [
571 | "%cd /content/segment-anything-2/\n",
572 | "!onnxsim {model_type}_encoder.onnx {model_type}_encoder.onnx\n",
573 | "!onnxsim {model_type}_decoder.onnx {model_type}_decoder.onnx"
574 | ],
575 | "metadata": {
576 | "id": "w4nMB2XD1-gx"
577 | },
578 | "execution_count": null,
579 | "outputs": []
580 | },
581 | {
582 | "cell_type": "markdown",
583 | "source": [
584 | "## Optional, mount GDrive for faster model download (Copy it to your Google Drive and then download)"
585 | ],
586 | "metadata": {
587 | "id": "JxyJ9H5xjoTT"
588 | }
589 | },
590 | {
591 | "cell_type": "code",
592 | "source": [
593 | "from google.colab import drive\n",
594 | "drive.mount('/content/gdrive',force_remount=True)"
595 | ],
596 | "metadata": {
597 | "id": "J6ameAOEjm9w",
598 | "colab": {
599 | "base_uri": "https://localhost:8080/"
600 | },
601 | "outputId": "0156245f-66eb-4da1-fc9b-91ebdade489c"
602 | },
603 | "execution_count": null,
604 | "outputs": [
605 | {
606 | "output_type": "stream",
607 | "name": "stdout",
608 | "text": [
609 | "Mounted at /content/gdrive\n"
610 | ]
611 | }
612 | ]
613 | },
614 | {
615 | "cell_type": "code",
616 | "source": [
617 | "%cd /content/segment-anything-2/\n",
618 | "!cp {model_type}_encoder.onnx '/content/gdrive/My Drive/'\n",
619 | "!cp {model_type}_decoder.onnx '/content/gdrive/My Drive/'"
620 | ],
621 | "metadata": {
622 | "id": "ZyBDSz5RjAb2",
623 | "colab": {
624 | "base_uri": "https://localhost:8080/"
625 | },
626 | "outputId": "ece91850-d66c-4423-e73f-7836d31f5891"
627 | },
628 | "execution_count": null,
629 | "outputs": [
630 | {
631 | "output_type": "stream",
632 | "name": "stdout",
633 | "text": [
634 | "/content/segment-anything-2\n"
635 | ]
636 | }
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "source": [
642 | "!pip install onnxruntime"
643 | ],
644 | "metadata": {
645 | "colab": {
646 | "base_uri": "https://localhost:8080/"
647 | },
648 | "id": "t2ZM_igGlnEe",
649 | "outputId": "d9828273-dc90-4b66-c1d7-a39a0f71868b"
650 | },
651 | "execution_count": null,
652 | "outputs": [
653 | {
654 | "output_type": "stream",
655 | "name": "stdout",
656 | "text": [
657 | "Requirement already satisfied: onnxruntime in /usr/local/lib/python3.10/dist-packages (1.19.0)\n",
658 | "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (15.0.1)\n",
659 | "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.3.25)\n",
660 | "Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (1.26.4)\n",
661 | "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (24.1)\n",
662 | "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (3.20.3)\n",
663 | "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (1.13.1)\n",
664 | "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime) (10.0)\n",
665 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime) (1.3.0)\n"
666 | ]
667 | }
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "source": [
673 | "import onnxruntime as ort\n",
674 | "\n",
675 | "session = ort.InferenceSession(\"/content/segment-anything-2/sam2_hiera_base_plus_decoder.onnx\" )\n",
676 | "print( [ t.shape for t in session.get_inputs() ] )\n",
677 | "print( [ t.type for t in session.get_inputs() ] )\n",
678 | "print( [ t.name for t in session.get_inputs() ] )\n",
679 | "print( [ t.shape for t in session.get_outputs() ] )\n",
680 | "print( [ t.type for t in session.get_outputs() ] )\n",
681 | "print( [ t.name for t in session.get_outputs() ] )\n",
682 | "print('\\n\\n\\n')"
683 | ],
684 | "metadata": {
685 | "colab": {
686 | "base_uri": "https://localhost:8080/"
687 | },
688 | "id": "nwpqCttDlsOh",
689 | "outputId": "4eb99fe4-dcae-43be-8fe5-03afb55f1b44"
690 | },
691 | "execution_count": null,
692 | "outputs": [
693 | {
694 | "output_type": "stream",
695 | "name": "stdout",
696 | "text": [
697 | "[[1, 256, 64, 64], [1, 32, 256, 256], [1, 64, 128, 128], ['num_labels', 'num_points', 2], ['num_labels', 'num_points'], ['num_labels', 1, 256, 256], ['num_labels'], [2]]\n",
698 | "['tensor(float)', 'tensor(float)', 'tensor(float)', 'tensor(float)', 'tensor(float)', 'tensor(float)', 'tensor(float)', 'tensor(int32)']\n",
699 | "['image_embed', 'high_res_feats_0', 'high_res_feats_1', 'point_coords', 'point_labels', 'mask_input', 'has_mask_input', 'orig_im_size']\n",
700 | "[['Resizemasks_dim_0', 'Resizemasks_dim_1', 'Resizemasks_dim_2', 'Resizemasks_dim_3'], ['Resizemasks_dim_0', 'Whereiou_predictions_dim_1']]\n",
701 | "['tensor(float)', 'tensor(float)']\n",
702 | "['masks', 'iou_predictions']\n",
703 | "\n",
704 | "\n",
705 | "\n",
706 | "\n"
707 | ]
708 | }
709 | ]
710 | },
711 | {
712 | "cell_type": "code",
713 | "source": [
714 | "from google.colab import files\n",
715 | "files.download('/content/segment-anything-2/sam2_hiera_base_plus_decoder.onnx')"
716 | ],
717 | "metadata": {
718 | "colab": {
719 | "base_uri": "https://localhost:8080/",
720 | "height": 34
721 | },
722 | "id": "D1VJtWbI6Svd",
723 | "outputId": "da601df8-5eee-4712-ad36-b52677838ab7"
724 | },
725 | "execution_count": null,
726 | "outputs": [
727 | {
728 | "output_type": "display_data",
729 | "data": {
730 | "text/plain": [
731 | ""
732 | ],
733 | "application/javascript": [
734 | "\n",
735 | " async function download(id, filename, size) {\n",
736 | " if (!google.colab.kernel.accessAllowed) {\n",
737 | " return;\n",
738 | " }\n",
739 | " const div = document.createElement('div');\n",
740 | " const label = document.createElement('label');\n",
741 | "          label.textContent = `Downloading \"${filename}\": `;\n",
742 | " div.appendChild(label);\n",
743 | " const progress = document.createElement('progress');\n",
744 | " progress.max = size;\n",
745 | " div.appendChild(progress);\n",
746 | " document.body.appendChild(div);\n",
747 | "\n",
748 | " const buffers = [];\n",
749 | " let downloaded = 0;\n",
750 | "\n",
751 | " const channel = await google.colab.kernel.comms.open(id);\n",
752 | " // Send a message to notify the kernel that we're ready.\n",
753 | " channel.send({})\n",
754 | "\n",
755 | " for await (const message of channel.messages) {\n",
756 | " // Send a message to notify the kernel that we're ready.\n",
757 | " channel.send({})\n",
758 | " if (message.buffers) {\n",
759 | " for (const buffer of message.buffers) {\n",
760 | " buffers.push(buffer);\n",
761 | " downloaded += buffer.byteLength;\n",
762 | " progress.value = downloaded;\n",
763 | " }\n",
764 | " }\n",
765 | " }\n",
766 | " const blob = new Blob(buffers, {type: 'application/binary'});\n",
767 | " const a = document.createElement('a');\n",
768 | " a.href = window.URL.createObjectURL(blob);\n",
769 | " a.download = filename;\n",
770 | " div.appendChild(a);\n",
771 | " a.click();\n",
772 | " div.remove();\n",
773 | " }\n",
774 | " "
775 | ]
776 | },
777 | "metadata": {}
778 | },
779 | {
780 | "output_type": "display_data",
781 | "data": {
782 | "text/plain": [
783 | ""
784 | ],
785 | "application/javascript": [
786 | "download(\"download_269f8a05-a455-409f-822c-ea99f536a38f\", \"sam2_hiera_base_plus_decoder.onnx\", 16541001)"
787 | ]
788 | },
789 | "metadata": {}
790 | }
791 | ]
792 | }
793 | ]
794 | }
--------------------------------------------------------------------------------
/settings.gradle.kts:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | google {
4 | content {
5 | includeGroupByRegex("com\\.android.*")
6 | includeGroupByRegex("com\\.google.*")
7 | includeGroupByRegex("androidx.*")
8 | }
9 | }
10 | mavenCentral()
11 | gradlePluginPortal()
12 | }
13 | }
14 | dependencyResolutionManagement {
15 | repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
16 | repositories {
17 | google()
18 | mavenCentral()
19 | }
20 | }
21 |
22 | rootProject.name = "SAM-Android"
23 | include(":app")
24 |
--------------------------------------------------------------------------------