├── __init__.py ├── model ├── __init__.py ├── predict_from_folder.py └── predict_from_file.py ├── dataset ├── __init__.py ├── utils.py ├── download_from_file.py ├── export_from_lobe.py └── download_from_flickr.py ├── app ├── components │ ├── __init__.py │ ├── stretch_wrapper.py │ ├── navbar.py │ ├── export.py │ ├── visualize.py │ ├── flickr.py │ ├── dataset.py │ └── model.py ├── assets │ └── icon.ico ├── __init__.py ├── app.spec └── app.py ├── requirements.txt ├── .github └── workflows │ ├── build-mac.yml │ └── build-windows.yml ├── LICENSE ├── .gitignore └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /app/assets/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lobe/image-tools/HEAD/app/assets/icon.ico -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.4.1 2 | xlrd==2.0.1 3 | openpyxl==3.0.9 4 | tqdm==4.62.3 5 | requests==2.27.1 6 | pyinstaller==4.9 7 | PyQT5==5.15.6 8 | pillow==9.0.1 9 | lobe[tf]==0.6.2 10 | -------------------------------------------------------------------------------- /app/__init__.py: 
# Directory holding bundled assets when running from source: `assets` when the
# working directory is `app/`, otherwise `app/assets` relative to the repo root.
ASSETS_PATH = 'assets' if os.path.exists('assets') else os.path.join('app', 'assets')


def resource_path(relative_path):
    """
    Resolve the on-disk location of a bundled resource.

    :param relative_path: path of the resource relative to the asset root (e.g. 'icon.ico').
    :return: `relative_path` joined onto `sys._MEIPASS` when that attribute exists, otherwise onto ASSETS_PATH.
    """
    # NOTE(review): sys._MEIPASS is presumably the PyInstaller unpack directory (app.spec ships
    # assets/icon.ico into '.'); it is simply absent when running from source, so a default is
    # enough -- no need to catch every exception for an attribute lookup.
    base_path = getattr(sys, '_MEIPASS', ASSETS_PATH)
    return os.path.join(base_path, relative_path)
