├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── SECURITY.md
├── environment.yaml
├── figures
│   ├── Comparison.png
│   └── Improvement_vs_FPS.png
├── hubconf.py
├── input
│   └── .placeholder
├── midas
│   ├── backbones
│   │   ├── beit.py
│   │   ├── levit.py
│   │   ├── next_vit.py
│   │   ├── swin.py
│   │   ├── swin2.py
│   │   ├── swin_common.py
│   │   ├── utils.py
│   │   └── vit.py
│   ├── base_model.py
│   ├── blocks.py
│   ├── dpt_depth.py
│   ├── midas_net.py
│   ├── midas_net_custom.py
│   ├── model_loader.py
│   └── transforms.py
├── mobile
│   ├── README.md
│   ├── android
│   │   ├── .gitignore
│   │   ├── EXPLORE_THE_CODE.md
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── app
│   │   │   ├── .gitignore
│   │   │   ├── build.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       ├── androidTest
│   │   │       │   ├── assets
│   │   │       │   │   ├── fox-mobilenet_v1_1.0_224_support.txt
│   │   │       │   │   └── fox-mobilenet_v1_1.0_224_task_api.txt
│   │   │       │   └── java
│   │   │       │       ├── AndroidManifest.xml
│   │   │       │       └── org
│   │   │       │           └── tensorflow
│   │   │       │               └── lite
│   │   │       │                   └── examples
│   │   │       │                       └── classification
│   │   │       │                           └── ClassifierTest.java
│   │   │       └── main
│   │   │           ├── AndroidManifest.xml
│   │   │           ├── java
│   │   │           │   └── org
│   │   │           │       └── tensorflow
│   │   │           │           └── lite
│   │   │           │               └── examples
│   │   │           │                   └── classification
│   │   │           │                       ├── CameraActivity.java
│   │   │           │                       ├── CameraConnectionFragment.java
│   │   │           │                       ├── ClassifierActivity.java
│   │   │           │                       ├── LegacyCameraConnectionFragment.java
│   │   │           │                       ├── customview
│   │   │           │                       │   ├── AutoFitTextureView.java
│   │   │           │                       │   ├── OverlayView.java
│   │   │           │                       │   ├── RecognitionScoreView.java
│   │   │           │                       │   └── ResultsView.java
│   │   │           │                       └── env
│   │   │           │                           ├── BorderedText.java
│   │   │           │                           ├── ImageUtils.java
│   │   │           │                           └── Logger.java
│   │   │           └── res
│   │   │               ├── drawable-hdpi
│   │   │               │   └── ic_launcher.png
│   │   │               ├── drawable-mdpi
│   │   │               │   └── ic_launcher.png
│   │   │               ├── drawable-v24
│   │   │               │   └── ic_launcher_foreground.xml
│   │   │               ├── drawable-xxhdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── icn_chevron_down.png
│   │   │               │   ├── icn_chevron_up.png
│   │   │               │   ├── tfl2_logo.png
│   │   │               │   └── tfl2_logo_dark.png
│   │   │               ├── drawable
│   │   │               │   ├── bottom_sheet_bg.xml
│   │   │               │   ├── ic_baseline_add.xml
│   │   │               │   ├── ic_baseline_remove.xml
│   │   │               │   ├── ic_launcher_background.xml
│   │   │               │   └── rectangle.xml
│   │   │               ├── layout
│   │   │               │   ├── tfe_ic_activity_camera.xml
│   │   │               │   ├── tfe_ic_camera_connection_fragment.xml
│   │   │               │   └── tfe_ic_layout_bottom_sheet.xml
│   │   │               ├── mipmap-anydpi-v26
│   │   │               │   ├── ic_launcher.xml
│   │   │               │   └── ic_launcher_round.xml
│   │   │               ├── mipmap-hdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── ic_launcher_foreground.png
│   │   │               │   └── ic_launcher_round.png
│   │   │               ├── mipmap-mdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── ic_launcher_foreground.png
│   │   │               │   └── ic_launcher_round.png
│   │   │               ├── mipmap-xhdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── ic_launcher_foreground.png
│   │   │               │   └── ic_launcher_round.png
│   │   │               ├── mipmap-xxhdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── ic_launcher_foreground.png
│   │   │               │   └── ic_launcher_round.png
│   │   │               ├── mipmap-xxxhdpi
│   │   │               │   ├── ic_launcher.png
│   │   │               │   ├── ic_launcher_foreground.png
│   │   │               │   └── ic_launcher_round.png
│   │   │               └── values
│   │   │                   ├── colors.xml
│   │   │                   ├── dimens.xml
│   │   │                   ├── strings.xml
│   │   │                   └── styles.xml
│   │   ├── build.gradle
│   │   ├── gradle.properties
│   │   ├── gradle
│   │   │   └── wrapper
│   │   │       ├── gradle-wrapper.jar
│   │   │       └── gradle-wrapper.properties
│   │   ├── gradlew
│   │   ├── gradlew.bat
│   │   ├── lib_support
│   │   │   ├── build.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       └── main
│   │   │           ├── AndroidManifest.xml
│   │   │           └── java
│   │   │               └── org
│   │   │                   └── tensorflow
│   │   │                       └── lite
│   │   │                           └── examples
│   │   │                               └── classification
│   │   │                                   └── tflite
│   │   │                                       ├── Classifier.java
│   │   │                                       ├── ClassifierFloatEfficientNet.java
│   │   │                                       ├── ClassifierFloatMobileNet.java
│   │   │                                       ├── ClassifierQuantizedEfficientNet.java
│   │   │                                       └── ClassifierQuantizedMobileNet.java
│   │   ├── lib_task_api
│   │   │   ├── build.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       └── main
│   │   │           ├── AndroidManifest.xml
│   │   │           └── java
│   │   │               └── org
│   │   │                   └── tensorflow
│   │   │                       └── lite
│   │   │                           └── examples
│   │   │                               └── classification
│   │   │                                   └── tflite
│   │   │                                       ├── Classifier.java
│   │   │                                       ├── ClassifierFloatEfficientNet.java
│   │   │                                       ├── ClassifierFloatMobileNet.java
│   │   │                                       ├── ClassifierQuantizedEfficientNet.java
│   │   │                                       └── ClassifierQuantizedMobileNet.java
│   │   ├── models
│   │   │   ├── build.gradle
│   │   │   ├── download.gradle
│   │   │   ├── proguard-rules.pro
│   │   │   └── src
│   │   │       └── main
│   │   │           ├── AndroidManifest.xml
│   │   │           └── assets
│   │   │               ├── labels.txt
│   │   │               ├── labels_without_background.txt
│   │   │               └── run_tflite.py
│   │   └── settings.gradle
│   └── ios
│       ├── .gitignore
│       ├── LICENSE
│       ├── Midas.xcodeproj
│       │   ├── project.pbxproj
│       │   ├── project.xcworkspace
│       │   │   ├── contents.xcworkspacedata
│       │   │   ├── xcshareddata
│       │   │   │   └── IDEWorkspaceChecks.plist
│       │   │   └── xcuserdata
│       │   │       └── admin.xcuserdatad
│       │   │           └── UserInterfaceState.xcuserstate
│       │   └── xcuserdata
│       │       └── admin.xcuserdatad
│       │           └── xcschemes
│       │               └── xcschememanagement.plist
│       ├── Midas
│       │   ├── AppDelegate.swift
│       │   ├── Assets.xcassets
│       │   │   ├── AppIcon.appiconset
│       │   │   │   ├── 100.png
│       │   │   │   ├── 1024.png
│       │   │   │   ├── 114.png
│       │   │   │   ├── 120.png
│       │   │   │   ├── 144.png
│       │   │   │   ├── 152.png
│       │   │   │   ├── 167.png
│       │   │   │   ├── 180.png
│       │   │   │   ├── 20.png
│       │   │   │   ├── 29.png
│       │   │   │   ├── 40.png
│       │   │   │   ├── 50.png
│       │   │   │   ├── 57.png
│       │   │   │   ├── 58.png
│       │   │   │   ├── 60.png
│       │   │   │   ├── 72.png
│       │   │   │   ├── 76.png
│       │   │   │   ├── 80.png
│       │   │   │   ├── 87.png
│       │   │   │   └── Contents.json
│       │   │   ├── Contents.json
│       │   │   └── tfl_logo.png
│       │   ├── Camera Feed
│       │   │   ├── CameraFeedManager.swift
│       │   │   └── PreviewView.swift
│       │   ├── Cells
│       │   │   └── InfoCell.swift
│       │   ├── Constants.swift
│       │   ├── Extensions
│       │   │   ├── CGSizeExtension.swift
│       │   │   ├── CVPixelBufferExtension.swift
│       │   │   └── TFLiteExtension.swift
│       │   ├── Info.plist
│       │   ├── ModelDataHandler
│       │   │   └── ModelDataHandler.swift
│       │   ├── Storyboards
│       │   │   └── Base.lproj
│       │   │       ├── Launch Screen.storyboard
│       │   │       └── Main.storyboard
│       │   ├── ViewControllers
│       │   │   └── ViewController.swift
│       │   └── Views
│       │       └── OverlayView.swift
│       ├── Podfile
│       ├── README.md
│       └── RunScripts
│           └── download_models.sh
├── output
│   └── .placeholder
├── ros
│   ├── LICENSE
│   ├── README.md
│   ├── additions
│   │   ├── do_catkin_make.sh
│   │   ├── downloads.sh
│   │   ├── install_ros_melodic_ubuntu_17_18.sh
│   │   ├── install_ros_noetic_ubuntu_20.sh
│   │   └── make_package_cpp.sh
│   ├── launch_midas_cpp.sh
│   ├── midas_cpp
│   │   ├── CMakeLists.txt
│   │   ├── launch
│   │   │   ├── midas_cpp.launch
│   │   │   └── midas_talker_listener.launch
│   │   ├── package.xml
│   │   ├── scripts
│   │   │   ├── listener.py
│   │   │   ├── listener_original.py
│   │   │   └── talker.py
│   │   └── src
│   │       └── main.cpp
│   └── run_talker_listener_test.sh
├── run.py
├── tf
│   ├── README.md
│   ├── input
│   │   └── .placeholder
│   ├── make_onnx_model.py
│   ├── output
│   │   └── .placeholder
│   ├── run_onnx.py
│   ├── run_pb.py
│   ├── transforms.py
│   └── utils.py
├── utils.py
└── weights
    └── .placeholder
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | *.png
107 | *.pfm
108 | *.jpg
109 | *.jpeg
110 | *.pt
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # enables cuda support in docker
2 | FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04
3 |
4 | # install python 3.6, pip and requirements for opencv-python
5 | # (see https://github.com/NVIDIA/nvidia-docker/issues/864)
6 | RUN apt-get update && apt-get -y install \
7 | python3 \
8 | python3-pip \
9 | libsm6 \
10 | libxext6 \
11 | libxrender-dev \
12 | curl \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | # install python dependencies
16 | RUN pip3 install --upgrade pip
17 | RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm
18 |
19 | # copy inference code
20 | WORKDIR /opt/MiDaS
21 | COPY ./midas ./midas
22 | COPY ./*.py ./
23 |
24 | # download model weights so the docker image can be used offline
25 | RUN mkdir -p weights && cd weights && { curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; }
26 | RUN python3 run.py --model_type dpt_hybrid; exit 0
27 |
28 | # entrypoint (don't forget to mount the input and output directories)
29 | CMD python3 run.py --model_type dpt_hybrid
30 |
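31 | # example usage (illustrative; assumes the NVIDIA container toolkit and local ./input and ./output folders):
32 | #   docker build -t midas .
33 | #   docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output midas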
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation.
3 |
4 | ## Reporting a Vulnerability
5 | Please report any security vulnerabilities in this project utilizing the guidelines [here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html).
6 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: midas-py310
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - nvidia::cudatoolkit=11.7
7 | - python=3.10.8
8 | - pytorch::pytorch=1.13.0
9 | - torchvision=0.14.0
10 | - pip=22.3.1
11 | - numpy=1.23.4
12 | - pip:
13 | - opencv-python==4.6.0.66
14 | - imutils==0.5.4
15 | - timm==0.6.12
16 | - einops==0.6.0
--------------------------------------------------------------------------------
/figures/Comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/figures/Comparison.png
--------------------------------------------------------------------------------
/figures/Improvement_vs_FPS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/figures/Improvement_vs_FPS.png
--------------------------------------------------------------------------------
/input/.placeholder:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/input/.placeholder
--------------------------------------------------------------------------------
/midas/backbones/levit.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | from .utils import activations, get_activation, Transpose
7 |
8 |
9 | def forward_levit(pretrained, x):
10 | pretrained.model.forward_features(x)
11 |
12 | layer_1 = pretrained.activations["1"]
13 | layer_2 = pretrained.activations["2"]
14 | layer_3 = pretrained.activations["3"]
15 |
16 | layer_1 = pretrained.act_postprocess1(layer_1)
17 | layer_2 = pretrained.act_postprocess2(layer_2)
18 | layer_3 = pretrained.act_postprocess3(layer_3)
19 |
20 | return layer_1, layer_2, layer_3
21 |
22 |
23 | def _make_levit_backbone(
24 | model,
25 | hooks=[3, 11, 21],
26 | patch_grid=[14, 14]
27 | ):
28 | pretrained = nn.Module()
29 |
30 | pretrained.model = model
31 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
32 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
33 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
34 |
35 | pretrained.activations = activations
36 |
37 | patch_grid_size = np.array(patch_grid, dtype=int)
38 |
39 | pretrained.act_postprocess1 = nn.Sequential(
40 | Transpose(1, 2),
41 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
42 | )
43 | pretrained.act_postprocess2 = nn.Sequential(
44 | Transpose(1, 2),
45 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
46 | )
47 | pretrained.act_postprocess3 = nn.Sequential(
48 | Transpose(1, 2),
49 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
50 | )
51 |
52 | return pretrained
53 |
54 |
55 | class ConvTransposeNorm(nn.Sequential):
56 | """
57 | Modification of
58 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
59 | such that ConvTranspose2d is used instead of Conv2d.
60 | """
61 |
62 | def __init__(
63 | self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
64 | groups=1, bn_weight_init=1):
65 | super().__init__()
66 | self.add_module('c',
67 | nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
68 | self.add_module('bn', nn.BatchNorm2d(out_chs))
69 |
70 | nn.init.constant_(self.bn.weight, bn_weight_init)
71 |
72 | @torch.no_grad()
73 | def fuse(self):
74 | c, bn = self._modules.values()
75 | w = bn.weight / (bn.running_var + bn.eps) ** 0.5
76 | w = c.weight * w[:, None, None, None]
77 | b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
78 | m = nn.ConvTranspose2d(
79 | w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
80 | padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
81 | m.weight.data.copy_(w)
82 | m.bias.data.copy_(b)
83 | return m
84 |
85 |
86 | def stem_b4_transpose(in_chs, out_chs, activation):
87 | """
88 | Modification of
89 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
90 | such that ConvTranspose2d is used instead of Conv2d and the stem is reduced to half the original depth.
91 | """
92 | return nn.Sequential(
93 | ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
94 | activation(),
95 | ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
96 | activation())
97 |
98 |
99 | def _make_pretrained_levit_384(pretrained, hooks=None):
100 | model = timm.create_model("levit_384", pretrained=pretrained)
101 |
102 | hooks = [3, 11, 21] if hooks == None else hooks
103 | return _make_levit_backbone(
104 | model,
105 | hooks=hooks
106 | )
107 |
--------------------------------------------------------------------------------
/midas/backbones/next_vit.py:
--------------------------------------------------------------------------------
1 | import timm
2 |
3 | import torch.nn as nn
4 |
5 | from pathlib import Path
6 | from .utils import activations, forward_default, get_activation
7 |
8 | from ..external.next_vit.classification.nextvit import *
9 |
10 |
11 | def forward_next_vit(pretrained, x):
12 | return forward_default(pretrained, x, "forward")
13 |
14 |
15 | def _make_next_vit_backbone(
16 | model,
17 | hooks=[2, 6, 36, 39],
18 | ):
19 | pretrained = nn.Module()
20 |
21 | pretrained.model = model
22 | pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
23 | pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
24 | pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
25 | pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
26 |
27 | pretrained.activations = activations
28 |
29 | return pretrained
30 |
31 |
32 | def _make_pretrained_next_vit_large_6m(hooks=None):
33 | model = timm.create_model("nextvit_large")
34 |
35 | hooks = [2, 6, 36, 39] if hooks == None else hooks
36 | return _make_next_vit_backbone(
37 | model,
38 | hooks=hooks,
39 | )
40 |
--------------------------------------------------------------------------------
/midas/backbones/swin.py:
--------------------------------------------------------------------------------
1 | import timm
2 |
3 | from .swin_common import _make_swin_backbone
4 |
5 |
6 | def _make_pretrained_swinl12_384(pretrained, hooks=None):
7 | model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)
8 |
9 | hooks = [1, 1, 17, 1] if hooks == None else hooks
10 | return _make_swin_backbone(
11 | model,
12 | hooks=hooks
13 | )
14 |
--------------------------------------------------------------------------------
/midas/backbones/swin2.py:
--------------------------------------------------------------------------------
1 | import timm
2 |
3 | from .swin_common import _make_swin_backbone
4 |
5 |
6 | def _make_pretrained_swin2l24_384(pretrained, hooks=None):
7 | model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)
8 |
9 | hooks = [1, 1, 17, 1] if hooks == None else hooks
10 | return _make_swin_backbone(
11 | model,
12 | hooks=hooks
13 | )
14 |
15 |
16 | def _make_pretrained_swin2b24_384(pretrained, hooks=None):
17 | model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)
18 |
19 | hooks = [1, 1, 17, 1] if hooks == None else hooks
20 | return _make_swin_backbone(
21 | model,
22 | hooks=hooks
23 | )
24 |
25 |
26 | def _make_pretrained_swin2t16_256(pretrained, hooks=None):
27 | model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
28 |
29 | hooks = [1, 1, 5, 1] if hooks == None else hooks
30 | return _make_swin_backbone(
31 | model,
32 | hooks=hooks,
33 | patch_grid=[64, 64]
34 | )
35 |
--------------------------------------------------------------------------------
/midas/backbones/swin_common.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 | from .utils import activations, forward_default, get_activation, Transpose
7 |
8 |
9 | def forward_swin(pretrained, x):
10 | return forward_default(pretrained, x)
11 |
12 |
13 | def _make_swin_backbone(
14 | model,
15 | hooks=[1, 1, 17, 1],
16 | patch_grid=[96, 96]
17 | ):
18 | pretrained = nn.Module()
19 |
20 | pretrained.model = model
21 | pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
22 | pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
23 | pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
24 | pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))
25 |
26 | pretrained.activations = activations
27 |
28 | if hasattr(model, "patch_grid"):
29 | used_patch_grid = model.patch_grid
30 | else:
31 | used_patch_grid = patch_grid
32 |
33 | patch_grid_size = np.array(used_patch_grid, dtype=int)
34 |
35 | pretrained.act_postprocess1 = nn.Sequential(
36 | Transpose(1, 2),
37 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
38 | )
39 | pretrained.act_postprocess2 = nn.Sequential(
40 | Transpose(1, 2),
41 | nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
42 | )
43 | pretrained.act_postprocess3 = nn.Sequential(
44 | Transpose(1, 2),
45 | nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
46 | )
47 | pretrained.act_postprocess4 = nn.Sequential(
48 | Transpose(1, 2),
49 | nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
50 | )
51 |
52 | return pretrained
53 |
--------------------------------------------------------------------------------
/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .base_model import BaseModel
5 | from .blocks import (
6 | FeatureFusionBlock_custom,
7 | Interpolate,
8 | _make_encoder,
9 | forward_beit,
10 | forward_swin,
11 | forward_levit,
12 | forward_vit,
13 | )
14 | from .backbones.levit import stem_b4_transpose
15 | from timm.models.layers import get_act_layer
16 |
17 |
18 | def _make_fusion_block(features, use_bn, size = None):
19 | return FeatureFusionBlock_custom(
20 | features,
21 | nn.ReLU(False),
22 | deconv=False,
23 | bn=use_bn,
24 | expand=False,
25 | align_corners=True,
26 | size=size,
27 | )
28 |
29 |
30 | class DPT(BaseModel):
31 | def __init__(
32 | self,
33 | head,
34 | features=256,
35 | backbone="vitb_rn50_384",
36 | readout="project",
37 | channels_last=False,
38 | use_bn=False,
39 | **kwargs
40 | ):
41 |
42 | super(DPT, self).__init__()
43 |
44 | self.channels_last = channels_last
45 |
46 | # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
47 | # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
48 | hooks = {
49 | "beitl16_512": [5, 11, 17, 23],
50 | "beitl16_384": [5, 11, 17, 23],
51 | "beitb16_384": [2, 5, 8, 11],
52 | "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
53 | "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
54 | "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
55 | "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
56 | "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
57 | "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
58 | "vitb_rn50_384": [0, 1, 8, 11],
59 | "vitb16_384": [2, 5, 8, 11],
60 | "vitl16_384": [5, 11, 17, 23],
61 | }[backbone]
62 |
63 | if "next_vit" in backbone:
64 | in_features = {
65 | "next_vit_large_6m": [96, 256, 512, 1024],
66 | }[backbone]
67 | else:
68 | in_features = None
69 |
70 | # Instantiate backbone and reassemble blocks
71 | self.pretrained, self.scratch = _make_encoder(
72 | backbone,
73 | features,
74 | False, # Set to true if you want to train from scratch, uses ImageNet weights
75 | groups=1,
76 | expand=False,
77 | exportable=False,
78 | hooks=hooks,
79 | use_readout=readout,
80 | in_features=in_features,
81 | )
82 |
83 | self.number_layers = len(hooks) if hooks is not None else 4
84 | size_refinenet3 = None
85 | self.scratch.stem_transpose = None
86 |
87 | if "beit" in backbone:
88 | self.forward_transformer = forward_beit
89 | elif "swin" in backbone:
90 | self.forward_transformer = forward_swin
91 | elif "next_vit" in backbone:
92 | from .backbones.next_vit import forward_next_vit
93 | self.forward_transformer = forward_next_vit
94 | elif "levit" in backbone:
95 | self.forward_transformer = forward_levit
96 | size_refinenet3 = 7
97 | self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
98 | else:
99 | self.forward_transformer = forward_vit
100 |
101 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
102 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
103 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
104 | if self.number_layers >= 4:
105 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
106 |
107 | self.scratch.output_conv = head
108 |
109 |
110 | def forward(self, x):
111 | if self.channels_last == True:
112 | x = x.contiguous(memory_format=torch.channels_last)
113 |
114 | layers = self.forward_transformer(self.pretrained, x)
115 | if self.number_layers == 3:
116 | layer_1, layer_2, layer_3 = layers
117 | else:
118 | layer_1, layer_2, layer_3, layer_4 = layers
119 |
120 | layer_1_rn = self.scratch.layer1_rn(layer_1)
121 | layer_2_rn = self.scratch.layer2_rn(layer_2)
122 | layer_3_rn = self.scratch.layer3_rn(layer_3)
123 | if self.number_layers >= 4:
124 | layer_4_rn = self.scratch.layer4_rn(layer_4)
125 |
126 | if self.number_layers == 3:
127 | path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
128 | else:
129 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
130 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
131 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
132 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
133 |
134 | if self.scratch.stem_transpose is not None:
135 | path_1 = self.scratch.stem_transpose(path_1)
136 |
137 | out = self.scratch.output_conv(path_1)
138 |
139 | return out
140 |
141 |
142 | class DPTDepthModel(DPT):
143 | def __init__(self, path=None, non_negative=True, **kwargs):
144 | features = kwargs["features"] if "features" in kwargs else 256
145 | head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
146 | head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
147 | kwargs.pop("head_features_1", None)
148 | kwargs.pop("head_features_2", None)
149 |
150 | head = nn.Sequential(
151 | nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
152 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
153 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
154 | nn.ReLU(True),
155 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
156 | nn.ReLU(True) if non_negative else nn.Identity(),
157 | nn.Identity(),
158 | )
159 |
160 | super().__init__(head, **kwargs)
161 |
162 | if path is not None:
163 | self.load(path)
164 |
165 | def forward(self, x):
166 | return super().forward(x).squeeze(dim=1)
167 |
--------------------------------------------------------------------------------
/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 | non_negative (bool, optional): If True, a final ReLU clamps the output to non-negative values. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 | features (int, optional): Number of features. Defaults to 64.
23 | backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x = x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/mobile/README.md:
--------------------------------------------------------------------------------
1 | ## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation
2 |
3 | ### Accuracy
4 |
5 | * Old small model - ResNet50 default-decoder 384x384
6 | * New small model - EfficientNet-Lite3 small-decoder 256x256
7 |
8 | **Zero-shot error** (lower is better):
9 |
10 | | Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 |
11 | |---|---|---|---|---|---|---|
12 | | Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 |
13 | | New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** |
14 | | Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** |
15 |
16 | No Train/Valid/Test subsets of the evaluation datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were involved in training or fine-tuning.
17 |
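The table metrics can be read as follows: WHDR is the Weighted Human Disagreement Rate on ordinal depth pairs, AbsRel is the mean absolute relative depth error, and δ>1.25 is the percentage of pixels whose depth ratio to ground truth exceeds 1.25. A minimal sketch of the latter two (assuming the prediction has already been aligned to the ground truth in scale and shift, as described in the paper):

```
import numpy as np

def abs_rel(pred, gt):
    # mean absolute relative error over valid (gt > 0) pixels
    mask = gt > 0
    return np.mean(np.abs(pred[mask] - gt[mask]) / gt[mask])

def bad_pixels_delta(pred, gt, thr=1.25):
    # percentage of pixels where max(pred/gt, gt/pred) > thr (lower is better)
    mask = gt > 0
    ratio = np.maximum(pred[mask] / gt[mask], gt[mask] / pred[mask])
    return 100.0 * np.mean(ratio > thr)
```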
18 | ### Inference speed (FPS) on iOS / Android
19 |
20 | **Frames Per Second** (higher is better):
21 |
22 | | Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI |
23 | |---|---|---|---|---|---|---|
24 | | Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 |
25 | | New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 |
26 | | SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** |
27 |
28 | N/A - run-time error (no data available)
29 |
30 |
31 | #### Models:
32 |
33 | * Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: PyTorch -> ONNX -> [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> TensorFlow SavedModel (PB) -> TFLite); see the sketch after this list
34 |
35 | (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor)
36 |
37 | * New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: PyTorch -> TFLite)
38 |
39 | (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape)
40 |
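A rough sketch of the old export path (file names and the hub entry point are illustrative, the package versions from the next subsection are assumed; the new small model instead used a custom PyTorch -> TFLite converter that is not reproduced here):

```
import torch, onnx
from onnx_tf.backend import prepare
import tensorflow as tf

# 1) PyTorch -> ONNX
model = torch.hub.load("isl-org/MiDaS", "MiDaS_small")  # illustrative; any MiDaS checkpoint can be exported this way
model.eval()
dummy = torch.randn(1, 3, 384, 384)
torch.onnx.export(model, dummy, "midas.onnx", opset_version=11,
                  input_names=["input"], output_names=["depth"])

# 2) ONNX -> TensorFlow SavedModel (PB) via onnx_tf
prepare(onnx.load("midas.onnx")).export_graph("midas_savedmodel")

# 3) SavedModel -> TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("midas_savedmodel")
with open("model_opt.tflite", "wb") as f:
    f.write(converter.convert())
```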
41 | #### Frameworks for training and conversions:
42 | ```
43 | pip install torch==1.6.0 torchvision==0.7.0
44 | pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0
45 | git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow
46 | ```
47 |
48 | #### SoC - OS - Library:
49 |
50 | * iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly
51 | * OnePlus 8 (Snapdragon 865) - Android 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly
52 |
53 |
54 | ### Citation
55 |
56 | This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3):
57 |
58 | >Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
59 | René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
60 |
61 | Please cite our paper if you use this code or any of the models:
62 | ```
63 | @article{Ranftl2020,
64 | author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
65 | title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
66 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
67 | year = {2020},
68 | }
69 | ```
70 |
71 |
--------------------------------------------------------------------------------
/mobile/android/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/libraries
5 | /.idea/modules.xml
6 | /.idea/workspace.xml
7 | .DS_Store
8 | /build
9 | /captures
10 | .externalNativeBuild
11 |
12 | /.gradle/
13 | /.idea/
--------------------------------------------------------------------------------
/mobile/android/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Alexey
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/mobile/android/README.md:
--------------------------------------------------------------------------------
1 | # MiDaS on Android smartphone by using TensorFlow-lite (TFLite)
2 |
3 |
4 | * Either use Android Studio for compilation.
5 |
6 | * Or use ready to install apk-file:
7 | * Or use URL: https://i.diawi.com/CVb8a9
8 | * Or use QR-code:
9 |
10 | Scan the QR-code or open the URL -> press `Install application` -> press `Download` and wait for the download -> Open -> Install -> Open -> allow MiDaS to take photos and videos from the camera while using the app
11 |
12 | 
13 |
14 | ----
15 |
16 | To use another model, convert it to `model_opt.tflite` and place it in the directory `models/src/main/assets` (see the sketch below)
17 |
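For example, a quick sanity check of a converted model before copying it into the assets folder (a sketch; the bundled small model expects a 1x256x256x3 float input, other models may differ):

```
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model_opt.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
print("input:", inp["shape"], "output:", out["shape"])

# run one random frame through the model to confirm it is executable
interpreter.set_tensor(inp["index"], np.random.rand(*inp["shape"]).astype(np.float32))
interpreter.invoke()
depth = interpreter.get_tensor(out["index"])
print("depth range:", depth.min(), depth.max())
```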
18 |
19 | ----
20 |
21 | Original repository: https://github.com/isl-org/MiDaS
22 |
--------------------------------------------------------------------------------
/mobile/android/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
2 |
3 | /build/
--------------------------------------------------------------------------------
/mobile/android/app/build.gradle:
--------------------------------------------------------------------------------
1 | apply plugin: 'com.android.application'
2 |
3 | android {
4 | compileSdkVersion 28
5 | defaultConfig {
6 | applicationId "org.tensorflow.lite.examples.classification"
7 | minSdkVersion 21
8 | targetSdkVersion 28
9 | versionCode 1
10 | versionName "1.0"
11 |
12 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
13 | }
14 | buildTypes {
15 | release {
16 | minifyEnabled false
17 | proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
18 | }
19 | }
20 | aaptOptions {
21 | noCompress "tflite"
22 | }
23 | compileOptions {
24 | sourceCompatibility = '1.8'
25 | targetCompatibility = '1.8'
26 | }
27 | lintOptions {
28 | abortOnError false
29 | }
30 | flavorDimensions "tfliteInference"
31 | productFlavors {
32 | // The TFLite inference is built using the TFLite Support library.
33 | support {
34 | dimension "tfliteInference"
35 | }
36 | // The TFLite inference is built using the TFLite Task library.
37 | taskApi {
38 | dimension "tfliteInference"
39 | }
40 | }
41 |
42 | }
43 |
44 | dependencies {
45 | implementation fileTree(dir: 'libs', include: ['*.jar'])
46 | supportImplementation project(":lib_support")
47 | taskApiImplementation project(":lib_task_api")
48 | implementation 'androidx.appcompat:appcompat:1.0.0'
49 | implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0'
50 | implementation 'com.google.android.material:material:1.0.0'
51 |
52 | androidTestImplementation 'androidx.test.ext:junit:1.1.1'
53 | androidTestImplementation 'com.google.truth:truth:1.0.1'
54 | androidTestImplementation 'androidx.test:runner:1.2.0'
55 | androidTestImplementation 'androidx.test:rules:1.1.0'
56 | }
57 |
--------------------------------------------------------------------------------
/mobile/android/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
22 |
--------------------------------------------------------------------------------
/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt:
--------------------------------------------------------------------------------
1 | red_fox 0.79403335
2 | kit_fox 0.16753247
3 | grey_fox 0.03619214
4 |
--------------------------------------------------------------------------------
/mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt:
--------------------------------------------------------------------------------
1 | red_fox 0.85
2 | kit_fox 0.13
3 | grey_fox 0.02
4 |
--------------------------------------------------------------------------------
/mobile/android/app/src/androidTest/java/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
--------------------------------------------------------------------------------
/mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.tensorflow.lite.examples.classification;
18 |
19 | import static com.google.common.truth.Truth.assertThat;
20 |
21 | import android.content.res.AssetManager;
22 | import android.graphics.Bitmap;
23 | import android.graphics.BitmapFactory;
24 | import android.util.Log;
25 | import androidx.test.ext.junit.runners.AndroidJUnit4;
26 | import androidx.test.platform.app.InstrumentationRegistry;
27 | import androidx.test.rule.ActivityTestRule;
28 | import java.io.IOException;
29 | import java.io.InputStream;
30 | import java.util.ArrayList;
31 | import java.util.Iterator;
32 | import java.util.List;
33 | import java.util.Scanner;
34 | import org.junit.Assert;
35 | import org.junit.Rule;
36 | import org.junit.Test;
37 | import org.junit.runner.RunWith;
38 | import org.tensorflow.lite.examples.classification.tflite.Classifier;
39 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device;
40 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Model;
41 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition;
42 |
43 | /** Golden test for Image Classification Reference app. */
44 | @RunWith(AndroidJUnit4.class)
45 | public class ClassifierTest {
46 |
47 | @Rule
48 | public ActivityTestRule&lt;ClassifierActivity&gt; rule =
49 | new ActivityTestRule<>(ClassifierActivity.class);
50 |
51 | private static final String[] INPUTS = {"fox.jpg"};
52 | private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"};
53 | private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"};
54 |
55 | @Test
56 | public void classificationResultsShouldNotChange() throws IOException {
57 | ClassifierActivity activity = rule.getActivity();
58 | Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1);
59 | for (int i = 0; i < INPUTS.length; i++) {
60 | String imageFileName = INPUTS[i];
61 | String goldenOutputFileName;
62 | // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy.
63 | // This is a temporary workaround to set different golden rest results as the preprocessing
64 | // of lib_support and lib_task_api are different. Will merge them once the above TODO is
65 | // resolved.
66 | if (Classifier.TAG.equals("ClassifierWithSupport")) {
67 | goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i];
68 | } else {
69 | goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i];
70 | }
71 | Bitmap input = loadImage(imageFileName);
72 | List goldenOutput = loadRecognitions(goldenOutputFileName);
73 |
74 | List result = classifier.recognizeImage(input, 0);
75 | Iterator goldenOutputIterator = goldenOutput.iterator();
76 |
77 | for (Recognition actual : result) {
78 | Assert.assertTrue(goldenOutputIterator.hasNext());
79 | Recognition expected = goldenOutputIterator.next();
80 | assertThat(actual.getTitle()).isEqualTo(expected.getTitle());
81 | assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence());
82 | }
83 | }
84 | }
85 |
86 | private static Bitmap loadImage(String fileName) {
87 | AssetManager assetManager =
88 | InstrumentationRegistry.getInstrumentation().getContext().getAssets();
89 | InputStream inputStream = null;
90 | try {
91 | inputStream = assetManager.open(fileName);
92 | } catch (IOException e) {
93 | Log.e("Test", "Cannot load image from assets");
94 | }
95 | return BitmapFactory.decodeStream(inputStream);
96 | }
97 |
98 | private static List&lt;Recognition&gt; loadRecognitions(String fileName) {
99 | AssetManager assetManager =
100 | InstrumentationRegistry.getInstrumentation().getContext().getAssets();
101 | InputStream inputStream = null;
102 | try {
103 | inputStream = assetManager.open(fileName);
104 | } catch (IOException e) {
105 | Log.e("Test", "Cannot load probability results from assets");
106 | }
107 | Scanner scanner = new Scanner(inputStream);
108 | List&lt;Recognition&gt; result = new ArrayList&lt;&gt;();
109 | while (scanner.hasNext()) {
110 | String category = scanner.next();
111 | category = category.replace('_', ' ');
112 | if (!scanner.hasNextFloat()) {
113 | break;
114 | }
115 | float probability = scanner.nextFloat();
116 | Recognition recognition = new Recognition(null, category, probability, null);
117 | result.add(recognition);
118 | }
119 | return result;
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/mobile/android/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
18 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.tensorflow.lite.examples.classification.customview;
18 |
19 | import android.content.Context;
20 | import android.util.AttributeSet;
21 | import android.view.TextureView;
22 |
23 | /** A {@link TextureView} that can be adjusted to a specified aspect ratio. */
24 | public class AutoFitTextureView extends TextureView {
25 | private int ratioWidth = 0;
26 | private int ratioHeight = 0;
27 |
28 | public AutoFitTextureView(final Context context) {
29 | this(context, null);
30 | }
31 |
32 | public AutoFitTextureView(final Context context, final AttributeSet attrs) {
33 | this(context, attrs, 0);
34 | }
35 |
36 | public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) {
37 | super(context, attrs, defStyle);
38 | }
39 |
40 | /**
41 | * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio
42 | * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is,
43 | * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) produce the same result.
44 | *
45 | * @param width Relative horizontal size
46 | * @param height Relative vertical size
47 | */
48 | public void setAspectRatio(final int width, final int height) {
49 | if (width < 0 || height < 0) {
50 | throw new IllegalArgumentException("Size cannot be negative.");
51 | }
52 | ratioWidth = width;
53 | ratioHeight = height;
54 | requestLayout();
55 | }
56 |
57 | @Override
58 | protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
59 | super.onMeasure(widthMeasureSpec, heightMeasureSpec);
60 | final int width = MeasureSpec.getSize(widthMeasureSpec);
61 | final int height = MeasureSpec.getSize(heightMeasureSpec);
62 | if (0 == ratioWidth || 0 == ratioHeight) {
63 | setMeasuredDimension(width, height);
64 | } else {
65 | if (width < height * ratioWidth / ratioHeight) {
66 | setMeasuredDimension(width, width * ratioHeight / ratioWidth);
67 | } else {
68 | setMeasuredDimension(height * ratioWidth / ratioHeight, height);
69 | }
70 | }
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.customview;
17 |
18 | import android.content.Context;
19 | import android.graphics.Canvas;
20 | import android.util.AttributeSet;
21 | import android.view.View;
22 | import java.util.LinkedList;
23 | import java.util.List;
24 |
25 | /** A simple View providing a render callback to other classes. */
26 | public class OverlayView extends View {
27 | private final List&lt;DrawCallback&gt; callbacks = new LinkedList&lt;DrawCallback&gt;();
28 |
29 | public OverlayView(final Context context, final AttributeSet attrs) {
30 | super(context, attrs);
31 | }
32 |
33 | public void addCallback(final DrawCallback callback) {
34 | callbacks.add(callback);
35 | }
36 |
37 | @Override
38 | public synchronized void draw(final Canvas canvas) {
39 | for (final DrawCallback callback : callbacks) {
40 | callback.drawCallback(canvas);
41 | }
42 | }
43 |
44 | /** Interface defining the callback for client classes. */
45 | public interface DrawCallback {
46 | public void drawCallback(final Canvas canvas);
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.customview;
17 |
18 | import android.content.Context;
19 | import android.graphics.Canvas;
20 | import android.graphics.Paint;
21 | import android.util.AttributeSet;
22 | import android.util.TypedValue;
23 | import android.view.View;
24 | import java.util.List;
25 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition;
26 |
27 | public class RecognitionScoreView extends View implements ResultsView {
28 | private static final float TEXT_SIZE_DIP = 16;
29 | private final float textSizePx;
30 | private final Paint fgPaint;
31 | private final Paint bgPaint;
32 | private List<Recognition> results;
33 |
34 | public RecognitionScoreView(final Context context, final AttributeSet set) {
35 | super(context, set);
36 |
37 | textSizePx =
38 | TypedValue.applyDimension(
39 | TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
40 | fgPaint = new Paint();
41 | fgPaint.setTextSize(textSizePx);
42 |
43 | bgPaint = new Paint();
44 | bgPaint.setColor(0xcc4285f4);
45 | }
46 |
47 | @Override
48 | public void setResults(final List<Recognition> results) {
49 | this.results = results;
50 | postInvalidate();
51 | }
52 |
53 | @Override
54 | public void onDraw(final Canvas canvas) {
55 | final int x = 10;
56 | int y = (int) (fgPaint.getTextSize() * 1.5f);
57 |
58 | canvas.drawPaint(bgPaint);
59 |
60 | if (results != null) {
61 | for (final Recognition recog : results) {
62 | canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint);
63 | y += (int) (fgPaint.getTextSize() * 1.5f);
64 | }
65 | }
66 | }
67 | }
68 |
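A hedged sketch of how the view is typically fed (the layout id, the rgbFrameBitmap and sensorOrientation variables, and the Classifier call are assumptions; the real wiring lives in the example's activity code):

    // Illustrative only: push classifier output into the view. setResults() calls
    // postInvalidate(), so onDraw() repaints with the new titles and confidences.
    RecognitionScoreView resultsView = findViewById(R.id.results);       // assumed id
    List<Classifier.Recognition> results =
        classifier.recognizeImage(rgbFrameBitmap, sensorOrientation);    // assumed API
    resultsView.setResults(results);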
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.customview;
17 |
18 | import java.util.List;
19 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition;
20 |
21 | public interface ResultsView {
22 | public void setResults(final List<Recognition> results);
23 | }
24 |
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/BorderedText.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.env;
17 |
18 | import android.graphics.Canvas;
19 | import android.graphics.Color;
20 | import android.graphics.Paint;
21 | import android.graphics.Paint.Align;
22 | import android.graphics.Paint.Style;
23 | import android.graphics.Rect;
24 | import android.graphics.Typeface;
25 | import java.util.Vector;
26 |
27 | /** A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas. */
28 | public class BorderedText {
29 | private final Paint interiorPaint;
30 | private final Paint exteriorPaint;
31 |
32 | private final float textSize;
33 |
34 | /**
35 | * Creates a left-aligned bordered text object with a white interior, and a black exterior with
36 | * the specified text size.
37 | *
38 | * @param textSize text size in pixels
39 | */
40 | public BorderedText(final float textSize) {
41 | this(Color.WHITE, Color.BLACK, textSize);
42 | }
43 |
44 | /**
45 | * Create a bordered text object with the specified interior and exterior colors, text size and
46 | * alignment.
47 | *
48 | * @param interiorColor the interior text color
49 | * @param exteriorColor the exterior text color
50 | * @param textSize text size in pixels
51 | */
52 | public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) {
53 | interiorPaint = new Paint();
54 | interiorPaint.setTextSize(textSize);
55 | interiorPaint.setColor(interiorColor);
56 | interiorPaint.setStyle(Style.FILL);
57 | interiorPaint.setAntiAlias(false);
58 | interiorPaint.setAlpha(255);
59 |
60 | exteriorPaint = new Paint();
61 | exteriorPaint.setTextSize(textSize);
62 | exteriorPaint.setColor(exteriorColor);
63 | exteriorPaint.setStyle(Style.FILL_AND_STROKE);
64 | exteriorPaint.setStrokeWidth(textSize / 8);
65 | exteriorPaint.setAntiAlias(false);
66 | exteriorPaint.setAlpha(255);
67 |
68 | this.textSize = textSize;
69 | }
70 |
71 | public void setTypeface(Typeface typeface) {
72 | interiorPaint.setTypeface(typeface);
73 | exteriorPaint.setTypeface(typeface);
74 | }
75 |
76 | public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
77 | canvas.drawText(text, posX, posY, exteriorPaint);
78 | canvas.drawText(text, posX, posY, interiorPaint);
79 | }
80 |
81 | public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) {
82 | int lineNum = 0;
83 | for (final String line : lines) {
84 | drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line);
85 | ++lineNum;
86 | }
87 | }
88 |
89 | public void setInteriorColor(final int color) {
90 | interiorPaint.setColor(color);
91 | }
92 |
93 | public void setExteriorColor(final int color) {
94 | exteriorPaint.setColor(color);
95 | }
96 |
97 | public float getTextSize() {
98 | return textSize;
99 | }
100 |
101 | public void setAlpha(final int alpha) {
102 | interiorPaint.setAlpha(alpha);
103 | exteriorPaint.setAlpha(alpha);
104 | }
105 |
106 | public void getTextBounds(
107 | final String line, final int index, final int count, final Rect lineBounds) {
108 | interiorPaint.getTextBounds(line, index, count, lineBounds);
109 | }
110 |
111 | public void setTextAlign(final Align align) {
112 | interiorPaint.setTextAlign(align);
113 | exteriorPaint.setTextAlign(align);
114 | }
115 | }
116 |
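A small usage sketch for BorderedText (the sizes, strings, and draw call site are illustrative assumptions):

    // Illustrative only: white-on-black text that stays legible over a camera preview.
    float textSizePx = TypedValue.applyDimension(
        TypedValue.COMPLEX_UNIT_DIP, 10f, getResources().getDisplayMetrics());
    BorderedText borderedText = new BorderedText(textSizePx);
    borderedText.setTypeface(Typeface.MONOSPACE);
    // Inside onDraw(Canvas canvas) or an OverlayView.DrawCallback:
    borderedText.drawText(canvas, 10f, 30f, "inference: 23 ms");

The stroked exterior paint is drawn first and the filled interior second, which is what keeps the text readable on both light and dark backgrounds.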
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/ImageUtils.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.env;
17 |
18 | import android.graphics.Bitmap;
19 | import android.os.Environment;
20 | import java.io.File;
21 | import java.io.FileOutputStream;
22 |
23 | /** Utility class for manipulating images. */
24 | public class ImageUtils {
25 | // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges
26 | // are normalized to eight bits.
27 | static final int kMaxChannelValue = 262143;
28 |
29 | @SuppressWarnings("unused")
30 | private static final Logger LOGGER = new Logger();
31 |
32 | /**
33 | * Utility method to compute the allocated size in bytes of a YUV420SP image of the given
34 | * dimensions.
35 | */
36 | public static int getYUVByteSize(final int width, final int height) {
37 | // The luminance plane requires 1 byte per pixel.
38 | final int ySize = width * height;
39 |
40 | // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up.
41 | // Each 2x2 block takes 2 bytes to encode, one each for U and V.
42 | final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;
43 |
44 | return ySize + uvSize;
45 | }
46 |
47 | /**
48 | * Saves a Bitmap object to disk for analysis.
49 | *
50 | * @param bitmap The bitmap to save.
51 | */
52 | public static void saveBitmap(final Bitmap bitmap) {
53 | saveBitmap(bitmap, "preview.png");
54 | }
55 |
56 | /**
57 | * Saves a Bitmap object to disk for analysis.
58 | *
59 | * @param bitmap The bitmap to save.
60 | * @param filename The location to save the bitmap to.
61 | */
62 | public static void saveBitmap(final Bitmap bitmap, final String filename) {
63 | final String root =
64 | Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow";
65 | LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root);
66 | final File myDir = new File(root);
67 |
68 | if (!myDir.mkdirs()) {
69 | LOGGER.i("Make dir failed");
70 | }
71 |
72 | final String fname = filename;
73 | final File file = new File(myDir, fname);
74 | if (file.exists()) {
75 | file.delete();
76 | }
77 | try {
78 | final FileOutputStream out = new FileOutputStream(file);
79 | bitmap.compress(Bitmap.CompressFormat.PNG, 99, out);
80 | out.flush();
81 | out.close();
82 | } catch (final Exception e) {
83 | LOGGER.e(e, "Exception!");
84 | }
85 | }
86 |
87 | public static void convertYUV420SPToARGB8888(byte[] input, int width, int height, int[] output) {
88 | final int frameSize = width * height;
89 | for (int j = 0, yp = 0; j < height; j++) {
90 | int uvp = frameSize + (j >> 1) * width;
91 | int u = 0;
92 | int v = 0;
93 |
94 | for (int i = 0; i < width; i++, yp++) {
95 | int y = 0xff & input[yp];
96 | if ((i & 1) == 0) {
97 | v = 0xff & input[uvp++];
98 | u = 0xff & input[uvp++];
99 | }
100 |
101 | output[yp] = YUV2RGB(y, u, v);
102 | }
103 | }
104 | }
105 |
106 | private static int YUV2RGB(int y, int u, int v) {
107 | // Adjust and check YUV values
108 | y = (y - 16) < 0 ? 0 : (y - 16);
109 | u -= 128;
110 | v -= 128;
111 |
112 | // This is the floating point equivalent. We do the conversion in integer
113 | // because some Android devices do not have floating point in hardware.
114 | // nR = (int)(1.164 * nY + 1.596 * nV);
115 | // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU);
116 | // nB = (int)(1.164 * nY + 2.018 * nU);
117 | int y1192 = 1192 * y;
118 | int r = (y1192 + 1634 * v);
119 | int g = (y1192 - 833 * v - 400 * u);
120 | int b = (y1192 + 2066 * u);
121 |
122 | // Clipping RGB values to be inside boundaries [ 0 , kMaxChannelValue ]
123 | r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r);
124 | g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g);
125 | b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 0 : b);
126 |
127 | return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff);
128 | }
129 |
130 | public static void convertYUV420ToARGB8888(
131 | byte[] yData,
132 | byte[] uData,
133 | byte[] vData,
134 | int width,
135 | int height,
136 | int yRowStride,
137 | int uvRowStride,
138 | int uvPixelStride,
139 | int[] out) {
140 | int yp = 0;
141 | for (int j = 0; j < height; j++) {
142 | int pY = yRowStride * j;
143 | int pUV = uvRowStride * (j >> 1);
144 |
145 | for (int i = 0; i < width; i++) {
146 | int uv_offset = pUV + (i >> 1) * uvPixelStride;
147 |
148 | out[yp++] = YUV2RGB(0xff & yData[pY + i], 0xff & uData[uv_offset], 0xff & vData[uv_offset]);
149 | }
150 | }
151 | }
152 | }
153 |
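A hedged sketch of driving convertYUV420ToARGB8888 from a camera2 frame; the Image plumbing and the toBytes helper are assumptions for illustration, not part of this file:

    // Illustrative only: convert a YUV_420_888 android.media.Image into an ARGB bitmap.
    static byte[] toBytes(ByteBuffer buffer) {           // hypothetical helper
      byte[] bytes = new byte[buffer.remaining()];
      buffer.get(bytes);
      return bytes;
    }

    Image image = reader.acquireLatestImage();           // reader: android.media.ImageReader
    Image.Plane[] planes = image.getPlanes();
    int width = image.getWidth();
    int height = image.getHeight();
    int[] argb = new int[width * height];
    ImageUtils.convertYUV420ToARGB8888(
        toBytes(planes[0].getBuffer()),                  // Y
        toBytes(planes[1].getBuffer()),                  // U
        toBytes(planes[2].getBuffer()),                  // V
        width, height,
        planes[0].getRowStride(),
        planes[1].getRowStride(),
        planes[1].getPixelStride(),
        argb);
    Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    bitmap.setPixels(argb, 0, width, 0, 0, width, height);
    image.close();

A production loop would reuse the byte and int buffers across frames instead of allocating them per frame.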
--------------------------------------------------------------------------------
/mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/Logger.java:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | package org.tensorflow.lite.examples.classification.env;
17 |
18 | import android.util.Log;
19 | import java.util.HashSet;
20 | import java.util.Set;
21 |
22 | /** Wrapper for the platform log function; allows convenient message prefixing and log disabling. */
23 | public final class Logger {
24 | private static final String DEFAULT_TAG = "tensorflow";
25 | private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG;
26 |
27 | // Classes to be ignored when examining the stack trace
28 | private static final Set<String> IGNORED_CLASS_NAMES;
29 |
30 | static {
31 | IGNORED_CLASS_NAMES = new HashSet<String>(3);
32 | IGNORED_CLASS_NAMES.add("dalvik.system.VMStack");
33 | IGNORED_CLASS_NAMES.add("java.lang.Thread");
34 | IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName());
35 | }
36 |
37 | private final String tag;
38 | private final String messagePrefix;
39 | private int minLogLevel = DEFAULT_MIN_LOG_LEVEL;
40 |
41 | /**
42 | * Creates a Logger using the class name as the message prefix.
43 | *
44 | * @param clazz the simple name of this class is used as the message prefix.
45 | */
46 | public Logger(final Class<?> clazz) {
47 | this(clazz.getSimpleName());
48 | }
49 |
50 | /**
51 | * Creates a Logger using the specified message prefix.
52 | *
53 | * @param messagePrefix is prepended to the text of every message.
54 | */
55 | public Logger(final String messagePrefix) {
56 | this(DEFAULT_TAG, messagePrefix);
57 | }
58 |
59 | /**
60 | * Creates a Logger with a custom tag and a custom message prefix. If the message prefix is set to
61 | *
62 | * <pre>null</pre>
63 | *
64 | * , the caller's class name is used as the prefix.
65 | *
66 | * @param tag identifies the source of a log message.
67 | * @param messagePrefix prepended to every message if non-null. If null, the caller's class name
68 | * is used instead.
69 | */
70 | public Logger(final String tag, final String messagePrefix) {
71 | this.tag = tag;
72 | final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix;
73 | this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix;
74 | }
75 |
76 | /** Creates a Logger using the caller's class name as the message prefix. */
77 | public Logger() {
78 | this(DEFAULT_TAG, null);
79 | }
80 |
81 | /** Creates a Logger using the caller's class name as the message prefix and the given minimum log level. */
82 | public Logger(final int minLogLevel) {
83 | this(DEFAULT_TAG, null);
84 | this.minLogLevel = minLogLevel;
85 | }
86 |
87 | /**
88 | * Return caller's simple name.
89 | *
90 | *