├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── SECURITY.md ├── environment.yaml ├── figures ├── Comparison.png └── Improvement_vs_FPS.png ├── hubconf.py ├── input └── .placeholder ├── midas ├── backbones │ ├── beit.py │ ├── levit.py │ ├── next_vit.py │ ├── swin.py │ ├── swin2.py │ ├── swin_common.py │ ├── utils.py │ └── vit.py ├── base_model.py ├── blocks.py ├── dpt_depth.py ├── midas_net.py ├── midas_net_custom.py ├── model_loader.py └── transforms.py ├── mobile ├── README.md ├── android │ ├── .gitignore │ ├── EXPLORE_THE_CODE.md │ ├── LICENSE │ ├── README.md │ ├── app │ │ ├── .gitignore │ │ ├── build.gradle │ │ ├── proguard-rules.pro │ │ └── src │ │ │ ├── androidTest │ │ │ ├── assets │ │ │ │ ├── fox-mobilenet_v1_1.0_224_support.txt │ │ │ │ └── fox-mobilenet_v1_1.0_224_task_api.txt │ │ │ └── java │ │ │ │ ├── AndroidManifest.xml │ │ │ │ └── org │ │ │ │ └── tensorflow │ │ │ │ └── lite │ │ │ │ └── examples │ │ │ │ └── classification │ │ │ │ └── ClassifierTest.java │ │ │ └── main │ │ │ ├── AndroidManifest.xml │ │ │ ├── java │ │ │ └── org │ │ │ │ └── tensorflow │ │ │ │ └── lite │ │ │ │ └── examples │ │ │ │ └── classification │ │ │ │ ├── CameraActivity.java │ │ │ │ ├── CameraConnectionFragment.java │ │ │ │ ├── ClassifierActivity.java │ │ │ │ ├── LegacyCameraConnectionFragment.java │ │ │ │ ├── customview │ │ │ │ ├── AutoFitTextureView.java │ │ │ │ ├── OverlayView.java │ │ │ │ ├── RecognitionScoreView.java │ │ │ │ └── ResultsView.java │ │ │ │ └── env │ │ │ │ ├── BorderedText.java │ │ │ │ ├── ImageUtils.java │ │ │ │ └── Logger.java │ │ │ └── res │ │ │ ├── drawable-hdpi │ │ │ └── ic_launcher.png │ │ │ ├── drawable-mdpi │ │ │ └── ic_launcher.png │ │ │ ├── drawable-v24 │ │ │ └── ic_launcher_foreground.xml │ │ │ ├── drawable-xxhdpi │ │ │ ├── ic_launcher.png │ │ │ ├── icn_chevron_down.png │ │ │ ├── icn_chevron_up.png │ │ │ ├── tfl2_logo.png │ │ │ └── tfl2_logo_dark.png │ │ │ ├── drawable │ │ │ ├── bottom_sheet_bg.xml │ │ │ ├── ic_baseline_add.xml │ │ │ ├── ic_baseline_remove.xml │ │ │ ├── ic_launcher_background.xml │ │ │ └── rectangle.xml │ │ │ ├── layout │ │ │ ├── tfe_ic_activity_camera.xml │ │ │ ├── tfe_ic_camera_connection_fragment.xml │ │ │ └── tfe_ic_layout_bottom_sheet.xml │ │ │ ├── mipmap-anydpi-v26 │ │ │ ├── ic_launcher.xml │ │ │ └── ic_launcher_round.xml │ │ │ ├── mipmap-hdpi │ │ │ ├── ic_launcher.png │ │ │ ├── ic_launcher_foreground.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-mdpi │ │ │ ├── ic_launcher.png │ │ │ ├── ic_launcher_foreground.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xhdpi │ │ │ ├── ic_launcher.png │ │ │ ├── ic_launcher_foreground.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xxhdpi │ │ │ ├── ic_launcher.png │ │ │ ├── ic_launcher_foreground.png │ │ │ └── ic_launcher_round.png │ │ │ ├── mipmap-xxxhdpi │ │ │ ├── ic_launcher.png │ │ │ ├── ic_launcher_foreground.png │ │ │ └── ic_launcher_round.png │ │ │ └── values │ │ │ ├── colors.xml │ │ │ ├── dimens.xml │ │ │ ├── strings.xml │ │ │ └── styles.xml │ ├── build.gradle │ ├── gradle.properties │ ├── gradle │ │ └── wrapper │ │ │ ├── gradle-wrapper.jar │ │ │ └── gradle-wrapper.properties │ ├── gradlew │ ├── gradlew.bat │ ├── lib_support │ │ ├── build.gradle │ │ ├── proguard-rules.pro │ │ └── src │ │ │ └── main │ │ │ ├── AndroidManifest.xml │ │ │ └── java │ │ │ └── org │ │ │ └── tensorflow │ │ │ └── lite │ │ │ └── examples │ │ │ └── classification │ │ │ └── tflite │ │ │ ├── Classifier.java │ │ │ ├── ClassifierFloatEfficientNet.java │ │ │ ├── ClassifierFloatMobileNet.java │ │ │ ├── 
ClassifierQuantizedEfficientNet.java │ │ │ └── ClassifierQuantizedMobileNet.java │ ├── lib_task_api │ │ ├── build.gradle │ │ ├── proguard-rules.pro │ │ └── src │ │ │ └── main │ │ │ ├── AndroidManifest.xml │ │ │ └── java │ │ │ └── org │ │ │ └── tensorflow │ │ │ └── lite │ │ │ └── examples │ │ │ └── classification │ │ │ └── tflite │ │ │ ├── Classifier.java │ │ │ ├── ClassifierFloatEfficientNet.java │ │ │ ├── ClassifierFloatMobileNet.java │ │ │ ├── ClassifierQuantizedEfficientNet.java │ │ │ └── ClassifierQuantizedMobileNet.java │ ├── models │ │ ├── build.gradle │ │ ├── download.gradle │ │ ├── proguard-rules.pro │ │ └── src │ │ │ └── main │ │ │ ├── AndroidManifest.xml │ │ │ └── assets │ │ │ ├── labels.txt │ │ │ ├── labels_without_background.txt │ │ │ └── run_tflite.py │ └── settings.gradle └── ios │ ├── .gitignore │ ├── LICENSE │ ├── Midas.xcodeproj │ ├── project.pbxproj │ ├── project.xcworkspace │ │ ├── contents.xcworkspacedata │ │ ├── xcshareddata │ │ │ └── IDEWorkspaceChecks.plist │ │ └── xcuserdata │ │ │ └── admin.xcuserdatad │ │ │ └── UserInterfaceState.xcuserstate │ └── xcuserdata │ │ └── admin.xcuserdatad │ │ └── xcschemes │ │ └── xcschememanagement.plist │ ├── Midas │ ├── AppDelegate.swift │ ├── Assets.xcassets │ │ ├── AppIcon.appiconset │ │ │ ├── 100.png │ │ │ ├── 1024.png │ │ │ ├── 114.png │ │ │ ├── 120.png │ │ │ ├── 144.png │ │ │ ├── 152.png │ │ │ ├── 167.png │ │ │ ├── 180.png │ │ │ ├── 20.png │ │ │ ├── 29.png │ │ │ ├── 40.png │ │ │ ├── 50.png │ │ │ ├── 57.png │ │ │ ├── 58.png │ │ │ ├── 60.png │ │ │ ├── 72.png │ │ │ ├── 76.png │ │ │ ├── 80.png │ │ │ ├── 87.png │ │ │ └── Contents.json │ │ ├── Contents.json │ │ └── tfl_logo.png │ ├── Camera Feed │ │ ├── CameraFeedManager.swift │ │ └── PreviewView.swift │ ├── Cells │ │ └── InfoCell.swift │ ├── Constants.swift │ ├── Extensions │ │ ├── CGSizeExtension.swift │ │ ├── CVPixelBufferExtension.swift │ │ └── TFLiteExtension.swift │ ├── Info.plist │ ├── ModelDataHandler │ │ └── ModelDataHandler.swift │ ├── Storyboards │ │ └── Base.lproj │ │ │ ├── Launch Screen.storyboard │ │ │ └── Main.storyboard │ ├── ViewControllers │ │ └── ViewController.swift │ └── Views │ │ └── OverlayView.swift │ ├── Podfile │ ├── README.md │ └── RunScripts │ └── download_models.sh ├── output └── .placeholder ├── ros ├── LICENSE ├── README.md ├── additions │ ├── do_catkin_make.sh │ ├── downloads.sh │ ├── install_ros_melodic_ubuntu_17_18.sh │ ├── install_ros_noetic_ubuntu_20.sh │ └── make_package_cpp.sh ├── launch_midas_cpp.sh ├── midas_cpp │ ├── CMakeLists.txt │ ├── launch │ │ ├── midas_cpp.launch │ │ └── midas_talker_listener.launch │ ├── package.xml │ ├── scripts │ │ ├── listener.py │ │ ├── listener_original.py │ │ └── talker.py │ └── src │ │ └── main.cpp └── run_talker_listener_test.sh ├── run.py ├── tf ├── README.md ├── input │ └── .placeholder ├── make_onnx_model.py ├── output │ └── .placeholder ├── run_onnx.py ├── run_pb.py ├── transforms.py └── utils.py ├── utils.py └── weights └── .placeholder /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # 
before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.png 107 | *.pfm 108 | *.jpg 109 | *.jpeg 110 | *.pt -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # enables cuda support in docker 2 | FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 3 | 4 | # install python 3.6, pip and requirements for opencv-python 5 | # (see https://github.com/NVIDIA/nvidia-docker/issues/864) 6 | RUN apt-get update && apt-get -y install \ 7 | python3 \ 8 | python3-pip \ 9 | libsm6 \ 10 | libxext6 \ 11 | libxrender-dev \ 12 | curl \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # install python dependencies 16 | RUN pip3 install --upgrade pip 17 | RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm 18 | 19 | # copy inference code 20 | WORKDIR /opt/MiDaS 21 | COPY ./midas ./midas 22 | COPY ./*.py ./ 23 | 24 | # download model weights so the docker image can be used offline 25 | RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } 26 | RUN python3 run.py --model_type dpt_hybrid; exit 0 27 | 28 | # entrypoint (dont forget to mount input and output directories) 29 | CMD python3 run.py --model_type dpt_hybrid 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation. 3 | 4 | ## Reporting a Vulnerability 5 | Please report any security vulnerabilities in this project utilizing the guidelines [here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). 6 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: midas-py310 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - nvidia::cudatoolkit=11.7 7 | - python=3.10.8 8 | - pytorch::pytorch=1.13.0 9 | - torchvision=0.14.0 10 | - pip=22.3.1 11 | - numpy=1.23.4 12 | - pip: 13 | - opencv-python==4.6.0.66 14 | - imutils==0.5.4 15 | - timm==0.6.12 16 | - einops==0.6.0 -------------------------------------------------------------------------------- /figures/Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/figures/Comparison.png -------------------------------------------------------------------------------- /figures/Improvement_vs_FPS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/figures/Improvement_vs_FPS.png -------------------------------------------------------------------------------- /input/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/input/.placeholder -------------------------------------------------------------------------------- /midas/backbones/levit.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from .utils import activations, get_activation, Transpose 7 | 8 | 9 | def forward_levit(pretrained, x): 10 | pretrained.model.forward_features(x) 11 | 12 | layer_1 = pretrained.activations["1"] 13 | layer_2 = pretrained.activations["2"] 14 | layer_3 = pretrained.activations["3"] 15 | 16 | layer_1 = pretrained.act_postprocess1(layer_1) 17 | layer_2 = pretrained.act_postprocess2(layer_2) 18 | layer_3 = pretrained.act_postprocess3(layer_3) 19 | 20 | return layer_1, layer_2, layer_3 21 | 22 | 23 | def _make_levit_backbone( 24 | model, 25 | hooks=[3, 11, 21], 26 | patch_grid=[14, 14] 27 | ): 28 | pretrained = nn.Module() 29 | 30 | pretrained.model = model 31 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 32 | 
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 33 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 34 | 35 | pretrained.activations = activations 36 | 37 | patch_grid_size = np.array(patch_grid, dtype=int) 38 | 39 | pretrained.act_postprocess1 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 42 | ) 43 | pretrained.act_postprocess2 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) 46 | ) 47 | pretrained.act_postprocess3 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) 50 | ) 51 | 52 | return pretrained 53 | 54 | 55 | class ConvTransposeNorm(nn.Sequential): 56 | """ 57 | Modification of 58 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm 59 | such that ConvTranspose2d is used instead of Conv2d. 60 | """ 61 | 62 | def __init__( 63 | self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, 64 | groups=1, bn_weight_init=1): 65 | super().__init__() 66 | self.add_module('c', 67 | nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) 68 | self.add_module('bn', nn.BatchNorm2d(out_chs)) 69 | 70 | nn.init.constant_(self.bn.weight, bn_weight_init) 71 | 72 | @torch.no_grad() 73 | def fuse(self): 74 | c, bn = self._modules.values() 75 | w = bn.weight / (bn.running_var + bn.eps) ** 0.5 76 | w = c.weight * w[:, None, None, None] 77 | b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 78 | m = nn.ConvTranspose2d( 79 | w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, 80 | padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 81 | m.weight.data.copy_(w) 82 | m.bias.data.copy_(b) 83 | return m 84 | 85 | 86 | def stem_b4_transpose(in_chs, out_chs, activation): 87 | """ 88 | Modification of 89 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 90 | such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 
91 | """ 92 | return nn.Sequential( 93 | ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), 94 | activation(), 95 | ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), 96 | activation()) 97 | 98 | 99 | def _make_pretrained_levit_384(pretrained, hooks=None): 100 | model = timm.create_model("levit_384", pretrained=pretrained) 101 | 102 | hooks = [3, 11, 21] if hooks == None else hooks 103 | return _make_levit_backbone( 104 | model, 105 | hooks=hooks 106 | ) 107 | -------------------------------------------------------------------------------- /midas/backbones/next_vit.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | import torch.nn as nn 4 | 5 | from pathlib import Path 6 | from .utils import activations, forward_default, get_activation 7 | 8 | from ..external.next_vit.classification.nextvit import * 9 | 10 | 11 | def forward_next_vit(pretrained, x): 12 | return forward_default(pretrained, x, "forward") 13 | 14 | 15 | def _make_next_vit_backbone( 16 | model, 17 | hooks=[2, 6, 36, 39], 18 | ): 19 | pretrained = nn.Module() 20 | 21 | pretrained.model = model 22 | pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) 23 | pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) 24 | pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) 25 | pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) 26 | 27 | pretrained.activations = activations 28 | 29 | return pretrained 30 | 31 | 32 | def _make_pretrained_next_vit_large_6m(hooks=None): 33 | model = timm.create_model("nextvit_large") 34 | 35 | hooks = [2, 6, 36, 39] if hooks == None else hooks 36 | return _make_next_vit_backbone( 37 | model, 38 | hooks=hooks, 39 | ) 40 | -------------------------------------------------------------------------------- /midas/backbones/swin.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from .swin_common import _make_swin_backbone 4 | 5 | 6 | def _make_pretrained_swinl12_384(pretrained, hooks=None): 7 | model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) 8 | 9 | hooks = [1, 1, 17, 1] if hooks == None else hooks 10 | return _make_swin_backbone( 11 | model, 12 | hooks=hooks 13 | ) 14 | -------------------------------------------------------------------------------- /midas/backbones/swin2.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from .swin_common import _make_swin_backbone 4 | 5 | 6 | def _make_pretrained_swin2l24_384(pretrained, hooks=None): 7 | model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) 8 | 9 | hooks = [1, 1, 17, 1] if hooks == None else hooks 10 | return _make_swin_backbone( 11 | model, 12 | hooks=hooks 13 | ) 14 | 15 | 16 | def _make_pretrained_swin2b24_384(pretrained, hooks=None): 17 | model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) 18 | 19 | hooks = [1, 1, 17, 1] if hooks == None else hooks 20 | return _make_swin_backbone( 21 | model, 22 | hooks=hooks 23 | ) 24 | 25 | 26 | def _make_pretrained_swin2t16_256(pretrained, hooks=None): 27 | model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) 28 | 29 | hooks = [1, 1, 5, 1] if hooks == None else hooks 30 | return _make_swin_backbone( 31 | model, 32 | hooks=hooks, 33 | patch_grid=[64, 64] 34 | ) 35 | 
-------------------------------------------------------------------------------- /midas/backbones/swin_common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | from .utils import activations, forward_default, get_activation, Transpose 7 | 8 | 9 | def forward_swin(pretrained, x): 10 | return forward_default(pretrained, x) 11 | 12 | 13 | def _make_swin_backbone( 14 | model, 15 | hooks=[1, 1, 17, 1], 16 | patch_grid=[96, 96] 17 | ): 18 | pretrained = nn.Module() 19 | 20 | pretrained.model = model 21 | pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) 22 | pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) 23 | pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) 24 | pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) 25 | 26 | pretrained.activations = activations 27 | 28 | if hasattr(model, "patch_grid"): 29 | used_patch_grid = model.patch_grid 30 | else: 31 | used_patch_grid = patch_grid 32 | 33 | patch_grid_size = np.array(used_patch_grid, dtype=int) 34 | 35 | pretrained.act_postprocess1 = nn.Sequential( 36 | Transpose(1, 2), 37 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 38 | ) 39 | pretrained.act_postprocess2 = nn.Sequential( 40 | Transpose(1, 2), 41 | nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) 42 | ) 43 | pretrained.act_postprocess3 = nn.Sequential( 44 | Transpose(1, 2), 45 | nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) 46 | ) 47 | pretrained.act_postprocess4 = nn.Sequential( 48 | Transpose(1, 2), 49 | nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) 50 | ) 51 | 52 | return pretrained 53 | -------------------------------------------------------------------------------- /midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base_model import BaseModel 5 | from .blocks import ( 6 | FeatureFusionBlock_custom, 7 | Interpolate, 8 | _make_encoder, 9 | forward_beit, 10 | forward_swin, 11 | forward_levit, 12 | forward_vit, 13 | ) 14 | from .backbones.levit import stem_b4_transpose 15 | from timm.models.layers import get_act_layer 16 | 17 | 18 | def _make_fusion_block(features, use_bn, size = None): 19 | return FeatureFusionBlock_custom( 20 | features, 21 | nn.ReLU(False), 22 | deconv=False, 23 | bn=use_bn, 24 | expand=False, 25 | align_corners=True, 26 | size=size, 27 | ) 28 | 29 | 30 | class DPT(BaseModel): 31 | def __init__( 32 | self, 33 | head, 34 | features=256, 35 | backbone="vitb_rn50_384", 36 | readout="project", 37 | channels_last=False, 38 | use_bn=False, 39 | **kwargs 40 | ): 41 | 42 | super(DPT, self).__init__() 43 | 44 | self.channels_last = channels_last 45 | 46 | # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the 47 | # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments. 48 | hooks = { 49 | "beitl16_512": [5, 11, 17, 23], 50 | "beitl16_384": [5, 11, 17, 23], 51 | "beitb16_384": [2, 5, 8, 11], 52 | "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1] 53 | "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] 54 | "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1] 55 | "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1] 56 | "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39] 57 | "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21] 58 | "vitb_rn50_384": [0, 1, 8, 11], 59 | "vitb16_384": [2, 5, 8, 11], 60 | "vitl16_384": [5, 11, 17, 23], 61 | }[backbone] 62 | 63 | if "next_vit" in backbone: 64 | in_features = { 65 | "next_vit_large_6m": [96, 256, 512, 1024], 66 | }[backbone] 67 | else: 68 | in_features = None 69 | 70 | # Instantiate backbone and reassemble blocks 71 | self.pretrained, self.scratch = _make_encoder( 72 | backbone, 73 | features, 74 | False, # Set to true of you want to train from scratch, uses ImageNet weights 75 | groups=1, 76 | expand=False, 77 | exportable=False, 78 | hooks=hooks, 79 | use_readout=readout, 80 | in_features=in_features, 81 | ) 82 | 83 | self.number_layers = len(hooks) if hooks is not None else 4 84 | size_refinenet3 = None 85 | self.scratch.stem_transpose = None 86 | 87 | if "beit" in backbone: 88 | self.forward_transformer = forward_beit 89 | elif "swin" in backbone: 90 | self.forward_transformer = forward_swin 91 | elif "next_vit" in backbone: 92 | from .backbones.next_vit import forward_next_vit 93 | self.forward_transformer = forward_next_vit 94 | elif "levit" in backbone: 95 | self.forward_transformer = forward_levit 96 | size_refinenet3 = 7 97 | self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish")) 98 | else: 99 | self.forward_transformer = forward_vit 100 | 101 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 102 | self.scratch.refinenet2 = 
_make_fusion_block(features, use_bn) 103 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3) 104 | if self.number_layers >= 4: 105 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 106 | 107 | self.scratch.output_conv = head 108 | 109 | 110 | def forward(self, x): 111 | if self.channels_last == True: 112 | x.contiguous(memory_format=torch.channels_last) 113 | 114 | layers = self.forward_transformer(self.pretrained, x) 115 | if self.number_layers == 3: 116 | layer_1, layer_2, layer_3 = layers 117 | else: 118 | layer_1, layer_2, layer_3, layer_4 = layers 119 | 120 | layer_1_rn = self.scratch.layer1_rn(layer_1) 121 | layer_2_rn = self.scratch.layer2_rn(layer_2) 122 | layer_3_rn = self.scratch.layer3_rn(layer_3) 123 | if self.number_layers >= 4: 124 | layer_4_rn = self.scratch.layer4_rn(layer_4) 125 | 126 | if self.number_layers == 3: 127 | path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:]) 128 | else: 129 | path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) 130 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) 131 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) 132 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 133 | 134 | if self.scratch.stem_transpose is not None: 135 | path_1 = self.scratch.stem_transpose(path_1) 136 | 137 | out = self.scratch.output_conv(path_1) 138 | 139 | return out 140 | 141 | 142 | class DPTDepthModel(DPT): 143 | def __init__(self, path=None, non_negative=True, **kwargs): 144 | features = kwargs["features"] if "features" in kwargs else 256 145 | head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features 146 | head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32 147 | kwargs.pop("head_features_1", None) 148 | kwargs.pop("head_features_2", None) 149 | 150 | head = nn.Sequential( 151 | nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1), 152 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 153 | nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), 154 | nn.ReLU(True), 155 | nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), 156 | nn.ReLU(True) if non_negative else nn.Identity(), 157 | nn.Identity(), 158 | ) 159 | 160 | super().__init__(head, **kwargs) 161 | 162 | if path is not None: 163 | self.load(path) 164 | 165 | def forward(self, x): 166 | return super().forward(x).squeeze(dim=1) 167 | -------------------------------------------------------------------------------- /midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 
22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 
76 |         Args:
77 |             x (tensor): input data (image)
78 | 
79 |         Returns:
80 |             tensor: depth
81 |         """
82 |         if self.channels_last==True:
83 |             print("self.channels_last = ", self.channels_last)
84 |             x = x.contiguous(memory_format=torch.channels_last)  # contiguous() is not in-place; keep the returned channels-last tensor
85 | 
86 | 
87 |         layer_1 = self.pretrained.layer1(x)
88 |         layer_2 = self.pretrained.layer2(layer_1)
89 |         layer_3 = self.pretrained.layer3(layer_2)
90 |         layer_4 = self.pretrained.layer4(layer_3)
91 | 
92 |         layer_1_rn = self.scratch.layer1_rn(layer_1)
93 |         layer_2_rn = self.scratch.layer2_rn(layer_2)
94 |         layer_3_rn = self.scratch.layer3_rn(layer_3)
95 |         layer_4_rn = self.scratch.layer4_rn(layer_4)
96 | 
97 | 
98 |         path_4 = self.scratch.refinenet4(layer_4_rn)
99 |         path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 |         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 |         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 | 
103 |         out = self.scratch.output_conv(path_1)
104 | 
105 |         return torch.squeeze(out, dim=1)
106 | 
107 | 
108 | 
109 | def fuse_model(m):
110 |     prev_previous_type = nn.Identity()
111 |     prev_previous_name = ''
112 |     previous_type = nn.Identity()
113 |     previous_name = ''
114 |     for name, module in m.named_modules():
115 |         if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 |             # print("FUSED ", prev_previous_name, previous_name, name)
117 |             torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 |         elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 |             # print("FUSED ", prev_previous_name, previous_name)
120 |             torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 |         # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 |         #    print("FUSED ", previous_name, name)
123 |         #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 | 
125 |         prev_previous_type = previous_type
126 |         prev_previous_name = previous_name
127 |         previous_type = type(module)
128 |         previous_name = name
--------------------------------------------------------------------------------
/mobile/README.md:
--------------------------------------------------------------------------------
1 | ## Mobile version of MiDaS for iOS / Android - Monocular Depth Estimation
2 | 
3 | ### Accuracy
4 | 
5 | * Old small model - ResNet50 default-decoder 384x384
6 | * New small model - EfficientNet-Lite3 small-decoder 256x256
7 | 
8 | **Zero-shot error** (lower is better):
9 | 
10 | | Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 |
11 | |---|---|---|---|---|---|---|
12 | | Old small model 384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 |
13 | | New small model 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** |
14 | | Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** |
15 | 
16 | None of the train/valid/test subsets of these datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were involved in training or fine-tuning.
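All six columns are error metrics, so lower is better: a negative entry in the "Relative improvement" row means the new model is worse on that benchmark, a positive entry means it is better. The row can be reproduced from the two model rows above it; the snippet below is a small illustrative sketch (the `(old - new) / old` formula is inferred from the published numbers, not taken from the MiDaS code):

```python
# Zero-shot errors copied from the table above (lower is better).
old = {"DIW": 0.1248, "Eth3d": 0.1550, "Sintel": 0.3300, "Kitti": 21.81, "NyuDepthV2": 15.73, "TUM": 17.00}
new = {"DIW": 0.1344, "Eth3d": 0.1344, "Sintel": 0.3370, "Kitti": 29.27, "NyuDepthV2": 13.43, "TUM": 14.53}

for name in old:
    # Positive = the new 256x256 model has lower error than the old 384x384 model.
    rel = (old[name] - new[name]) / old[name]
    print(f"{name}: {rel:+.0%}")  # DIW: -8%, Eth3d: +13%, Sintel: -2%, Kitti: -34%, NyuDepthV2: +15%, TUM: +15%
```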
17 | 
18 | ### Inference speed (FPS) on iOS / Android
19 | 
20 | **Frames Per Second** (higher is better):
21 | 
22 | | Model | iPhone CPU | iPhone GPU | iPhone NPU | OnePlus8 CPU | OnePlus8 GPU | OnePlus8 NNAPI |
23 | |---|---|---|---|---|---|---|
24 | | Old small model 384x384 | 0.6 | N/A | N/A | 0.45 | 0.50 | 0.50 |
25 | | New small model 256x256 | 8 | 22 | **30** | 6 | **22** | 4 |
26 | | SpeedUp, X times | **12.8x** | - | - | **13.2x** | **44x** | **8x** |
27 | 
28 | N/A - run-time error (no data available)
29 | 
30 | 
31 | #### Models:
32 | 
33 | * Old small model - ResNet50 default-decoder 1x384x384x3, batch=1 FP32 (converters: Pytorch -> ONNX - [onnx_tf](https://github.com/onnx/onnx-tensorflow) -> (saved model) PB -> TFlite)
34 | 
35 | (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor)
36 | 
37 | * New small model - EfficientNet-Lite3 small-decoder 1x256x256x3, batch=1 FP32 (custom converter: Pytorch -> TFlite)
38 | 
39 | (Trained on datasets: RedWeb, MegaDepth, WSVD, 3D Movies, DIML indoor, HRWSI, IRS, TartanAir, BlendedMVS, ApolloScape)
40 | 
41 | #### Frameworks for training and conversions:
42 | ```
43 | pip install torch==1.6.0 torchvision==0.7.0
44 | pip install tf-nightly-gpu==2.5.0.dev20201031 tensorflow-addons==0.11.2 numpy==1.18.0
45 | git clone --depth 1 --branch v1.6.0 https://github.com/onnx/onnx-tensorflow
46 | ```
47 | 
48 | #### SoC - OS - Library:
49 | 
50 | * iPhone 11 (A13 Bionic) - iOS 13.7 - TensorFlowLiteSwift 0.0.1-nightly
51 | * OnePlus 8 (Snapdragon 865) - Android 10 - org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly
52 | 
53 | 
54 | ### Citation
55 | 
56 | This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3):
57 | 
58 | >Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
59 | René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
60 | 
61 | Please cite our paper if you use this code or any of the models:
62 | ```
63 | @article{Ranftl2020,
64 | author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
65 | title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
66 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
67 | year = {2020},
68 | }
69 | ```
70 | 
71 | 
--------------------------------------------------------------------------------
/mobile/android/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/libraries
5 | /.idea/modules.xml
6 | /.idea/workspace.xml
7 | .DS_Store
8 | /build
9 | /captures
10 | .externalNativeBuild
11 | 
12 | /.gradle/
13 | /.idea/
--------------------------------------------------------------------------------
/mobile/android/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Alexey
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright
notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mobile/android/README.md: -------------------------------------------------------------------------------- 1 | # MiDaS on Android smartphone by using TensorFlow-lite (TFLite) 2 | 3 | 4 | * Either use Android Studio for compilation. 5 | 6 | * Or use ready to install apk-file: 7 | * Or use URL: https://i.diawi.com/CVb8a9 8 | * Or use QR-code: 9 | 10 | Scan QR-code or open URL -> Press `Install application` -> Press `Download` and wait for download -> Open -> Install -> Open -> Press: Allow MiDaS to take photo and video from the camera While using the APP 11 | 12 | ![CVb8a9](https://user-images.githubusercontent.com/4096485/97727213-38552500-1ae1-11eb-8b76-4ea11216f76d.png) 13 | 14 | ---- 15 | 16 | To use another model, you should convert it to `model_opt.tflite` and place it to the directory: `models\src\main\assets` 17 | 18 | 19 | ---- 20 | 21 | Original repository: https://github.com/isl-org/MiDaS 22 | -------------------------------------------------------------------------------- /mobile/android/app/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | 3 | /build/ -------------------------------------------------------------------------------- /mobile/android/app/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.application' 2 | 3 | android { 4 | compileSdkVersion 28 5 | defaultConfig { 6 | applicationId "org.tensorflow.lite.examples.classification" 7 | minSdkVersion 21 8 | targetSdkVersion 28 9 | versionCode 1 10 | versionName "1.0" 11 | 12 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" 13 | } 14 | buildTypes { 15 | release { 16 | minifyEnabled false 17 | proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' 18 | } 19 | } 20 | aaptOptions { 21 | noCompress "tflite" 22 | } 23 | compileOptions { 24 | sourceCompatibility = '1.8' 25 | targetCompatibility = '1.8' 26 | } 27 | lintOptions { 28 | abortOnError false 29 | } 30 | flavorDimensions "tfliteInference" 31 | productFlavors { 32 | // The TFLite inference is built using the TFLite Support library. 33 | support { 34 | dimension "tfliteInference" 35 | } 36 | // The TFLite inference is built using the TFLite Task library. 
37 | taskApi { 38 | dimension "tfliteInference" 39 | } 40 | } 41 | 42 | } 43 | 44 | dependencies { 45 | implementation fileTree(dir: 'libs', include: ['*.jar']) 46 | supportImplementation project(":lib_support") 47 | taskApiImplementation project(":lib_task_api") 48 | implementation 'androidx.appcompat:appcompat:1.0.0' 49 | implementation 'androidx.coordinatorlayout:coordinatorlayout:1.0.0' 50 | implementation 'com.google.android.material:material:1.0.0' 51 | 52 | androidTestImplementation 'androidx.test.ext:junit:1.1.1' 53 | androidTestImplementation 'com.google.truth:truth:1.0.1' 54 | androidTestImplementation 'androidx.test:runner:1.2.0' 55 | androidTestImplementation 'androidx.test:rules:1.1.0' 56 | } 57 | -------------------------------------------------------------------------------- /mobile/android/app/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 21 | #-renamesourcefileattribute SourceFile 22 | -------------------------------------------------------------------------------- /mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_support.txt: -------------------------------------------------------------------------------- 1 | red_fox 0.79403335 2 | kit_fox 0.16753247 3 | grey_fox 0.03619214 4 | -------------------------------------------------------------------------------- /mobile/android/app/src/androidTest/assets/fox-mobilenet_v1_1.0_224_task_api.txt: -------------------------------------------------------------------------------- 1 | red_fox 0.85 2 | kit_fox 0.13 3 | grey_fox 0.02 4 | -------------------------------------------------------------------------------- /mobile/android/app/src/androidTest/java/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | -------------------------------------------------------------------------------- /mobile/android/app/src/androidTest/java/org/tensorflow/lite/examples/classification/ClassifierTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | package org.tensorflow.lite.examples.classification; 18 | 19 | import static com.google.common.truth.Truth.assertThat; 20 | 21 | import android.content.res.AssetManager; 22 | import android.graphics.Bitmap; 23 | import android.graphics.BitmapFactory; 24 | import android.util.Log; 25 | import androidx.test.ext.junit.runners.AndroidJUnit4; 26 | import androidx.test.platform.app.InstrumentationRegistry; 27 | import androidx.test.rule.ActivityTestRule; 28 | import java.io.IOException; 29 | import java.io.InputStream; 30 | import java.util.ArrayList; 31 | import java.util.Iterator; 32 | import java.util.List; 33 | import java.util.Scanner; 34 | import org.junit.Assert; 35 | import org.junit.Rule; 36 | import org.junit.Test; 37 | import org.junit.runner.RunWith; 38 | import org.tensorflow.lite.examples.classification.tflite.Classifier; 39 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 40 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Model; 41 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; 42 | 43 | /** Golden test for Image Classification Reference app. */ 44 | @RunWith(AndroidJUnit4.class) 45 | public class ClassifierTest { 46 | 47 | @Rule 48 | public ActivityTestRule rule = 49 | new ActivityTestRule<>(ClassifierActivity.class); 50 | 51 | private static final String[] INPUTS = {"fox.jpg"}; 52 | private static final String[] GOLDEN_OUTPUTS_SUPPORT = {"fox-mobilenet_v1_1.0_224_support.txt"}; 53 | private static final String[] GOLDEN_OUTPUTS_TASK = {"fox-mobilenet_v1_1.0_224_task_api.txt"}; 54 | 55 | @Test 56 | public void classificationResultsShouldNotChange() throws IOException { 57 | ClassifierActivity activity = rule.getActivity(); 58 | Classifier classifier = Classifier.create(activity, Model.FLOAT_MOBILENET, Device.CPU, 1); 59 | for (int i = 0; i < INPUTS.length; i++) { 60 | String imageFileName = INPUTS[i]; 61 | String goldenOutputFileName; 62 | // TODO(b/169379396): investigate the impact of the resize algorithm on accuracy. 63 | // This is a temporary workaround to set different golden rest results as the preprocessing 64 | // of lib_support and lib_task_api are different. Will merge them once the above TODO is 65 | // resolved. 
66 | if (Classifier.TAG.equals("ClassifierWithSupport")) { 67 | goldenOutputFileName = GOLDEN_OUTPUTS_SUPPORT[i]; 68 | } else { 69 | goldenOutputFileName = GOLDEN_OUTPUTS_TASK[i]; 70 | } 71 | Bitmap input = loadImage(imageFileName); 72 | List goldenOutput = loadRecognitions(goldenOutputFileName); 73 | 74 | List result = classifier.recognizeImage(input, 0); 75 | Iterator goldenOutputIterator = goldenOutput.iterator(); 76 | 77 | for (Recognition actual : result) { 78 | Assert.assertTrue(goldenOutputIterator.hasNext()); 79 | Recognition expected = goldenOutputIterator.next(); 80 | assertThat(actual.getTitle()).isEqualTo(expected.getTitle()); 81 | assertThat(actual.getConfidence()).isWithin(0.01f).of(expected.getConfidence()); 82 | } 83 | } 84 | } 85 | 86 | private static Bitmap loadImage(String fileName) { 87 | AssetManager assetManager = 88 | InstrumentationRegistry.getInstrumentation().getContext().getAssets(); 89 | InputStream inputStream = null; 90 | try { 91 | inputStream = assetManager.open(fileName); 92 | } catch (IOException e) { 93 | Log.e("Test", "Cannot load image from assets"); 94 | } 95 | return BitmapFactory.decodeStream(inputStream); 96 | } 97 | 98 | private static List loadRecognitions(String fileName) { 99 | AssetManager assetManager = 100 | InstrumentationRegistry.getInstrumentation().getContext().getAssets(); 101 | InputStream inputStream = null; 102 | try { 103 | inputStream = assetManager.open(fileName); 104 | } catch (IOException e) { 105 | Log.e("Test", "Cannot load probability results from assets"); 106 | } 107 | Scanner scanner = new Scanner(inputStream); 108 | List result = new ArrayList<>(); 109 | while (scanner.hasNext()) { 110 | String category = scanner.next(); 111 | category = category.replace('_', ' '); 112 | if (!scanner.hasNextFloat()) { 113 | break; 114 | } 115 | float probability = scanner.nextFloat(); 116 | Recognition recognition = new Recognition(null, category, probability, null); 117 | result.add(recognition); 118 | } 119 | return result; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 18 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/AutoFitTextureView.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | package org.tensorflow.lite.examples.classification.customview; 18 | 19 | import android.content.Context; 20 | import android.util.AttributeSet; 21 | import android.view.TextureView; 22 | 23 | /** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ 24 | public class AutoFitTextureView extends TextureView { 25 | private int ratioWidth = 0; 26 | private int ratioHeight = 0; 27 | 28 | public AutoFitTextureView(final Context context) { 29 | this(context, null); 30 | } 31 | 32 | public AutoFitTextureView(final Context context, final AttributeSet attrs) { 33 | this(context, attrs, 0); 34 | } 35 | 36 | public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { 37 | super(context, attrs, defStyle); 38 | } 39 | 40 | /** 41 | * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio 42 | * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, 43 | * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. 44 | * 45 | * @param width Relative horizontal size 46 | * @param height Relative vertical size 47 | */ 48 | public void setAspectRatio(final int width, final int height) { 49 | if (width < 0 || height < 0) { 50 | throw new IllegalArgumentException("Size cannot be negative."); 51 | } 52 | ratioWidth = width; 53 | ratioHeight = height; 54 | requestLayout(); 55 | } 56 | 57 | @Override 58 | protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { 59 | super.onMeasure(widthMeasureSpec, heightMeasureSpec); 60 | final int width = MeasureSpec.getSize(widthMeasureSpec); 61 | final int height = MeasureSpec.getSize(heightMeasureSpec); 62 | if (0 == ratioWidth || 0 == ratioHeight) { 63 | setMeasuredDimension(width, height); 64 | } else { 65 | if (width < height * ratioWidth / ratioHeight) { 66 | setMeasuredDimension(width, width * ratioHeight / ratioWidth); 67 | } else { 68 | setMeasuredDimension(height * ratioWidth / ratioHeight, height); 69 | } 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/OverlayView.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.customview; 17 | 18 | import android.content.Context; 19 | import android.graphics.Canvas; 20 | import android.util.AttributeSet; 21 | import android.view.View; 22 | import java.util.LinkedList; 23 | import java.util.List; 24 | 25 | /** A simple View providing a render callback to other classes. 
*/ 26 | public class OverlayView extends View { 27 | private final List callbacks = new LinkedList(); 28 | 29 | public OverlayView(final Context context, final AttributeSet attrs) { 30 | super(context, attrs); 31 | } 32 | 33 | public void addCallback(final DrawCallback callback) { 34 | callbacks.add(callback); 35 | } 36 | 37 | @Override 38 | public synchronized void draw(final Canvas canvas) { 39 | for (final DrawCallback callback : callbacks) { 40 | callback.drawCallback(canvas); 41 | } 42 | } 43 | 44 | /** Interface defining the callback for client classes. */ 45 | public interface DrawCallback { 46 | public void drawCallback(final Canvas canvas); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/RecognitionScoreView.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.customview; 17 | 18 | import android.content.Context; 19 | import android.graphics.Canvas; 20 | import android.graphics.Paint; 21 | import android.util.AttributeSet; 22 | import android.util.TypedValue; 23 | import android.view.View; 24 | import java.util.List; 25 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; 26 | 27 | public class RecognitionScoreView extends View implements ResultsView { 28 | private static final float TEXT_SIZE_DIP = 16; 29 | private final float textSizePx; 30 | private final Paint fgPaint; 31 | private final Paint bgPaint; 32 | private List results; 33 | 34 | public RecognitionScoreView(final Context context, final AttributeSet set) { 35 | super(context, set); 36 | 37 | textSizePx = 38 | TypedValue.applyDimension( 39 | TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); 40 | fgPaint = new Paint(); 41 | fgPaint.setTextSize(textSizePx); 42 | 43 | bgPaint = new Paint(); 44 | bgPaint.setColor(0xcc4285f4); 45 | } 46 | 47 | @Override 48 | public void setResults(final List results) { 49 | this.results = results; 50 | postInvalidate(); 51 | } 52 | 53 | @Override 54 | public void onDraw(final Canvas canvas) { 55 | final int x = 10; 56 | int y = (int) (fgPaint.getTextSize() * 1.5f); 57 | 58 | canvas.drawPaint(bgPaint); 59 | 60 | if (results != null) { 61 | for (final Recognition recog : results) { 62 | canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); 63 | y += (int) (fgPaint.getTextSize() * 1.5f); 64 | } 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/customview/ResultsView.java: 
-------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.customview; 17 | 18 | import java.util.List; 19 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Recognition; 20 | 21 | public interface ResultsView { 22 | public void setResults(final List results); 23 | } 24 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/BorderedText.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.env; 17 | 18 | import android.graphics.Canvas; 19 | import android.graphics.Color; 20 | import android.graphics.Paint; 21 | import android.graphics.Paint.Align; 22 | import android.graphics.Paint.Style; 23 | import android.graphics.Rect; 24 | import android.graphics.Typeface; 25 | import java.util.Vector; 26 | 27 | /** A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas. */ 28 | public class BorderedText { 29 | private final Paint interiorPaint; 30 | private final Paint exteriorPaint; 31 | 32 | private final float textSize; 33 | 34 | /** 35 | * Creates a left-aligned bordered text object with a white interior, and a black exterior with 36 | * the specified text size. 37 | * 38 | * @param textSize text size in pixels 39 | */ 40 | public BorderedText(final float textSize) { 41 | this(Color.WHITE, Color.BLACK, textSize); 42 | } 43 | 44 | /** 45 | * Create a bordered text object with the specified interior and exterior colors, text size and 46 | * alignment. 
47 | * 48 | * @param interiorColor the interior text color 49 | * @param exteriorColor the exterior text color 50 | * @param textSize text size in pixels 51 | */ 52 | public BorderedText(final int interiorColor, final int exteriorColor, final float textSize) { 53 | interiorPaint = new Paint(); 54 | interiorPaint.setTextSize(textSize); 55 | interiorPaint.setColor(interiorColor); 56 | interiorPaint.setStyle(Style.FILL); 57 | interiorPaint.setAntiAlias(false); 58 | interiorPaint.setAlpha(255); 59 | 60 | exteriorPaint = new Paint(); 61 | exteriorPaint.setTextSize(textSize); 62 | exteriorPaint.setColor(exteriorColor); 63 | exteriorPaint.setStyle(Style.FILL_AND_STROKE); 64 | exteriorPaint.setStrokeWidth(textSize / 8); 65 | exteriorPaint.setAntiAlias(false); 66 | exteriorPaint.setAlpha(255); 67 | 68 | this.textSize = textSize; 69 | } 70 | 71 | public void setTypeface(Typeface typeface) { 72 | interiorPaint.setTypeface(typeface); 73 | exteriorPaint.setTypeface(typeface); 74 | } 75 | 76 | public void drawText(final Canvas canvas, final float posX, final float posY, final String text) { 77 | canvas.drawText(text, posX, posY, exteriorPaint); 78 | canvas.drawText(text, posX, posY, interiorPaint); 79 | } 80 | 81 | public void drawLines(Canvas canvas, final float posX, final float posY, Vector lines) { 82 | int lineNum = 0; 83 | for (final String line : lines) { 84 | drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line); 85 | ++lineNum; 86 | } 87 | } 88 | 89 | public void setInteriorColor(final int color) { 90 | interiorPaint.setColor(color); 91 | } 92 | 93 | public void setExteriorColor(final int color) { 94 | exteriorPaint.setColor(color); 95 | } 96 | 97 | public float getTextSize() { 98 | return textSize; 99 | } 100 | 101 | public void setAlpha(final int alpha) { 102 | interiorPaint.setAlpha(alpha); 103 | exteriorPaint.setAlpha(alpha); 104 | } 105 | 106 | public void getTextBounds( 107 | final String line, final int index, final int count, final Rect lineBounds) { 108 | interiorPaint.getTextBounds(line, index, count, lineBounds); 109 | } 110 | 111 | public void setTextAlign(final Align align) { 112 | interiorPaint.setTextAlign(align); 113 | exteriorPaint.setTextAlign(align); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/ImageUtils.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.env; 17 | 18 | import android.graphics.Bitmap; 19 | import android.os.Environment; 20 | import java.io.File; 21 | import java.io.FileOutputStream; 22 | 23 | /** Utility class for manipulating images. 
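BorderedText above draws each string twice, exterior stroke first and interior fill second, so the fill stays legible over any background. A short usage sketch (the canvas, coordinates and label are hypothetical):

    // Illustrative only; assumes an android.graphics.Canvas named "canvas" and
    // android.graphics.Paint are in scope.
    final BorderedText borderedText = new BorderedText(32f); // white fill, black border, 32 px
    borderedText.setTextAlign(Paint.Align.LEFT);
    borderedText.drawText(canvas, 16f, 48f, "inference: 42 ms");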
*/ 24 | public class ImageUtils { 25 | // This value is 2 ^ 18 - 1, and is used to clamp the RGB values before their ranges 26 | // are normalized to eight bits. 27 | static final int kMaxChannelValue = 262143; 28 | 29 | @SuppressWarnings("unused") 30 | private static final Logger LOGGER = new Logger(); 31 | 32 | /** 33 | * Utility method to compute the allocated size in bytes of a YUV420SP image of the given 34 | * dimensions. 35 | */ 36 | public static int getYUVByteSize(final int width, final int height) { 37 | // The luminance plane requires 1 byte per pixel. 38 | final int ySize = width * height; 39 | 40 | // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up. 41 | // Each 2x2 block takes 2 bytes to encode, one each for U and V. 42 | final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2; 43 | 44 | return ySize + uvSize; 45 | } 46 | 47 | /** 48 | * Saves a Bitmap object to disk for analysis. 49 | * 50 | * @param bitmap The bitmap to save. 51 | */ 52 | public static void saveBitmap(final Bitmap bitmap) { 53 | saveBitmap(bitmap, "preview.png"); 54 | } 55 | 56 | /** 57 | * Saves a Bitmap object to disk for analysis. 58 | * 59 | * @param bitmap The bitmap to save. 60 | * @param filename The location to save the bitmap to. 61 | */ 62 | public static void saveBitmap(final Bitmap bitmap, final String filename) { 63 | final String root = 64 | Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow"; 65 | LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root); 66 | final File myDir = new File(root); 67 | 68 | if (!myDir.mkdirs()) { 69 | LOGGER.i("Make dir failed"); 70 | } 71 | 72 | final String fname = filename; 73 | final File file = new File(myDir, fname); 74 | if (file.exists()) { 75 | file.delete(); 76 | } 77 | try { 78 | final FileOutputStream out = new FileOutputStream(file); 79 | bitmap.compress(Bitmap.CompressFormat.PNG, 99, out); 80 | out.flush(); 81 | out.close(); 82 | } catch (final Exception e) { 83 | LOGGER.e(e, "Exception!"); 84 | } 85 | } 86 | 87 | public static void convertYUV420SPToARGB8888(byte[] input, int width, int height, int[] output) { 88 | final int frameSize = width * height; 89 | for (int j = 0, yp = 0; j < height; j++) { 90 | int uvp = frameSize + (j >> 1) * width; 91 | int u = 0; 92 | int v = 0; 93 | 94 | for (int i = 0; i < width; i++, yp++) { 95 | int y = 0xff & input[yp]; 96 | if ((i & 1) == 0) { 97 | v = 0xff & input[uvp++]; 98 | u = 0xff & input[uvp++]; 99 | } 100 | 101 | output[yp] = YUV2RGB(y, u, v); 102 | } 103 | } 104 | } 105 | 106 | private static int YUV2RGB(int y, int u, int v) { 107 | // Adjust and check YUV values 108 | y = (y - 16) < 0 ? 0 : (y - 16); 109 | u -= 128; 110 | v -= 128; 111 | 112 | // This is the floating point equivalent. We do the conversion in integer 113 | // because some Android devices do not have floating point in hardware. 114 | // nR = (int)(1.164 * nY + 2.018 * nU); 115 | // nG = (int)(1.164 * nY - 0.813 * nV - 0.391 * nU); 116 | // nB = (int)(1.164 * nY + 1.596 * nV); 117 | int y1192 = 1192 * y; 118 | int r = (y1192 + 1634 * v); 119 | int g = (y1192 - 833 * v - 400 * u); 120 | int b = (y1192 + 2066 * u); 121 | 122 | // Clipping RGB values to be inside boundaries [ 0 , kMaxChannelValue ] 123 | r = r > kMaxChannelValue ? kMaxChannelValue : (r < 0 ? 0 : r); 124 | g = g > kMaxChannelValue ? kMaxChannelValue : (g < 0 ? 0 : g); 125 | b = b > kMaxChannelValue ? kMaxChannelValue : (b < 0 ? 
0 : b); 126 | 127 | return 0xff000000 | ((r << 6) & 0xff0000) | ((g >> 2) & 0xff00) | ((b >> 10) & 0xff); 128 | } 129 | 130 | public static void convertYUV420ToARGB8888( 131 | byte[] yData, 132 | byte[] uData, 133 | byte[] vData, 134 | int width, 135 | int height, 136 | int yRowStride, 137 | int uvRowStride, 138 | int uvPixelStride, 139 | int[] out) { 140 | int yp = 0; 141 | for (int j = 0; j < height; j++) { 142 | int pY = yRowStride * j; 143 | int pUV = uvRowStride * (j >> 1); 144 | 145 | for (int i = 0; i < width; i++) { 146 | int uv_offset = pUV + (i >> 1) * uvPixelStride; 147 | 148 | out[yp++] = YUV2RGB(0xff & yData[pY + i], 0xff & uData[uv_offset], 0xff & vData[uv_offset]); 149 | } 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/java/org/tensorflow/lite/examples/classification/env/Logger.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.env; 17 | 18 | import android.util.Log; 19 | import java.util.HashSet; 20 | import java.util.Set; 21 | 22 | /** Wrapper for the platform log function, allows convenient message prefixing and log disabling. */ 23 | public final class Logger { 24 | private static final String DEFAULT_TAG = "tensorflow"; 25 | private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG; 26 | 27 | // Classes to be ignored when examining the stack trace 28 | private static final Set IGNORED_CLASS_NAMES; 29 | 30 | static { 31 | IGNORED_CLASS_NAMES = new HashSet(3); 32 | IGNORED_CLASS_NAMES.add("dalvik.system.VMStack"); 33 | IGNORED_CLASS_NAMES.add("java.lang.Thread"); 34 | IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName()); 35 | } 36 | 37 | private final String tag; 38 | private final String messagePrefix; 39 | private int minLogLevel = DEFAULT_MIN_LOG_LEVEL; 40 | 41 | /** 42 | * Creates a Logger using the class name as the message prefix. 43 | * 44 | * @param clazz the simple name of this class is used as the message prefix. 45 | */ 46 | public Logger(final Class clazz) { 47 | this(clazz.getSimpleName()); 48 | } 49 | 50 | /** 51 | * Creates a Logger using the specified message prefix. 52 | * 53 | * @param messagePrefix is prepended to the text of every message. 54 | */ 55 | public Logger(final String messagePrefix) { 56 | this(DEFAULT_TAG, messagePrefix); 57 | } 58 | 59 | /** 60 | * Creates a Logger with a custom tag and a custom message prefix. If the message prefix is set to 61 | * 62 | *
null
63 | * 64 | * , the caller's class name is used as the prefix. 65 | * 66 | * @param tag identifies the source of a log message. 67 | * @param messagePrefix prepended to every message if non-null. If null, the name of the caller is 68 | * being used 69 | */ 70 | public Logger(final String tag, final String messagePrefix) { 71 | this.tag = tag; 72 | final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix; 73 | this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix; 74 | } 75 | 76 | /** Creates a Logger using the caller's class name as the message prefix. */ 77 | public Logger() { 78 | this(DEFAULT_TAG, null); 79 | } 80 | 81 | /** Creates a Logger using the caller's class name as the message prefix. */ 82 | public Logger(final int minLogLevel) { 83 | this(DEFAULT_TAG, null); 84 | this.minLogLevel = minLogLevel; 85 | } 86 | 87 | /** 88 | * Return caller's simple name. 89 | * 90 | *

Android getStackTrace() returns an array that looks like this: stackTrace[0]: 91 | * dalvik.system.VMStack stackTrace[1]: java.lang.Thread stackTrace[2]: 92 | * com.google.android.apps.unveil.env.UnveilLogger stackTrace[3]: 93 | * com.google.android.apps.unveil.BaseApplication 94 | * 95 | *

This function returns the simple version of the first non-filtered name. 96 | * 97 | * @return caller's simple name 98 | */ 99 | private static String getCallerSimpleName() { 100 | // Get the current callstack so we can pull the class of the caller off of it. 101 | final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace(); 102 | 103 | for (final StackTraceElement elem : stackTrace) { 104 | final String className = elem.getClassName(); 105 | if (!IGNORED_CLASS_NAMES.contains(className)) { 106 | // We're only interested in the simple name of the class, not the complete package. 107 | final String[] classParts = className.split("\\."); 108 | return classParts[classParts.length - 1]; 109 | } 110 | } 111 | 112 | return Logger.class.getSimpleName(); 113 | } 114 | 115 | public void setMinLogLevel(final int minLogLevel) { 116 | this.minLogLevel = minLogLevel; 117 | } 118 | 119 | public boolean isLoggable(final int logLevel) { 120 | return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel); 121 | } 122 | 123 | private String toMessage(final String format, final Object... args) { 124 | return messagePrefix + (args.length > 0 ? String.format(format, args) : format); 125 | } 126 | 127 | public void v(final String format, final Object... args) { 128 | if (isLoggable(Log.VERBOSE)) { 129 | Log.v(tag, toMessage(format, args)); 130 | } 131 | } 132 | 133 | public void v(final Throwable t, final String format, final Object... args) { 134 | if (isLoggable(Log.VERBOSE)) { 135 | Log.v(tag, toMessage(format, args), t); 136 | } 137 | } 138 | 139 | public void d(final String format, final Object... args) { 140 | if (isLoggable(Log.DEBUG)) { 141 | Log.d(tag, toMessage(format, args)); 142 | } 143 | } 144 | 145 | public void d(final Throwable t, final String format, final Object... args) { 146 | if (isLoggable(Log.DEBUG)) { 147 | Log.d(tag, toMessage(format, args), t); 148 | } 149 | } 150 | 151 | public void i(final String format, final Object... args) { 152 | if (isLoggable(Log.INFO)) { 153 | Log.i(tag, toMessage(format, args)); 154 | } 155 | } 156 | 157 | public void i(final Throwable t, final String format, final Object... args) { 158 | if (isLoggable(Log.INFO)) { 159 | Log.i(tag, toMessage(format, args), t); 160 | } 161 | } 162 | 163 | public void w(final String format, final Object... args) { 164 | if (isLoggable(Log.WARN)) { 165 | Log.w(tag, toMessage(format, args)); 166 | } 167 | } 168 | 169 | public void w(final Throwable t, final String format, final Object... args) { 170 | if (isLoggable(Log.WARN)) { 171 | Log.w(tag, toMessage(format, args), t); 172 | } 173 | } 174 | 175 | public void e(final String format, final Object... args) { 176 | if (isLoggable(Log.ERROR)) { 177 | Log.e(tag, toMessage(format, args)); 178 | } 179 | } 180 | 181 | public void e(final Throwable t, final String format, final Object... 
args) { 182 | if (isLoggable(Log.ERROR)) { 183 | Log.e(tag, toMessage(format, args), t); 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-hdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-hdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-mdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-mdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-v24/ic_launcher_foreground.xml: -------------------------------------------------------------------------------- 1 | 7 | 12 | 13 | 19 | 22 | 25 | 26 | 27 | 28 | 34 | 35 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-xxhdpi/icn_chevron_down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-xxhdpi/icn_chevron_down.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-xxhdpi/icn_chevron_up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-xxhdpi/icn_chevron_up.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-xxhdpi/tfl2_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-xxhdpi/tfl2_logo.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable-xxhdpi/tfl2_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/drawable-xxhdpi/tfl2_logo_dark.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable/bottom_sheet_bg.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable/ic_baseline_add.xml: 
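The Logger class a few files above prefixes every message and drops anything below minLogLevel before delegating to android.util.Log. A brief usage sketch (the prefix, method and timing value are illustrative, not repository code):

    // Illustrative only; assumes android.util.Log is imported.
    private static final Logger LOGGER = new Logger("DepthDemo"); // tag "tensorflow", prefix "DepthDemo: "

    void reportTiming(final long inferenceMs) {
      LOGGER.setMinLogLevel(Log.INFO);
      LOGGER.d("suppressed, below the INFO threshold (%d ms)", inferenceMs);
      LOGGER.i("inference took %d ms", inferenceMs);
    }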
-------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable/ic_baseline_remove.xml: -------------------------------------------------------------------------------- 1 | 6 | 9 | 10 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable/ic_launcher_background.xml: -------------------------------------------------------------------------------- 1 | 2 | 7 | 10 | 15 | 20 | 25 | 30 | 35 | 40 | 45 | 50 | 55 | 60 | 65 | 70 | 75 | 80 | 85 | 90 | 95 | 100 | 105 | 110 | 115 | 120 | 125 | 130 | 135 | 140 | 145 | 150 | 155 | 160 | 165 | 170 | 171 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/drawable/rectangle.xml: -------------------------------------------------------------------------------- 1 | 2 | 4 | 7 | 12 | 13 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/layout/tfe_ic_activity_camera.xml: -------------------------------------------------------------------------------- 1 | 16 | 21 | 22 | 27 | 28 | 29 | 36 | 37 | 38 | 44 | 45 | 49 | 50 | 51 | 52 | 53 | 56 | 57 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/layout/tfe_ic_camera_connection_fragment.xml: -------------------------------------------------------------------------------- 1 | 16 | 19 | 20 | 25 | 26 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher_foreground.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-hdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher_foreground.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-mdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher_foreground.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_foreground.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_foreground.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_foreground.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.png -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/values/colors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ffa800 4 | #ff6f00 5 | #425066 6 | 7 | #66000000 8 | 9 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/values/dimens.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 15dp 4 | 8dp 5 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/values/strings.xml: -------------------------------------------------------------------------------- 1 | 2 | Midas 3 | This device doesn\'t support Camera2 API. 4 | GPU does not yet supported quantized models. 5 | Model: 6 | 7 | Float_EfficientNet 8 | 13 | 14 | 15 | Device: 16 | 17 | GPU 18 | CPU 19 | NNAPI 20 | 21 | 22 | -------------------------------------------------------------------------------- /mobile/android/app/src/main/res/values/styles.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /mobile/android/build.gradle: -------------------------------------------------------------------------------- 1 | // Top-level build file where you can add configuration options common to all sub-projects/modules. 2 | 3 | buildscript { 4 | 5 | repositories { 6 | google() 7 | jcenter() 8 | } 9 | dependencies { 10 | classpath 'com.android.tools.build:gradle:4.0.0' 11 | classpath 'de.undercouch:gradle-download-task:4.0.2' 12 | // NOTE: Do not place your application dependencies here; they belong 13 | // in the individual module build.gradle files 14 | } 15 | } 16 | 17 | allprojects { 18 | repositories { 19 | google() 20 | jcenter() 21 | } 22 | } 23 | 24 | task clean(type: Delete) { 25 | delete rootProject.buildDir 26 | } 27 | 28 | -------------------------------------------------------------------------------- /mobile/android/gradle.properties: -------------------------------------------------------------------------------- 1 | # Project-wide Gradle settings. 2 | # IDE (e.g. Android Studio) users: 3 | # Gradle settings configured through the IDE *will override* 4 | # any settings specified in this file. 
5 | # For more details on how to configure your build environment visit 6 | # http://www.gradle.org/docs/current/userguide/build_environment.html 7 | # Specifies the JVM arguments used for the daemon process. 8 | # The setting is particularly useful for tweaking memory settings. 9 | org.gradle.jvmargs=-Xmx1536m 10 | # When configured, Gradle will run in incubating parallel mode. 11 | # This option should only be used with decoupled projects. More details, visit 12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects 13 | # org.gradle.parallel=true 14 | android.useAndroidX=true 15 | android.enableJetifier=true 16 | -------------------------------------------------------------------------------- /mobile/android/gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/android/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /mobile/android/gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /mobile/android/gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # 4 | # Copyright 2015 the original author or authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | ############################################################################## 20 | ## 21 | ## Gradle start up script for UN*X 22 | ## 23 | ############################################################################## 24 | 25 | # Attempt to set APP_HOME 26 | # Resolve links: $0 may be a link 27 | PRG="$0" 28 | # Need this for relative symlinks. 29 | while [ -h "$PRG" ] ; do 30 | ls=`ls -ld "$PRG"` 31 | link=`expr "$ls" : '.*-> \(.*\)$'` 32 | if expr "$link" : '/.*' > /dev/null; then 33 | PRG="$link" 34 | else 35 | PRG=`dirname "$PRG"`"/$link" 36 | fi 37 | done 38 | SAVED="`pwd`" 39 | cd "`dirname \"$PRG\"`/" >/dev/null 40 | APP_HOME="`pwd -P`" 41 | cd "$SAVED" >/dev/null 42 | 43 | APP_NAME="Gradle" 44 | APP_BASE_NAME=`basename "$0"` 45 | 46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 48 | 49 | # Use the maximum available, or set MAX_FD != -1 to use that value. 50 | MAX_FD="maximum" 51 | 52 | warn () { 53 | echo "$*" 54 | } 55 | 56 | die () { 57 | echo 58 | echo "$*" 59 | echo 60 | exit 1 61 | } 62 | 63 | # OS specific support (must be 'true' or 'false'). 
64 | cygwin=false 65 | msys=false 66 | darwin=false 67 | nonstop=false 68 | case "`uname`" in 69 | CYGWIN* ) 70 | cygwin=true 71 | ;; 72 | Darwin* ) 73 | darwin=true 74 | ;; 75 | MINGW* ) 76 | msys=true 77 | ;; 78 | NONSTOP* ) 79 | nonstop=true 80 | ;; 81 | esac 82 | 83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 84 | 85 | # Determine the Java command to use to start the JVM. 86 | if [ -n "$JAVA_HOME" ] ; then 87 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 88 | # IBM's JDK on AIX uses strange locations for the executables 89 | JAVACMD="$JAVA_HOME/jre/sh/java" 90 | else 91 | JAVACMD="$JAVA_HOME/bin/java" 92 | fi 93 | if [ ! -x "$JAVACMD" ] ; then 94 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 95 | 96 | Please set the JAVA_HOME variable in your environment to match the 97 | location of your Java installation." 98 | fi 99 | else 100 | JAVACMD="java" 101 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 102 | 103 | Please set the JAVA_HOME variable in your environment to match the 104 | location of your Java installation." 105 | fi 106 | 107 | # Increase the maximum file descriptors if we can. 108 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 109 | MAX_FD_LIMIT=`ulimit -H -n` 110 | if [ $? -eq 0 ] ; then 111 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 112 | MAX_FD="$MAX_FD_LIMIT" 113 | fi 114 | ulimit -n $MAX_FD 115 | if [ $? -ne 0 ] ; then 116 | warn "Could not set maximum file descriptor limit: $MAX_FD" 117 | fi 118 | else 119 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 120 | fi 121 | fi 122 | 123 | # For Darwin, add options to specify how the application appears in the dock 124 | if $darwin; then 125 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 126 | fi 127 | 128 | # For Cygwin or MSYS, switch paths to Windows format before running java 129 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then 130 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 131 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 132 | JAVACMD=`cygpath --unix "$JAVACMD"` 133 | 134 | # We build the pattern for arguments to be converted via cygpath 135 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 136 | SEP="" 137 | for dir in $ROOTDIRSRAW ; do 138 | ROOTDIRS="$ROOTDIRS$SEP$dir" 139 | SEP="|" 140 | done 141 | OURCYGPATTERN="(^($ROOTDIRS))" 142 | # Add a user-defined pattern to the cygpath arguments 143 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 144 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 145 | fi 146 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 147 | i=0 148 | for arg in "$@" ; do 149 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 150 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 151 | 152 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 153 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 154 | else 155 | eval `echo args$i`="\"$arg\"" 156 | fi 157 | i=`expr $i + 1` 158 | done 159 | case $i in 160 | 0) set -- ;; 161 | 1) set -- "$args0" ;; 162 | 2) set -- "$args0" "$args1" ;; 163 | 3) set -- "$args0" "$args1" "$args2" ;; 164 | 4) set -- "$args0" "$args1" "$args2" "$args3" ;; 165 | 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 166 | 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 167 | 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" 
"$args6" ;; 168 | 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 169 | 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 170 | esac 171 | fi 172 | 173 | # Escape application args 174 | save () { 175 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 176 | echo " " 177 | } 178 | APP_ARGS=`save "$@"` 179 | 180 | # Collect all arguments for the java command, following the shell quoting and substitution rules 181 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 182 | 183 | exec "$JAVACMD" "$@" 184 | -------------------------------------------------------------------------------- /mobile/android/gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 33 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 34 | 35 | @rem Find java.exe 36 | if defined JAVA_HOME goto findJavaFromJavaHome 37 | 38 | set JAVA_EXE=java.exe 39 | %JAVA_EXE% -version >NUL 2>&1 40 | if "%ERRORLEVEL%" == "0" goto init 41 | 42 | echo. 43 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 44 | echo. 45 | echo Please set the JAVA_HOME variable in your environment to match the 46 | echo location of your Java installation. 47 | 48 | goto fail 49 | 50 | :findJavaFromJavaHome 51 | set JAVA_HOME=%JAVA_HOME:"=% 52 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 53 | 54 | if exist "%JAVA_EXE%" goto init 55 | 56 | echo. 57 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 58 | echo. 59 | echo Please set the JAVA_HOME variable in your environment to match the 60 | echo location of your Java installation. 61 | 62 | goto fail 63 | 64 | :init 65 | @rem Get command-line arguments, handling Windows variants 66 | 67 | if not "%OS%" == "Windows_NT" goto win9xME_args 68 | 69 | :win9xME_args 70 | @rem Slurp the command line arguments. 
71 | set CMD_LINE_ARGS= 72 | set _SKIP=2 73 | 74 | :win9xME_args_slurp 75 | if "x%~1" == "x" goto execute 76 | 77 | set CMD_LINE_ARGS=%* 78 | 79 | :execute 80 | @rem Setup the command line 81 | 82 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 83 | 84 | @rem Execute Gradle 85 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 86 | 87 | :end 88 | @rem End local scope for the variables with windows NT shell 89 | if "%ERRORLEVEL%"=="0" goto mainEnd 90 | 91 | :fail 92 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 93 | rem the _cmd.exe /c_ return code! 94 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 95 | exit /b 1 96 | 97 | :mainEnd 98 | if "%OS%"=="Windows_NT" endlocal 99 | 100 | :omega 101 | -------------------------------------------------------------------------------- /mobile/android/lib_support/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.library' 2 | 3 | android { 4 | compileSdkVersion 28 5 | buildToolsVersion "28.0.0" 6 | 7 | defaultConfig { 8 | minSdkVersion 21 9 | targetSdkVersion 28 10 | versionCode 1 11 | versionName "1.0" 12 | 13 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" 14 | 15 | } 16 | 17 | buildTypes { 18 | release { 19 | minifyEnabled false 20 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 21 | } 22 | } 23 | 24 | aaptOptions { 25 | noCompress "tflite" 26 | } 27 | 28 | lintOptions { 29 | checkReleaseBuilds false 30 | // Or, if you prefer, you can continue to check for errors in release builds, 31 | // but continue the build even when errors are found: 32 | abortOnError false 33 | } 34 | } 35 | 36 | dependencies { 37 | implementation fileTree(dir: 'libs', include: ['*.jar']) 38 | implementation project(":models") 39 | implementation 'androidx.appcompat:appcompat:1.1.0' 40 | 41 | // Build off of nightly TensorFlow Lite 42 | implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true } 43 | implementation('org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly') { changing = true } 44 | implementation('org.tensorflow:tensorflow-lite-support:0.0.0-nightly') { changing = true } 45 | // Use local TensorFlow library 46 | // implementation 'org.tensorflow:tensorflow-lite-local:0.0.0' 47 | } 48 | -------------------------------------------------------------------------------- /mobile/android/lib_support/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 
21 | #-renamesourcefileattribute SourceFile 22 | -------------------------------------------------------------------------------- /mobile/android/lib_support/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | -------------------------------------------------------------------------------- /mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | import org.tensorflow.lite.support.common.TensorOperator; 22 | import org.tensorflow.lite.support.common.ops.NormalizeOp; 23 | 24 | /** This TensorFlowLite classifier works with the float EfficientNet model. */ 25 | public class ClassifierFloatEfficientNet extends Classifier { 26 | 27 | private static final float IMAGE_MEAN = 115.0f; //127.0f; 28 | private static final float IMAGE_STD = 58.0f; //128.0f; 29 | 30 | /** 31 | * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f 32 | * and 1.0f, repectively, to bypass the normalization. 33 | */ 34 | private static final float PROBABILITY_MEAN = 0.0f; 35 | 36 | private static final float PROBABILITY_STD = 1.0f; 37 | 38 | /** 39 | * Initializes a {@code ClassifierFloatMobileNet}. 40 | * 41 | * @param activity 42 | */ 43 | public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) 44 | throws IOException { 45 | super(activity, device, numThreads); 46 | } 47 | 48 | @Override 49 | protected String getModelPath() { 50 | // you can download this file from 51 | // see build.gradle for where to obtain this file. It should be auto 52 | // downloaded into assets. 53 | //return "efficientnet-lite0-fp32.tflite"; 54 | return "model_opt.tflite"; 55 | } 56 | 57 | @Override 58 | protected String getLabelPath() { 59 | return "labels_without_background.txt"; 60 | } 61 | 62 | @Override 63 | protected TensorOperator getPreprocessNormalizeOp() { 64 | return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); 65 | } 66 | 67 | @Override 68 | protected TensorOperator getPostprocessNormalizeOp() { 69 | return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | import org.tensorflow.lite.support.common.TensorOperator; 22 | import org.tensorflow.lite.support.common.ops.NormalizeOp; 23 | 24 | /** This TensorFlowLite classifier works with the float MobileNet model. */ 25 | public class ClassifierFloatMobileNet extends Classifier { 26 | 27 | /** Float MobileNet requires additional normalization of the used input. */ 28 | private static final float IMAGE_MEAN = 127.5f; 29 | 30 | private static final float IMAGE_STD = 127.5f; 31 | 32 | /** 33 | * Float model does not need dequantization in the post-processing. Setting mean and std as 0.0f 34 | * and 1.0f, repectively, to bypass the normalization. 35 | */ 36 | private static final float PROBABILITY_MEAN = 0.0f; 37 | 38 | private static final float PROBABILITY_STD = 1.0f; 39 | 40 | /** 41 | * Initializes a {@code ClassifierFloatMobileNet}. 42 | * 43 | * @param activity 44 | */ 45 | public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) 46 | throws IOException { 47 | super(activity, device, numThreads); 48 | } 49 | 50 | @Override 51 | protected String getModelPath() { 52 | // you can download this file from 53 | // see build.gradle for where to obtain this file. It should be auto 54 | // downloaded into assets. 55 | return "model_0.tflite"; 56 | } 57 | 58 | @Override 59 | protected String getLabelPath() { 60 | return "labels.txt"; 61 | } 62 | 63 | @Override 64 | protected TensorOperator getPreprocessNormalizeOp() { 65 | return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); 66 | } 67 | 68 | @Override 69 | protected TensorOperator getPostprocessNormalizeOp() { 70 | return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
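The float classifiers above differ only in their model/label paths and normalization constants; the NormalizeOp they return maps each input channel to (value - IMAGE_MEAN) / IMAGE_STD. A dependency-free, worked sketch using the MobileNet constants (purely illustrative, not part of the TensorFlow Lite support library):

    // (pixel - 127.5) / 127.5 maps the 0..255 range into roughly -1..1.
    public class NormalizeDemo {
      public static void main(String[] args) {
        final float imageMean = 127.5f;
        final float imageStd = 127.5f;
        for (final int pixel : new int[] {0, 127, 255}) {
          final float normalized = (pixel - imageMean) / imageStd;
          System.out.printf("%3d -> %+.3f%n", pixel, normalized); // -1.000, -0.004, +1.000
        }
      }
    }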
14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.support.common.TensorOperator; 21 | import org.tensorflow.lite.support.common.ops.NormalizeOp; 22 | 23 | /** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ 24 | public class ClassifierQuantizedEfficientNet extends Classifier { 25 | 26 | /** 27 | * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to 28 | * bypass the normalization. 29 | */ 30 | private static final float IMAGE_MEAN = 0.0f; 31 | 32 | private static final float IMAGE_STD = 1.0f; 33 | 34 | /** Quantized MobileNet requires additional dequantization to the output probability. */ 35 | private static final float PROBABILITY_MEAN = 0.0f; 36 | 37 | private static final float PROBABILITY_STD = 255.0f; 38 | 39 | /** 40 | * Initializes a {@code ClassifierQuantizedMobileNet}. 41 | * 42 | * @param activity 43 | */ 44 | public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) 45 | throws IOException { 46 | super(activity, device, numThreads); 47 | } 48 | 49 | @Override 50 | protected String getModelPath() { 51 | // you can download this file from 52 | // see build.gradle for where to obtain this file. It should be auto 53 | // downloaded into assets. 54 | return "model_quant.tflite"; 55 | } 56 | 57 | @Override 58 | protected String getLabelPath() { 59 | return "labels_without_background.txt"; 60 | } 61 | 62 | @Override 63 | protected TensorOperator getPreprocessNormalizeOp() { 64 | return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); 65 | } 66 | 67 | @Override 68 | protected TensorOperator getPostprocessNormalizeOp() { 69 | return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /mobile/android/lib_support/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | import org.tensorflow.lite.support.common.TensorOperator; 22 | import org.tensorflow.lite.support.common.ops.NormalizeOp; 23 | 24 | /** This TensorFlow Lite classifier works with the quantized MobileNet model. 
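For the quantized classifiers, the post-processing NormalizeOp(0.0f, 255.0f) is what turns the uint8 network output back into a probability, i.e. probability = (quantized - 0) / 255. A one-line illustrative check (the score of 204 is a made-up value):

    final int quantized = 204;                              // raw uint8 score from the model
    final float probability = (quantized - 0.0f) / 255.0f;  // = 0.8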
*/ 25 | public class ClassifierQuantizedMobileNet extends Classifier { 26 | 27 | /** 28 | * The quantized model does not require normalization, thus set mean as 0.0f, and std as 1.0f to 29 | * bypass the normalization. 30 | */ 31 | private static final float IMAGE_MEAN = 0.0f; 32 | 33 | private static final float IMAGE_STD = 1.0f; 34 | 35 | /** Quantized MobileNet requires additional dequantization to the output probability. */ 36 | private static final float PROBABILITY_MEAN = 0.0f; 37 | 38 | private static final float PROBABILITY_STD = 255.0f; 39 | 40 | /** 41 | * Initializes a {@code ClassifierQuantizedMobileNet}. 42 | * 43 | * @param activity 44 | */ 45 | public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) 46 | throws IOException { 47 | super(activity, device, numThreads); 48 | } 49 | 50 | @Override 51 | protected String getModelPath() { 52 | // you can download this file from 53 | // see build.gradle for where to obtain this file. It should be auto 54 | // downloaded into assets. 55 | return "model_quant_0.tflite"; 56 | } 57 | 58 | @Override 59 | protected String getLabelPath() { 60 | return "labels.txt"; 61 | } 62 | 63 | @Override 64 | protected TensorOperator getPreprocessNormalizeOp() { 65 | return new NormalizeOp(IMAGE_MEAN, IMAGE_STD); 66 | } 67 | 68 | @Override 69 | protected TensorOperator getPostprocessNormalizeOp() { 70 | return new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.library' 2 | 3 | android { 4 | compileSdkVersion 28 5 | buildToolsVersion "28.0.0" 6 | 7 | defaultConfig { 8 | minSdkVersion 21 9 | targetSdkVersion 28 10 | versionCode 1 11 | versionName "1.0" 12 | 13 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" 14 | 15 | } 16 | 17 | buildTypes { 18 | release { 19 | minifyEnabled false 20 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 21 | } 22 | } 23 | compileOptions { 24 | sourceCompatibility = '1.8' 25 | targetCompatibility = '1.8' 26 | } 27 | aaptOptions { 28 | noCompress "tflite" 29 | } 30 | 31 | lintOptions { 32 | checkReleaseBuilds false 33 | // Or, if you prefer, you can continue to check for errors in release builds, 34 | // but continue the build even when errors are found: 35 | abortOnError false 36 | } 37 | } 38 | 39 | dependencies { 40 | implementation fileTree(dir: 'libs', include: ['*.jar']) 41 | implementation project(":models") 42 | implementation 'androidx.appcompat:appcompat:1.1.0' 43 | 44 | // Build off of nightly TensorFlow Lite Task Library 45 | implementation('org.tensorflow:tensorflow-lite-task-vision:0.0.0-nightly') { changing = true } 46 | implementation('org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly') { changing = true } 47 | } 48 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 
4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 21 | #-renamesourcefileattribute SourceFile 22 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatEfficientNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | 22 | /** This TensorFlowLite classifier works with the float EfficientNet model. */ 23 | public class ClassifierFloatEfficientNet extends Classifier { 24 | 25 | /** 26 | * Initializes a {@code ClassifierFloatMobileNet}. 27 | * 28 | * @param device a {@link Device} object to configure the hardware accelerator 29 | * @param numThreads the number of threads during the inference 30 | * @throws IOException if the model is not loaded correctly 31 | */ 32 | public ClassifierFloatEfficientNet(Activity activity, Device device, int numThreads) 33 | throws IOException { 34 | super(activity, device, numThreads); 35 | } 36 | 37 | @Override 38 | protected String getModelPath() { 39 | // you can download this file from 40 | // see build.gradle for where to obtain this file. It should be auto 41 | // downloaded into assets. 42 | //return "efficientnet-lite0-fp32.tflite"; 43 | return "model.tflite"; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierFloatMobileNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | 22 | /** This TensorFlowLite classifier works with the float MobileNet model. */ 23 | public class ClassifierFloatMobileNet extends Classifier { 24 | /** 25 | * Initializes a {@code ClassifierFloatMobileNet}. 26 | * 27 | * @param device a {@link Device} object to configure the hardware accelerator 28 | * @param numThreads the number of threads during the inference 29 | * @throws IOException if the model is not loaded correctly 30 | */ 31 | public ClassifierFloatMobileNet(Activity activity, Device device, int numThreads) 32 | throws IOException { 33 | super(activity, device, numThreads); 34 | } 35 | 36 | @Override 37 | protected String getModelPath() { 38 | // you can download this file from 39 | // see build.gradle for where to obtain this file. It should be auto 40 | // downloaded into assets. 41 | return "mobilenet_v1_1.0_224.tflite"; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedEfficientNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | 21 | /** This TensorFlow Lite classifier works with the quantized EfficientNet model. */ 22 | public class ClassifierQuantizedEfficientNet extends Classifier { 23 | 24 | /** 25 | * Initializes a {@code ClassifierQuantizedMobileNet}. 
26 | * 27 | * @param device a {@link Device} object to configure the hardware accelerator 28 | * @param numThreads the number of threads during the inference 29 | * @throws IOException if the model is not loaded correctly 30 | */ 31 | public ClassifierQuantizedEfficientNet(Activity activity, Device device, int numThreads) 32 | throws IOException { 33 | super(activity, device, numThreads); 34 | } 35 | 36 | @Override 37 | protected String getModelPath() { 38 | // you can download this file from 39 | // see build.gradle for where to obtain this file. It should be auto 40 | // downloaded into assets. 41 | return "efficientnet-lite0-int8.tflite"; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /mobile/android/lib_task_api/src/main/java/org/tensorflow/lite/examples/classification/tflite/ClassifierQuantizedMobileNet.java: -------------------------------------------------------------------------------- 1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | ==============================================================================*/ 15 | 16 | package org.tensorflow.lite.examples.classification.tflite; 17 | 18 | import android.app.Activity; 19 | import java.io.IOException; 20 | import org.tensorflow.lite.examples.classification.tflite.Classifier.Device; 21 | 22 | /** This TensorFlow Lite classifier works with the quantized MobileNet model. */ 23 | public class ClassifierQuantizedMobileNet extends Classifier { 24 | 25 | /** 26 | * Initializes a {@code ClassifierQuantizedMobileNet}. 27 | * 28 | * @param device a {@link Device} object to configure the hardware accelerator 29 | * @param numThreads the number of threads during the inference 30 | * @throws IOException if the model is not loaded correctly 31 | */ 32 | public ClassifierQuantizedMobileNet(Activity activity, Device device, int numThreads) 33 | throws IOException { 34 | super(activity, device, numThreads); 35 | } 36 | 37 | @Override 38 | protected String getModelPath() { 39 | // you can download this file from 40 | // see build.gradle for where to obtain this file. It should be auto 41 | // downloaded into assets. 
42 | return "mobilenet_v1_1.0_224_quant.tflite"; 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /mobile/android/models/build.gradle: -------------------------------------------------------------------------------- 1 | apply plugin: 'com.android.library' 2 | apply plugin: 'de.undercouch.download' 3 | 4 | android { 5 | compileSdkVersion 28 6 | buildToolsVersion "28.0.0" 7 | 8 | defaultConfig { 9 | minSdkVersion 21 10 | targetSdkVersion 28 11 | versionCode 1 12 | versionName "1.0" 13 | 14 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" 15 | 16 | } 17 | 18 | buildTypes { 19 | release { 20 | minifyEnabled false 21 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' 22 | } 23 | } 24 | 25 | aaptOptions { 26 | noCompress "tflite" 27 | } 28 | 29 | lintOptions { 30 | checkReleaseBuilds false 31 | // Or, if you prefer, you can continue to check for errors in release builds, 32 | // but continue the build even when errors are found: 33 | abortOnError false 34 | } 35 | } 36 | 37 | // Download default models; if you wish to use your own models then 38 | // place them in the "assets" directory and comment out this line. 39 | project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' 40 | apply from:'download.gradle' 41 | -------------------------------------------------------------------------------- /mobile/android/models/download.gradle: -------------------------------------------------------------------------------- 1 | def modelFloatDownloadUrl = "https://github.com/isl-org/MiDaS/releases/download/v2_1/model_opt.tflite" 2 | def modelFloatFile = "model_opt.tflite" 3 | 4 | task downloadModelFloat(type: Download) { 5 | src "${modelFloatDownloadUrl}" 6 | dest project.ext.ASSET_DIR + "/${modelFloatFile}" 7 | overwrite false 8 | } 9 | 10 | preBuild.dependsOn downloadModelFloat 11 | -------------------------------------------------------------------------------- /mobile/android/models/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 21 | #-renamesourcefileattribute SourceFile 22 | -------------------------------------------------------------------------------- /mobile/android/models/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- 1 | 3 | 4 | -------------------------------------------------------------------------------- /mobile/android/models/src/main/assets/run_tflite.py: -------------------------------------------------------------------------------- 1 | # Flex ops are included in the nightly build of the TensorFlow Python package. 
You can use TFLite models containing Flex ops by the same Python API as normal TFLite models. The nightly TensorFlow build can be installed with this command: 2 | # Flex ops will be added to the TensorFlow Python package's and the tflite_runtime package from version 2.3 for Linux and 2.4 for other environments. 3 | # https://www.tensorflow.org/lite/guide/ops_select#running_the_model 4 | 5 | # You must use: tf-nightly 6 | # pip install tf-nightly 7 | 8 | import os 9 | import glob 10 | import cv2 11 | import numpy as np 12 | 13 | import tensorflow as tf 14 | 15 | width=256 16 | height=256 17 | model_name="model.tflite" 18 | #model_name="model_quant.tflite" 19 | image_name="dog.jpg" 20 | 21 | # input 22 | img = cv2.imread(image_name) 23 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 24 | 25 | mean=[0.485, 0.456, 0.406] 26 | std=[0.229, 0.224, 0.225] 27 | img = (img - mean) / std 28 | 29 | img_resized = tf.image.resize(img, [width,height], method='bicubic', preserve_aspect_ratio=False) 30 | #img_resized = tf.transpose(img_resized, [2, 0, 1]) 31 | img_input = img_resized.numpy() 32 | reshape_img = img_input.reshape(1,width,height,3) 33 | tensor = tf.convert_to_tensor(reshape_img, dtype=tf.float32) 34 | 35 | # load model 36 | print("Load model...") 37 | interpreter = tf.lite.Interpreter(model_path=model_name) 38 | print("Allocate tensor...") 39 | interpreter.allocate_tensors() 40 | print("Get input/output details...") 41 | input_details = interpreter.get_input_details() 42 | output_details = interpreter.get_output_details() 43 | print("Get input shape...") 44 | input_shape = input_details[0]['shape'] 45 | print(input_shape) 46 | print(input_details) 47 | print(output_details) 48 | #input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) 49 | print("Set input tensor...") 50 | interpreter.set_tensor(input_details[0]['index'], tensor) 51 | 52 | print("invoke()...") 53 | interpreter.invoke() 54 | 55 | # The function `get_tensor()` returns a copy of the tensor data. 56 | # Use `tensor()` in order to get a pointer to the tensor. 
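# Optional zero-copy variant (sketch, standard tf.lite API): `interpreter.tensor()` returns a
# callable that yields a numpy view into the interpreter's internal buffer, e.g.
#   output_view = interpreter.tensor(output_details[0]['index'])()
# Such a view must not be held across later allocate_tensors()/invoke() calls, so the
# copying `get_tensor()` used below is the safer default for this script.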
57 | print("get output tensor...") 58 | output = interpreter.get_tensor(output_details[0]['index']) 59 | #output = np.squeeze(output) 60 | output = output.reshape(width, height) 61 | #print(output) 62 | prediction = np.array(output) 63 | print("reshape prediction...") 64 | prediction = prediction.reshape(width, height) 65 | 66 | # output file 67 | #prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) 68 | print(" Write image to: output.png") 69 | depth_min = prediction.min() 70 | depth_max = prediction.max() 71 | img_out = (255 * (prediction - depth_min) / (depth_max - depth_min)).astype("uint8") 72 | print("save output image...") 73 | cv2.imwrite("output.png", img_out) 74 | 75 | print("finished") -------------------------------------------------------------------------------- /mobile/android/settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'TFLite Image Classification Demo App' 2 | include ':app', ':lib_support', ':lib_task_api', ':models' -------------------------------------------------------------------------------- /mobile/ios/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore model file 2 | #*.tflite 3 | -------------------------------------------------------------------------------- /mobile/ios/Midas.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /mobile/ios/Midas.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas.xcodeproj/project.xcworkspace/xcuserdata/admin.xcuserdatad/UserInterfaceState.xcuserstate -------------------------------------------------------------------------------- /mobile/ios/Midas.xcodeproj/xcuserdata/admin.xcuserdatad/xcschemes/xcschememanagement.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | SchemeUserState 6 | 7 | PoseNet.xcscheme_^#shared#^_ 8 | 9 | orderHint 10 | 3 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /mobile/ios/Midas/AppDelegate.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | import UIKit 16 | 17 | @UIApplicationMain 18 | class AppDelegate: UIResponder, UIApplicationDelegate { 19 | 20 | var window: UIWindow? 21 | 22 | func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplication.LaunchOptionsKey: Any]?) -> Bool { 23 | return true 24 | } 25 | 26 | func applicationWillResignActive(_ application: UIApplication) { 27 | } 28 | 29 | func applicationDidEnterBackground(_ application: UIApplication) { 30 | } 31 | 32 | func applicationWillEnterForeground(_ application: UIApplication) { 33 | } 34 | 35 | func applicationDidBecomeActive(_ application: UIApplication) { 36 | } 37 | 38 | func applicationWillTerminate(_ application: UIApplication) { 39 | } 40 | } 41 | 42 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/100.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/1024.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/114.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/114.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/120.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/120.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/144.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/144.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/152.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/152.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/167.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/167.png 
-------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/180.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/20.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/29.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/40.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/50.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/57.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/57.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/58.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/60.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/60.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/72.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/72.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/76.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/76.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/80.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/80.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/87.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/87.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | {"images":[{"size":"60x60","expected-size":"180","filename":"180.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"40x40","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"60x60","expected-size":"120","filename":"120.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"57x57","expected-size":"57","filename":"57.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"1x"},{"size":"29x29","expected-size":"87","filename":"87.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"57x57","expected-size":"114","filename":"114.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"2x"},{"size":"20x20","expected-size":"60","filename":"60.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"iphone","scale":"3x"},{"size":"1024x1024","filename":"1024.png","expected-size":"1024","idiom":"ios-marketing","folder":"Assets.xcassets/AppIcon.appiconset/","scale":"1x"},{"size":"40x40","expected-size":"80","filename":"80.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"72x72","expected-size":"72","filename":"72.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"76x76","expected-size":"152","filename":"152.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"50x50","expected-size":"100","filename":"100.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"29x29","expected-size":"58","filename":"58.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"76x76","expected-s
ize":"76","filename":"76.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"29x29","expected-size":"29","filename":"29.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"50x50","expected-size":"50","filename":"50.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"72x72","expected-size":"144","filename":"144.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"40x40","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"83.5x83.5","expected-size":"167","filename":"167.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"},{"size":"20x20","expected-size":"20","filename":"20.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"1x"},{"size":"20x20","expected-size":"40","filename":"40.png","folder":"Assets.xcassets/AppIcon.appiconset/","idiom":"ipad","scale":"2x"}]} -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "version" : 1, 4 | "author" : "xcode" 5 | } 6 | } -------------------------------------------------------------------------------- /mobile/ios/Midas/Assets.xcassets/tfl_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/mobile/ios/Midas/Assets.xcassets/tfl_logo.png -------------------------------------------------------------------------------- /mobile/ios/Midas/Camera Feed/PreviewView.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | import UIKit 16 | import AVFoundation 17 | 18 | /// The camera frame is displayed on this view. 19 | class PreviewView: UIView { 20 | var previewLayer: AVCaptureVideoPreviewLayer { 21 | guard let layer = layer as? AVCaptureVideoPreviewLayer else { 22 | fatalError("Layer expected is of type VideoPreviewLayer") 23 | } 24 | return layer 25 | } 26 | 27 | var session: AVCaptureSession? { 28 | get { 29 | return previewLayer.session 30 | } 31 | set { 32 | previewLayer.session = newValue 33 | } 34 | } 35 | 36 | override class var layerClass: AnyClass { 37 | return AVCaptureVideoPreviewLayer.self 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Cells/InfoCell.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | import UIKit 16 | 17 | /// Table cell for inference result in bottom view. 18 | class InfoCell: UITableViewCell { 19 | @IBOutlet weak var fieldNameLabel: UILabel! 20 | @IBOutlet weak var infoLabel: UILabel! 21 | } 22 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Constants.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | enum Constants { 17 | // MARK: - Constants related to the image processing 18 | static let bgraPixel = (channels: 4, alphaComponent: 3, lastBgrComponent: 2) 19 | static let rgbPixelChannels = 3 20 | static let maxRGBValue: Float32 = 255.0 21 | 22 | // MARK: - Constants related to the model interperter 23 | static let defaultThreadCount = 2 24 | static let defaultDelegate: Delegates = .CPU 25 | } 26 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Extensions/CGSizeExtension.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | import Accelerate 17 | import Foundation 18 | 19 | extension CGSize { 20 | /// Returns `CGAfineTransform` to resize `self` to fit in destination size, keeping aspect ratio 21 | /// of `self`. `self` image is resized to be inscribe to destination size and located in center of 22 | /// destination. 
23 | /// 24 | /// - Parameter toFitIn: destination size to be filled. 25 | /// - Returns: `CGAffineTransform` to transform `self` image to `dest` image. 26 | func transformKeepAspect(toFitIn dest: CGSize) -> CGAffineTransform { 27 | let sourceRatio = self.height / self.width 28 | let destRatio = dest.height / dest.width 29 | 30 | // Calculates ratio `self` to `dest`. 31 | var ratio: CGFloat 32 | var x: CGFloat = 0 33 | var y: CGFloat = 0 34 | if sourceRatio > destRatio { 35 | // Source size is taller than destination. Resized to fit in destination height, and find 36 | // horizontal starting point to be centered. 37 | ratio = dest.height / self.height 38 | x = (dest.width - self.width * ratio) / 2 39 | } else { 40 | ratio = dest.width / self.width 41 | y = (dest.height - self.height * ratio) / 2 42 | } 43 | return CGAffineTransform(a: ratio, b: 0, c: 0, d: ratio, tx: x, ty: y) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Extensions/TFLiteExtension.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | // ============================================================================= 15 | 16 | import Accelerate 17 | import CoreImage 18 | import Foundation 19 | import TensorFlowLite 20 | 21 | // MARK: - Data 22 | extension Data { 23 | /// Creates a new buffer by copying the buffer pointer of the given array. 24 | /// 25 | /// - Warning: The given array's element type `T` must be trivial in that it can be copied bit 26 | /// for bit with no indirection or reference-counting operations; otherwise, reinterpreting 27 | /// data from the resulting buffer has undefined behavior. 28 | /// - Parameter array: An array with elements of type `T`. 29 | init(copyingBufferOf array: [T]) { 30 | self = array.withUnsafeBufferPointer(Data.init) 31 | } 32 | 33 | /// Convert a Data instance to Array representation. 34 | func toArray(type: T.Type) -> [T] where T: AdditiveArithmetic { 35 | var array = [T](repeating: T.zero, count: self.count / MemoryLayout.stride) 36 | _ = array.withUnsafeMutableBytes { self.copyBytes(to: $0) } 37 | return array 38 | } 39 | } 40 | 41 | // MARK: - Wrappers 42 | /// Struct for handling multidimension `Data` in flat `Array`. 43 | struct FlatArray { 44 | private var array: [Element] 45 | var dimensions: [Int] 46 | 47 | init(tensor: Tensor) { 48 | dimensions = tensor.shape.dimensions 49 | array = tensor.data.toArray(type: Element.self) 50 | } 51 | 52 | private func flatIndex(_ index: [Int]) -> Int { 53 | guard index.count == dimensions.count else { 54 | fatalError("Invalid index: got \(index.count) index(es) for \(dimensions.count) index(es).") 55 | } 56 | 57 | var result = 0 58 | for i in 0.. 
index[i] else { 60 | fatalError("Invalid index: \(index[i]) is bigger than \(dimensions[i])") 61 | } 62 | result = dimensions[i] * result + index[i] 63 | } 64 | return result 65 | } 66 | 67 | subscript(_ index: Int...) -> Element { 68 | get { 69 | return array[flatIndex(index)] 70 | } 71 | set(newValue) { 72 | array[flatIndex(index)] = newValue 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | CFBundleDevelopmentRegion 6 | $(DEVELOPMENT_LANGUAGE) 7 | CFBundleExecutable 8 | $(EXECUTABLE_NAME) 9 | CFBundleIdentifier 10 | $(PRODUCT_BUNDLE_IDENTIFIER) 11 | CFBundleInfoDictionaryVersion 12 | 6.0 13 | CFBundleName 14 | $(PRODUCT_NAME) 15 | CFBundlePackageType 16 | APPL 17 | CFBundleShortVersionString 18 | 1.0 19 | CFBundleVersion 20 | 1 21 | LSRequiresIPhoneOS 22 | 23 | NSCameraUsageDescription 24 | This app will use camera to continuously estimate the depth map. 25 | UILaunchStoryboardName 26 | LaunchScreen 27 | UIMainStoryboardFile 28 | Main 29 | UIRequiredDeviceCapabilities 30 | 31 | armv7 32 | 33 | UISupportedInterfaceOrientations 34 | 35 | UIInterfaceOrientationPortrait 36 | 37 | UISupportedInterfaceOrientations~ipad 38 | 39 | UIInterfaceOrientationPortrait 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Storyboards/Base.lproj/Launch Screen.storyboard: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 24 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /mobile/ios/Midas/Views/OverlayView.swift: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | import UIKit 16 | 17 | /// UIView for rendering inference output. 
18 | class OverlayView: UIView { 19 | 20 | var dots = [CGPoint]() 21 | var lines = [Line]() 22 | 23 | override func draw(_ rect: CGRect) { 24 | for dot in dots { 25 | drawDot(of: dot) 26 | } 27 | for line in lines { 28 | drawLine(of: line) 29 | } 30 | } 31 | 32 | func drawDot(of dot: CGPoint) { 33 | let dotRect = CGRect( 34 | x: dot.x - Traits.dot.radius / 2, y: dot.y - Traits.dot.radius / 2, 35 | width: Traits.dot.radius, height: Traits.dot.radius) 36 | let dotPath = UIBezierPath(ovalIn: dotRect) 37 | 38 | Traits.dot.color.setFill() 39 | dotPath.fill() 40 | } 41 | 42 | func drawLine(of line: Line) { 43 | let linePath = UIBezierPath() 44 | linePath.move(to: CGPoint(x: line.from.x, y: line.from.y)) 45 | linePath.addLine(to: CGPoint(x: line.to.x, y: line.to.y)) 46 | linePath.close() 47 | 48 | linePath.lineWidth = Traits.line.width 49 | Traits.line.color.setStroke() 50 | 51 | linePath.stroke() 52 | } 53 | 54 | func clear() { 55 | self.dots = [] 56 | self.lines = [] 57 | } 58 | } 59 | 60 | private enum Traits { 61 | static let dot = (radius: CGFloat(5), color: UIColor.orange) 62 | static let line = (width: CGFloat(1.0), color: UIColor.orange) 63 | } 64 | -------------------------------------------------------------------------------- /mobile/ios/Podfile: -------------------------------------------------------------------------------- 1 | # Uncomment the next line to define a global platform for your project 2 | platform :ios, '12.0' 3 | 4 | target 'Midas' do 5 | # Comment the next line if you're not using Swift and don't want to use dynamic frameworks 6 | use_frameworks! 7 | 8 | # Pods for Midas 9 | pod 'TensorFlowLiteSwift', '~> 0.0.1-nightly' 10 | pod 'TensorFlowLiteSwift/CoreML', '~> 0.0.1-nightly' 11 | pod 'TensorFlowLiteSwift/Metal', '~> 0.0.1-nightly' 12 | end 13 | -------------------------------------------------------------------------------- /mobile/ios/README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Lite MiDaS iOS Example 2 | 3 | ### Requirements 4 | 5 | - XCode 11.0 or above 6 | - iOS 12.0 or above, [iOS 14 breaks the NPU Delegate](https://github.com/tensorflow/tensorflow/issues/43339) 7 | - TensorFlow 2.4.0, TensorFlowLiteSwift -> 0.0.1-nightly 8 | 9 | ## Quick Start with a MiDaS Example 10 | 11 | MiDaS is a neural network to compute depth from a single image. It uses TensorFlowLiteSwift / C++ libraries on iOS. The code is written in Swift. 
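For orientation, the core TensorFlowLiteSwift calls look roughly like the sketch below. The app's real pre- and post-processing lives in `Midas/ModelDataHandler/ModelDataHandler.swift`; the function name, `modelPath`, `inputData` and the plain `[Float32]` return type here are illustrative assumptions, not the exact code.

```swift
import TensorFlowLite

// Minimal sketch: run MiDaS once on an already-preprocessed camera frame.
// Assumptions: `modelPath` points to the bundled .tflite model and `inputData`
// holds Float32 bytes of a 1 x H x W x 3 RGB tensor.
func runDepthEstimation(modelPath: String, inputData: Data) throws -> [Float32] {
    let interpreter = try Interpreter(modelPath: modelPath)
    try interpreter.allocateTensors()
    try interpreter.copy(inputData, toInputAt: 0)   // feed the preprocessed frame
    try interpreter.invoke()                        // run inference
    let output = try interpreter.output(at: 0)      // single-channel inverse relative depth
    return output.data.withUnsafeBytes { Array($0.bindMemory(to: Float32.self)) }
}
```

In practice the `Interpreter` would be created once and reused for every camera frame rather than rebuilt per call; the sketch folds everything into one function only for brevity.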
12 | 13 | Paper: https://arxiv.org/abs/1907.01341 14 | 15 | > Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer 16 | > René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun 17 | 18 | ### Install TensorFlow 19 | 20 | Set the default Python version to Python 3: 21 | 22 | ``` 23 | echo 'export PATH=/usr/local/opt/python/libexec/bin:$PATH' >> ~/.zshenv 24 | echo 'alias python=python3' >> ~/.zshenv 25 | echo 'alias pip=pip3' >> ~/.zshenv 26 | ``` 27 | 28 | Install TensorFlow: 29 | 30 | ```shell 31 | pip install tensorflow 32 | ``` 33 | 34 | ### Install TensorFlowLiteSwift via Cocoapods 35 | 36 | Set the required TensorFlowLiteSwift version in the file (`0.0.1-nightly` is recommended): https://github.com/isl-org/MiDaS/blob/master/mobile/ios/Podfile#L9 37 | 38 | Install brew, ruby, and cocoapods: 39 | 40 | ``` 41 | ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 42 | brew install mc rbenv ruby-build 43 | sudo gem install cocoapods 44 | ``` 45 | 46 | 47 | The TensorFlowLiteSwift library is available on [Cocoapods](https://cocoapods.org/). To integrate it into the project, run the following in the root directory of the project: 48 | 49 | ```shell 50 | pod install 51 | ``` 52 | 53 | Now open the `Midas.xcworkspace` file in XCode, select your iPhone device (XCode->Product->Destination->iPhone) and launch it (cmd + R). If everything works well, you should see a real-time depth map from your camera. 54 | 55 | ### Model 56 | 57 | The TensorFlow Lite (TFLite) model `midas.tflite` is in the folder `/Midas/Model` 58 | 59 | 60 | To use another model, convert it from a TensorFlow saved model to a TFLite model (so that it can be deployed): 61 | 62 | ```python 63 | saved_model_export_dir = "./saved_model" 64 | converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_export_dir) 65 | tflite_model = converter.convert() 66 | open("model.tflite", "wb").write(tflite_model) 67 | ``` 68 | 69 | ### Setup XCode 70 | 71 | * Open the `.xcworkspace` directory from XCode 72 | 73 | * Click on your project name (top-left corner) -> change the Bundle Identifier to `com.midas.tflite-npu` or something similar (it must be unique) 74 | 75 | * Select your Developer Team (you should be signed in with your Apple ID) 76 | 77 | * Connect your iPhone (if you want to run on a real device instead of the simulator), then select your iPhone device (XCode->Product->Destination->iPhone) 78 | 79 | * In XCode, click: Product -> Run 80 | 81 | * On your iPhone device go to: Settings -> General -> Device Management (or Profiles) -> Apple Development -> Trust Apple Development 82 | 83 | ---- 84 | 85 | Original repository: https://github.com/isl-org/MiDaS 86 | 87 | 88 | ### Examples: 89 | 90 | | ![photo_2020-09-27_17-43-20](https://user-images.githubusercontent.com/4096485/94367804-9610de80-00e9-11eb-8a23-8b32a6f52d41.jpg) | ![photo_2020-09-27_17-49-22](https://user-images.githubusercontent.com/4096485/94367974-7201cd00-00ea-11eb-8e0a-68eb9ea10f63.jpg) | ![photo_2020-09-27_17-52-30](https://user-images.githubusercontent.com/4096485/94367976-729a6380-00ea-11eb-8ce0-39d3e26dd550.jpg) | ![photo_2020-09-27_17-43-21](https://user-images.githubusercontent.com/4096485/94367807-97420b80-00e9-11eb-9dcd-848ad9e89e03.jpg) | 91 | |---|---|---|---| 92 | 93 | ## LICENSE 94 | 95 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 96 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 97 | IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 98 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 99 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 100 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 101 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 102 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 103 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 104 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 105 | POSSIBILITY OF SUCH DAMAGE. 106 | -------------------------------------------------------------------------------- /mobile/ios/RunScripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Download TF Lite model from the internet if it does not exist. 3 | 4 | TFLITE_MODEL="model_opt.tflite" 5 | TFLITE_FILE="Midas/Model/${TFLITE_MODEL}" 6 | MODEL_SRC="https://github.com/isl-org/MiDaS/releases/download/v2/${TFLITE_MODEL}" 7 | 8 | if test -f "${TFLITE_FILE}"; then 9 | echo "INFO: TF Lite model already exists. Skip downloading and use the local model." 10 | else 11 | curl --create-dirs -o "${TFLITE_FILE}" -LJO "${MODEL_SRC}" 12 | echo "INFO: Downloaded TensorFlow Lite model to ${TFLITE_FILE}." 13 | fi 14 | 15 | -------------------------------------------------------------------------------- /output/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/output/.placeholder -------------------------------------------------------------------------------- /ros/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Alexey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ros/README.md: -------------------------------------------------------------------------------- 1 | # MiDaS for ROS1 by using LibTorch in C++ 2 | 3 | ### Requirements 4 | 5 | - Ubuntu 17.10 / 18.04 / 20.04, Debian Stretch 6 | - ROS Melodic for Ubuntu (17.10 / 18.04) / Debian Stretch, ROS Noetic for Ubuntu 20.04 7 | - C++11 8 | - LibTorch >= 1.6 9 | 10 | ## Quick Start with a MiDaS Example 11 | 12 | MiDaS is a neural network to compute depth from a single image. 13 | 14 | * input from `image_topic`: `sensor_msgs/Image` - `RGB8` image with any shape 15 | * output to `midas_topic`: `sensor_msgs/Image` - `TYPE_32FC1` inverse relative depth maps in range [0 - 255] with original size and channels=1 16 | 17 | ### Install Dependecies 18 | 19 | * install ROS Melodic for Ubuntu 17.10 / 18.04: 20 | ```bash 21 | wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_melodic_ubuntu_17_18.sh 22 | ./install_ros_melodic_ubuntu_17_18.sh 23 | ``` 24 | 25 | or Noetic for Ubuntu 20.04: 26 | 27 | ```bash 28 | wget https://raw.githubusercontent.com/isl-org/MiDaS/master/ros/additions/install_ros_noetic_ubuntu_20.sh 29 | ./install_ros_noetic_ubuntu_20.sh 30 | ``` 31 | 32 | 33 | * install LibTorch 1.7 with CUDA 11.0: 34 | 35 | On **Jetson (ARM)**: 36 | ```bash 37 | wget https://nvidia.box.com/shared/static/wa34qwrwtk9njtyarwt5nvo6imenfy26.whl -O torch-1.7.0-cp36-cp36m-linux_aarch64.whl 38 | sudo apt-get install python3-pip libopenblas-base libopenmpi-dev 39 | pip3 install Cython 40 | pip3 install numpy torch-1.7.0-cp36-cp36m-linux_aarch64.whl 41 | ``` 42 | Or compile LibTorch from source: https://github.com/pytorch/pytorch#from-source 43 | 44 | On **Linux (x86_64)**: 45 | ```bash 46 | cd ~/ 47 | wget https://download.pytorch.org/libtorch/cu110/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu110.zip 48 | unzip libtorch-cxx11-abi-shared-with-deps-1.7.0+cu110.zip 49 | ``` 50 | 51 | * create symlink for OpenCV: 52 | 53 | ```bash 54 | sudo ln -s /usr/include/opencv4 /usr/include/opencv 55 | ``` 56 | 57 | * download and install MiDaS: 58 | 59 | ```bash 60 | source ~/.bashrc 61 | cd ~/ 62 | mkdir catkin_ws 63 | cd catkin_ws 64 | git clone https://github.com/isl-org/MiDaS 65 | mkdir src 66 | cp -r MiDaS/ros/* src 67 | 68 | chmod +x src/additions/*.sh 69 | chmod +x src/*.sh 70 | chmod +x src/midas_cpp/scripts/*.py 71 | cp src/additions/do_catkin_make.sh ./do_catkin_make.sh 72 | ./do_catkin_make.sh 73 | ./src/additions/downloads.sh 74 | ``` 75 | 76 | ### Usage 77 | 78 | * run only `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` 79 | 80 | #### Test 81 | 82 | * Test - capture video and show result in the window: 83 | * place any `test.mp4` video file to the directory `~/catkin_ws/src/` 84 | * run `midas` node: `~/catkin_ws/src/launch_midas_cpp.sh` 85 | * run test nodes in another terminal: `cd ~/catkin_ws/src && ./run_talker_listener_test.sh` and wait 30 seconds 86 | 87 | (to use Python 2, run command `sed -i 's/python3/python2/' ~/catkin_ws/src/midas_cpp/scripts/*.py` ) 88 | 89 | ## Mobile version of MiDaS - Monocular Depth Estimation 90 | 91 | ### Accuracy 92 | 93 | * MiDaS v2 small - ResNet50 default-decoder 384x384 94 | * MiDaS v2.1 small - EfficientNet-Lite3 small-decoder 256x256 95 | 96 | **Zero-shot error** (the lower - the better): 97 | 98 | | Model | DIW WHDR | Eth3d AbsRel | Sintel AbsRel | Kitti δ>1.25 | NyuDepthV2 δ>1.25 | TUM δ>1.25 | 99 | |---|---|---|---|---|---|---| 100 | | MiDaS v2 small 
384x384 | **0.1248** | 0.1550 | **0.3300** | **21.81** | 15.73 | 17.00 | 101 | | MiDaS v2.1 small 256x256 | 0.1344 | **0.1344** | 0.3370 | 29.27 | **13.43** | **14.53** | 102 | | Relative improvement, % | -8 % | **+13 %** | -2 % | -34 % | **+15 %** | **+15 %** | 103 | 104 | None of the train/valid/test subsets of these datasets (DIW, Eth3d, Sintel, Kitti, NyuDepthV2, TUM) were used for training or fine-tuning. 105 | 106 | ### Inference speed (FPS) on NVIDIA GPU 107 | 108 | Inference speed excluding pre- and post-processing, batch=1, **Frames Per Second** (the higher - the better): 109 | 110 | | Model | Jetson Nano, FPS | RTX 2080Ti, FPS | 111 | |---|---|---| 112 | | MiDaS v2 small 384x384 | 1.6 | 117 | 113 | | MiDaS v2.1 small 256x256 | 8.1 | 232 | 114 | | SpeedUp, X times | **5x** | **2x** | 115 | 116 | ### Citation 117 | 118 | This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3): 119 | 120 | >Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer 121 | René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun 122 | 123 | Please cite our paper if you use this code or any of the models: 124 | ``` 125 | @article{Ranftl2020, 126 | author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, 127 | title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, 128 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 129 | year = {2020}, 130 | } 131 | ``` 132 | -------------------------------------------------------------------------------- /ros/additions/do_catkin_make.sh: -------------------------------------------------------------------------------- 1 | mkdir src 2 | catkin_make 3 | source devel/setup.bash 4 | echo $ROS_PACKAGE_PATH 5 | chmod +x ./devel/setup.bash 6 | -------------------------------------------------------------------------------- /ros/additions/downloads.sh: -------------------------------------------------------------------------------- 1 | mkdir ~/.ros 2 | wget https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small-traced.pt 3 | cp ./model-small-traced.pt ~/.ros/model-small-traced.pt 4 | 5 | 6 | -------------------------------------------------------------------------------- /ros/additions/install_ros_melodic_ubuntu_17_18.sh: -------------------------------------------------------------------------------- 1 | #@title { display-mode: "code" } 2 | 3 | #from http://wiki.ros.org/indigo/Installation/Ubuntu 4 | 5 | #1.2 Setup sources.list 6 | sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 7 | 8 | # 1.3 Setup keys 9 | sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 10 | sudo apt-key adv --keyserver 'hkp://ha.pool.sks-keyservers.net:80' --recv-key 421C365BD9FF1F717815A3895523BAEEB01FA116 11 | 12 | curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - 13 | 14 | # 1.4 Installation 15 | sudo apt-get update 16 | sudo apt-get upgrade 17 | 18 | # Desktop-Full Install: 19 | sudo apt-get install ros-melodic-desktop-full 20 | 21 | printf "\nsource /opt/ros/melodic/setup.bash\n" >> ~/.bashrc 22 | 23 | # 1.5 Initialize rosdep 24 | sudo rosdep init 25 | rosdep update 26 | 27 | 28 | # 1.7 Getting rosinstall
(python) 29 | sudo apt-get install python-rosinstall 30 | sudo apt-get install python-catkin-tools 31 | sudo apt-get install python-rospy 32 | sudo apt-get install python-rosdep 33 | sudo apt-get install python-roscd 34 | sudo apt-get install python-pip -------------------------------------------------------------------------------- /ros/additions/install_ros_noetic_ubuntu_20.sh: -------------------------------------------------------------------------------- 1 | #@title { display-mode: "code" } 2 | 3 | #from http://wiki.ros.org/indigo/Installation/Ubuntu 4 | 5 | #1.2 Setup sources.list 6 | sudo sh -c 'echo "deb http://packages.ros.org/ros/ubuntu $(lsb_release -sc) main" > /etc/apt/sources.list.d/ros-latest.list' 7 | 8 | # 1.3 Setup keys 9 | sudo apt-key adv --keyserver 'hkp://keyserver.ubuntu.com:80' --recv-key C1CF6E31E6BADE8868B172B4F42ED6FBAB17C654 10 | 11 | curl -sSL 'http://keyserver.ubuntu.com/pks/lookup?op=get&search=0xC1CF6E31E6BADE8868B172B4F42ED6FBAB17C654' | sudo apt-key add - 12 | 13 | # 1.4 Installation 14 | sudo apt-get update 15 | sudo apt-get upgrade 16 | 17 | # Desktop-Full Install: 18 | sudo apt-get install ros-noetic-desktop-full 19 | 20 | printf "\nsource /opt/ros/noetic/setup.bash\n" >> ~/.bashrc 21 | 22 | # 1.5 Initialize rosdep 23 | sudo rosdep init 24 | rosdep update 25 | 26 | 27 | # 1.7 Getting rosinstall (python) 28 | sudo apt-get install python3-rosinstall 29 | sudo apt-get install python3-catkin-tools 30 | sudo apt-get install python3-rospy 31 | sudo apt-get install python3-rosdep 32 | sudo apt-get install python3-roscd 33 | sudo apt-get install python3-pip -------------------------------------------------------------------------------- /ros/additions/make_package_cpp.sh: -------------------------------------------------------------------------------- 1 | cd ~/catkin_ws/src 2 | catkin_create_pkg midas_cpp std_msgs roscpp cv_bridge sensor_msgs image_transport 3 | cd ~/catkin_ws 4 | catkin_make 5 | 6 | chmod +x ~/catkin_ws/devel/setup.bash 7 | printf "\nsource ~/catkin_ws/devel/setup.bash" >> ~/.bashrc 8 | source ~/catkin_ws/devel/setup.bash 9 | 10 | 11 | sudo rosdep init 12 | rosdep update 13 | #rospack depends1 midas_cpp 14 | roscd midas_cpp 15 | #cat package.xml 16 | #rospack depends midas_cpp -------------------------------------------------------------------------------- /ros/launch_midas_cpp.sh: -------------------------------------------------------------------------------- 1 | source ~/catkin_ws/devel/setup.bash 2 | roslaunch midas_cpp midas_cpp.launch model_name:="model-small-traced.pt" input_topic:="image_topic" output_topic:="midas_topic" out_orig_size:="true" -------------------------------------------------------------------------------- /ros/midas_cpp/launch/midas_cpp.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /ros/midas_cpp/launch/midas_talker_listener.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /ros/midas_cpp/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | midas_cpp 4 | 0.1.0 5 | The midas_cpp package 6 | 7 | Alexey Bochkovskiy 8 | MIT 9 | 
https://github.com/isl-org/MiDaS/tree/master/ros 10 | 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | catkin 52 | cv_bridge 53 | image_transport 54 | roscpp 55 | rospy 56 | sensor_msgs 57 | std_msgs 58 | cv_bridge 59 | image_transport 60 | roscpp 61 | rospy 62 | sensor_msgs 63 | std_msgs 64 | cv_bridge 65 | image_transport 66 | roscpp 67 | rospy 68 | sensor_msgs 69 | std_msgs 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /ros/midas_cpp/scripts/listener.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import print_function 3 | 4 | import roslib 5 | #roslib.load_manifest('my_package') 6 | import sys 7 | import rospy 8 | import cv2 9 | import numpy as np 10 | from std_msgs.msg import String 11 | from sensor_msgs.msg import Image 12 | from cv_bridge import CvBridge, CvBridgeError 13 | 14 | class video_show: 15 | 16 | def __init__(self): 17 | self.show_output = rospy.get_param('~show_output', True) 18 | self.save_output = rospy.get_param('~save_output', False) 19 | self.output_video_file = rospy.get_param('~output_video_file','result.mp4') 20 | # rospy.loginfo(f"Listener - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") 21 | 22 | self.bridge = CvBridge() 23 | self.image_sub = rospy.Subscriber("midas_topic", Image, self.callback) 24 | 25 | def callback(self, data): 26 | try: 27 | cv_image = self.bridge.imgmsg_to_cv2(data) 28 | except CvBridgeError as e: 29 | print(e) 30 | return 31 | 32 | if cv_image.size == 0: 33 | return 34 | 35 | rospy.loginfo("Listener: Received new frame") 36 | cv_image = cv_image.astype("uint8") 37 | 38 | if self.show_output==True: 39 | cv2.imshow("video_show", cv_image) 40 | cv2.waitKey(10) 41 | 42 | if self.save_output==True: 43 | if not hasattr(self, 'out'):  # create the video writer on the first saved frame 44 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 45 | self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0]), isColor=(cv_image.ndim == 3)) 46 | 47 | self.out.write(cv_image) 48 | 49 | 50 | 51 | def main(args): 52 | rospy.init_node('listener', anonymous=True) 53 | ic = video_show() 54 | try: 55 | rospy.spin() 56 | except KeyboardInterrupt: 57 | print("Shutting down") 58 | cv2.destroyAllWindows() 59 | 60 | if __name__ == '__main__': 61 | main(sys.argv) -------------------------------------------------------------------------------- /ros/midas_cpp/scripts/listener_original.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from __future__ import print_function 3 | 4 | import roslib 5 | #roslib.load_manifest('my_package') 6 | import sys 7 | import rospy 8 | import cv2 9 | import numpy as np 10 | from std_msgs.msg import String 11 | from sensor_msgs.msg import Image 12 | from cv_bridge import CvBridge, CvBridgeError 13 | 14 | class video_show: 15 | 16 | def __init__(self): 17 | self.show_output = rospy.get_param('~show_output', True) 18 | self.save_output = rospy.get_param('~save_output', False) 19 | self.output_video_file = rospy.get_param('~output_video_file','result.mp4') 20 | # rospy.loginfo(f"Listener original - params: show_output={self.show_output}, save_output={self.save_output}, output_video_file={self.output_video_file}") 21 | 22 | 
self.bridge = CvBridge() 23 | self.image_sub = rospy.Subscriber("image_topic", Image, self.callback) 24 | 25 | def callback(self, data): 26 | try: 27 | cv_image = self.bridge.imgmsg_to_cv2(data) 28 | except CvBridgeError as e: 29 | print(e) 30 | return 31 | 32 | if cv_image.size == 0: 33 | return 34 | 35 | rospy.loginfo("Listener_original: Received new frame") 36 | cv_image = cv_image.astype("uint8") 37 | 38 | if self.show_output==True: 39 | cv2.imshow("video_show_orig", cv_image) 40 | cv2.waitKey(10) 41 | 42 | if self.save_output==True: 43 | if not hasattr(self, 'out'):  # create the video writer on the first saved frame 44 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 45 | self.out = cv2.VideoWriter(self.output_video_file, fourcc, 25, (cv_image.shape[1], cv_image.shape[0]), isColor=(cv_image.ndim == 3)) 46 | 47 | self.out.write(cv_image) 48 | 49 | 50 | 51 | def main(args): 52 | rospy.init_node('listener_original', anonymous=True) 53 | ic = video_show() 54 | try: 55 | rospy.spin() 56 | except KeyboardInterrupt: 57 | print("Shutting down") 58 | cv2.destroyAllWindows() 59 | 60 | if __name__ == '__main__': 61 | main(sys.argv) -------------------------------------------------------------------------------- /ros/midas_cpp/scripts/talker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import roslib 5 | #roslib.load_manifest('my_package') 6 | import sys 7 | import rospy 8 | import cv2 9 | from std_msgs.msg import String 10 | from sensor_msgs.msg import Image 11 | from cv_bridge import CvBridge, CvBridgeError 12 | 13 | 14 | def talker(): 15 | rospy.init_node('talker', anonymous=True) 16 | 17 | use_camera = rospy.get_param('~use_camera', False) 18 | input_video_file = rospy.get_param('~input_video_file','test.mp4') 19 | # rospy.loginfo(f"Talker - params: use_camera={use_camera}, input_video_file={input_video_file}") 20 | 21 | # rospy.loginfo("Talker: Trying to open a video stream") 22 | if use_camera == True: 23 | cap = cv2.VideoCapture(0) 24 | else: 25 | cap = cv2.VideoCapture(input_video_file) 26 | 27 | pub = rospy.Publisher('image_topic', Image, queue_size=1) 28 | rate = rospy.Rate(30) # 30hz 29 | bridge = CvBridge() 30 | 31 | while not rospy.is_shutdown(): 32 | ret, cv_image = cap.read() 33 | if ret==False: 34 | print("Talker: Video is over") 35 | rospy.loginfo("Video is over") 36 | return 37 | 38 | try: 39 | image = bridge.cv2_to_imgmsg(cv_image, "bgr8") 40 | except CvBridgeError as e: 41 | rospy.logerr("Talker: cv2image conversion failed: %s", e) 42 | print(e) 43 | continue 44 | 45 | rospy.loginfo("Talker: Publishing frame") 46 | pub.publish(image) 47 | rate.sleep() 48 | 49 | if __name__ == '__main__': 50 | try: 51 | talker() 52 | except rospy.ROSInterruptException: 53 | pass 54 | -------------------------------------------------------------------------------- /ros/run_talker_listener_test.sh: -------------------------------------------------------------------------------- 1 | # place any test.mp4 file next to this file 2 | 3 | # roscore 4 | # rosnode kill -a 5 | 6 | source ~/catkin_ws/devel/setup.bash 7 | 8 | roscore & 9 | P1=$! 10 | rosrun midas_cpp talker.py & 11 | P2=$! 12 | rosrun midas_cpp listener_original.py & 13 | P3=$! 14 | rosrun midas_cpp listener.py & 15 | P4=$! 
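# The wait below blocks until roscore and the three nodes started above (talker, listener_original, listener) exit;
# a cleanup handler such as `trap 'kill $P1 $P2 $P3 $P4' INT` could be added here to stop them on Ctrl+C.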
16 | wait $P1 $P2 $P3 $P4 -------------------------------------------------------------------------------- /tf/README.md: -------------------------------------------------------------------------------- 1 | ## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer 2 | 3 | ### TensorFlow inference using `.pb` and `.onnx` models 4 | 5 | 1. [Run inference on TensorFlow-model by using TensorFlow](#run-inference-on-tensorflow-model-by-using-tensorflow) 6 | 7 | 2. [Run inference on ONNX-model by using ONNX-Runtime](#run-inference-on-onnx-model-by-using-onnx-runtime) 8 | 9 | 3. [Make ONNX model from downloaded PyTorch model file](#make-onnx-model-from-downloaded-pytorch-model-file) 10 | 11 | 12 | ### Run inference on TensorFlow-model by using TensorFlow 13 | 14 | 1) Download the model weights [model-f6b98070.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pb) 15 | and [model-small.pb](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.pb) and place the 16 | files in the `/tf/` folder. 17 | 18 | 2) Set up dependencies: 19 | 20 | ```shell 21 | # install OpenCV 22 | pip install --upgrade pip 23 | pip install opencv-python 24 | 25 | # install TensorFlow 26 | pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 27 | ``` 28 | 29 | #### Usage 30 | 31 | 1) Place one or more input images in the folder `tf/input`. 32 | 33 | 2) Run the model: 34 | 35 | ```shell 36 | python tf/run_pb.py 37 | ``` 38 | 39 | Or run the small model: 40 | 41 | ```shell 42 | python tf/run_pb.py --model_weights model-small.pb --model_type small 43 | ``` 44 | 45 | 3) The resulting inverse depth maps are written to the `tf/output` folder. 46 | 47 | 48 | ### Run inference on ONNX-model by using ONNX-Runtime 49 | 50 | 1) Download the model weights [model-f6b98070.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.onnx) 51 | and [model-small.onnx](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-small.onnx) and place the 52 | files in the `/tf/` folder. 53 | 54 | 2) Set up dependencies: 55 | 56 | ```shell 57 | # install OpenCV 58 | pip install --upgrade pip 59 | pip install opencv-python 60 | 61 | # install ONNX 62 | pip install onnx==1.7.0 63 | 64 | # install ONNX Runtime 65 | pip install onnxruntime==1.5.2 66 | ``` 67 | 68 | #### Usage 69 | 70 | 1) Place one or more input images in the folder `tf/input`. 71 | 72 | 2) Run the model: 73 | 74 | ```shell 75 | python tf/run_onnx.py 76 | ``` 77 | 78 | Or run the small model: 79 | 80 | ```shell 81 | python tf/run_onnx.py --model_weights model-small.onnx --model_type small 82 | ``` 83 | 84 | 3) The resulting inverse depth maps are written to the `tf/output` folder. 85 | 86 | 87 | 88 | ### Make ONNX model from downloaded PyTorch model file 89 | 90 | 1) Download the model weights [model-f6b98070.pt](https://github.com/isl-org/MiDaS/releases/download/v2_1/model-f6b98070.pt) and place the 91 | file in the root folder. 
92 | 93 | 2) Set up dependencies: 94 | 95 | ```shell 96 | # install OpenCV 97 | pip install --upgrade pip 98 | pip install opencv-python 99 | 100 | # install PyTorch TorchVision 101 | pip install -I torch==1.7.0 torchvision==0.8.0 102 | 103 | # install TensorFlow 104 | pip install -I grpcio tensorflow==2.3.0 tensorflow-addons==0.11.2 numpy==1.18.0 105 | 106 | # install ONNX 107 | pip install onnx==1.7.0 108 | 109 | # install ONNX-TensorFlow 110 | git clone https://github.com/onnx/onnx-tensorflow.git 111 | cd onnx-tensorflow 112 | git checkout 095b51b88e35c4001d70f15f80f31014b592b81e 113 | pip install -e . 114 | ``` 115 | 116 | #### Usage 117 | 118 | 1) Run the converter: 119 | 120 | ```shell 121 | python tf/make_onnx_model.py 122 | ``` 123 | 124 | 2) The resulting `model-f6b98070.onnx` file is written to the `/tf/` folder. 125 | 126 | 127 | ### Requirements 128 | 129 | The code was tested with Python 3.6.9, PyTorch 1.5.1, TensorFlow 2.2.0, TensorFlow-addons 0.8.3, ONNX 1.7.0, ONNX-TensorFlow (GitHub-master-17.07.2020) and OpenCV 4.3.0. 130 | 131 | ### Citation 132 | 133 | Please cite our paper if you use this code or any of the models: 134 | ``` 135 | @article{Ranftl2019, 136 | author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun}, 137 | title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer}, 138 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 139 | year = {2020}, 140 | } 141 | ``` 142 | 143 | ### License 144 | 145 | MIT License 146 | 147 | 148 | -------------------------------------------------------------------------------- /tf/input/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/tf/input/.placeholder -------------------------------------------------------------------------------- /tf/make_onnx_model.py: -------------------------------------------------------------------------------- 1 | """Compute depth maps for images in the input folder. 2 | """ 3 | import os 4 | import ntpath 5 | import glob 6 | import torch 7 | import utils 8 | import cv2 9 | import numpy as np 10 | from torchvision.transforms import Compose, Normalize 11 | from torchvision import transforms 12 | 13 | from shutil import copyfile 14 | import fileinput 15 | import sys 16 | sys.path.append(os.getcwd() + '/..') 17 | 18 | def modify_file(): 19 | modify_filename = '../midas/blocks.py' 20 | copyfile(modify_filename, modify_filename+'.bak') 21 | 22 | with open(modify_filename, 'r') as file : 23 | filedata = file.read() 24 | 25 | filedata = filedata.replace('align_corners=True', 'align_corners=False') 26 | filedata = filedata.replace('import torch.nn as nn', 'import torch.nn as nn\nimport torchvision.models as models') 27 | filedata = filedata.replace('torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")', 'models.resnext101_32x8d()') 28 | 29 | with open(modify_filename, 'w') as file: 30 | file.write(filedata) 31 | 32 | def restore_file(): 33 | modify_filename = '../midas/blocks.py' 34 | copyfile(modify_filename+'.bak', modify_filename) 35 | 36 | modify_file() 37 | 38 | from midas.midas_net import MidasNet 39 | from midas.transforms import Resize, NormalizeImage, PrepareForNet 40 | 41 | restore_file() 42 | 43 | 44 | class MidasNet_preprocessing(MidasNet): 45 | """Network for monocular depth estimation. 
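Same as MidasNet, but the ImageNet mean/std normalization of the input is applied inside forward(), so the exported ONNX model can be fed RGB values in the range [0, 1] without a separate normalization step.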
46 | """ 47 | def forward(self, x): 48 | """Forward pass. 49 | 50 | Args: 51 | x (tensor): input data (image) 52 | 53 | Returns: 54 | tensor: depth 55 | """ 56 | 57 | mean = torch.tensor([0.485, 0.456, 0.406]) 58 | std = torch.tensor([0.229, 0.224, 0.225]) 59 | x.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 60 | 61 | return MidasNet.forward(self, x) 62 | 63 | 64 | def run(model_path): 65 | """Run MonoDepthNN to compute depth maps. 66 | 67 | Args: 68 | model_path (str): path to saved model 69 | """ 70 | print("initialize") 71 | 72 | # select device 73 | 74 | # load network 75 | #model = MidasNet(model_path, non_negative=True) 76 | model = MidasNet_preprocessing(model_path, non_negative=True) 77 | 78 | model.eval() 79 | 80 | print("start processing") 81 | 82 | # input 83 | img_input = np.zeros((3, 384, 384), np.float32) 84 | 85 | # compute 86 | with torch.no_grad(): 87 | sample = torch.from_numpy(img_input).unsqueeze(0) 88 | prediction = model.forward(sample) 89 | prediction = ( 90 | torch.nn.functional.interpolate( 91 | prediction.unsqueeze(1), 92 | size=img_input.shape[:2], 93 | mode="bicubic", 94 | align_corners=False, 95 | ) 96 | .squeeze() 97 | .cpu() 98 | .numpy() 99 | ) 100 | 101 | torch.onnx.export(model, sample, ntpath.basename(model_path).rsplit('.', 1)[0]+'.onnx', opset_version=9) 102 | 103 | print("finished") 104 | 105 | 106 | if __name__ == "__main__": 107 | # set paths 108 | # MODEL_PATH = "model.pt" 109 | MODEL_PATH = "../model-f6b98070.pt" 110 | 111 | # compute depth maps 112 | run(MODEL_PATH) 113 | -------------------------------------------------------------------------------- /tf/output/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/tf/output/.placeholder -------------------------------------------------------------------------------- /tf/run_onnx.py: -------------------------------------------------------------------------------- 1 | """Compute depth maps for images in the input folder. 2 | """ 3 | import os 4 | import glob 5 | import utils 6 | import cv2 7 | import sys 8 | import numpy as np 9 | import argparse 10 | 11 | import onnx 12 | import onnxruntime as rt 13 | 14 | from transforms import Resize, NormalizeImage, PrepareForNet 15 | 16 | 17 | def run(input_path, output_path, model_path, model_type="large"): 18 | """Run MonoDepthNN to compute depth maps. 
19 | 20 | Args: 21 | input_path (str): path to input folder 22 | output_path (str): path to output folder 23 | model_path (str): path to saved model 24 | """ 25 | print("initialize") 26 | 27 | # select device 28 | device = "CUDA:0" 29 | #device = "CPU" 30 | print("device: %s" % device) 31 | 32 | # network resolution 33 | if model_type == "large": 34 | net_w, net_h = 384, 384 35 | elif model_type == "small": 36 | net_w, net_h = 256, 256 37 | else: 38 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 39 | assert False 40 | 41 | # load network 42 | print("loading model...") 43 | model = rt.InferenceSession(model_path) 44 | input_name = model.get_inputs()[0].name 45 | output_name = model.get_outputs()[0].name 46 | 47 | resize_image = Resize( 48 | net_w, 49 | net_h, 50 | resize_target=None, 51 | keep_aspect_ratio=False, 52 | ensure_multiple_of=32, 53 | resize_method="upper_bound", 54 | image_interpolation_method=cv2.INTER_CUBIC, 55 | ) 56 | 57 | def compose2(f1, f2): 58 | return lambda x: f2(f1(x)) 59 | 60 | transform = compose2(resize_image, PrepareForNet()) 61 | 62 | # get input 63 | img_names = glob.glob(os.path.join(input_path, "*")) 64 | num_images = len(img_names) 65 | 66 | # create output folder 67 | os.makedirs(output_path, exist_ok=True) 68 | 69 | print("start processing") 70 | 71 | for ind, img_name in enumerate(img_names): 72 | 73 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) 74 | 75 | # input 76 | img = utils.read_image(img_name) 77 | img_input = transform({"image": img})["image"] 78 | 79 | # compute 80 | output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0] 81 | prediction = np.array(output).reshape(net_h, net_w) 82 | prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) 83 | 84 | # output 85 | filename = os.path.join( 86 | output_path, os.path.splitext(os.path.basename(img_name))[0] 87 | ) 88 | utils.write_depth(filename, prediction, bits=2) 89 | 90 | print("finished") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | 96 | parser.add_argument('-i', '--input_path', 97 | default='input', 98 | help='folder with input images' 99 | ) 100 | 101 | parser.add_argument('-o', '--output_path', 102 | default='output', 103 | help='folder for output images' 104 | ) 105 | 106 | parser.add_argument('-m', '--model_weights', 107 | default='model-f6b98070.onnx', 108 | help='path to the trained weights of model' 109 | ) 110 | 111 | parser.add_argument('-t', '--model_type', 112 | default='large', 113 | help='model type: large or small' 114 | ) 115 | 116 | args = parser.parse_args() 117 | 118 | # compute depth maps 119 | run(args.input_path, args.output_path, args.model_weights, args.model_type) 120 | -------------------------------------------------------------------------------- /tf/run_pb.py: -------------------------------------------------------------------------------- 1 | """Compute depth maps for images in the input folder. 2 | """ 3 | import os 4 | import glob 5 | import utils 6 | import cv2 7 | import argparse 8 | 9 | import tensorflow as tf 10 | 11 | from transforms import Resize, NormalizeImage, PrepareForNet 12 | 13 | def run(input_path, output_path, model_path, model_type="large"): 14 | """Run MonoDepthNN to compute depth maps. 
15 | 16 | Args: 17 | input_path (str): path to input folder 18 | output_path (str): path to output folder 19 | model_path (str): path to saved model 20 | """ 21 | print("initialize") 22 | 23 | # the runtime initialization will not allocate all memory on the device to avoid out of GPU memory 24 | gpus = tf.config.experimental.list_physical_devices('GPU') 25 | if gpus: 26 | try: 27 | for gpu in gpus: 28 | #tf.config.experimental.set_memory_growth(gpu, True) 29 | tf.config.experimental.set_virtual_device_configuration(gpu, 30 | [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4000)]) 31 | except RuntimeError as e: 32 | print(e) 33 | 34 | # network resolution 35 | if model_type == "large": 36 | net_w, net_h = 384, 384 37 | elif model_type == "small": 38 | net_w, net_h = 256, 256 39 | else: 40 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 41 | assert False 42 | 43 | # load network 44 | graph_def = tf.compat.v1.GraphDef() 45 | with tf.io.gfile.GFile(model_path, 'rb') as f: 46 | graph_def.ParseFromString(f.read()) 47 | tf.import_graph_def(graph_def, name='') 48 | 49 | 50 | model_operations = tf.compat.v1.get_default_graph().get_operations() 51 | input_node = '0:0' 52 | output_layer = model_operations[len(model_operations) - 1].name + ':0' 53 | print("Last layer name: ", output_layer) 54 | 55 | resize_image = Resize( 56 | net_w, 57 | net_h, 58 | resize_target=None, 59 | keep_aspect_ratio=False, 60 | ensure_multiple_of=32, 61 | resize_method="upper_bound", 62 | image_interpolation_method=cv2.INTER_CUBIC, 63 | ) 64 | 65 | def compose2(f1, f2): 66 | return lambda x: f2(f1(x)) 67 | 68 | transform = compose2(resize_image, PrepareForNet()) 69 | 70 | # get input 71 | img_names = glob.glob(os.path.join(input_path, "*")) 72 | num_images = len(img_names) 73 | 74 | # create output folder 75 | os.makedirs(output_path, exist_ok=True) 76 | 77 | print("start processing") 78 | 79 | with tf.compat.v1.Session() as sess: 80 | try: 81 | # load images 82 | for ind, img_name in enumerate(img_names): 83 | 84 | print(" processing {} ({}/{})".format(img_name, ind + 1, num_images)) 85 | 86 | # input 87 | img = utils.read_image(img_name) 88 | img_input = transform({"image": img})["image"] 89 | 90 | # compute 91 | prob_tensor = sess.graph.get_tensor_by_name(output_layer) 92 | prediction, = sess.run(prob_tensor, {input_node: [img_input] }) 93 | prediction = prediction.reshape(net_h, net_w) 94 | prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) 95 | 96 | # output 97 | filename = os.path.join( 98 | output_path, os.path.splitext(os.path.basename(img_name))[0] 99 | ) 100 | utils.write_depth(filename, prediction, bits=2) 101 | 102 | except KeyError: 103 | print ("Couldn't find input node: ' + input_node + ' or output layer: " + output_layer + ".") 104 | exit(-1) 105 | 106 | print("finished") 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument('-i', '--input_path', 113 | default='input', 114 | help='folder with input images' 115 | ) 116 | 117 | parser.add_argument('-o', '--output_path', 118 | default='output', 119 | help='folder for output images' 120 | ) 121 | 122 | parser.add_argument('-m', '--model_weights', 123 | default='model-f6b98070.pb', 124 | help='path to the trained weights of model' 125 | ) 126 | 127 | parser.add_argument('-t', '--model_type', 128 | default='large', 129 | help='model type: large or small' 130 | ) 131 | 132 | args = parser.parse_args() 133 | 134 
| # compute depth maps 135 | run(args.input_path, args.output_path, args.model_weights, args.model_type) 136 | -------------------------------------------------------------------------------- /tf/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import cv2 4 | 5 | 6 | def write_pfm(path, image, scale=1): 7 | """Write pfm file. 8 | Args: 9 | path (str): pathto file 10 | image (array): data 11 | scale (int, optional): Scale. Defaults to 1. 12 | """ 13 | 14 | with open(path, "wb") as file: 15 | color = None 16 | 17 | if image.dtype.name != "float32": 18 | raise Exception("Image dtype must be float32.") 19 | 20 | image = np.flipud(image) 21 | 22 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 23 | color = True 24 | elif ( 25 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 26 | ): # greyscale 27 | color = False 28 | else: 29 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 30 | 31 | file.write("PF\n" if color else "Pf\n".encode()) 32 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 33 | 34 | endian = image.dtype.byteorder 35 | 36 | if endian == "<" or endian == "=" and sys.byteorder == "little": 37 | scale = -scale 38 | 39 | file.write("%f\n".encode() % scale) 40 | 41 | image.tofile(file) 42 | 43 | def read_image(path): 44 | """Read image and output RGB image (0-1). 45 | Args: 46 | path (str): path to file 47 | Returns: 48 | array: RGB image (0-1) 49 | """ 50 | img = cv2.imread(path) 51 | 52 | if img.ndim == 2: 53 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 54 | 55 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 56 | 57 | return img 58 | 59 | def write_depth(path, depth, bits=1): 60 | """Write depth map to pfm and png file. 61 | Args: 62 | path (str): filepath without extension 63 | depth (array): depth 64 | """ 65 | write_pfm(path + ".pfm", depth.astype(np.float32)) 66 | 67 | depth_min = depth.min() 68 | depth_max = depth.max() 69 | 70 | max_val = (2**(8*bits))-1 71 | 72 | if depth_max - depth_min > np.finfo("float").eps: 73 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 74 | else: 75 | out = 0 76 | 77 | if bits == 1: 78 | cv2.imwrite(path + ".png", out.astype("uint8")) 79 | elif bits == 2: 80 | cv2.imwrite(path + ".png", out.astype("uint16")) 81 | 82 | return -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth. 2 | """ 3 | import sys 4 | import re 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | 10 | def read_pfm(path): 11 | """Read pfm file. 
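(PFM stores image rows from bottom to top, so the array is flipped vertically after reading.)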
12 | 13 | Args: 14 | path (str): path to file 15 | 16 | Returns: 17 | tuple: (data, scale) 18 | """ 19 | with open(path, "rb") as file: 20 | 21 | color = None 22 | width = None 23 | height = None 24 | scale = None 25 | endian = None 26 | 27 | header = file.readline().rstrip() 28 | if header.decode("ascii") == "PF": 29 | color = True 30 | elif header.decode("ascii") == "Pf": 31 | color = False 32 | else: 33 | raise Exception("Not a PFM file: " + path) 34 | 35 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 36 | if dim_match: 37 | width, height = list(map(int, dim_match.groups())) 38 | else: 39 | raise Exception("Malformed PFM header.") 40 | 41 | scale = float(file.readline().decode("ascii").rstrip()) 42 | if scale < 0: 43 | # little-endian 44 | endian = "<" 45 | scale = -scale 46 | else: 47 | # big-endian 48 | endian = ">" 49 | 50 | data = np.fromfile(file, endian + "f") 51 | shape = (height, width, 3) if color else (height, width) 52 | 53 | data = np.reshape(data, shape) 54 | data = np.flipud(data) 55 | 56 | return data, scale 57 | 58 | 59 | def write_pfm(path, image, scale=1): 60 | """Write pfm file. 61 | 62 | Args: 63 | path (str): pathto file 64 | image (array): data 65 | scale (int, optional): Scale. Defaults to 1. 66 | """ 67 | 68 | with open(path, "wb") as file: 69 | color = None 70 | 71 | if image.dtype.name != "float32": 72 | raise Exception("Image dtype must be float32.") 73 | 74 | image = np.flipud(image) 75 | 76 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 77 | color = True 78 | elif ( 79 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 80 | ): # greyscale 81 | color = False 82 | else: 83 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 84 | 85 | file.write("PF\n" if color else "Pf\n".encode()) 86 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 87 | 88 | endian = image.dtype.byteorder 89 | 90 | if endian == "<" or endian == "=" and sys.byteorder == "little": 91 | scale = -scale 92 | 93 | file.write("%f\n".encode() % scale) 94 | 95 | image.tofile(file) 96 | 97 | 98 | def read_image(path): 99 | """Read image and output RGB image (0-1). 100 | 101 | Args: 102 | path (str): path to file 103 | 104 | Returns: 105 | array: RGB image (0-1) 106 | """ 107 | img = cv2.imread(path) 108 | 109 | if img.ndim == 2: 110 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 111 | 112 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 113 | 114 | return img 115 | 116 | 117 | def resize_image(img): 118 | """Resize image and make it fit for network. 119 | 120 | Args: 121 | img (array): image 122 | 123 | Returns: 124 | tensor: data ready for network 125 | """ 126 | height_orig = img.shape[0] 127 | width_orig = img.shape[1] 128 | 129 | if width_orig > height_orig: 130 | scale = width_orig / 384 131 | else: 132 | scale = height_orig / 384 133 | 134 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 135 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 136 | 137 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 138 | 139 | img_resized = ( 140 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 141 | ) 142 | img_resized = img_resized.unsqueeze(0) 143 | 144 | return img_resized 145 | 146 | 147 | def resize_depth(depth, width, height): 148 | """Resize depth map and bring to CPU (numpy). 
149 | 150 | Args: 151 | depth (tensor): depth 152 | width (int): image width 153 | height (int): image height 154 | 155 | Returns: 156 | array: processed depth 157 | """ 158 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 159 | 160 | depth_resized = cv2.resize( 161 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 162 | ) 163 | 164 | return depth_resized 165 | 166 | def write_depth(path, depth, grayscale, bits=1): 167 | """Write depth map to png file. 168 | 169 | Args: 170 | path (str): filepath without extension 171 | depth (array): depth 172 | grayscale (bool): use a grayscale colormap? 173 | """ 174 | if not grayscale: 175 | bits = 1 176 | 177 | if not np.isfinite(depth).all(): 178 | depth=np.nan_to_num(depth, nan=0.0, posinf=0.0, neginf=0.0) 179 | print("WARNING: Non-finite depth values present") 180 | 181 | depth_min = depth.min() 182 | depth_max = depth.max() 183 | 184 | max_val = (2**(8*bits))-1 185 | 186 | if depth_max - depth_min > np.finfo("float").eps: 187 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 188 | else: 189 | out = np.zeros(depth.shape, dtype=depth.dtype) 190 | 191 | if not grayscale: 192 | out = cv2.applyColorMap(np.uint8(out), cv2.COLORMAP_INFERNO) 193 | 194 | if bits == 1: 195 | cv2.imwrite(path + ".png", out.astype("uint8")) 196 | elif bits == 2: 197 | cv2.imwrite(path + ".png", out.astype("uint16")) 198 | 199 | return 200 | -------------------------------------------------------------------------------- /weights/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isl-org/MiDaS/454597711a62eabcbf7d1e89f3fb9f569051ac9b/weights/.placeholder --------------------------------------------------------------------------------
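The `.pfm` files that the `tf` run scripts write alongside their `.png` outputs (via `write_depth`/`write_pfm` in `tf/utils.py`) can be read back with `read_pfm` from the root `utils.py`. A minimal sketch of that round trip, run from the repository root; the file name `tf/output/example.pfm` is a hypothetical example and not part of the repository:

```python
# Minimal sketch (not part of the repository): load a PFM inverse-depth map
# saved by the tf run scripts and convert it to an 8-bit preview image.
import cv2
import numpy as np

import utils  # the root utils.py shown above (read_pfm / write_pfm)

# read_pfm returns the float32 data array and the PFM scale factor
depth, scale = utils.read_pfm("tf/output/example.pfm")

# normalize to 0-255 the same way write_depth does, guarding against a constant map
depth_min, depth_max = depth.min(), depth.max()
if depth_max - depth_min > np.finfo("float").eps:
    preview = (255 * (depth - depth_min) / (depth_max - depth_min)).astype("uint8")
else:
    preview = np.zeros(depth.shape, dtype="uint8")

cv2.imwrite("tf/output/example_preview.png", preview)
```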