-------------------------------------------------------------------------------- 1 | name: PyInstaller Windows 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | build: 9 | 10 | runs-on: windows-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v2 15 | - name: Set up Python 3.9 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.9 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt 23 | - name: Build with PyInstaller 24 | run: | 25 | pyinstaller --onefile app/app.spec 26 | - name: Upload exe 27 | uses: actions/upload-artifact@v2 28 | with: 29 | name: Image Tools Windows 30 | path: dist/Image Tools.exe 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-present Markus Beissinger 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /app/app.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | import sys 3 | import os 4 | 5 | block_cipher = None 6 | 7 | spec_path = os.path.realpath(SPECPATH) 8 | 9 | 10 | a = Analysis(['app.py'], 11 | pathex=[os.path.join(spec_path, '..')], 12 | binaries=[], 13 | datas=[('assets/icon.ico', '.')], 14 | hiddenimports=['cmath'], 15 | hookspath=[], 16 | runtime_hooks=[], 17 | excludes=[], 18 | win_no_prefer_redirects=False, 19 | win_private_assemblies=False, 20 | cipher=block_cipher, 21 | noarchive=False) 22 | pyz = PYZ(a.pure, a.zipped_data, 23 | cipher=block_cipher) 24 | exe = EXE(pyz, 25 | a.scripts, 26 | a.binaries, 27 | a.zipfiles, 28 | a.datas, 29 | [], 30 | # exclude_binaries=True, 31 | name='Image Tools', 32 | debug=False, 33 | bootloader_ignore_signals=False, 34 | strip=False, 35 | upx=True, 36 | upx_exclude=[], 37 | runtime_tmpdir=None, 38 | console=False, 39 | icon='assets/icon.ico') 40 | 41 | # Build a .app if on OS X 42 | if sys.platform == 'darwin': 43 | app = BUNDLE(exe, 44 | name='Image Tools.app', 45 | icon='assets/icon.ico') 46 | -------------------------------------------------------------------------------- /app/components/navbar.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import QButtonGroup, QPushButton, QVBoxLayout, QFrame, QLabel 2 | from PyQt5.QtGui import QPixmap 3 | from app import resource_path 4 | 5 | 6 | class NavBar(QFrame): 7 | 8 | def __init__(self, click_callback, tabs): 9 | super().__init__() 10 | # initialize our variables 11 | 
def download_image(url, directory, lock, label=None):
    """
    Download the image at `url` into `directory` (or `directory/label` when a label is given).

    This is best-effort: any failure (bad url, network error, non-OK response, write error)
    is swallowed and reported through the return value.

    :param url: url of the image to download.
    :param directory: base directory to save into; created (with parents) if it doesn't exist.
    :param lock: a context-manager lock held while the destination filename is reserved, so that
        concurrent downloads never pick the same filename.
    :param label: optional subdirectory name (under `directory`) to save the image into.
    :return: absolute path of the saved image, or None if the download failed.
    """
    img_file = None
    filepath = None
    try:
        # get our image save location
        save_dir = os.path.abspath(directory)
        if label is not None:
            save_dir = os.path.join(save_dir, label)
        # make our destination directory if it doesn't exist
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        # only the name reservation needs the lock -- the download itself can run in parallel
        with lock:
            img_file = _get_filepath(url=url, save_dir=save_dir)
        response = requests.get(url, timeout=30)
        if response.ok:
            # save the image!
            with open(img_file, 'wb') as f:
                f.write(response.content)
            filepath = os.path.abspath(img_file)
    except Exception:
        # deliberate best-effort: the caller only needs to know that this url failed
        filepath = None
    if filepath is None and img_file is not None:
        # with failure, also delete the placeholder (or partial) file we reserved
        try:
            os.remove(img_file)
        except OSError:
            pass
    return filepath


def _get_filepath(url, save_dir):
    """
    Reserve and return the full filepath the image at `url` should be saved to.

    The name comes from the last segment of the url path. An empty placeholder file is created
    right away so that later name searches see the name as taken without waiting for the download.

    :param url: url of the image.
    :param save_dir: existing directory the image will be saved into.
    :return: path (inside `save_dir`) of the newly reserved file.
    """
    # strip the query string and fragment BEFORE taking the last path segment -- they may
    # themselves contain '/' (e.g. '?next=http://host/page') and must never end up in the name
    url_path = url.split('?', 1)[0].split('#', 1)[0]
    filename = str(url_path.split('/')[-1])
    # if this file already exists in the path, increment its name
    # (since different URLs can have the same end filename)
    filename = _resolve_filename_conflict(directory=save_dir, filename=filename)
    filename = os.path.join(save_dir, filename)
    # touch the file to reserve the name
    with open(filename, 'a'):
        pass
    return filename


def _resolve_filename_conflict(directory, filename, sep="__"):
    """
    Return a variant of `filename` that does not exist yet inside `directory`.

    While the name is taken, a counter is appended (or incremented) after `sep`, in front of
    the extension: 'a.jpg' -> 'a__1.jpg' -> 'a__2.jpg' ...

    :param directory: directory to check for existing files.
    :param filename: desired filename (no directory part).
    :param sep: separator placed between the base name and the counter.
    :return: `filename` itself if it is free, otherwise the first free incremented name.
    """
    while os.path.exists(os.path.join(directory, filename)):
        name, extension = os.path.splitext(filename)
        name_parts = name.rsplit(sep, 1)
        base_name = name_parts[0]
        # get the counter value after the sep
        counter = 1
        if len(name_parts) > 1:
            try:
                counter = int(name_parts[-1]) + 1
            except ValueError:
                # what follows the sep isn't a counter, so it is part of the real name
                base_name = sep.join(name_parts)
        filename = f"{base_name}{sep}{counter}{extension}"
    return filename
class Export(QFrame):
    """
    Content pane that exports the labeled dataset of a Lobe project into a directory chosen by the user.

    :param app: object whose `processEvents()` is called to keep the UI responsive during the
        export (presumably the QApplication -- confirm against the caller).
    """
    export_text = "Export"
    exporting_text = "Exporting..."

    def __init__(self, app):
        super().__init__()
        # initialize our variables
        self.app = app
        self.export_button = None
        self.progress_bar = None
        # (name, id) pairs; filled in by populate_projects() while init_ui() builds the dropdown,
        # so the project list is only queried once during construction
        self.projects = []
        self.project_dropdown = None
        self.init_ui()

    def init_ui(self):
        """Build the widgets: title, description, project dropdown, export button and progress bar."""
        self.setObjectName("content")
        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # our main content area
        content = QFrame()
        content_layout = QVBoxLayout()

        # some info
        title = QLabel("Export")
        title.setObjectName("h1")
        description = QLabel("Export your labeled dataset from a Lobe project.")
        description.setObjectName("h2")

        # project dropdown
        project_label = QLabel("Project:")
        project_label.setObjectName("separate")
        self.project_dropdown = QComboBox()
        self.project_dropdown.setSizeAdjustPolicy(QComboBox.AdjustToContents)
        self.project_dropdown.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
        project_container = NoStretch(self.project_dropdown)
        self.populate_projects()

        # button
        self.export_button = QPushButton(self.export_text)
        self.export_button.setEnabled(True)
        self.export_button.clicked.connect(self.export)
        export_container = NoStretch(self.export_button)
        export_container.setObjectName("separate")

        # hidden until an export is running
        self.progress_bar = QProgressBar()
        self.progress_bar.hide()

        # make our content layout
        content_layout.addWidget(title)
        content_layout.addWidget(description)
        content_layout.addWidget(project_label)
        content_layout.addWidget(project_container)
        content_layout.addWidget(export_container)
        content_layout.addWidget(self.progress_bar)
        content_layout.addStretch(1)
        content.setLayout(content_layout)

        layout.addWidget(content)
        layout.addStretch(1)
        self.setLayout(layout)

    def populate_projects(self):
        """Reload the project list from get_projects() and refill the dropdown with the project names."""
        self.projects = get_projects()
        self.project_dropdown.clear()
        self.project_dropdown.addItems([name for name, _ in self.projects])
        self.project_dropdown.adjustSize()

    def export(self):
        """
        Ask for an output directory and export the selected project into a new subdirectory of it.

        The subdirectory is named after the project, with " (n)" appended if that name is taken.
        Errors are shown to the user in a message box; the UI is always restored afterwards.
        """
        # disable the buttons so we can't click again
        self.export_button.setEnabled(False)
        self.export_button.setText(self.exporting_text)
        self.progress_bar.setValue(0)
        self.progress_bar.show()
        self.app.processEvents()
        destination_directory = QFileDialog.getExistingDirectory(self, "Select Output Directory")
        # if they hit cancel, don't download
        if not destination_directory:
            self.done()
            return
        # otherwise try exporting to the desired location
        try:
            project_idx = self.project_dropdown.currentIndex()
            if project_idx < 0:
                # an index of -1 (nothing selected) must not silently pick the last project
                raise ValueError("no Lobe project is selected")
            project_name, project_id = self.projects[project_idx]
            export_dir = os.path.join(destination_directory, project_name)
            # rename the directory if there is a conflict
            rename_idx = 1
            while os.path.exists(export_dir):
                export_dir = os.path.abspath(os.path.join(destination_directory, f"{project_name} ({rename_idx})"))
                rename_idx += 1
            export_dataset(project_id=project_id, destination_dir=export_dir, progress_hook=self.progress_hook)
        except Exception as e:
            QMessageBox.about(self, "Alert", f"Error exporting dataset: {e}")
        finally:
            # restore the UI even if the export finished without a final progress callback
            self.done()

    def progress_hook(self, current, total):
        """
        Progress callback handed to export_dataset.

        :param current: number of items processed so far.
        :param total: total number of items to process.
        """
        # QProgressBar.setValue only takes an int; an empty export (total == 0) counts as complete
        percent = int(float(current) / total * 100) if total else 100
        self.progress_bar.setValue(percent)
        if current == total:
            self.done()
        # make sure to update the UI
        self.app.processEvents()

    def done(self):
        """Reset the progress bar and re-enable the export button. Safe to call more than once."""
        self.progress_bar.setValue(0)
        self.progress_bar.hide()
        self.export_button.setEnabled(True)
        self.export_button.setText(self.export_text)
        self.app.processEvents()
def visualize(self):
    """
    Run the selected model on the selected image and show the predicted label and heatmap.

    Does nothing but reset the output area until BOTH a model directory and an image file have
    been chosen. Any error while loading or predicting is shown in a message box and drops the
    cached model and image.
    """
    self.app.processEvents()
    # QFileDialog hands back an empty string when the user cancels, so test truthiness rather
    # than `is not None` -- otherwise a cancelled dialog tries to load '' and pops up an error
    if not (self.tf_directory and self.file):
        # inputs are incomplete: drop any stale result so the display matches the labels
        self.image = None
        self.done()
        return
    # disable the buttons so we can't click again
    self.model_button.setEnabled(False)
    self.file_button.setEnabled(False)
    self.image_label.setText(self.loading_text)
    self.prediction_label.setText("")
    self.app.processEvents()
    try:
        # the model is cached between images; select_directory resets it to None
        if self.model is None:
            self.model = ImageModel.load(self.tf_directory)
        self.image = Image.open(self.file)
        prediction = self.model.predict(self.image).prediction
        self.prediction_label.setText(f"Predicted label: {prediction}")
        viz = self.model.visualize(self.image, prediction)
        self.image_label.setPixmap(QPixmap.fromImage(ImageQt(viz)))
    except Exception as e:
        self.image = None
        self.model = None
        QMessageBox.about(self, "Alert", f"Error visualizing: {e}")
    finally:
        # re-enable the buttons (and clear the output if there is no image)
        self.done()
37 | *.manifest 38 | #*.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | ### JetBrains template 146 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 147 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 148 | 149 | # User-specific stuff 150 | .idea/**/workspace.xml 151 | .idea/**/tasks.xml 152 | .idea/**/usage.statistics.xml 153 | .idea/**/dictionaries 154 | .idea/**/shelf 155 | 156 | # Generated files 157 | .idea/**/contentModel.xml 158 | 159 | # Sensitive or high-churn files 160 | .idea/**/dataSources/ 161 | .idea/**/dataSources.ids 162 | .idea/**/dataSources.local.xml 163 | .idea/**/sqlDataSources.xml 164 | .idea/**/dynamic.xml 165 | .idea/**/uiDesigner.xml 166 | .idea/**/dbnavigator.xml 167 | 168 | # Gradle 169 | .idea/**/gradle.xml 170 | .idea/**/libraries 171 | 172 | # Gradle and Maven with auto-import 173 | # When using Gradle or Maven with auto-import, you should exclude module files, 174 | # since they will be recreated, and may cause churn. Uncomment if using 175 | # auto-import. 
def predict_folder(img_dir, model_dir, progress_hook=None, move=True, csv=False):
    """
    Run your model on a directory of images. This will also go through any images in existing subdirectories.
    Move each image into a subdirectory structure based on the prediction -- the predicted label
    becomes the directory name where the image goes.

    :param img_dir: the filepath to your directory of images.
    :param model_dir: path to the Lobe Tensorflow SavedModel export.
    :param progress_hook: an optional function that will be run with progress_hook(currentProgress, totalProgress) when progress updates.
    :param move: a flag for whether you want to physically move the image files into a subfolder structure based on the predicted label
    :param csv: a flag for whether you want to create an output csv showing the image filenames and their predictions
    :raises ValueError: if `img_dir` is not a directory.
    """
    print(f"Predicting {img_dir}")
    img_dir = os.path.abspath(img_dir)
    if not os.path.isdir(img_dir):
        raise ValueError(f"Please specify a directory to images. Found {img_dir}")

    out_csv = os.path.join(img_dir, "predictions.csv")

    # Collect the files exactly once, so the progress total always matches the jobs we run.
    # Our own output csv lives inside img_dir and must never be queued as an image.
    image_files = []
    for root, _, files in os.walk(img_dir):
        for filename in files:
            image_file = os.path.abspath(os.path.join(root, filename))
            if image_file != out_csv:
                image_files.append(image_file)
    num_items = len(image_files)
    print(f"Predicting {num_items} items...")

    # load the model
    print("Loading model...")
    model = ImageModel.load(model_path=model_dir)
    print("Model loaded!")

    # create our output csv
    if csv:
        with open(out_csv, 'w', encoding="utf-8", newline='') as f:
            csv_writer(f).writerow(['File', 'Label', 'Confidence'])

    no_labels = 0
    with tqdm(total=num_items) as pbar, ThreadPoolExecutor() as executor:
        # make our prediction jobs
        model_futures = [
            (executor.submit(predict_label_from_image_file, image_file=image_file, model=model), image_file)
            for image_file in image_files
        ]

        # consume the results in submission order
        for curr_progress, (future, img_file) in enumerate(model_futures, start=1):
            label, confidence = future.result()
            if label is None:
                no_labels += 1
            else:
                # where the file ends up (its original path if we don't, or can't, move it)
                dest_file = _move_to_label_dir(img_file, img_dir, label) if move else img_file
                # write the results to a csv (appended row by row so partial runs keep their results)
                if csv:
                    with open(out_csv, 'a', encoding="utf-8", newline='') as f:
                        csv_writer(f).writerow([dest_file, label, confidence])
            pbar.update(1)
            if progress_hook:
                progress_hook(curr_progress, num_items)
    print(f"Done! Number of images without predicted labels: {no_labels}")


def _move_to_label_dir(img_file, img_dir, label):
    """
    Move `img_file` into the `label` subdirectory of `img_dir` and return where the file now lives.

    :param img_file: absolute path of the image to move.
    :param img_dir: absolute path of the root image directory.
    :param label: predicted label; used as the subdirectory name.
    :return: the new absolute path, or `img_file` unchanged if the file is already in place or the move failed.
    """
    filename = os.path.basename(img_file)
    name, ext = os.path.splitext(filename)
    dest_dir = os.path.join(img_dir, label)
    os.makedirs(dest_dir, exist_ok=True)
    dest_file = os.path.abspath(os.path.join(dest_dir, filename))
    # only move if the destination is different than the file
    if dest_file == img_file:
        return img_file
    # rename the file if there is a conflict
    rename_idx = 0
    while os.path.exists(dest_file):
        dest_file = os.path.abspath(os.path.join(dest_dir, f'{name}_{rename_idx}{ext}'))
        rename_idx += 1
    try:
        shutil.move(img_file, dest_file)
    except OSError as e:
        print(f"Problem moving file: {e}")
        # the file is still at its original location -- report that, not the failed destination
        return img_file
    return dest_file
23 | """ 24 | print(f"Predicting {filepath}") 25 | filepath = os.path.abspath(filepath) 26 | filename, ext = _name_and_extension(filepath) 27 | # read the file 28 | # if this a .txt file, don't treat the first row as a header. Otherwise, use the first row for header column names. 29 | if ext != '.xlsx': 30 | csv = pd.read_csv(filepath, header=None if ext == '.txt' else 0) 31 | else: 32 | csv = pd.read_excel(filepath, header=0) 33 | if ext in ['.csv', '.xlsx'] and not url_col: 34 | raise ValueError(f"Please specify an image url column for the csv.") 35 | url_col_idx = 0 36 | if url_col: 37 | try: 38 | url_col_idx = list(csv.columns).index(url_col) 39 | except ValueError: 40 | raise ValueError(f"Image url column {url_col} not found in csv headers {csv.columns}") 41 | 42 | num_items = len(csv) 43 | print(f"Predicting {num_items} items...") 44 | 45 | # load the model 46 | print("Loading model...") 47 | model = ImageModel.load(model_path=model_dir) 48 | print("Model loaded!") 49 | 50 | # create our output csv 51 | fname, ext = os.path.splitext(filepath) 52 | out_file = f"{fname}_predictions.csv" 53 | with open(out_file, 'w', encoding="utf-8", newline='') as f: 54 | # our header names from the pandas columns 55 | writer = csv_writer(f) 56 | writer.writerow([*[str(col) if not pd.isna(col) else '' for col in csv.columns], 'label', 'confidence']) 57 | 58 | # iterate over the rows and predict the label 59 | with tqdm(total=len(csv)) as pbar: 60 | with ThreadPoolExecutor() as executor: 61 | model_futures = [] 62 | # make our prediction jobs 63 | for i, row in enumerate(csv.itertuples(index=False)): 64 | url = row[url_col_idx] 65 | model_futures.append(executor.submit(predict_image_url, url=url, model=model, row=row)) 66 | 67 | # write the results from the predict (this should go in order of the futures) 68 | for i, future in enumerate(model_futures): 69 | label, confidence, row = future.result() 70 | with open(out_file, 'a', encoding="utf-8", newline='') as f: 71 | writer = 
csv_writer(f) 72 | writer.writerow([*[str(col) if not pd.isna(col) else '' for col in row], label, confidence]) 73 | pbar.update(1) 74 | if progress_hook: 75 | progress_hook(i+1, len(csv)) 76 | 77 | 78 | def predict_image_url(url, model: ImageModel, row): 79 | label, confidence = '', '' 80 | try: 81 | result = model.predict_from_url(url=url) 82 | label, confidence = result.labels[0] 83 | except Exception as e: 84 | print(f"Problem predicting image from url: {e}") 85 | return label, confidence, row 86 | 87 | 88 | def _name_and_extension(filepath): 89 | # returns a tuple of the filename and the extension, ignoring any other prefixes in the filepath 90 | # raises if not a file 91 | fpath = os.path.abspath(filepath) 92 | if not os.path.isfile(fpath): 93 | raise ValueError(f"File {filepath} doesn't exist.") 94 | filename = os.path.split(fpath)[-1] 95 | name, ext = os.path.splitext(filename) 96 | return name, str.lower(ext) 97 | 98 | 99 | def _valid_file(filepath): 100 | # file must exist and have a valid extension 101 | valid_extensions = ['.txt', '.csv', '.xlsx'] 102 | _, extension = _name_and_extension(filepath) 103 | if extension not in valid_extensions: 104 | raise ValueError(f"File {filepath} doesn't have one of the valid extensions: {valid_extensions}") 105 | # good to go 106 | return filepath 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser(description='Label an image dataset from csv or txt file.') 111 | parser.add_argument('file', help='Path to your csv or txt file.') 112 | parser.add_argument('model_dir', help='Path to your SavedModel from Lobe.') 113 | parser.add_argument('--url', help='If this is a csv with column headers, the column that contains the image urls to download.') 114 | args = parser.parse_args() 115 | predict_dataset(filepath=args.file, model_dir=args.model_dir, url_col=args.url) 116 | -------------------------------------------------------------------------------- /app/components/flickr.py: 
class Flickr(QFrame):
    """
    Content page with text boxes for a Flickr API key, a latitude/longitude bounding box and a search term.
    Pressing the download button asks for an output directory and passes the entered values to
    download_flickr, showing progress in a progress bar.
    """
    download_text = "Download"
    downloading_text = "Downloading..."

    def __init__(self, app):
        super().__init__()
        # initialize our variables
        self.app = app
        self.api_textbox = None
        self.min_lat_textbox = None
        self.min_long_textbox = None
        self.max_lat_textbox = None
        self.max_long_textbox = None
        self.search_textbox = None
        self.download_button = None
        self.progress_bar = None
        self.init_ui()

    def init_ui(self):
        """Build the page's widgets and lay them out top to bottom."""
        self.setObjectName("content")
        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # our main content area
        content = QFrame()
        content_layout = QVBoxLayout()

        # some info
        title = QLabel("Flickr")
        title.setObjectName("h1")
        description = QLabel("Download images from Flickr from a geographic bounding box location.")
        description.setObjectName("h2")

        # API key
        api_label = QLabel("Flickr API Key:")
        api_label.setObjectName("separateSmall")
        self.api_textbox = QLineEdit()
        api_container = NoStretch(self.api_textbox)

        # geo box
        bbox_label = QLabel("Bounding Box:")
        bbox_label.setObjectName("separateSmall")

        minlat_label = QLabel("Min Latitude:")
        self.min_lat_textbox = QLineEdit()
        min_lat_container = NoStretch(self.min_lat_textbox)
        minlong_label = QLabel("Min Longitude:")
        self.min_long_textbox = QLineEdit()
        min_long_container = NoStretch(self.min_long_textbox)

        maxlat_label = QLabel("Max Latitude:")
        maxlat_label.setObjectName("separateSmall")
        self.max_lat_textbox = QLineEdit()
        max_lat_container = NoStretch(self.max_lat_textbox)
        maxlong_label = QLabel("Max Longitude:")
        self.max_long_textbox = QLineEdit()
        max_long_container = NoStretch(self.max_long_textbox)

        # search term
        search_label = QLabel("Search term:")
        search_label.setObjectName("separateSmall")
        self.search_textbox = QLineEdit()
        search_container = NoStretch(self.search_textbox)

        # download button
        self.download_button = QPushButton(self.download_text)
        self.download_button.setEnabled(True)
        self.download_button.clicked.connect(self.download)
        download_container = NoStretch(self.download_button)
        download_container.setObjectName("separate")

        self.progress_bar = QProgressBar()
        self.progress_bar.hide()

        # make our content layout (widgets appear in this order)
        for widget in (
            title,
            description,
            api_label,
            api_container,
            bbox_label,
            minlat_label,
            min_lat_container,
            minlong_label,
            min_long_container,
            maxlat_label,
            max_lat_container,
            maxlong_label,
            max_long_container,
            search_label,
            search_container,
            download_container,
            self.progress_bar,
        ):
            content_layout.addWidget(widget)
        content_layout.addStretch(1)
        content.setLayout(content_layout)

        layout.addWidget(content)
        layout.addStretch(1)
        self.setLayout(layout)

    def download(self):
        """Ask for an output directory and run download_flickr with the values typed into the form."""
        # disable the button so we can't click again
        self.download_button.setEnabled(False)
        self.download_button.setText(self.downloading_text)
        self.progress_bar.setValue(0)
        self.progress_bar.show()
        self.app.processEvents()
        destination_directory = QFileDialog.getExistingDirectory(self, "Select Output Directory")
        # if they hit cancel, don't download
        if not destination_directory:
            self.done()
            return
        # otherwise try downloading to the desired location
        try:
            download_flickr(
                api_key=self.api_textbox.text(),
                directory=destination_directory,
                # empty text boxes are passed as None
                min_lat=self.min_lat_textbox.text() or None,
                min_long=self.min_long_textbox.text() or None,
                max_lat=self.max_lat_textbox.text() or None,
                max_long=self.max_long_textbox.text() or None,
                search=self.search_textbox.text() or None,
                progress_hook=self.progress_hook,
            )
        except Exception as e:
            QMessageBox.about(self, "Alert", f"Error creating dataset: {e}")
        finally:
            # always give the user their button back, whether or not the download worked
            self.done()

    def progress_hook(self, current, total):
        """
        Show current/total as a percentage on the progress bar.

        :param current: number of items processed so far.
        :param total: total number of items to process.
        """
        # QProgressBar.setValue takes an int; a float argument raises TypeError on Python >= 3.10.
        # A zero total is shown as 0% instead of dividing by zero.
        percent = int(current / total * 100) if total else 0
        self.progress_bar.setValue(percent)
        if current == total:
            self.done()
        # make sure to update the UI
        self.app.processEvents()

    def done(self):
        """Hide and reset the progress bar and re-enable the download button."""
        self.progress_bar.setValue(0)
        self.progress_bar.hide()
        self.download_button.setEnabled(True)
        self.download_button.setText(self.download_text)
        self.app.processEvents()
4 | 5 | ## Download the desktop application on Windows 6 | We use GitHub Actions to build the desktop version of this app. If you would like to download it for 7 | Windows, please click on [Actions](https://github.com/lobe/image-tools/actions) and then 8 | you will see [PyInstaller Windows](https://github.com/lobe/image-tools/actions?query=workflow%3A%22PyInstaller+Windows%22) 9 | on the left under 'All Workflows'. Once you click into the Windows workflow, you will see a list of the builds 10 | in the center of the screen. Click on the topmost item in the results list for the latest version. Once you click 11 | on the latest build, you should see a section titled 'Artifacts' with an item called 'Image Tools Windows'. 12 | When you click on this artifact, it should download the zip containing the app for you! 13 | 14 | *Support for MacOS is still in progress.* 15 | ### Run Desktop application on MacOS 16 | While the compiled Mac app does not work because it is unsigned, you can either create it locally or run the app via 17 | Python command line. 18 | 1. Make sure you have Python 3.8 or Python 3.9. See guide here for installing: https://www.python.org/downloads/mac-osx/ 19 | 2. Download this code repository to your machine. 20 | 3. Open a terminal and navigate to where you downloaded this code with the `cd` command, or open the 21 | terminal directly at the *Image Tools* folder by right clicking on it and selecting *Services > Open Terminal at Folder* 22 | 4. Install dependencies via `pip3 install -r requirements.txt` (make sure your pip3 is pointing to your Python 3.8 or 3.9 installation by checking `pip3 --version`) 23 | 5. Run the app! 24 | * Run from command line: `python3 -m app.app` 25 | * Compile the app yourself: `pyinstaller --windowed --onefile app/app.spec`. This will create a `dist/` folder that contains 26 | the app you can run. 27 | 28 | 29 | 30 | ## Code Setup 31 | Install the required packages. 
32 | ```shell script 33 | pip install -r requirements.txt 34 | ``` 35 | If you are on Windows, you will also need to install the latest PyInstaller from GitHub: 36 | ```shell script 37 | pip install git+https://github.com/pyinstaller/pyinstaller.git 38 | ``` 39 | 40 | ## Command Line Usage 41 | ### CSV, XLSX, or TXT files 42 | #### Downloading an image dataset from the urls in a csv or txt file: 43 | ```shell script 44 | python -m dataset.download_from_file your_file.csv --url UrlHeader --label LabelHeader 45 | ``` 46 | This downloader script takes either a csv, xlsx, or txt file and will format an image dataset for you. The resulting images 47 | will be downloaded to a folder with the same name as your input file. If you supplied labels, the images will be 48 | grouped into sub-folders with the label name. 49 | 50 | * csv or xlsx file 51 | * specify the column header for the image urls with the --url flag 52 | * you can optionally give the column header for labels to assign the images if this is a pre-labeled dataset 53 | 54 | * txt file 55 | * separate each image url by a newline 56 | 57 | #### Predicting labels and confidences for images in a csv, xlsx, or txt file: 58 | ```shell script 59 | python -m model.predict_from_file your_file.csv path/to/lobe/savedmodel --url UrlHeader 60 | ``` 61 | This prediction script will take a csv or txt file with urls to images and a Lobe TensorFlow SavedModel export directory, 62 | and create an output csv with the label and confidence as the last two columns.
63 | 64 | * csv or xlsx file 65 | * specify the column header for the image urls with the --url flag 66 | 67 | * txt file 68 | * separate each image url by a newline 69 | 70 | ### Folder of images 71 | ```shell script 72 | python -m model.predict_from_folder path/to/images path/to/lobe/savedmodel 73 | ``` 74 | This prediction script will take a directory of images and a Lobe TensorFlow SavedModel export directory, 75 | and reorganize those images into subdirectories by their predicted label. 76 | 77 | 78 | ### Flickr downloader 79 | Download images from Flickr by latitude and longitude bounding box location and any desired search terms. 80 | ```shell script 81 | python -m dataset.download_from_flickr api_key dest_folder --bbox min_lat,min_long,max_lat,max_long --search searchTerm 82 | ``` 83 | This will create an `images.csv` file in your destination folder that includes the EXIF data for the downloaded photos. 84 | 85 | 86 | ### Export Lobe dataset 87 | Export your project's dataset by giving the project name and desired export directory: 88 | ```shell script 89 | python -m dataset.export_from_lobe 'Project Name' destination/export/folder 90 | ``` 91 | Your images will be copied to the destination folder, and their labels will be the subfolder name. You can take this 92 | exported folder and drag it directly to a new project in Lobe. 93 | 94 | 95 | ## Build Desktop Application 96 | You can create a desktop GUI application using PyInstaller: 97 | 98 | ```shell script 99 | pyinstaller --onefile --windowed app/app.spec 100 | ``` 101 | 102 | This will create a `dist/` folder that will contain the application file `Image Tools.exe` or `Image Tools.app` 103 | depending on your OS. 
def create_dataset(filepath, url_col=None, label_col=None, progress_hook=None, destination_directory=None):
    """
    Given a file with urls to images, downloads those images to a new directory that has the same name
    as the file without the extension. If labels are present, the label of each row is passed along to
    the downloader so the images can be grouped by label. Rows whose download returned nothing are
    written to a `<input name>_errors.csv` file next to the input file.

    :param filepath: path to a valid txt, csv, or xlsx file with image urls to download.
    :param url_col: if this is a csv or xlsx, the column header name for the urls to download.
    :param label_col: if this is a csv or xlsx, the column header name for the labels of the images.
    :param progress_hook: an optional function that will be run with progress_hook(currentProgress, totalProgress) when progress updates.
    :param destination_directory: an optional directory path to download the dataset to.
    :raises ValueError: if the file doesn't exist, if url_col is missing for a csv/xlsx file, or if
        url_col/label_col aren't among the file's column headers.
    """
    print(f"Processing {filepath}")
    filepath = os.path.abspath(filepath)
    filename, ext = _name_and_extension(filepath)
    # read the file
    # if this a .txt file, don't treat the first row as a header. Otherwise, use the first row for header column names.
    if ext != '.xlsx':
        csv = pd.read_csv(filepath, header=None if ext == '.txt' else 0)
    else:
        csv = pd.read_excel(filepath, header=0)
    if ext in ['.csv', '.xlsx'] and not url_col:
        raise ValueError("Please specify an image url column for the csv.")
    url_col_idx = 0
    if url_col:
        try:
            url_col_idx = list(csv.columns).index(url_col)
        except ValueError:
            raise ValueError(f"Image url column {url_col} not found in csv headers {csv.columns}")
    label_col_idx = None
    if label_col:
        try:
            label_col_idx = list(csv.columns).index(label_col)
        except ValueError:
            raise ValueError(f"Label column {label_col} not found in csv headers {csv.columns}")
    # compare against None explicitly: the label column may legitimately be column 0, which is falsy
    has_labels = label_col_idx is not None

    total_jobs = len(csv)
    print(f"Downloading {total_jobs} items...")

    errors = []
    dest = os.path.join(destination_directory, filename) if destination_directory else filename

    # iterate over the rows and add to our download processing job!
    with tqdm(total=total_jobs) as pbar, ThreadPoolExecutor() as executor:
        # maps each submitted download to the (1-based row index, url, label) it came from
        download_futures = {}
        lock = Lock()
        for index, row in enumerate(csv.itertuples(index=False), start=1):
            url = row[url_col_idx]
            label = None
            if has_labels:
                label = row[label_col_idx]
                label = None if pd.isnull(label) else label
            future = executor.submit(download_image, url=url, directory=dest, lock=lock, label=label)
            download_futures[future] = (index, url, label)

        # iterate over the results to update our progress bar and collect any errors for the error csv
        num_processed = 0
        for future in as_completed(download_futures):
            index, url, label = download_futures[future]
            downloaded_file = future.result()
            if not downloaded_file:
                error_row = [index, url]
                if has_labels:
                    error_row.append(label)
                errors.append(error_row)
            # update progress
            pbar.update(1)
            num_processed += 1
            if progress_hook:
                progress_hook(num_processed, total_jobs)

    print('Cleaning up...')
    # write out the error csv
    if errors:
        # downloads finish in arbitrary order; put the failures back in input-row order
        errors.sort(key=lambda error_row: error_row[0])
        error_file = f"{os.path.splitext(filepath)[0]}_errors.csv"
        with open(error_file, 'w', encoding="utf-8", newline='') as f:
            writer = csv_writer(f)
            writer.writerow(['index', 'url', 'label'] if has_labels else ['index', 'url'])
            writer.writerows(errors)
'.csv', '.xlsx'] 123 | _, extension = _name_and_extension(filepath) 124 | if extension not in valid_extensions: 125 | raise ValueError(f"File {filepath} doesn't have one of the valid extensions: {valid_extensions}") 126 | # good to go 127 | return filepath 128 | 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser(description='Download an image dataset from csv or txt file.') 132 | parser.add_argument('file', help='Path to your csv or txt file.') 133 | parser.add_argument('--url', help='If this is a csv with column headers, the column that contains the image urls to download.') 134 | parser.add_argument('--label', help='If this is a csv with column headers, the column that contains the labels to assign the images.') 135 | args = parser.parse_args() 136 | create_dataset(filepath=args.file, url_col=args.url, label_col=args.label) 137 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import QtGui, QtCore 2 | from PyQt5.QtWidgets import ( 3 | QApplication, QMainWindow, QHBoxLayout, 4 | QDesktopWidget, QFrame 5 | ) 6 | import sys 7 | from collections import OrderedDict 8 | from multiprocessing import freeze_support 9 | from app.components.navbar import NavBar 10 | from app.components.dataset import Dataset 11 | from app.components.model import Model 12 | from app.components.flickr import Flickr 13 | from app.components.export import Export 14 | from app.components.visualize import Visualize 15 | from app import resource_path 16 | 17 | 18 | try: 19 | # Include in try/except block if you're also targeting Mac/Linux 20 | from PyQt5.QtWinExtras import QtWin 21 | QtWin.setCurrentProcessExplicitAppUserModelID('image-tools.0.1') 22 | except ImportError: 23 | pass 24 | 25 | if hasattr(QtCore.Qt, 'AA_EnableHighDpiScaling'): 26 | QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True) 27 | 28 | if 
class MainWindow(QMainWindow):
    """Main application window: a nav bar on the left and one content page per tab on the right."""

    def __init__(self, app, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the QApplication is handed to every page
        self.app = app
        # name of the tab currently shown (None until the first nav_click)
        self.nav = None
        self.pages = OrderedDict()
        self.init_ui()

    def init_ui(self):
        """Create the pages and nav bar, lay them out side by side, and show the window."""
        self.setMinimumSize(900, 650)
        self.setWindowTitle("Image Tools")
        self.center()

        # our main app consists of two sections -- nav on left and content on right
        app_layout = QHBoxLayout()

        # one page per tab, in the order the tabs are handed to the nav bar
        pages = OrderedDict()
        pages[Tabs.DATASET] = Dataset(self.app)
        pages[Tabs.MODEL] = Model(self.app)
        pages[Tabs.EXPORT] = Export(self.app)
        pages[Tabs.FLICKR] = Flickr(self.app)
        pages[Tabs.VISUALIZE] = Visualize(self.app)
        self.pages = pages

        tab_names = list(self.pages.keys())
        navbar = NavBar(self.nav_click, tab_names)
        # we are on the first page by default
        self.nav_click(tab_names[0])

        app_layout.addWidget(navbar)
        for page in self.pages.values():
            app_layout.addWidget(page)
        app_layout.setContentsMargins(0, 0, 0, 0)
        app_layout.setSpacing(0)

        # bind our widget and show
        window = QFrame()
        window.setObjectName("window")
        window.setLayout(app_layout)
        self.setCentralWidget(window)
        self.show()

    def center(self):
        """Move the window so its frame is centered on the available desktop geometry."""
        frame = self.frameGeometry()
        desktop_center = QDesktopWidget().availableGeometry().center()
        frame.moveCenter(desktop_center)
        self.move(frame.topLeft())

    def nav_click(self, button: str):
        """Show the page named `button` and hide the others; does nothing if it is already current."""
        if button == self.nav:
            return
        for name, page in self.pages.items():
            if button == name:
                page.show()
            else:
                page.hide()
        self.nav = button
class Dataset(QFrame):
    """
    Content page for picking a .csv/.xlsx file, choosing its url column (and optionally a label column)
    from dropdowns filled with the file's headers, and passing those choices to create_dataset.
    """
    # "\\i" is a literal backslash followed by "i": the same runtime text as the previous "\i" spelling,
    # written without relying on an invalid escape sequence.
    default_text = "Please select a file.<\\i>"
    download_text = "Download"
    downloading_text = "Downloading..."

    def __init__(self, app):
        super().__init__()
        # initialize our variables
        self.app = app
        self.file = None
        # widgets created in init_ui
        self.file_button = None
        self.path_label = None
        self.header_container = None
        self.url_dropdown = None
        self.label_dropdown = None
        self.download_button = None
        self.progress_bar = None
        self.init_ui()

    def init_ui(self):
        """Build the page's widgets and lay them out top to bottom."""
        self.setObjectName("content")
        layout = QHBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)

        # our main content area
        content = QFrame()
        content_layout = QVBoxLayout()

        # some info
        title = QLabel("Dataset")
        title.setObjectName("h1")
        description = QLabel(
            "Download images from URLs in a .csv or .xlsx file.\nOptionally, supply labels to organize your images into folders by label.")
        description.setObjectName("h2")

        # file selection button
        self.file_button = QPushButton("Select file")
        self.file_button.clicked.connect(self.select_file)
        button_container = NoStretch(self.file_button)
        button_container.setObjectName("separate")

        # display filepath
        self.path_label = QLabel(self.default_text)

        # url column header and optional label column header (hidden until a file has been parsed)
        self.header_container = QFrame()
        self.header_container.setObjectName("separateSmall")
        header_layout = QVBoxLayout()
        header_layout.setContentsMargins(0, 0, 0, 0)
        url_label = QLabel("Column with image URLs:")
        self.url_dropdown = QComboBox()
        self.url_dropdown.setSizeAdjustPolicy(QComboBox.AdjustToContents)
        self.url_dropdown.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
        url_container = NoStretch(self.url_dropdown)
        label_label = QLabel("(Optional) column with labels:")
        label_label.setObjectName("separateSmall")
        self.label_dropdown = QComboBox()
        self.label_dropdown.setSizeAdjustPolicy(QComboBox.AdjustToContents)
        self.label_dropdown.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding)
        label_container = NoStretch(self.label_dropdown)
        header_layout.addWidget(url_label)
        header_layout.addWidget(url_container)
        header_layout.addWidget(label_label)
        header_layout.addWidget(label_container)
        self.header_container.setLayout(header_layout)
        self.header_container.hide()

        # download button (disabled until a file's headers have been read)
        self.download_button = QPushButton(self.download_text)
        self.download_button.setEnabled(False)
        self.download_button.clicked.connect(self.download)
        download_container = NoStretch(self.download_button)
        download_container.setObjectName("separate")

        self.progress_bar = QProgressBar()
        self.progress_bar.hide()

        # make our content layout
        content_layout.addWidget(title)
        content_layout.addWidget(description)
        content_layout.addWidget(button_container)
        content_layout.addWidget(self.path_label)
        content_layout.addWidget(self.header_container)
        content_layout.addWidget(download_container)
        content_layout.addWidget(self.progress_bar)
        content_layout.addStretch(1)
        content.setLayout(content_layout)

        layout.addWidget(content)
        layout.addStretch(1)
        self.setLayout(layout)

    def select_file(self):
        """Let the user pick a csv/xlsx file, show its path, and refresh the header dropdowns."""
        self.file = QFileDialog.getOpenFileName(self, 'Select CSV File', filter="CSV (*.csv *.xlsx)")[0]
        self.path_label.setText(f"{self.file}" if self.file else self.default_text)
        self.parse_headers()

    def parse_headers(self):
        """Read the selected file's column headers into the dropdowns, or clear them if no file is set."""
        if not self.file:
            self.clear_headers()
            return
        # read the file for its headers and set our dropdown boxes appropriately
        try:
            # compare the extension case-insensitively so e.g. "data.CSV" is still read as a csv
            if os.path.splitext(self.file)[1].lower() == ".csv":
                csv = pd.read_csv(self.file, header=0)
            else:
                csv = pd.read_excel(self.file, header=0)
            self.label_dropdown.clear()
            self.url_dropdown.clear()
            # leading empty entry: the label column is optional
            self.label_dropdown.addItem("")
            for header in list(csv.columns):
                self.url_dropdown.addItem(header)
                self.label_dropdown.addItem(header)
            self.url_dropdown.adjustSize()
            self.header_container.show()
            self.download_button.setEnabled(True)
        except Exception as e:
            QMessageBox.about(self, "Alert", f"Error reading csv: {e}")
            self.clear_headers()

    def clear_headers(self):
        """Hide and empty the header dropdowns and disable the download button."""
        self.header_container.hide()
        self.url_dropdown.clear()
        self.label_dropdown.clear()
        self.download_button.setEnabled(False)

    def download(self):
        """Ask for an output directory and run create_dataset with the selected file and columns."""
        # disable the buttons so we can't click again
        self.download_button.setEnabled(False)
        self.download_button.setText(self.downloading_text)
        self.file_button.setEnabled(False)
        self.progress_bar.setValue(0)
        self.progress_bar.show()
        self.app.processEvents()
        url_col = self.url_dropdown.currentText()
        label_col = self.label_dropdown.currentText()
        destination_directory = QFileDialog.getExistingDirectory(self, "Select Output Directory")
        # if they hit cancel, don't download
        if not destination_directory:
            self.done()
            return
        # otherwise try downloading to the desired location
        try:
            create_dataset(
                # the empty dropdown entry means "no label column"
                filepath=self.file, url_col=url_col, label_col=label_col or None,
                progress_hook=self.progress_hook, destination_directory=destination_directory,
            )
        except Exception as e:
            QMessageBox.about(self, "Alert", f"Error creating dataset: {e}")
        finally:
            # always restore the buttons, whether or not the download worked
            self.done()

    def progress_hook(self, current, total):
        """
        Show current/total as a percentage on the progress bar.

        :param current: number of items processed so far.
        :param total: total number of items to process.
        """
        # QProgressBar.setValue takes an int; a float argument raises TypeError on Python >= 3.10.
        # A zero total is shown as 0% instead of dividing by zero.
        percent = int(current / total * 100) if total else 0
        self.progress_bar.setValue(percent)
        if current == total:
            self.done()
        # make sure to update the UI
        self.app.processEvents()

    def done(self):
        """Hide and reset the progress bar and re-enable the download and file buttons."""
        self.progress_bar.setValue(0)
        self.progress_bar.hide()
        self.download_button.setEnabled(True)
        self.download_button.setText(self.download_text)
        self.file_button.setEnabled(True)
        self.app.processEvents()
15 | 16 | def __init__(self, app): 17 | super().__init__() 18 | # initialize our variables 19 | self.app = app 20 | self.tf_directory = None 21 | self.file = None 22 | self.folder = None 23 | self.init_ui() 24 | 25 | def init_ui(self): 26 | # make our UI 27 | self.setObjectName("content") 28 | layout = QHBoxLayout() 29 | layout.setContentsMargins(0, 0, 0, 0) 30 | 31 | # our main content area 32 | content = QFrame() 33 | content_layout = QVBoxLayout() 34 | 35 | # some info 36 | title = QLabel("Model") 37 | title.setObjectName("h1") 38 | description = QLabel( 39 | "Run your exported TensorFlow model from Lobe \non a folder of images or a .csv/.xlsx file of image URLs.\nThis will produce a new prediction .csv with the image filepath or URL, \nthe model's prediction, and the model's confidence.") 40 | description.setObjectName("h2") 41 | 42 | # model select button 43 | self.model_button = QPushButton("Select model directory") 44 | self.model_button.clicked.connect(self.select_directory) 45 | model_container = NoStretch(self.model_button) 46 | model_container.setObjectName("separate") 47 | self.model_label = QLabel(self.default_model_text) 48 | 49 | # file or folder selection button 50 | self.folder_button = QPushButton("Select folder") 51 | self.folder_button.clicked.connect(self.select_image_folder) 52 | self.file_button = QPushButton("Select file") 53 | self.file_button.clicked.connect(self.select_file) 54 | buttons_container = NoStretch([self.folder_button, self.file_button]) 55 | buttons_container.setObjectName("separate") 56 | self.path_label = QLabel(self.default_file_text) 57 | 58 | # url column header 59 | self.url_label = QLabel("Column with image URLs:") 60 | self.url_label.setObjectName("separateSmall") 61 | self.url_label.hide() 62 | self.url_dropdown = QComboBox() 63 | self.url_dropdown.setSizeAdjustPolicy(QComboBox.AdjustToContents) 64 | self.url_dropdown.setSizePolicy(QSizePolicy.Preferred, QSizePolicy.Expanding) 65 | self.url_container = 
NoStretch(self.url_dropdown) 66 | self.url_container.hide() 67 | 68 | # predict button 69 | self.predict_button = QPushButton(self.predict_text) 70 | self.predict_button.setEnabled(False) 71 | self.predict_button.clicked.connect(self.predict) 72 | predict_container = NoStretch(self.predict_button) 73 | predict_container.setObjectName("separate") 74 | 75 | self.progress_bar = QProgressBar() 76 | self.progress_bar.hide() 77 | 78 | # make our content layout 79 | content_layout.addWidget(title) 80 | content_layout.addWidget(description) 81 | content_layout.addWidget(model_container) 82 | content_layout.addWidget(self.model_label) 83 | content_layout.addWidget(buttons_container) 84 | content_layout.addWidget(self.path_label) 85 | content_layout.addWidget(self.url_label) 86 | content_layout.addWidget(self.url_container) 87 | content_layout.addWidget(predict_container) 88 | content_layout.addWidget(self.progress_bar) 89 | content_layout.addStretch(1) 90 | content.setLayout(content_layout) 91 | 92 | layout.addWidget(content) 93 | layout.addStretch(1) 94 | self.setLayout(layout) 95 | 96 | def select_directory(self): 97 | self.tf_directory = QFileDialog.getExistingDirectory(self, "Select TensorFlow Model Directory") 98 | self.model_label.setText(f"{self.tf_directory}" if self.tf_directory else self.default_model_text) 99 | self.check_predict_button() 100 | 101 | def select_file(self): 102 | self.file = QFileDialog.getOpenFileName(self, 'Select CSV File', filter="CSV (*.csv *.xlsx)")[0] 103 | self.path_label.setText(f"{self.file}" if self.file else self.default_file_text) 104 | self.folder = None 105 | self.parse_headers() 106 | self.check_predict_button() 107 | 108 | def select_image_folder(self): 109 | self.folder = QFileDialog.getExistingDirectory(self, "Select Images Directory") 110 | self.path_label.setText(f"{self.folder}" if self.folder else self.default_file_text) 111 | self.file = None 112 | self.parse_headers() 113 | self.check_predict_button() 114 | 115 | def 
check_predict_button(self): 116 | # enable the button when we have both a model and file 117 | if self.tf_directory and (self.file or self.folder): 118 | self.predict_button.setEnabled(True) 119 | else: 120 | self.predict_button.setEnabled(False) 121 | 122 | def parse_headers(self): 123 | if self.file: 124 | # read the file for its headers and set our dropdown boxes appropriately 125 | try: 126 | if os.path.splitext(self.file)[1] == ".csv": 127 | csv = pd.read_csv(self.file, header=0) 128 | else: 129 | csv = pd.read_excel(self.file, header=0) 130 | self.url_dropdown.clear() 131 | for header in list(csv.columns): 132 | self.url_dropdown.addItem(header) 133 | self.url_dropdown.adjustSize() 134 | self.url_label.show() 135 | self.url_container.show() 136 | except Exception as e: 137 | QMessageBox.about(self, "Alert", f"Error reading csv: {e}") 138 | self.clear_headers() 139 | else: 140 | self.clear_headers() 141 | 142 | def clear_headers(self): 143 | self.url_dropdown.clear() 144 | self.url_label.hide() 145 | self.url_container.hide() 146 | 147 | def predict(self): 148 | # disable the buttons so we can't click again 149 | self.predict_button.setEnabled(False) 150 | self.predict_button.setText(self.predicting_text) 151 | self.model_button.setEnabled(False) 152 | self.file_button.setEnabled(False) 153 | self.folder_button.setEnabled(False) 154 | self.progress_bar.setValue(0) 155 | self.progress_bar.show() 156 | self.app.processEvents() 157 | url_col = self.url_dropdown.currentText() 158 | try: 159 | if self.file: 160 | predict_dataset(model_dir=self.tf_directory, filepath=self.file, url_col=url_col, 161 | progress_hook=self.progress_hook) 162 | elif self.folder: 163 | predict_folder(model_dir=self.tf_directory, img_dir=self.folder, move=True, csv=True, 164 | progress_hook=self.progress_hook) 165 | except Exception as e: 166 | QMessageBox.about(self, "Alert", f"Error predicting: {e}") 167 | finally: 168 | self.done() 169 | 170 | def progress_hook(self, current, total): 
171 | self.progress_bar.setValue(float(current) / total * 100) 172 | if current == total: 173 | self.done() 174 | # make sure to update the UI 175 | self.app.processEvents() 176 | 177 | def done(self): 178 | self.progress_bar.setValue(0) 179 | self.progress_bar.hide() 180 | self.predict_button.setEnabled(True) 181 | self.predict_button.setText(self.predict_text) 182 | self.model_button.setEnabled(True) 183 | self.file_button.setEnabled(True) 184 | self.folder_button.setEnabled(True) 185 | self.app.processEvents() 186 | -------------------------------------------------------------------------------- /dataset/export_from_lobe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Export your dataset from a Lobe project 3 | """ 4 | import argparse 5 | from sys import platform 6 | import os 7 | import json 8 | import sqlite3 9 | from concurrent.futures import ThreadPoolExecutor, as_completed 10 | from threading import Lock 11 | from tqdm import tqdm 12 | from PIL import Image 13 | from dataset.utils import _resolve_filename_conflict 14 | 15 | if platform == 'darwin': 16 | PROJECTS_DIR_MAC = '~/Library/Application Support/lobe/projects' 17 | PROJECTS_DIR = os.path.realpath(os.path.expanduser(PROJECTS_DIR_MAC)) 18 | else: 19 | PROJECTS_DIR_WINDOWS = os.path.join(os.getenv('APPDATA'), 'lobe', 'projects') 20 | PROJECTS_DIR = os.path.realpath(PROJECTS_DIR_WINDOWS) 21 | 22 | PROJECT_JSON_FILE = 'project.json' 23 | PROJECT_ID_KEY = 'id' 24 | PROJECT_META_KEY = 'meta' 25 | PROJECT_NAME_KEY = 'name' 26 | PROJECT_DB_FILE = 'db.sqlite' 27 | PROJECT_BLOBS = os.path.join('data', 'blobs') 28 | 29 | 30 | def get_projects(): 31 | """ 32 | Returns tuples of (project name, project id) from Lobe's appdata directory, sorted by modified date 33 | """ 34 | projects = [] 35 | for project in os.listdir(PROJECTS_DIR): 36 | project_dir = os.path.join(PROJECTS_DIR, project) 37 | if os.path.isdir(project_dir): 38 | try: 39 | project_json_file = 
def export_dataset(project_id, destination_dir, progress_hook=None, batch_size=1000):
    """
    Given a project id and a destination export parent directory, copy the images into a subfolder structure

    Labeled images are written to ``destination_dir/<label>``; images without a
    label are written directly to ``destination_dir``.

    :param project_id: the project uuid; dashes are stripped to find the project directory.
    :param destination_dir: export directory, created if it doesn't exist.
    :param progress_hook: optional callable invoked as ``progress_hook(num_processed, num_images)``
        after each image finishes (whether it succeeded or failed).
    :param batch_size: number of rows fetched from the project db per query.

    Errors are printed rather than raised.
    """
    # make the desired destination if it doesn't exist
    os.makedirs(destination_dir, exist_ok=True)
    # project directory doesn't include the '-' from the project uuid
    project_dir = os.path.join(PROJECTS_DIR, project_id.replace('-', ''))
    blob_dir = os.path.join(project_dir, PROJECT_BLOBS)
    # connect to our project db
    db_file = os.path.join(project_dir, PROJECT_DB_FILE)
    conn = None
    try:
        # db connection
        conn = sqlite3.connect(db_file)
        cursor = conn.cursor()
        # first get the total number of images for our progress bar
        cursor.execute("SELECT count(*) FROM example_images")
        count_row = cursor.fetchone()
        # count(*) always yields a row, so the value itself has to be checked for the empty case
        num_images = count_row[0] if count_row else 0
        if not num_images:
            print(f"Didn't find any images for project {project_id}")
            return

        futures = []
        lock = Lock()
        examples_query = """
            SELECT example_images.hash, example_labels.label
            FROM example_images LEFT JOIN example_labels
            ON example_images.example_id = example_labels.example_id
            LIMIT ?
            OFFSET ?
        """
        with tqdm(total=num_images) as pbar:
            with ThreadPoolExecutor() as executor:
                # go through the data item entries in the db, find the blob filenames, and queue the export
                for offset in range(0, num_images, batch_size):
                    cursor.execute(examples_query, [batch_size, offset])
                    for img_hash, label in cursor.fetchall():
                        # get the image filepath from the hash
                        img_filepath = os.path.join(blob_dir, img_hash)
                        # if we had a label, make the destination directory the subdirectory with label name
                        dest_dir = os.path.join(destination_dir, label) if label is not None else destination_dir
                        futures.append(
                            executor.submit(
                                _export_blob, blob_path=img_filepath, destination_dir=dest_dir, lock=lock
                            )
                        )

                num_processed = 0
                num_failed = 0
                # wait for all our futures
                for future in as_completed(futures):
                    # a worker exception stays hidden inside the future unless we look at it
                    error = future.exception()
                    if error is not None:
                        num_failed += 1
                        print(f"Error exporting image: {error}")
                    # update our progress bar for the finished image
                    pbar.update(1)
                    num_processed += 1
                    if progress_hook:
                        progress_hook(num_processed, num_images)
        if num_failed:
            print(f"Failed to export {num_failed} of {num_processed} images for project {project_id}")
    except Exception as e:
        print(f"Error exporting project {project_id} to {destination_dir}:\n{e}")
    finally:
        if conn:
            conn.close()
def _export_blob(blob_path, destination_dir, lock=None):
    """
    Export the image to the destination, resolving names on conflict

    :param blob_path: path of the source blob; its basename becomes the exported file's stem.
    :param destination_dir: directory to save into, created if it doesn't exist.
    :param lock: optional lock guarding filename selection when called from several threads.
    """
    os.makedirs(destination_dir, exist_ok=True)
    # get the blob id from the blob path
    blob_id = os.path.basename(blob_path)
    # open as a context manager so the underlying file handle is released once the image is saved
    with Image.open(blob_path) as img:
        # save with the image's native format as the extension
        img_filename = f'{blob_id}.{img.format.lower()}'
        # look for file name conflict and resolve
        if lock:
            with lock:
                img_filename = _resolve_filename_conflict(directory=destination_dir, filename=img_filename)
                # reserve the chosen name with an empty file so other threads searching for a free
                # name see it as taken before the image has actually been written
                open(os.path.join(destination_dir, img_filename), 'a').close()
        else:
            img_filename = _resolve_filename_conflict(directory=destination_dir, filename=img_filename)
        # now save the file
        destination_file = os.path.join(destination_dir, img_filename)
        img.save(destination_file, quality=100)
def download_flickr(
        api_key, directory,
        min_lat=None, min_long=None, max_lat=None, max_long=None,
        search=None, size='z', progress_hook=None
):
    """
    Search Flickr and download every matching photo into ``directory``, writing photo details to a csv.

    :param api_key: Flickr API key.
    :param directory: directory the images (and the csv) are written to.
    :param min_lat, min_long, max_lat, max_long: optional bounding box; only used when all four are given.
        Strings and numbers are both accepted.
    :param search: optional free-text search term.
    :param size: Flickr size suffix used when building each image url.
    :param progress_hook: optional callable invoked as ``progress_hook(num_processed, total_jobs)``.
    :raises RuntimeError: if Flickr answers the search with an error document instead of results.
    """
    base_url = 'https://www.flickr.com/services/rest/'
    search_params = {
        'api_key': api_key,
        'method': 'flickr.photos.search',
        'page': 1,
        'per_page': 250,
        'media': 'photos',
    }
    bbox = [min_long, min_lat, max_long, max_lat]
    if None not in bbox:
        # the command line passes floats, so convert before joining
        search_params['bbox'] = ','.join(str(coord) for coord in bbox)
    if search is not None:
        search_params['text'] = search
    print(f"Searching Flickr with params {search_params}")
    response = requests.get(url=base_url, params=search_params, timeout=30)
    duplicates = 0
    search_errors = 0
    download_errors = 0
    downloaded_images = 0
    seen_urls = set()
    csv_lock = Lock()
    filesystem_lock = Lock()
    num_processed = 0
    search_imgs = 0
    if not response.ok:
        print(f"Flickr search request failed with status {response.status_code}")
        return

    root = ET.fromstring(response.content)
    page = root.find('photos')
    if page is None:
        # no <photos> element in the response: report the <err> message if there is one
        err = root.find('err')
        reason = err.get('msg') if err is not None else 'no photos element in the response'
        raise RuntimeError(f"Flickr search failed: {reason}")
    total_images = int(page.get('total'))
    pages = int(page.get('pages'))
    print(f"Found {total_images} images for location min: ({min_lat}, {min_long}) max: ({max_lat}, {max_long}) and search term '{search}' | {pages} pages")
    total_jobs = pages+total_images
    with tqdm(total=total_jobs) as pbar:
        with ThreadPoolExecutor() as executor:
            # run the search page parser; every task gets its own copy of the params because
            # images_from_search writes the page number into the dict it is given
            search_futures = [
                executor.submit(
                    images_from_search, page_index=i, base_url=base_url, search_params=dict(search_params))
                for i in range(1, pages+1)
            ]

            # now for all the search results, start downloading
            download_futures = {}
            for future in as_completed(search_futures):
                try:
                    for farm_id, server_id, photo_id, secret in future.result():
                        search_imgs += 1
                        # the image download url from the search result info
                        img_url = f"https://farm{farm_id}.staticflickr.com/{server_id}/{photo_id}_{secret}_{size}.jpg"
                        # don't download duplicates
                        if img_url not in seen_urls:
                            seen_urls.add(img_url)
                            # submit job to download the image
                            download_futures[
                                executor.submit(download_image, url=img_url, directory=directory, lock=filesystem_lock)
                            ] = (photo_id, secret, img_url)
                        else:
                            # duplicate found, so don't count this image for jobs
                            duplicates += 1
                            total_jobs -= 1
                            pbar.total = total_jobs
                            pbar.refresh()
                            if progress_hook:
                                progress_hook(num_processed, total_jobs)
                    # update progress bar for search page
                    pbar.update(1)
                    num_processed += 1
                    if progress_hook:
                        progress_hook(num_processed, total_jobs)
                except Exception:
                    # search page error, so don't count this page for jobs
                    search_errors += 1
                    total_jobs -= 1
                    pbar.total = total_jobs
                    pbar.refresh()
                    if progress_hook:
                        progress_hook(num_processed, total_jobs)

            # now for all of our downloaded images, write the csv with info if we can
            info_futures = []
            for future in as_completed(download_futures):
                photo_id, secret, url = download_futures[future]
                filename = future.result()
                if not filename:
                    # image download error, so don't count this image for jobs
                    download_errors += 1
                    total_jobs -= 1
                    pbar.total = total_jobs
                    pbar.refresh()
                    if progress_hook:
                        progress_hook(num_processed, total_jobs)
                    # nothing was saved, so there is no csv row to write for this photo
                    continue
                info_futures.append(
                    executor.submit(
                        write_photo_csv,
                        directory=directory, base_url=base_url, api_key=api_key, img_filename=filename,
                        url=url, photo_id=photo_id, secret=secret, lock=csv_lock
                    )
                )

            # wait for all our final csv jobs to finish
            for _ in as_completed(info_futures):
                # update our progress bar for the finished image download and csv write
                pbar.update(1)
                downloaded_images += 1
                num_processed += 1
                if progress_hook:
                    progress_hook(num_processed, total_jobs)

        # update our progress to be 100%
        # (because original number of images reported from flickr api search can be incorrect)
        pbar.update(total_jobs-num_processed)
    print(f"Downloaded {downloaded_images}\nSearch errors: {search_errors} | Duplicates: {duplicates} | Download errors: {download_errors} | Found {search_imgs} images")
    if progress_hook:
        progress_hook(total_jobs, total_jobs)
def parse_search_xml(xml):
    """
    Lazily yield a ``(farm, server, id, secret)`` tuple of attribute values for
    every child of the ``photos`` element in a search response document.

    An attribute that is absent on a child is yielded as ``None``.
    """
    wanted_attributes = ('farm', 'server', 'id', 'secret')
    photos = ET.fromstring(xml).find('photos')
    for entry in photos:
        yield tuple(entry.get(attribute) for attribute in wanted_attributes)
def write_photo_csv(directory, base_url, api_key, img_filename, url, photo_id, secret, lock):
    """
    Append one row describing a downloaded photo to ``images.csv`` inside ``directory``.

    The photo's location and info are looked up first; the csv is then opened
    and written while holding ``lock``. A header row is written when the csv
    file does not exist yet.
    """
    csv_path = os.path.join(directory, 'images.csv')
    # look up the details before taking the lock so the requests are not made while holding it
    lat, lon, geo_accuracy = get_photo_location(url=base_url, api_key=api_key, photo_id=photo_id)
    owner_id, photo_title, taken = get_photo_info(url=base_url, api_key=api_key, photo_id=photo_id, secret=secret)
    column_names = ['File', 'URL', 'User ID', 'Title', 'Date Taken', 'Latitude', 'Longitude', 'Geo Accuracy']
    with lock:
        needs_header = not os.path.isfile(csv_path)
        with open(csv_path, 'a', newline='', encoding='utf-8') as handle:
            out = csv.writer(handle)
            # make this header if not a file already
            if needs_header:
                out.writerow(column_names)
            # write to csv the filename, url, gps data
            try:
                out.writerow([img_filename, url, owner_id, photo_title, taken, lat, lon, geo_accuracy])
            except Exception:
                # probably has exception on title -- make that None (from non-encodable character)
                out.writerow([img_filename, url, owner_id, None, taken, lat, lon, geo_accuracy])
if __name__ == '__main__':
    # command line entry point: parse the arguments and start the download
    cli = argparse.ArgumentParser(description='Download images from flickr by geo location.')
    cli.add_argument('api', type=str, help='Your Flickr API key.')
    cli.add_argument('directory', help='Directory to download the images to.')
    cli.add_argument(
        '--bbox', type=str,
        help='Geographic bounding box to search. Comma separated list of: Min Latitude, Min Longitude, Max Latitude, Max Longitude',
        default=None
    )
    cli.add_argument('--search', type=str, help='Search term to use.', default=None)
    options = cli.parse_args()
    if options.bbox is None:
        bounds = (None, None, None, None)
    else:
        # "min lat, min long, max lat, max long" -> four floats
        bounds = [float(piece.strip()) for piece in options.bbox.split(',')]
    min_lat, min_long, max_lat, max_long = bounds
    download_flickr(
        api_key=options.api, directory=options.directory,
        min_lat=min_lat, min_long=min_long, max_lat=max_lat, max_long=max_long,
        search=options.search,
    )