├── tests ├── __init__.py ├── README.md ├── conftest.py ├── test_base_tab.py ├── test_tab_manager.py ├── test_images_tab.py ├── test_camera_models.py ├── test_image_processing.py └── test_main_app.py ├── utils ├── file_manager.py ├── logger.py ├── datasets │ ├── download_dataset.py │ ├── normalize.py │ └── traj.py └── gsplat_utils │ └── utils.py ├── app ├── config_loader.py ├── tabs │ ├── __init__.py │ ├── gsplat_tab.py │ ├── reconstruct_tab.py │ ├── features_tab.py │ ├── masks_tab.py │ ├── matching_tab.py │ └── images_tab.py ├── tab_manager.py ├── base_tab.py ├── point_cloud_utils.py ├── point_cloud_viewer.py ├── point_cloud_visualizer.py ├── image_processing.py ├── camera_models.py └── mask_manager.py ├── entrypoint.sh ├── main.py ├── .gitmodules ├── .github └── workflows │ └── docker-image.yml ├── models └── README.md ├── README_TESTS.md ├── README.md ├── Dockerfile ├── LICENSE └── config └── config.yaml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Tests package initialization -------------------------------------------------------------------------------- /utils/file_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def create_directory(path): 4 | if not os.path.exists(path): 5 | os.makedirs(path) -------------------------------------------------------------------------------- /app/config_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | def load_config(config_path): 4 | with open(config_path, 'r') as file: 5 | config = yaml.safe_load(file) 6 | return config 7 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import os 3 | 4 | def setup_logger(log_dir='logs', log_file='app.log'): 5 | os.makedirs(log_dir, 
exist_ok=True) 6 | logger.add(os.path.join(log_dir, log_file)) 7 | return logger -------------------------------------------------------------------------------- /entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ "$1" == "--dac" ]]; then 4 | echo "Running DAC setup..." 5 | cd /depth_any_camera/dac/models/ops && pip install -e . && cd /source/splat_one && bash 6 | else 7 | echo "Default command executed" 8 | exec "$@" 9 | fi 10 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # main.py 2 | import sys 3 | from PyQt5.QtWidgets import QApplication 4 | from app.main_app import MainApp 5 | 6 | if __name__ == "__main__": 7 | app = QApplication(sys.argv) 8 | main_window = MainApp() 9 | main_window.show() 10 | sys.exit(app.exec_()) 11 | -------------------------------------------------------------------------------- /app/tabs/__init__.py: -------------------------------------------------------------------------------- 1 | # tabs package initialization 2 | from app.tabs.images_tab import ImagesTab 3 | from app.tabs.masks_tab import MasksTab 4 | from app.tabs.depth_tab import DepthTab 5 | from app.tabs.features_tab import FeaturesTab 6 | from app.tabs.matching_tab import MatchingTab 7 | from app.tabs.reconstruct_tab import ReconstructTab 8 | from app.tabs.gsplat_tab import GsplatTab 9 | 10 | # All available tabs 11 | __all__ = [ 12 | 'ImagesTab', 13 | 'MasksTab', 14 | 'DepthTab', 15 | 'FeaturesTab', 16 | 'MatchingTab', 17 | 'ReconstructTab', 18 | 'GsplatTab' 19 | ] -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # SPLAT_ONE Tests 2 | 3 | This directory contains tests for the SPLAT_ONE application. 
4 | 5 | ## Running Tests 6 | 7 | To run the tests, make sure you have pytest installed: 8 | 9 | ```bash 10 | pip install pytest pytest-cov 11 | ``` 12 | 13 | Then run the tests from the project root directory: 14 | 15 | ```bash 16 | pytest tests/ 17 | ``` 18 | 19 | For coverage reports: 20 | 21 | ```bash 22 | pytest --cov=app tests/ 23 | ``` 24 | 25 | ## Test Organization 26 | 27 | - `conftest.py`: Common fixtures and test configurations 28 | - `test_*.py`: Test modules for different components 29 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/sam2"] 2 | path = submodules/sam2 3 | url = https://github.com/facebookresearch/sam2.git 4 | [submodule "submodules/opensfm"] 5 | path = submodules/opensfm 6 | url = https://github.com/inuex35/ind-bermuda-opensfm.git 7 | [submodule "submodules/lightglue"] 8 | path = submodules/lightglue 9 | url = https://github.com/cvg/LightGlue.git 10 | [submodule "submodules/aliked"] 11 | path = submodules/aliked 12 | url = https://github.com/Shiaoming/ALIKED.git 13 | [submodule "submodules/gsplat"] 14 | path = submodules/gsplat 15 | url = https://github.com/inuex35/gsplat.git 16 | branch = spherical_render 17 | [submodule "submodules/Depth-Anything-V2"] 18 | path = submodules/Depth-Anything-V2 19 | url = https://github.com/DepthAnything/Depth-Anything-V2.git 20 | [submodule "submodules/depth_any_camera"] 21 | path = submodules/depth_any_camera 22 | url = https://github.com/DepthAnything/Depth-Any-Camera.git 23 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/docker-image.yml 2 | 3 | name: Build and Push Docker Image 4 | 5 | on: 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | 
13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Free Disk Space 18 | uses: jlumbroso/free-disk-space@main 19 | with: 20 | # Options can be adjusted to fit the project's requirements 21 | tool-cache: false 22 | android: true 23 | dotnet: true 24 | haskell: true 25 | large-packages: true 26 | docker-images: true 27 | swap-storage: true 28 | - name: Checkout code 29 | uses: actions/checkout@v3 30 | 31 | - name: Set up Docker Buildx 32 | uses: docker/setup-buildx-action@v2 33 | 34 | - name: Log in to Docker Hub 35 | uses: docker/login-action@v2 36 | with: 37 | username: ${{ secrets.DOCKER_USERNAME }} 38 | password: ${{ secrets.DOCKER_PASSWORD }} 39 | 40 | - name: Build and push Docker image 41 | uses: docker/build-push-action@v5 42 | with: 43 | context: . 44 | push: true 45 | tags: inuex35/splat_one:latest 46 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # Depth Estimation Models 2 | 3 | This directory should contain the pre-trained models for depth estimation. 
4 | 5 | ## Depth Anything V2 Models 6 | 7 | Download the models from the official repository: 8 | - https://github.com/DepthAnything/Depth-Anything-V2 9 | 10 | ### Recommended models: 11 | - **depth_anything_v2_vitl.pth** - Large model (best quality) 12 | - **depth_anything_v2_vits.pth** - Small model (faster inference) 13 | 14 | Download links: 15 | - ViT-L: https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth 16 | - ViT-S: https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth 17 | 18 | ## DAC (Depth Any Camera) Models 19 | 20 | Download the models from: 21 | - https://github.com/xanderchf/depth_any_camera 22 | 23 | ### Available models: 24 | - **dac_vitl_hypersim.pth** - Trained on HyperSim dataset 25 | - **dac_vitl_vkitti.pth** - Trained on Virtual KITTI dataset 26 | 27 | Choose based on your use case: 28 | - HyperSim: Better for indoor scenes 29 | - VKITTI: Better for outdoor/driving scenes 30 | 31 | ## Installation 32 | 33 | 1. Create this models directory if it doesn't exist 34 | 2. Download the desired model files 35 | 3. Place them in this directory 36 | 4. The depth tab will automatically detect and use them 37 | 38 | ## Note 39 | 40 | The models are large files (several hundred MB each). Make sure you have enough disk space. 
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Common test fixtures and configurations 2 | import os 3 | import sys 4 | import pytest 5 | from unittest.mock import MagicMock 6 | import tempfile 7 | import shutil 8 | 9 | # Add the parent directory to path so we can import app modules 10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 11 | 12 | @pytest.fixture 13 | def mock_qapplication(monkeypatch): 14 | """Mock QApplication to avoid GUI during tests""" 15 | mock_app = MagicMock() 16 | monkeypatch.setattr('PyQt5.QtWidgets.QApplication', MagicMock(return_value=mock_app)) 17 | return mock_app 18 | 19 | @pytest.fixture 20 | def temp_workdir(): 21 | """Create a temporary working directory for tests""" 22 | temp_dir = tempfile.mkdtemp() 23 | 24 | # Create necessary subdirectories 25 | os.makedirs(os.path.join(temp_dir, 'images'), exist_ok=True) 26 | os.makedirs(os.path.join(temp_dir, 'exif'), exist_ok=True) 27 | os.makedirs(os.path.join(temp_dir, 'masks'), exist_ok=True) 28 | 29 | try: 30 | yield temp_dir 31 | finally: 32 | # Cleanup after tests 33 | shutil.rmtree(temp_dir) 34 | 35 | @pytest.fixture 36 | def sample_image_data(): 37 | """Sample image data for testing""" 38 | return [ 39 | {'name': 'image1.jpg', 'width': 1920, 'height': 1080, 'camera': 'Camera1'}, 40 | {'name': 'image2.jpg', 'width': 1920, 'height': 1080, 'camera': 'Camera1'}, 41 | {'name': 'image3.jpg', 'width': 3840, 'height': 2160, 'camera': 'Camera2'} 42 | ] 43 | -------------------------------------------------------------------------------- /README_TESTS.md: -------------------------------------------------------------------------------- 1 | # SPLAT_ONE Testing Documentation 2 | 3 | ## Overview 4 | 5 | This document provides information about the testing system implemented for SPLAT_ONE. 
The test suite is designed to ensure code reliability, maintainability, and proper functionality across the application. 6 | 7 | ## Changes in Code Structure 8 | 9 | The original codebase has been refactored to improve modularity and testability: 10 | 11 | 1. **Module Separation**: 12 | - Core functionality has been separated into domain-specific modules 13 | - UI components are now organized into a logical hierarchy 14 | - Business logic has been moved from UI classes where possible 15 | 16 | 2. **Object-Oriented Design**: 17 | - Added abstraction through base classes and inheritance 18 | - Improved encapsulation by moving related functionality into specific classes 19 | - Reduced code duplication through shared behaviors 20 | 21 | 3. **Testing Infrastructure**: 22 | - Added comprehensive test suite using pytest 23 | - Created fixtures and test helpers for common testing scenarios 24 | - Implemented mocking for external dependencies 25 | 26 | ## Testing Structure 27 | 28 | The tests are organized as follows: 29 | 30 | - `tests/conftest.py`: Common fixtures and test utilities 31 | - `tests/test_*.py`: Individual test modules for each application component 32 | - `tests/README.md`: Documentation for the test system 33 | 34 | ## Running Tests 35 | 36 | To run the test suite, make sure you have pytest and the required dependencies installed: 37 | 38 | ```bash 39 | pip install pytest pytest-cov pytest-mock 40 | ``` 41 | 42 | From the project root directory, run: 43 | 44 | ```bash 45 | pytest tests/ 46 | ``` 47 | 48 | For a coverage report: 49 | 50 | ```bash 51 | pytest --cov=app tests/ 52 | ``` 53 | 54 | To run specific tests: 55 | 56 | ```bash 57 | pytest tests/test_camera_models.py 58 | ``` 59 | 60 | ## Test Types 61 | 62 | 1. **Unit Tests**: 63 | - Test individual functions and classes in isolation 64 | - Mock dependencies to focus on the unit under test 65 | - Fast execution for quick feedback 66 | 67 | 2. 
**Integration Tests**: 68 | - Test interactions between components 69 | - Verify proper communication between modules 70 | - Ensure system works as a whole 71 | 72 | ## Mocking Strategy 73 | 74 | To avoid dependencies on GUI components and external libraries during testing, we use a combination of: 75 | 76 | - `unittest.mock` for Python standard library mocking 77 | - `pytest-mock` for pytest-style fixtures 78 | - Custom mock objects for complex dependencies 79 | 80 | ## Future Testing Improvements 81 | 82 | 1. Add more tests for other tabs and components as they are refactored 83 | 2. Implement UI testing for checking GUI interactions 84 | 3. Add performance tests for critical operations 85 | 4. Implement continuous integration for automatic test runs -------------------------------------------------------------------------------- /app/tab_manager.py: -------------------------------------------------------------------------------- 1 | # tab_manager.py 2 | 3 | from PyQt5.QtWidgets import QTabWidget 4 | from app.base_tab import BaseTab 5 | 6 | class TabManager(QTabWidget): 7 | """Manager for application tabs""" 8 | def __init__(self, parent=None): 9 | super().__init__(parent) 10 | self.tab_instances = {} 11 | self.parent_app = parent 12 | self.currentChanged.connect(self.on_tab_changed) 13 | self.active_tab_index = -1 14 | 15 | def register_tab(self, tab_class, tab_name=None, *args, **kwargs): 16 | """Register a tab with the manager""" 17 | # Create the tab instance 18 | tab_instance = tab_class(*args, parent=self.parent_app, **kwargs) 19 | 20 | # Use provided name or get from the tab 21 | if tab_name is None and isinstance(tab_instance, BaseTab): 22 | tab_name = tab_instance.get_tab_name() 23 | elif tab_name is None: 24 | tab_name = tab_class.__name__ 25 | 26 | # Add to tab widget 27 | index = self.addTab(tab_instance, tab_name) 28 | 29 | # Store reference 30 | self.tab_instances[index] = tab_instance 31 | 32 | # If this is the first tab, set it as active 33 | if 
index == 0: 34 | self.active_tab_index = 0 35 | # Initialize the first tab immediately 36 | if isinstance(tab_instance, BaseTab) and hasattr(tab_instance, 'on_tab_activated'): 37 | tab_instance.on_tab_activated() 38 | 39 | return index 40 | 41 | def get_tab_instance(self, index): 42 | """Get the tab instance at the given index""" 43 | return self.tab_instances.get(index) 44 | 45 | def get_current_tab(self): 46 | """Get the currently active tab instance""" 47 | index = self.currentIndex() 48 | return self.get_tab_instance(index) 49 | 50 | def on_tab_changed(self, index): 51 | """Handle tab changed event""" 52 | # Deactivate previous tab 53 | prev_tab = self.get_tab_instance(self.active_tab_index) 54 | if prev_tab and hasattr(prev_tab, 'on_tab_deactivated'): 55 | prev_tab.on_tab_deactivated() 56 | 57 | # Activate new tab 58 | current_tab = self.get_tab_instance(index) 59 | if current_tab and hasattr(current_tab, 'on_tab_activated'): 60 | current_tab.on_tab_activated() 61 | 62 | # Update active index 63 | self.active_tab_index = index 64 | 65 | def update_all_tabs(self, workdir=None, image_list=None): 66 | """Update all tabs with new workdir and/or image list""" 67 | for tab_instance in self.tab_instances.values(): 68 | if hasattr(tab_instance, 'update_workdir') and workdir is not None: 69 | tab_instance.update_workdir(workdir) 70 | 71 | if hasattr(tab_instance, 'update_image_list') and image_list is not None: 72 | tab_instance.update_image_list(image_list) -------------------------------------------------------------------------------- /app/tabs/gsplat_tab.py: -------------------------------------------------------------------------------- 1 | # gsplat_tab.py 2 | 3 | import os 4 | from PyQt5.QtWidgets import ( 5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox 6 | ) 7 | from PyQt5.QtCore import Qt 8 | 9 | from app.base_tab import BaseTab 10 | 11 | class GsplatTab(BaseTab): 12 | """Gsplat tab implementation""" 13 | def __init__(self, workdir=None, 
image_list=None, parent=None): 14 | super().__init__(workdir, image_list, parent) 15 | self.gsplat_manager = None 16 | self.gsplat_image_tree = None 17 | 18 | def get_tab_name(self): 19 | return "Gsplat" 20 | 21 | def initialize(self): 22 | """Initialize the Gsplat tab""" 23 | if not self.workdir: 24 | QMessageBox.warning(self, "Error", "Work directory is not set.") 25 | return 26 | 27 | layout = self.create_horizontal_splitter() 28 | 29 | # Left side: Tree of images 30 | self.gsplat_image_tree = QTreeWidget() 31 | self.gsplat_image_tree.setHeaderLabel("Images") 32 | self.gsplat_image_tree.setFixedWidth(250) 33 | layout.addWidget(self.gsplat_image_tree) 34 | 35 | try: 36 | # Import GsplatManager here to avoid circular imports 37 | from app.gsplat_manager import GsplatManager 38 | 39 | # Right side: GsplatManager widget 40 | self.gsplat_manager = GsplatManager(self.workdir) 41 | layout.addWidget(self.gsplat_manager) 42 | 43 | # Set stretch factors 44 | layout.setStretchFactor(0, 1) # Left side (image tree) 45 | layout.setStretchFactor(1, 4) # Right side (gsplat manager) 46 | 47 | # Set layout for gsplat tab 48 | self._layout.addWidget(layout) 49 | 50 | # Populate the tree with camera data 51 | self.setup_camera_image_tree(self.gsplat_image_tree) 52 | 53 | # Connect double-click signal to handler 54 | self.gsplat_image_tree.itemDoubleClicked.connect(self.handle_image_double_click) 55 | 56 | self.is_initialized = True 57 | 58 | except Exception as e: 59 | error_message = f"Failed to initialize GsplatManager: {str(e)}" 60 | QMessageBox.critical(self, "Error", error_message) 61 | placeholder = QLabel("GsplatManager could not be initialized. 
See error message.") 62 | placeholder.setAlignment(Qt.AlignCenter) 63 | layout.addWidget(placeholder) 64 | self._layout.addWidget(layout) 65 | 66 | def handle_image_double_click(self, item, column): 67 | """Handle double-click event on image tree item""" 68 | if item.childCount() == 0 and item.parent() is not None: 69 | image_name = item.text(0) 70 | if self.gsplat_manager and hasattr(self.gsplat_manager, 'on_camera_image_tree_double_click'): 71 | self.gsplat_manager.on_camera_image_tree_double_click(image_name) 72 | 73 | def refresh(self): 74 | """Refresh the tab content""" 75 | if self.is_initialized: 76 | # Remove old widgets 77 | for i in reversed(range(self._layout.count())): 78 | self._layout.itemAt(i).widget().setParent(None) 79 | 80 | # Reinitialize 81 | self.is_initialized = False 82 | self.initialize() -------------------------------------------------------------------------------- /app/base_tab.py: -------------------------------------------------------------------------------- 1 | # base_tab.py 2 | 3 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QSplitter, QTreeWidget, QTreeWidgetItem 4 | from PyQt5.QtCore import Qt 5 | 6 | class BaseTab(QWidget): 7 | """Base class for application tabs""" 8 | def __init__(self, workdir=None, image_list=None, parent=None): 9 | super().__init__(parent) 10 | self.workdir = workdir 11 | self.image_list = image_list or [] 12 | self.parent_app = parent 13 | self._layout = QVBoxLayout(self) 14 | self._layout.setContentsMargins(0, 0, 0, 0) 15 | self.is_initialized = False 16 | 17 | def get_tab_name(self): 18 | """Get the name of the tab. To be overridden by child classes.""" 19 | return "Unnamed Tab" 20 | 21 | def initialize(self): 22 | """Initialize the tab. 
Must be implemented by child classes.""" 23 | raise NotImplementedError("This method should be implemented by child classes.") 24 | 25 | def setup_camera_image_tree(self, tree_widget, select_callback=None): 26 | """Set up a camera-grouped image tree widget""" 27 | tree_widget.clear() 28 | camera_groups = {} 29 | 30 | if not self.workdir or not self.image_list: 31 | return 32 | 33 | import json 34 | import os 35 | exif_dir = os.path.join(self.workdir, "exif") 36 | 37 | for image_name in self.image_list: 38 | exif_file = os.path.join(exif_dir, image_name + '.exif') 39 | if os.path.exists(exif_file): 40 | with open(exif_file, 'r') as f: 41 | exif_data = json.load(f) 42 | camera = exif_data.get('camera', 'Unknown Camera') 43 | else: 44 | camera = 'Unknown Camera' 45 | 46 | if camera not in camera_groups: 47 | camera_groups[camera] = [] 48 | 49 | camera_groups[camera].append(image_name) 50 | 51 | for camera, images in camera_groups.items(): 52 | camera_item = QTreeWidgetItem(tree_widget) 53 | camera_item.setText(0, camera) 54 | for img in images: 55 | img_item = QTreeWidgetItem(camera_item) 56 | img_item.setText(0, img) 57 | 58 | # Connect callback if provided 59 | if select_callback: 60 | tree_widget.itemClicked.connect(select_callback) 61 | 62 | def create_horizontal_splitter(self): 63 | """Create and return a horizontal splitter""" 64 | return QSplitter(Qt.Horizontal) 65 | 66 | def update_workdir(self, workdir): 67 | """Update the working directory""" 68 | self.workdir = workdir 69 | if self.is_initialized: 70 | self.refresh() 71 | 72 | def update_image_list(self, image_list): 73 | """Update the image list""" 74 | self.image_list = image_list 75 | if self.is_initialized: 76 | self.refresh() 77 | 78 | def refresh(self): 79 | """Refresh the tab contents. Can be overridden by child classes.""" 80 | pass 81 | 82 | def on_tab_activated(self): 83 | """Called when the tab is activated. 
Can be overridden by child classes.""" 84 | if not self.is_initialized: 85 | self.initialize() 86 | self.is_initialized = True 87 | 88 | def on_tab_deactivated(self): 89 | """Called when the tab is deactivated. Can be overridden by child classes.""" 90 | pass -------------------------------------------------------------------------------- /tests/test_base_tab.py: -------------------------------------------------------------------------------- 1 | # test_base_tab.py 2 | 3 | import pytest 4 | from unittest.mock import MagicMock, patch 5 | from PyQt5.QtWidgets import QWidget, QVBoxLayout 6 | from app.base_tab import BaseTab 7 | 8 | class TestTab(BaseTab): 9 | """Test implementation of BaseTab for testing""" 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.initialize_called = False 13 | self.refresh_called = False 14 | self.tab_activated_called = False 15 | self.tab_deactivated_called = False 16 | 17 | def get_tab_name(self): 18 | return "Test Tab" 19 | 20 | def initialize(self): 21 | self.initialize_called = True 22 | self.is_initialized = True 23 | 24 | def refresh(self): 25 | self.refresh_called = True 26 | 27 | def on_tab_activated(self): 28 | super().on_tab_activated() 29 | self.tab_activated_called = True 30 | 31 | def on_tab_deactivated(self): 32 | super().on_tab_deactivated() 33 | self.tab_deactivated_called = True 34 | 35 | 36 | def test_base_tab_init(): 37 | """Test BaseTab initialization""" 38 | workdir = "/test/workdir" 39 | image_list = ["image1.jpg", "image2.jpg"] 40 | parent = MagicMock() 41 | 42 | tab = TestTab(workdir, image_list, parent) 43 | 44 | assert tab.workdir == workdir 45 | assert tab.image_list == image_list 46 | assert tab.parent_app == parent 47 | assert isinstance(tab._layout, QVBoxLayout) 48 | assert tab.is_initialized is False 49 | 50 | 51 | def test_base_tab_get_tab_name(): 52 | """Test get_tab_name method""" 53 | tab = TestTab() 54 | assert tab.get_tab_name() == "Test Tab" 55 | 56 | 57 | def 
test_base_tab_on_tab_activated(): 58 | """Test on_tab_activated method""" 59 | tab = TestTab() 60 | tab.on_tab_activated() 61 | 62 | assert tab.initialize_called is True 63 | assert tab.is_initialized is True 64 | assert tab.tab_activated_called is True 65 | 66 | 67 | def test_base_tab_update_workdir(): 68 | """Test update_workdir method""" 69 | tab = TestTab("/old/workdir") 70 | tab.is_initialized = True 71 | 72 | tab.update_workdir("/new/workdir") 73 | 74 | assert tab.workdir == "/new/workdir" 75 | assert tab.refresh_called is True 76 | 77 | 78 | def test_base_tab_update_image_list(): 79 | """Test update_image_list method""" 80 | tab = TestTab(image_list=["old.jpg"]) 81 | tab.is_initialized = True 82 | 83 | new_images = ["new1.jpg", "new2.jpg"] 84 | tab.update_image_list(new_images) 85 | 86 | assert tab.image_list == new_images 87 | assert tab.refresh_called is True 88 | 89 | 90 | @patch('app.base_tab.QTreeWidget') 91 | def test_base_tab_setup_camera_image_tree(mock_tree_widget): 92 | """Test setup_camera_image_tree method""" 93 | # This test is simplified due to complexity of mocking file operations 94 | # A more comprehensive test would use a fixture with actual files 95 | tab = TestTab("/test/workdir", ["image1.jpg"]) 96 | 97 | tree = MagicMock() 98 | callback = MagicMock() 99 | 100 | # Call the method 101 | tab.setup_camera_image_tree(tree, callback) 102 | 103 | # Check that tree was cleared 104 | assert tree.clear.call_count == 1 105 | 106 | # If callback provided, it should be connected 107 | if callback: 108 | assert tree.itemClicked.connect.call_count == 1 -------------------------------------------------------------------------------- /app/point_cloud_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import os 4 | 5 | def rgb_depth_to_pointcloud(rgb_image, depth_image, fx=1000, fy=1000, cx=None, cy=None): 6 | """ 7 | Generate point cloud from RGB image and depth 
image 8 | 9 | Args: 10 | rgb_image: RGB image (PIL Image or numpy array) 11 | depth_image: Depth image (PIL Image or numpy array) 12 | fx, fy: Camera focal length 13 | cx, cy: Camera optical center (uses image center if None) 14 | 15 | Returns: 16 | points: Point cloud coordinates (N, 3) 17 | colors: Point cloud colors (N, 3) 18 | """ 19 | # Convert images to numpy arrays 20 | if isinstance(rgb_image, Image.Image): 21 | rgb_array = np.array(rgb_image) 22 | else: 23 | rgb_array = rgb_image 24 | 25 | if isinstance(depth_image, Image.Image): 26 | depth_array = np.array(depth_image) 27 | else: 28 | depth_array = depth_image 29 | 30 | # Convert depth image to RGB if grayscale 31 | if len(depth_array.shape) == 2: 32 | depth_array = np.stack([depth_array] * 3, axis=-1) 33 | 34 | # Use first channel if depth image is RGB 35 | if len(depth_array.shape) == 3: 36 | depth_array = depth_array[:, :, 0] 37 | 38 | # Get image dimensions 39 | height, width = depth_array.shape 40 | 41 | # Use image center if optical center is not specified 42 | if cx is None: 43 | cx = width / 2 44 | if cy is None: 45 | cy = height / 2 46 | 47 | # Create mesh grid 48 | y, x = np.meshgrid(np.arange(height), np.arange(width), indexing='ij') 49 | 50 | # Normalize depth values (0-255 to 0-1) 51 | depth_normalized = depth_array.astype(np.float32) / 255.0 52 | 53 | # Calculate 3D coordinates 54 | Z = depth_normalized * 10.0 # Adjust depth scale as needed 55 | X = (x - cx) * Z / fx 56 | Y = (y - cy) * Z / fy 57 | 58 | # Extract only points with valid depth values 59 | valid_mask = (depth_normalized > 0.01) & (depth_normalized < 0.99) 60 | 61 | # Extract point cloud data 62 | points = np.stack([X[valid_mask], Y[valid_mask], Z[valid_mask]], axis=1) 63 | colors = rgb_array[valid_mask] / 255.0 # Normalize colors to 0-1 64 | 65 | return points, colors 66 | 67 | def load_image_and_depth(image_path, depth_path): 68 | """ 69 | Load RGB image and depth image 70 | 71 | Args: 72 | image_path: Path to RGB image 73 | 
depth_path: Path to depth image 74 | 75 | Returns: 76 | rgb_image: RGB image (PIL Image) 77 | depth_image: Depth image (PIL Image) 78 | """ 79 | if not os.path.exists(image_path): 80 | raise FileNotFoundError(f"Image file not found: {image_path}") 81 | 82 | if not os.path.exists(depth_path): 83 | raise FileNotFoundError(f"Depth file not found: {depth_path}") 84 | 85 | rgb_image = Image.open(image_path).convert('RGB') 86 | depth_image = Image.open(depth_path).convert('L') # Load as grayscale 87 | 88 | return rgb_image, depth_image 89 | 90 | def downsample_pointcloud(points, colors, target_points=10000): 91 | """ 92 | Downsample point cloud 93 | 94 | Args: 95 | points: Point cloud coordinates (N, 3) 96 | colors: Point cloud colors (N, 3) 97 | target_points: Target number of points 98 | 99 | Returns: 100 | downsampled_points: Downsampled point cloud coordinates 101 | downsampled_colors: Downsampled point cloud colors 102 | """ 103 | if len(points) <= target_points: 104 | return points, colors 105 | 106 | # Random sampling 107 | indices = np.random.choice(len(points), target_points, replace=False) 108 | return points[indices], colors[indices] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # splat_one 2 | 3 | splat_one is an integrated application that combines Gaussian Splatting and Structure from Motion (SfM) into a single workflow. It allows users to visualize and process image data through multiple interactive tabs, making it easier to generate 3D point clouds and analyze camera parameters. 4 | 5 | **Note:** This project is currently under development. The Masks tab is not functioning properly at this time. 6 | 7 | ## Tab Descriptions 8 | 9 | - **Images Tab** 10 | Displays information about the imported images, such as file names, resolutions, and metadata. 
11 | 12 | - **Masks Tab** 13 | Provides functionality to generate masks for images. 14 | *Currently under development; may not work as expected.* 15 | 16 | - **Features Tab** 17 | Shows the extracted image features, including keypoints and descriptors, for visualization and analysis. 18 | 19 | - **Matching Tab** 20 | Displays the results of feature point matching, allowing you to inspect the quality and accuracy of the correspondences. 21 | 22 | - **Reconstruct Tab** 23 | Visualizes the 3D point cloud reconstructed via SfM, along with camera positions and overall scene structure. 24 | 25 | - **Gsplat Tab** 26 | Presents the output of the Gaussian Splatting process. You can adjust parameters to explore different renderings of the 3D point cloud. 27 | 28 |

29 | 30 | splat_one_resized 31 | 32 |

33 | 34 | ## Data Preparation 35 | 36 | Before running the application, organize your images in the following directory structure inside the `dataset` folder: 37 | 38 | ```bash 39 | splat_one 40 | └─ dataset 41 | └─ your_data 42 | └─ images 43 | ``` 44 | 45 | Place all the images you want to process under the `images` directory. 46 | 47 | ## Docker Deployment 48 | 49 | This application is designed to run via Docker. The Docker image is available as `inuex35/splat_one`. Please refer to the `Dockerfile` for more details on the installation. 50 | 51 | To launch the Docker container with GPU support and proper X11 forwarding (for GUI display), run the following command: 52 | 53 | ```bash 54 | docker run --gpus all -e DISPLAY=host.docker.internal:0.0 -v /tmp/.X11-unix:/tmp/.X11-unix -v ${PWD}/dataset:/source/splat_one/dataset -v C:\Users\$env:USERNAME\.cache:/home/user/.cache/ -p 7007:7007 --rm -it --shm-size=12gb inuex35/splat_one 55 | ``` 56 | 57 | Once inside the container, start the application by executing: 58 | 59 | ```bash 60 | python main.py 61 | ``` 62 | 63 | ### Running with Depth Any Camera (DAC) Mode 64 | 65 | To enable the Depth Any Camera feature for advanced depth estimation, you can run the application with the `--dac` flag: 66 | 67 | ```bash 68 | docker run --gpus all -e DISPLAY=host.docker.internal:0.0 -v /tmp/.X11-unix:/tmp/.X11-unix -v ${PWD}/dataset:/source/splat_one/dataset -v C:\Users\$env:USERNAME\.cache:/home/user/.cache/ -p 7007:7007 --rm -it --shm-size=12gb inuex35/splat_one --dac 69 | ``` 70 | 71 | The `--dac` option enables depth estimation capabilities using the Depth Any Camera model, which provides robust depth prediction for various camera types including perspective, fisheye, and 360-degree cameras. 72 | 73 | ## Dependencies 74 | - [OpenSfM](https://github.com/inuex35/ind-bermuda-opensfm/) 75 | - [Gsplat](https://github.com/inuex35/gsplat/) 76 | 77 | We use code from these repositories. 
Please note that the license for these components follows their respective original repositories. 78 | 79 | 80 | # Development Status 81 | The project is under active development. 82 | Some features, such as the mask generation in the Masks tab, are not yet fully functional. 83 | Contributions via Issues and Pull Requests are welcome. 84 | License 85 | This project is licensed under the MIT License. See the LICENSE file for details. 86 | -------------------------------------------------------------------------------- /app/tabs/reconstruct_tab.py: -------------------------------------------------------------------------------- 1 | # reconstruct_tab.py 2 | 3 | import os 4 | from PyQt5.QtWidgets import ( 5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox 6 | ) 7 | from PyQt5.QtCore import Qt 8 | 9 | from app.base_tab import BaseTab 10 | 11 | class ReconstructTab(BaseTab): 12 | """Reconstruct tab implementation""" 13 | def __init__(self, workdir=None, image_list=None, parent=None): 14 | super().__init__(workdir, image_list, parent) 15 | self.reconstruction_viewer = None 16 | self.camera_image_tree = None 17 | 18 | def get_tab_name(self): 19 | return "Reconstruct" 20 | 21 | def initialize(self): 22 | """Initialize the Reconstruct tab""" 23 | if not self.workdir: 24 | QMessageBox.warning(self, "Error", "Work directory is not set.") 25 | return 26 | 27 | layout = self.create_horizontal_splitter() 28 | 29 | # Left side: Tree of images grouped by camera 30 | self.camera_image_tree = QTreeWidget() 31 | self.camera_image_tree.setHeaderLabel("Cameras and Images") 32 | self.camera_image_tree.setFixedWidth(250) 33 | layout.addWidget(self.camera_image_tree) 34 | 35 | try: 36 | # Import Reconstruction here to avoid circular imports 37 | from app.point_cloud_visualizer import Reconstruction 38 | 39 | # Right side: Reconstruction (PointCloudVisualizer) widget 40 | self.reconstruction_viewer = Reconstruction(self.workdir) 41 | 
layout.addWidget(self.reconstruction_viewer) 42 | 43 | # Set layout for reconstruct tab 44 | self._layout.addWidget(layout) 45 | 46 | # Setup camera image tree with click and double click event handlers 47 | self.setup_camera_image_tree(self.camera_image_tree) 48 | self.camera_image_tree.itemClicked.connect(self.handle_camera_image_tree_click) 49 | self.camera_image_tree.itemDoubleClicked.connect(self.handle_camera_image_tree_double_click) 50 | 51 | # Set stretch factors 52 | layout.setStretchFactor(0, 1) # Left side (image tree) 53 | layout.setStretchFactor(1, 4) # Right side (reconstruction viewer) 54 | 55 | self.is_initialized = True 56 | 57 | except Exception as e: 58 | error_message = f"Failed to initialize Reconstruction: {str(e)}" 59 | QMessageBox.critical(self, "Error", error_message) 60 | placeholder = QLabel("Reconstruction could not be initialized. See error message.") 61 | placeholder.setAlignment(Qt.AlignCenter) 62 | layout.addWidget(placeholder) 63 | self._layout.addWidget(layout) 64 | 65 | def handle_camera_image_tree_click(self, item, column): 66 | """Handle single click event for camera_image_tree""" 67 | if item.childCount() == 0 and item.parent() is not None: 68 | image_name = item.text(0) 69 | if self.reconstruction_viewer: 70 | self.reconstruction_viewer.on_camera_image_tree_click(image_name) 71 | 72 | def handle_camera_image_tree_double_click(self, item, column): 73 | """Handle double click event for camera_image_tree""" 74 | if item.childCount() == 0 and item.parent() is not None: 75 | image_name = item.text(0) 76 | if self.reconstruction_viewer: 77 | self.reconstruction_viewer.on_camera_image_tree_double_click(image_name) 78 | 79 | def refresh(self): 80 | """Refresh the tab content""" 81 | if self.is_initialized: 82 | # Remove old widgets 83 | for i in reversed(range(self._layout.count())): 84 | self._layout.itemAt(i).widget().setParent(None) 85 | 86 | # Reinitialize 87 | self.is_initialized = False 88 | self.initialize() 89 | 90 | def 
on_tab_activated(self): 91 | """Called when tab is activated""" 92 | super().on_tab_activated() 93 | # Update visualization when tab is activated 94 | if self.is_initialized and self.reconstruction_viewer: 95 | self.reconstruction_viewer.update_visualization() -------------------------------------------------------------------------------- /app/tabs/features_tab.py: -------------------------------------------------------------------------------- 1 | # features_tab.py 2 | 3 | import os 4 | from PyQt5.QtWidgets import ( 5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox 6 | ) 7 | from PyQt5.QtCore import Qt 8 | 9 | from app.base_tab import BaseTab 10 | 11 | class FeaturesTab(BaseTab): 12 | """Features tab implementation""" 13 | def __init__(self, workdir=None, image_list=None, parent=None): 14 | super().__init__(workdir, image_list, parent) 15 | self.feature_extractor = None 16 | self.camera_image_tree = None 17 | 18 | # Set up basic UI structure 19 | self.setup_basic_ui() 20 | 21 | def get_tab_name(self): 22 | return "Features" 23 | 24 | def setup_basic_ui(self): 25 | """Set up the basic UI structure""" 26 | layout = self.create_horizontal_splitter() 27 | 28 | # Left side: Tree of images grouped by camera 29 | self.camera_image_tree = QTreeWidget() 30 | self.camera_image_tree.setHeaderLabel("Cameras and Images") 31 | self.camera_image_tree.setFixedWidth(250) 32 | layout.addWidget(self.camera_image_tree) 33 | 34 | # Right side: Placeholder for feature extractor 35 | right_widget = QLabel("Feature Extractor will be displayed here.") 36 | right_widget.setAlignment(Qt.AlignCenter) 37 | layout.addWidget(right_widget) 38 | 39 | # Set stretch factors 40 | layout.setStretchFactor(0, 1) # Left side (image tree) 41 | layout.setStretchFactor(1, 4) # Right side (feature extractor) 42 | 43 | # Set layout for features tab 44 | self._layout.addWidget(layout) 45 | 46 | # Connect signals 47 | 
self.camera_image_tree.itemClicked.connect(self.display_features_for_image) 48 | 49 | def initialize(self): 50 | """Initialize the Features tab with data""" 51 | if not self.workdir: 52 | QMessageBox.warning(self, "Error", "Work directory is not set.") 53 | return 54 | 55 | try: 56 | # Import FeatureExtractor here to avoid circular imports 57 | from app.feature_extractor import FeatureExtractor 58 | 59 | # Get the main layout and its widgets 60 | main_layout = self._layout.itemAt(0).widget() 61 | right_widget = main_layout.widget(1) # The placeholder 62 | 63 | # Remove the placeholder 64 | right_widget.setParent(None) 65 | 66 | # Create the feature extractor 67 | self.feature_extractor = FeatureExtractor(self.workdir, self.image_list) 68 | main_layout.addWidget(self.feature_extractor) 69 | 70 | # Set stretch factors again 71 | main_layout.setStretchFactor(0, 1) # Left side (image tree) 72 | main_layout.setStretchFactor(1, 4) # Right side (feature extractor) 73 | 74 | # Populate the tree with camera data 75 | self.setup_camera_image_tree(self.camera_image_tree) 76 | 77 | self.is_initialized = True 78 | 79 | except Exception as e: 80 | error_message = f"Failed to initialize FeatureExtractor: {str(e)}" 81 | QMessageBox.critical(self, "Error", error_message) 82 | 83 | def display_features_for_image(self, item, column): 84 | """Display features for the selected image""" 85 | if not self.is_initialized: 86 | self.initialize() 87 | 88 | if item.childCount() == 0 and item.parent() is not None: 89 | image_name = item.text(0) 90 | if self.feature_extractor and hasattr(self.feature_extractor, 'load_image_by_name'): 91 | self.feature_extractor.load_image_by_name(image_name) 92 | 93 | def refresh(self): 94 | """Refresh the tab content""" 95 | if self.is_initialized: 96 | # Remove old widgets 97 | for i in reversed(range(self._layout.count())): 98 | self._layout.itemAt(i).widget().setParent(None) 99 | 100 | # Reinitialize 101 | self.setup_basic_ui() 102 | self.is_initialized = 
False 103 | self.initialize() -------------------------------------------------------------------------------- /utils/datasets/download_dataset.py: -------------------------------------------------------------------------------- 1 | """Script to download benchmark dataset(s)""" 2 | 3 | import os 4 | import subprocess 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import Literal 8 | 9 | import tyro 10 | 11 | # dataset names 12 | dataset_names = Literal[ 13 | "mipnerf360", 14 | "mipnerf360_extra", 15 | "bilarf_data", 16 | "zipnerf", 17 | "zipnerf_undistorted", 18 | ] 19 | 20 | # dataset urls 21 | urls = { 22 | "mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip", 23 | "mipnerf360_extra": "https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip", 24 | "bilarf_data": "https://huggingface.co/datasets/Yuehao/bilarf_data/resolve/main/bilarf_data.zip", 25 | "zipnerf": [ 26 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf/berlin.zip", 27 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf/london.zip", 28 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf/nyc.zip", 29 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf/alameda.zip", 30 | ], 31 | "zipnerf_undistorted": [ 32 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/berlin.zip", 33 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/london.zip", 34 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/nyc.zip", 35 | "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/alameda.zip", 36 | ], 37 | } 38 | 39 | # rename maps 40 | dataset_rename_map = { 41 | "mipnerf360": "360_v2", 42 | "mipnerf360_extra": "360_v2", 43 | "bilarf_data": "bilarf", 44 | "zipnerf": "zipnerf", 45 | "zipnerf_undistorted": "zipnerf_undistorted", 46 | } 47 | 48 | 49 | @dataclass 50 | class DownloadData: 51 | dataset: dataset_names = "mipnerf360" 52 
| save_dir: Path = Path(os.getcwd() + "/data") 53 | 54 | def main(self): 55 | self.save_dir.mkdir(parents=True, exist_ok=True) 56 | self.dataset_download(self.dataset) 57 | 58 | def dataset_download(self, dataset: dataset_names): 59 | if isinstance(urls[dataset], list): 60 | for url in urls[dataset]: 61 | url_file_name = Path(url).name 62 | extract_path = self.save_dir / dataset_rename_map[dataset] 63 | download_path = extract_path / url_file_name 64 | download_and_extract(url, download_path, extract_path) 65 | else: 66 | url = urls[dataset] 67 | url_file_name = Path(url).name 68 | extract_path = self.save_dir / dataset_rename_map[dataset] 69 | download_path = extract_path / url_file_name 70 | download_and_extract(url, download_path, extract_path) 71 | 72 | 73 | def download_and_extract(url: str, download_path: Path, extract_path: Path) -> None: 74 | download_path.parent.mkdir(parents=True, exist_ok=True) 75 | extract_path.mkdir(parents=True, exist_ok=True) 76 | 77 | # download 78 | download_command = [ 79 | "curl", 80 | "-L", 81 | "-o", 82 | str(download_path), 83 | url, 84 | ] 85 | try: 86 | subprocess.run(download_command, check=True) 87 | print("File downloaded successfully.") 88 | except subprocess.CalledProcessError as e: 89 | print(f"Error downloading file: {e}") 90 | 91 | # if .zip 92 | if Path(url).suffix == ".zip": 93 | if os.name == "nt": # Windows doesn't have 'unzip' but 'tar' works 94 | extract_command = [ 95 | "tar", 96 | "-xvf", 97 | download_path, 98 | "-C", 99 | extract_path, 100 | ] 101 | else: 102 | extract_command = [ 103 | "unzip", 104 | download_path, 105 | "-d", 106 | extract_path, 107 | ] 108 | # if .tar 109 | else: 110 | extract_command = [ 111 | "tar", 112 | "-xvzf", 113 | download_path, 114 | "-C", 115 | extract_path, 116 | ] 117 | 118 | # extract 119 | try: 120 | subprocess.run(extract_command, check=True) 121 | os.remove(download_path) 122 | print("Extraction complete.") 123 | except subprocess.CalledProcessError as e: 124 | 
print(f"Extraction failed: {e}") 125 | 126 | 127 | if __name__ == "__main__": 128 | tyro.cli(DownloadData).main() 129 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 2 | 3 | ENV LC_ALL=C.UTF-8 4 | ENV LANG=C.UTF-8 5 | ARG TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6" 6 | ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" 7 | ARG DEBIAN_FRONTEND=noninteractive 8 | 9 | # Install system dependencies 10 | RUN apt-get update && apt-get install -y --no-install-recommends \ 11 | software-properties-common && \ 12 | add-apt-repository ppa:deadsnakes/ppa && \ 13 | apt-get install -y --no-install-recommends \ 14 | build-essential cmake git libeigen3-dev libopencv-dev libceres-dev \ 15 | python3.10 python3.10-dev python3-pip python3.10-distutils \ 16 | python-is-python3 curl ninja-build libglm-dev \ 17 | libboost-all-dev libflann-dev libfreeimage-dev libmetis-dev \ 18 | libgoogle-glog-dev libgtest-dev libgmock-dev libsqlite3-dev \ 19 | libglew-dev qtbase5-dev libqt5opengl5-dev wget libcgal-dev \ 20 | graphviz mesa-utils libgraphviz-dev libgl1 libglib2.0-0 \ 21 | ffmpeg && \ 22 | rm -rf /var/lib/apt/lists/* 23 | 24 | # Set Python 3.10 as default and install pip 25 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \ 26 | update-alternatives --config python3 --skip-auto && \ 27 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3 && \ 28 | pip install --no-cache-dir --upgrade pip setuptools "setuptools<68.0.0" 29 | 30 | # Install PyTorch 31 | RUN pip install --no-cache-dir torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 32 | 33 | # Core Python dependencies 34 | RUN pip install --no-cache-dir \ 35 | numpy pandas matplotlib scipy seaborn \ 36 | opencv-python-headless Pillow tqdm \ 37 | flask loguru h5py scikit-learn jupyter 
jupyterlab \ 38 | pyqtgraph pyyaml packaging pyparsing==3.0.9 networkx==2.5 \ 39 | kornia==0.7.3 torchmetrics[image] imageio[ffmpeg] \ 40 | black coverage mypy pylint pytest flake8 isort tyro>=0.8.8 \ 41 | open3d colour tabulate simplejson parameterized pydegensac \ 42 | tensorboard tensorly xmltodict cloudpickle==0.4.0 \ 43 | fpdf2==2.4.6 python-dateutil Sphinx==4.2.0 \ 44 | wheel viser nerfview \ 45 | jsonschema==4.17.3 jupyter-events==0.6.3 \ 46 | "PyOpenGL==3.1.1a1" "PyQt5" bokeh==2.4.3 \ 47 | splines pyproj \ 48 | mapillary_tools 49 | 50 | ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" 51 | # Install additional GitHub-based packages 52 | RUN pip install --no-cache-dir \ 53 | git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e 54 | 55 | # Clone and patch fused-ssim 56 | RUN git clone https://github.com/rahul-goel/fused-ssim.git /tmp/fused-ssim && \ 57 | sed -i '/compute_100/d' /tmp/fused-ssim/setup.py && \ 58 | sed -i '/compute_101/d' /tmp/fused-ssim/setup.py && \ 59 | pip install /tmp/fused-ssim && \ 60 | rm -rf /tmp/fused-ssim 61 | 62 | # FlashAttention 63 | RUN git clone --depth 1 https://github.com/Dao-AILab/flash-attention.git /flash-attention && \ 64 | cd /flash-attention && pip install --no-cache-dir . 65 | 66 | # LightGlue 67 | RUN git clone --depth 1 https://github.com/cvg/LightGlue.git /LightGlue && \ 68 | cd /LightGlue && pip install --no-cache-dir . 69 | 70 | # Clone and build splat_one and submodules 71 | RUN git clone https://github.com/inuex35/splat_one.git /source/splat_one && \ 72 | cd /source/splat_one && \ 73 | git submodule update --init --recursive 74 | 75 | # Build OpenSfM 76 | RUN cd /source/splat_one/submodules/opensfm && \ 77 | python3 setup.py build && \ 78 | python3 setup.py install 79 | 80 | # Build gsplat 81 | RUN cd /source/splat_one/submodules/gsplat && \ 82 | git checkout b0e978da67fb4364611c6683c5f4e6e6c1d8d8cb && \ 83 | MAX_JOBS=4 pip install -e . 
84 | 85 | # Build sam2 86 | RUN sed -i 's/setuptools>=61.0/setuptools>=62.3.8,<75.9/' /source/splat_one/submodules/sam2/pyproject.toml && \ 87 | cd /source/splat_one/submodules/sam2 && \ 88 | pip install -e ".[notebooks]" && \ 89 | cd checkpoints && ./download_ckpts.sh 90 | 91 | # Clone and setup depth_any_camera 92 | RUN git clone https://github.com/yuliangguo/depth_any_camera /depth_any_camera && \ 93 | cd /depth_any_camera && \ 94 | pip install -r requirements.txt 95 | ENV PYTHONPATH="/depth_any_camera:$PYTHONPATH" 96 | 97 | # Pre-download PyTorch model 98 | RUN mkdir -p /root/.cache/torch/hub/checkpoints && \ 99 | wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth -O /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth 100 | 101 | WORKDIR /source/splat_one 102 | 103 | COPY entrypoint.sh /entrypoint.sh 104 | RUN chmod +x /entrypoint.sh 105 | ENTRYPOINT ["/entrypoint.sh"] 106 | CMD ["bash"] 107 | -------------------------------------------------------------------------------- /app/point_cloud_viewer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSlider, QSpinBox 3 | from PyQt5.QtCore import Qt, QTimer 4 | from PyQt5.QtGui import QVector3D 5 | import pyqtgraph.opengl as gl 6 | import pyqtgraph as pg 7 | import os 8 | 9 | from app.point_cloud_utils import rgb_depth_to_pointcloud, load_image_and_depth, downsample_pointcloud 10 | 11 | class PointCloudViewer(QWidget): 12 | """Point cloud viewer widget (pyqtgraph version)""" 13 | 14 | def __init__(self, workdir=None): 15 | super().__init__() 16 | self.workdir = workdir 17 | self.current_points = None 18 | self.current_colors = None 19 | self.point_cloud_item = None 20 | 21 | self.setup_ui() 22 | 23 | def setup_ui(self): 24 | """Setup UI""" 25 | layout = QVBoxLayout(self) 26 | 27 | # pyqtgraph 3D viewer (same method as reconstruction tab) 28 | 
self.viewer = gl.GLViewWidget() 29 | self.viewer.setFocusPolicy(Qt.NoFocus) 30 | self.viewer.setCameraPosition(distance=10) 31 | layout.addWidget(self.viewer) 32 | 33 | # Control panel 34 | control_layout = QHBoxLayout() 35 | 36 | # Point count adjustment 37 | control_layout.addWidget(QLabel("Point Count:")) 38 | self.point_count_spinbox = QSpinBox() 39 | self.point_count_spinbox.setRange(1000, 100000) 40 | self.point_count_spinbox.setValue(10000) 41 | self.point_count_spinbox.valueChanged.connect(self.update_point_cloud) 42 | control_layout.addWidget(self.point_count_spinbox) 43 | 44 | # Point size adjustment 45 | control_layout.addWidget(QLabel("Point Size:")) 46 | self.point_size_slider = QSlider(Qt.Horizontal) 47 | self.point_size_slider.setRange(1, 10) 48 | self.point_size_slider.setValue(3) 49 | self.point_size_slider.valueChanged.connect(self.update_point_size) 50 | control_layout.addWidget(self.point_size_slider) 51 | 52 | # Reset button 53 | self.reset_button = QPushButton("Reset Camera") 54 | self.reset_button.clicked.connect(self.reset_camera) 55 | control_layout.addWidget(self.reset_button) 56 | 57 | control_layout.addStretch() 58 | layout.addLayout(control_layout) 59 | 60 | def load_point_cloud_from_images(self, image_name): 61 | """Load and display point cloud from images and depth""" 62 | if not self.workdir: 63 | return 64 | 65 | try: 66 | # Image and depth paths 67 | image_path = os.path.join(self.workdir, "images", image_name) 68 | depth_path = os.path.join(self.workdir, "depth", f"{image_name}_depth.png") 69 | 70 | # Load images and depth 71 | rgb_image, depth_image = load_image_and_depth(image_path, depth_path) 72 | 73 | # Generate point cloud 74 | points, colors = rgb_depth_to_pointcloud(rgb_image, depth_image) 75 | 76 | # Save point cloud 77 | self.current_points = points 78 | self.current_colors = colors 79 | 80 | # Display point cloud 81 | self.update_point_cloud() 82 | 83 | except Exception as e: 84 | print(f"Point cloud generation error: 
{e}") 85 | self.clear_point_cloud() 86 | 87 | def update_point_cloud(self): 88 | """Update and display point cloud""" 89 | if self.current_points is None or self.current_colors is None: 90 | return 91 | 92 | # Downsample point cloud 93 | target_points = self.point_count_spinbox.value() 94 | points, colors = downsample_pointcloud(self.current_points, self.current_colors, target_points) 95 | 96 | # Remove existing point cloud 97 | self.clear_point_cloud() 98 | 99 | # Add new point cloud 100 | if len(points) > 0: 101 | # Create point cloud item (same method as reconstruction tab) 102 | self.point_cloud_item = gl.GLScatterPlotItem( 103 | pos=points, 104 | color=colors, 105 | size=self.point_size_slider.value(), 106 | pxMode=True 107 | ) 108 | self.viewer.addItem(self.point_cloud_item) 109 | 110 | def update_point_size(self): 111 | """Update point size""" 112 | if self.point_cloud_item: 113 | self.point_cloud_item.setData(size=self.point_size_slider.value()) 114 | 115 | def clear_point_cloud(self): 116 | """Clear point cloud""" 117 | if self.point_cloud_item: 118 | self.viewer.removeItem(self.point_cloud_item) 119 | self.point_cloud_item = None 120 | 121 | def reset_camera(self): 122 | """Reset camera position""" 123 | self.viewer.setCameraPosition(distance=10) 124 | 125 | def set_workdir(self, workdir): 126 | """Set working directory""" 127 | self.workdir = workdir -------------------------------------------------------------------------------- /utils/datasets/normalize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): 5 | """ 6 | reference: nerf-factory 7 | Get a similarity transform to normalize dataset 8 | from c2w (OpenCV convention) cameras 9 | :param c2w: (N, 4) 10 | :return T (4,4) , scale (float) 11 | """ 12 | t = c2w[:, :3, 3] 13 | R = c2w[:, :3, :3] 14 | 15 | # (1) Rotate the world so that z+ is the up axis 16 | 
# we estimate the up axis by averaging the camera up axes 17 | ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) 18 | world_up = np.mean(ups, axis=0) 19 | world_up /= np.linalg.norm(world_up) 20 | 21 | up_camspace = np.array([0.0, -1.0, 0.0]) 22 | c = (up_camspace * world_up).sum() 23 | cross = np.cross(world_up, up_camspace) 24 | skew = np.array( 25 | [ 26 | [0.0, -cross[2], cross[1]], 27 | [cross[2], 0.0, -cross[0]], 28 | [-cross[1], cross[0], 0.0], 29 | ] 30 | ) 31 | if c > -1: 32 | R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) 33 | else: 34 | # In the unlikely case the original data has y+ up axis, 35 | # rotate 180-deg about x axis 36 | R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 37 | 38 | # R_align = np.eye(3) # DEBUG 39 | R = R_align @ R 40 | fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) 41 | t = (R_align @ t[..., None])[..., 0] 42 | 43 | # (2) Recenter the scene. 44 | if center_method == "focus": 45 | # find the closest point to the origin for each camera's center ray 46 | nearest = t + (fwds * -t).sum(-1)[:, None] * fwds 47 | translate = -np.median(nearest, axis=0) 48 | elif center_method == "poses": 49 | # use center of the camera positions 50 | translate = -np.median(t, axis=0) 51 | else: 52 | raise ValueError(f"Unknown center_method {center_method}") 53 | 54 | transform = np.eye(4) 55 | transform[:3, 3] = translate 56 | transform[:3, :3] = R_align 57 | 58 | # (3) Rescale the scene using camera distances 59 | scale_fn = np.max if strict_scaling else np.median 60 | scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1)) 61 | transform[:3, :] *= scale 62 | 63 | return transform 64 | 65 | 66 | def align_principle_axes(point_cloud): 67 | # Compute centroid 68 | centroid = np.median(point_cloud, axis=0) 69 | 70 | # Translate point cloud to centroid 71 | translated_point_cloud = point_cloud - centroid 72 | 73 | # Compute covariance matrix 74 | covariance_matrix = np.cov(translated_point_cloud, rowvar=False) 
75 | 76 | # Compute eigenvectors and eigenvalues 77 | eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) 78 | 79 | # Sort eigenvectors by eigenvalues (descending order) so that the z-axis 80 | # is the principal axis with the smallest eigenvalue. 81 | sort_indices = eigenvalues.argsort()[::-1] 82 | eigenvectors = eigenvectors[:, sort_indices] 83 | 84 | # Check orientation of eigenvectors. If the determinant of the eigenvectors is 85 | # negative, then we need to flip the sign of one of the eigenvectors. 86 | if np.linalg.det(eigenvectors) < 0: 87 | eigenvectors[:, 0] *= -1 88 | 89 | # Create rotation matrix 90 | rotation_matrix = eigenvectors.T 91 | 92 | # Create SE(3) matrix (4x4 transformation matrix) 93 | transform = np.eye(4) 94 | transform[:3, :3] = rotation_matrix 95 | transform[:3, 3] = -rotation_matrix @ centroid 96 | 97 | return transform 98 | 99 | 100 | def transform_points(matrix, points): 101 | """Transform points using an SE(3) matrix. 102 | 103 | Args: 104 | matrix: 4x4 SE(3) matrix 105 | points: Nx3 array of points 106 | 107 | Returns: 108 | Nx3 array of transformed points 109 | """ 110 | assert matrix.shape == (4, 4) 111 | assert len(points.shape) == 2 and points.shape[1] == 3 112 | return points @ matrix[:3, :3].T + matrix[:3, 3] 113 | 114 | 115 | def transform_cameras(matrix, camtoworlds): 116 | """Transform cameras using an SE(3) matrix. 
117 | 118 | Args: 119 | matrix: 4x4 SE(3) matrix 120 | camtoworlds: Nx4x4 array of camera-to-world matrices 121 | 122 | Returns: 123 | Nx4x4 array of transformed camera-to-world matrices 124 | """ 125 | assert matrix.shape == (4, 4) 126 | assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) 127 | camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) 128 | scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) 129 | camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] 130 | return camtoworlds 131 | 132 | 133 | def normalize(camtoworlds, points=None): 134 | T1 = similarity_from_cameras(camtoworlds) 135 | camtoworlds = transform_cameras(T1, camtoworlds) 136 | if points is not None: 137 | points = transform_points(T1, points) 138 | T2 = align_principle_axes(points) 139 | camtoworlds = transform_cameras(T2, camtoworlds) 140 | points = transform_points(T2, points) 141 | return camtoworlds, points, T2 @ T1 142 | else: 143 | return camtoworlds, T1 144 | -------------------------------------------------------------------------------- /app/tabs/masks_tab.py: -------------------------------------------------------------------------------- 1 | # masks_tab.py 2 | 3 | import os 4 | import sys 5 | import importlib.util 6 | from PyQt5.QtWidgets import ( 7 | QWidget, QVBoxLayout, QHBoxLayout, QSplitter, QLabel, 8 | QTreeWidget, QTreeWidgetItem, QMessageBox 9 | ) 10 | from PyQt5.QtCore import Qt 11 | 12 | from app.base_tab import BaseTab 13 | 14 | class MaskManagerWidget(QWidget): 15 | def __init__(self, mask_manager, parent=None): 16 | super().__init__(parent) 17 | self.mask_manager = mask_manager 18 | layout = QVBoxLayout() 19 | layout.addWidget(self.mask_manager) 20 | self.setLayout(layout) 21 | 22 | class MasksTab(BaseTab): 23 | """Masks tab implementation""" 24 | def __init__(self, workdir=None, image_list=None, parent=None): 25 | super().__init__(workdir, image_list, parent) 26 | self.mask_manager = None 27 | 
self.mask_manager_widget = None 28 | self.camera_image_tree = None 29 | 30 | def get_tab_name(self): 31 | return "Masks" 32 | 33 | def initialize(self): 34 | """Initialize the Masks tab""" 35 | if not self.workdir: 36 | QMessageBox.warning(self, "Error", "Work directory is not set.") 37 | return 38 | 39 | layout = self.create_horizontal_splitter() 40 | 41 | # Left side: Tree of images grouped by camera 42 | self.camera_image_tree = QTreeWidget() 43 | self.camera_image_tree.setHeaderLabel("Cameras and Images") 44 | self.camera_image_tree.setFixedWidth(250) 45 | layout.addWidget(self.camera_image_tree) 46 | 47 | # Initialize MaskManager 48 | try: 49 | sam2_dir = self.get_sam2_directory() 50 | checkpoint_path = os.path.join(sam2_dir, "checkpoints", "sam2.1_hiera_large.pt") 51 | config_path = os.path.join("configs", "sam2.1", "sam2.1_hiera_l.yaml") 52 | mask_dir = os.path.join(self.workdir, "masks") 53 | img_dir = os.path.join(self.workdir, "images") 54 | 55 | # Ensure mask directory exists 56 | os.makedirs(mask_dir, exist_ok=True) 57 | 58 | # Import MaskManager here to avoid importing when not needed 59 | from app.mask_manager import MaskManager 60 | 61 | self.mask_manager = MaskManager( 62 | checkpoint_path, config_path, mask_dir, img_dir, self.image_list 63 | ) 64 | 65 | # Right side: MaskManager widget 66 | self.mask_manager_widget = MaskManagerWidget(self.mask_manager) 67 | layout.addWidget(self.mask_manager_widget) 68 | 69 | # Set stretch factors 70 | layout.setStretchFactor(0, 1) # Left side (image tree) 71 | layout.setStretchFactor(1, 4) # Right side (MaskManager) 72 | 73 | # Set layout for mask tab 74 | self._layout.addWidget(layout) 75 | 76 | # Populate the camera image tree 77 | self.setup_camera_image_tree(self.camera_image_tree, self.display_mask) 78 | 79 | self.is_initialized = True 80 | 81 | except Exception as e: 82 | error_message = f"Failed to initialize MaskManager: {str(e)}" 83 | QMessageBox.critical(self, "Error", error_message) 84 | placeholder = 
QLabel("MaskManager could not be initialized. See error message.") 85 | placeholder.setAlignment(Qt.AlignCenter) 86 | layout.addWidget(placeholder) 87 | self._layout.addWidget(layout) 88 | 89 | def display_mask(self, item, column): 90 | """Display the mask for the selected image""" 91 | if item.childCount() == 0 and item.parent() is not None: 92 | image_name = item.text(0) 93 | if self.mask_manager is not None: 94 | self.mask_manager.load_image_by_name(image_name) 95 | 96 | def on_tab_activated(self): 97 | """Called when tab is activated""" 98 | super().on_tab_activated() 99 | # Load SAM model when tab is activated 100 | if self.mask_manager and hasattr(self.mask_manager, 'init_sam_model'): 101 | self.mask_manager.init_sam_model() 102 | 103 | def on_tab_deactivated(self): 104 | """Called when tab is deactivated""" 105 | super().on_tab_deactivated() 106 | # Unload SAM model when tab is deactivated to save memory 107 | if self.mask_manager and hasattr(self.mask_manager, 'unload_sam_model'): 108 | self.mask_manager.unload_sam_model() 109 | 110 | def refresh(self): 111 | """Refresh the tab content""" 112 | # Reinitialize the tab if initialized 113 | if self.is_initialized: 114 | # Remove old widgets 115 | for i in reversed(range(self._layout.count())): 116 | self._layout.itemAt(i).widget().setParent(None) 117 | 118 | # Reinitialize 119 | self.is_initialized = False 120 | self.on_tab_activated() 121 | 122 | def get_sam2_directory(self): 123 | """Get sam2 install directory""" 124 | module_name = 'sam2' 125 | spec = importlib.util.find_spec(module_name) 126 | if spec is None: 127 | raise ModuleNotFoundError(f"{module_name} is not installed.") 128 | return os.path.dirname(os.path.dirname(spec.origin)) -------------------------------------------------------------------------------- /app/tabs/matching_tab.py: -------------------------------------------------------------------------------- 1 | # matching_tab.py 2 | 3 | import os 4 | from PyQt5.QtWidgets import ( 5 | QWidget, 
QVBoxLayout, QHBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox 6 | ) 7 | from PyQt5.QtCore import Qt 8 | 9 | from app.base_tab import BaseTab 10 | 11 | class MatchingTab(BaseTab): 12 | """Matching tab implementation""" 13 | def __init__(self, workdir=None, image_list=None, parent=None): 14 | super().__init__(workdir, image_list, parent) 15 | self.matching_viewer = None 16 | self.camera_image_tree_left = None 17 | self.camera_image_tree_right = None 18 | 19 | # Set up basic UI structure 20 | self.setup_basic_ui() 21 | 22 | def get_tab_name(self): 23 | return "Matching" 24 | 25 | def setup_basic_ui(self): 26 | """Set up the basic UI structure""" 27 | # Create a main horizontal splitter 28 | main_layout = self.create_horizontal_splitter() 29 | 30 | # Left side: Combined tree widget container 31 | left_widget = QWidget() 32 | left_layout = QVBoxLayout(left_widget) 33 | left_layout.setContentsMargins(0, 0, 0, 0) 34 | 35 | # Create label for left and right image selection 36 | left_layout.addWidget(QLabel("Left Image Selection:")) 37 | 38 | # Tree for left image 39 | self.camera_image_tree_left = QTreeWidget() 40 | self.camera_image_tree_left.setHeaderLabel("Cameras and Images - Left") 41 | left_layout.addWidget(self.camera_image_tree_left) 42 | 43 | # Add a separator label 44 | left_layout.addWidget(QLabel("Right Image Selection:")) 45 | 46 | # Tree for right image 47 | self.camera_image_tree_right = QTreeWidget() 48 | self.camera_image_tree_right.setHeaderLabel("Cameras and Images - Right") 49 | left_layout.addWidget(self.camera_image_tree_right) 50 | 51 | # Add the left widget to the main layout with fixed width 52 | left_widget.setFixedWidth(300) # A bit wider to accommodate the trees 53 | main_layout.addWidget(left_widget) 54 | 55 | # Right side: Placeholder for matching viewer 56 | right_widget = QLabel("Matching Viewer will be displayed here.") 57 | right_widget.setAlignment(Qt.AlignCenter) 58 | main_layout.addWidget(right_widget) 59 | 60 | # Set stretch 
factors 61 | main_layout.setStretchFactor(0, 1) # Left side (trees) 62 | main_layout.setStretchFactor(1, 3) # Right side (matching viewer) 63 | 64 | # Set layout for matching tab 65 | self._layout.addWidget(main_layout) 66 | 67 | def initialize(self): 68 | """Initialize the Matching tab with data""" 69 | if not self.workdir: 70 | QMessageBox.warning(self, "Error", "Work directory is not set.") 71 | return 72 | 73 | try: 74 | # Import FeatureMatching here to avoid circular imports 75 | from app.feature_matching import FeatureMatching 76 | 77 | # Get the main layout and its widgets 78 | main_layout = self._layout.itemAt(0).widget() 79 | right_widget = main_layout.widget(1) # The placeholder 80 | 81 | # Remove the placeholder 82 | right_widget.setParent(None) 83 | 84 | # Create the matching viewer 85 | self.matching_viewer = FeatureMatching(workdir=self.workdir, image_list=self.image_list) 86 | main_layout.addWidget(self.matching_viewer) 87 | 88 | # Set stretch factors again (they might be reset after widget changes) 89 | main_layout.setStretchFactor(0, 1) # Left side (trees) 90 | main_layout.setStretchFactor(1, 3) # Right side (matching viewer) 91 | 92 | # Populate the trees with camera data 93 | self.setup_camera_image_tree(self.camera_image_tree_left, self.on_image_selected_left) 94 | self.setup_camera_image_tree(self.camera_image_tree_right, self.on_image_selected_right) 95 | 96 | self.is_initialized = True 97 | 98 | except Exception as e: 99 | error_message = f"Failed to initialize FeatureMatching: {str(e)}" 100 | QMessageBox.critical(self, "Error", error_message) 101 | 102 | def on_image_selected_left(self, item, column): 103 | """Handle image selection in the left camera tree""" 104 | if not self.is_initialized: 105 | self.initialize() 106 | 107 | if item.childCount() == 0 and item.parent() is not None: 108 | image_name = item.text(0) 109 | if self.matching_viewer: 110 | self.matching_viewer.load_image_by_name(image_name, position="left") 111 | 112 | def 
on_image_selected_right(self, item, column): 113 | """Handle image selection in the right camera tree""" 114 | if not self.is_initialized: 115 | self.initialize() 116 | 117 | if item.childCount() == 0 and item.parent() is not None: 118 | image_name = item.text(0) 119 | if self.matching_viewer: 120 | self.matching_viewer.load_image_by_name(image_name, position="right") 121 | 122 | def refresh(self): 123 | """Refresh the tab content""" 124 | if self.is_initialized: 125 | # Remove old widgets 126 | for i in reversed(range(self._layout.count())): 127 | self._layout.itemAt(i).widget().setParent(None) 128 | 129 | # Reinitialize 130 | self.setup_basic_ui() 131 | self.is_initialized = False 132 | self.initialize() -------------------------------------------------------------------------------- /tests/test_tab_manager.py: -------------------------------------------------------------------------------- 1 | # test_tab_manager.py 2 | 3 | import pytest 4 | from unittest.mock import MagicMock, patch 5 | from PyQt5.QtWidgets import QWidget 6 | from app.tab_manager import TabManager 7 | from app.base_tab import BaseTab 8 | 9 | class MockTab(BaseTab): 10 | """Mock implementation of BaseTab for testing""" 11 | def __init__(self, workdir=None, image_list=None, parent=None, custom_arg=None): 12 | super().__init__(workdir, image_list, parent) 13 | self.custom_arg = custom_arg 14 | self.activated = False 15 | self.deactivated = False 16 | 17 | def get_tab_name(self): 18 | return "Mock Tab" 19 | 20 | def initialize(self): 21 | self.is_initialized = True 22 | 23 | def on_tab_activated(self): 24 | super().on_tab_activated() 25 | self.activated = True 26 | 27 | def on_tab_deactivated(self): 28 | super().on_tab_deactivated() 29 | self.deactivated = True 30 | 31 | 32 | class SimpleWidget(QWidget): 33 | """Simple widget for testing non-BaseTab tabs""" 34 | def __init__(self, parent=None): 35 | super().__init__(parent) 36 | 37 | 38 | @pytest.fixture 39 | def tab_manager(mock_qapplication): 40 | 
"""Create a TabManager for testing""" 41 | return TabManager() 42 | 43 | 44 | def test_tab_manager_init(tab_manager): 45 | """Test TabManager initialization""" 46 | assert isinstance(tab_manager, TabManager) 47 | assert tab_manager.tab_instances == {} 48 | assert tab_manager.active_tab_index == -1 49 | 50 | 51 | def test_register_tab_with_base_tab(tab_manager): 52 | """Test registering a BaseTab-derived tab""" 53 | # Register a tab 54 | index = tab_manager.register_tab(MockTab, workdir="/test/workdir", custom_arg="test") 55 | 56 | # Check that the tab was added 57 | assert index == 0 58 | assert index in tab_manager.tab_instances 59 | assert isinstance(tab_manager.tab_instances[index], MockTab) 60 | assert tab_manager.tab_instances[index].workdir == "/test/workdir" 61 | assert tab_manager.tab_instances[index].custom_arg == "test" 62 | assert tab_manager.tabText(index) == "Mock Tab" 63 | 64 | 65 | def test_register_tab_with_widget(tab_manager): 66 | """Test registering a non-BaseTab widget""" 67 | # Register a widget with explicit name 68 | index = tab_manager.register_tab(SimpleWidget, tab_name="Simple Widget") 69 | 70 | # Check that the tab was added 71 | assert index == 0 72 | assert index in tab_manager.tab_instances 73 | assert isinstance(tab_manager.tab_instances[index], SimpleWidget) 74 | assert tab_manager.tabText(index) == "Simple Widget" 75 | 76 | # Register a widget without explicit name 77 | index = tab_manager.register_tab(SimpleWidget) 78 | 79 | # Check that the tab was added with class name 80 | assert index == 1 81 | assert tab_manager.tabText(index) == "SimpleWidget" 82 | 83 | 84 | def test_get_tab_instance(tab_manager): 85 | """Test get_tab_instance method""" 86 | # Register tabs 87 | index1 = tab_manager.register_tab(MockTab) 88 | index2 = tab_manager.register_tab(SimpleWidget) 89 | 90 | # Get instances 91 | tab1 = tab_manager.get_tab_instance(index1) 92 | tab2 = tab_manager.get_tab_instance(index2) 93 | 94 | # Check correct instances 95 | assert 
isinstance(tab1, MockTab) 96 | assert isinstance(tab2, SimpleWidget) 97 | 98 | # Test invalid index 99 | assert tab_manager.get_tab_instance(999) is None 100 | 101 | 102 | @patch('app.tab_manager.TabManager.currentIndex') 103 | def test_get_current_tab(mock_current_index, tab_manager): 104 | """Test get_current_tab method""" 105 | # Register tabs 106 | index1 = tab_manager.register_tab(MockTab) 107 | index2 = tab_manager.register_tab(SimpleWidget) 108 | 109 | # Mock currentIndex 110 | mock_current_index.return_value = index1 111 | 112 | # Get current tab 113 | current_tab = tab_manager.get_current_tab() 114 | 115 | # Check correct instance 116 | assert isinstance(current_tab, MockTab) 117 | 118 | # Change current tab 119 | mock_current_index.return_value = index2 120 | current_tab = tab_manager.get_current_tab() 121 | assert isinstance(current_tab, SimpleWidget) 122 | 123 | 124 | def test_on_tab_changed(tab_manager): 125 | """Test on_tab_changed method""" 126 | # Register tabs 127 | index1 = tab_manager.register_tab(MockTab) 128 | index2 = tab_manager.register_tab(MockTab) 129 | 130 | # Get tab instances 131 | tab1 = tab_manager.get_tab_instance(index1) 132 | tab2 = tab_manager.get_tab_instance(index2) 133 | 134 | # Initially, no tab is active 135 | assert tab_manager.active_tab_index == -1 136 | 137 | # Activate first tab 138 | tab_manager.on_tab_changed(index1) 139 | assert tab1.activated is True 140 | assert tab1.deactivated is False 141 | assert tab2.activated is False 142 | assert tab2.deactivated is False 143 | assert tab_manager.active_tab_index == index1 144 | 145 | # Change to second tab 146 | tab_manager.on_tab_changed(index2) 147 | assert tab1.activated is True 148 | assert tab1.deactivated is True 149 | assert tab2.activated is True 150 | assert tab2.deactivated is False 151 | assert tab_manager.active_tab_index == index2 152 | 153 | 154 | def test_update_all_tabs(tab_manager): 155 | """Test update_all_tabs method""" 156 | # Register tabs 157 | index1 = 
tab_manager.register_tab(MockTab, workdir="/old/workdir", image_list=["old.jpg"]) 158 | index2 = tab_manager.register_tab(SimpleWidget) # Non-BaseTab widget 159 | 160 | # Get first tab instance 161 | tab1 = tab_manager.get_tab_instance(index1) 162 | 163 | # Mock update methods 164 | tab1.update_workdir = MagicMock() 165 | tab1.update_image_list = MagicMock() 166 | 167 | # Call update_all_tabs 168 | new_workdir = "/new/workdir" 169 | new_image_list = ["new1.jpg", "new2.jpg"] 170 | tab_manager.update_all_tabs(workdir=new_workdir, image_list=new_image_list) 171 | 172 | # Check that update methods were called 173 | tab1.update_workdir.assert_called_once_with(new_workdir) 174 | tab1.update_image_list.assert_called_once_with(new_image_list) 175 | 176 | # Test with only workdir 177 | tab1.update_workdir.reset_mock() 178 | tab1.update_image_list.reset_mock() 179 | tab_manager.update_all_tabs(workdir="/another/workdir") 180 | tab1.update_workdir.assert_called_once_with("/another/workdir") 181 | assert tab1.update_image_list.call_count == 0 182 | 183 | # Test with only image_list 184 | tab1.update_workdir.reset_mock() 185 | tab1.update_image_list.reset_mock() 186 | tab_manager.update_all_tabs(image_list=["another.jpg"]) 187 | assert tab1.update_workdir.call_count == 0 188 | tab1.update_image_list.assert_called_once_with(["another.jpg"]) 189 | -------------------------------------------------------------------------------- /tests/test_images_tab.py: -------------------------------------------------------------------------------- 1 | # test_images_tab.py 2 | 3 | import os 4 | import json 5 | import pytest 6 | from unittest.mock import MagicMock, patch 7 | from PyQt5.QtWidgets import QApplication, QTreeWidgetItem, QDialog 8 | from PyQt5.QtGui import QPixmap 9 | from app.tabs.images_tab import ImagesTab 10 | 11 | @pytest.fixture 12 | def mock_camera_model_manager(): 13 | """Mock CameraModelManager for testing""" 14 | manager = MagicMock() 15 | 
    manager.get_camera_models.return_value = {
        "Camera1": {
            "projection_type": "perspective",
            "focal_ratio": 1.5
        }
    }
    return manager

@pytest.fixture
def mock_image_processor():
    """Mock ImageProcessor for testing"""
    processor = MagicMock()
    processor.get_sample_image_dimensions.return_value = (1920, 1080)
    processor.resize_images.return_value = 2  # Number of images processed
    processor.restore_original_images.return_value = True
    return processor

@pytest.fixture
def setup_images_tab(mock_qapplication, setup_image_folders):
    """Set up ImagesTab for testing"""
    tab = ImagesTab(setup_image_folders)
    return tab

@patch('app.tabs.images_tab.CameraModelManager')
@patch('app.tabs.images_tab.ImageProcessor')
def test_images_tab_initialize(mock_image_processor_class, mock_model_manager_class, setup_images_tab):
    """Test ImagesTab initialization"""
    # Setup mock returns
    mock_model_manager = MagicMock()
    mock_model_manager_class.return_value = mock_model_manager

    mock_processor = MagicMock()
    mock_image_processor_class.return_value = mock_processor

    # Initialize the tab
    setup_images_tab.initialize()

    # Check that managers were created
    assert mock_model_manager_class.call_count == 1
    assert mock_image_processor_class.call_count == 1

    # Check that UI elements were created
    assert setup_images_tab.camera_image_tree is not None
    assert setup_images_tab.image_viewer is not None
    assert setup_images_tab.exif_table is not None
    assert setup_images_tab.is_initialized is True

@patch('app.tabs.images_tab.QPixmap')
@patch('os.path.exists')
def test_display_image_and_exif(mock_exists, mock_pixmap, setup_images_tab, monkeypatch):
    """Test display_image_and_exif method"""
    # Setup mocks
    mock_exists.return_value = True
    mock_pixmap_instance = MagicMock()
    mock_pixmap.return_value = mock_pixmap_instance

    # Mock open and json.load so no real file is touched
    mock_open = MagicMock()
    mock_json_load = MagicMock(return_value={"camera": "Camera1", "width": 1920})

    monkeypatch.setattr('builtins.open', mock_open)
    monkeypatch.setattr('json.load', mock_json_load)

    # Create mock item (a leaf node: no children, has a parent)
    item = MagicMock()
    item.childCount.return_value = 0
    item.parent.return_value = MagicMock()  # Not None
    item.text.return_value = "test_image.jpg"

    # Initialize necessary UI components
    setup_images_tab.image_viewer = MagicMock()
    setup_images_tab.exif_table = MagicMock()
    setup_images_tab.display_exif_data = MagicMock()  # Mock the method to avoid testing it here

    # Call the method
    setup_images_tab.display_image_and_exif(item, 0)

    # Check that image was displayed
    assert setup_images_tab.image_viewer.setPixmap.call_count == 1
    assert mock_pixmap.call_count == 1

    # Check that EXIF was loaded
    mock_open.assert_called_once()
    mock_json_load.assert_called_once()
    setup_images_tab.display_exif_data.assert_called_once()

@patch('json.dumps')
def test_display_exif_data(mock_json_dumps, setup_images_tab, mock_camera_model_manager):
    """Test display_exif_data method"""
    # Setup
    setup_images_tab.exif_table = MagicMock()
    setup_images_tab.camera_model_manager = mock_camera_model_manager
    mock_json_dumps.return_value = '{"lat": 35.123, "lon": 139.456}'

    # Sample EXIF data
    exif_data = {
        "camera": "Camera1",
        "make": "Test Make",
        "model": "Test Model",
        "width": 1920,
        "height": 1080,
        "gps": {"lat": 35.123, "lon": 139.456}
    }

    # Call the method
    setup_images_tab.display_exif_data(exif_data)

    # Check that table was populated
    assert setup_images_tab.exif_table.setRowCount.call_count >= 1
    assert setup_images_tab.exif_table.insertRow.call_count > 0
    assert setup_images_tab.exif_table.setItem.call_count > 0

    # Check that overrides were applied (should be using focal_ratio from mock_camera_model_manager)
    mock_camera_model_manager.get_camera_models.assert_called_once()

@patch('app.tabs.images_tab.QMessageBox')
def test_open_camera_model_editor(mock_msgbox, setup_images_tab, mock_camera_model_manager):
    """Test open_camera_model_editor method"""
    # Setup
    setup_images_tab.camera_model_manager = mock_camera_model_manager

    # Call the method
    setup_images_tab.open_camera_model_editor()

    # Check that editor was opened
    mock_camera_model_manager.open_camera_model_editor.assert_called_once_with(parent=setup_images_tab)

    # Test error case: no manager available -> warning shown
    setup_images_tab.camera_model_manager = None
    setup_images_tab.open_camera_model_editor()
    mock_msgbox.warning.assert_called_once()

@patch('app.tabs.images_tab.ResolutionDialog')
@patch('app.tabs.images_tab.ExifExtractProgressDialog')
@patch('app.tabs.images_tab.QMessageBox')
def test_resize_images_in_folder(mock_msgbox, mock_progress_dialog, mock_resolution_dialog,
                                 setup_images_tab, mock_image_processor):
    """Test resize_images_in_folder method"""
    # Setup
    setup_images_tab.image_processor = mock_image_processor

    # Mock dialog
    mock_dialog = MagicMock()
    mock_dialog.exec_.return_value = QDialog.Accepted
    mock_dialog.get_values.return_value = ("Percentage (%)", 50)
    mock_resolution_dialog.return_value = mock_dialog

    # Mock progress dialog
    mock_progress = MagicMock()
    mock_progress_dialog.return_value = mock_progress

    # Call the method
    setup_images_tab.resize_images_in_folder()

    # Check that dialog was shown and resize was called
    mock_resolution_dialog.assert_called_once_with(1920, 1080, parent=setup_images_tab)
    mock_image_processor.resize_images.assert_called_once_with("Percentage (%)", 50)
    mock_msgbox.information.assert_called_once()

    # Test dialog rejected
    mock_dialog.exec_.return_value = QDialog.Rejected
    setup_images_tab.resize_images_in_folder()
    assert mock_image_processor.resize_images.call_count == 1  # Still just one call

    # Test error case
    setup_images_tab.image_processor = None
    setup_images_tab.resize_images_in_folder()
    assert mock_msgbox.warning.call_count == 1

@patch('app.tabs.images_tab.QMessageBox')
def test_restore_original_images(mock_msgbox, setup_images_tab, mock_image_processor):
    """Test restore_original_images method"""
    # Setup
    setup_images_tab.image_processor = mock_image_processor

    # Call the method
    setup_images_tab.restore_original_images()

    # Check that restore was called
    mock_image_processor.restore_original_images.assert_called_once()
    mock_msgbox.information.assert_called_once()

    # Test error case: restore reports failure -> warning shown
    mock_image_processor.restore_original_images.return_value = False
    setup_images_tab.restore_original_images()
    assert mock_msgbox.warning.call_count == 1

    # Test no image processor
    setup_images_tab.image_processor = None
    setup_images_tab.restore_original_images()
    assert mock_msgbox.warning.call_count == 2
--------------------------------------------------------------------------------
/tests/test_camera_models.py:
--------------------------------------------------------------------------------
# test_camera_models.py

import os
import json
import pytest
from unittest.mock import MagicMock, patch
from PyQt5.QtWidgets import QTableWidgetItem, QComboBox
from app.camera_models import CameraModelManager, CameraModelEditor

@pytest.fixture
def sample_camera_models():
    """Sample camera
models for testing""" 13 | return { 14 | "Camera1": { 15 | "projection_type": "perspective", 16 | "width": 1920, 17 | "height": 1080, 18 | "focal_ratio": 1.2 19 | }, 20 | "Camera2": { 21 | "projection_type": "spherical", 22 | "width": 3840, 23 | "height": 2160, 24 | "focal_ratio": 1.0 25 | } 26 | } 27 | 28 | @pytest.fixture 29 | def setup_workdir(temp_workdir, sample_camera_models): 30 | """Set up a temporary working directory with camera model files""" 31 | # Create camera_models.json 32 | camera_models_path = os.path.join(temp_workdir, "camera_models.json") 33 | with open(camera_models_path, "w") as f: 34 | json.dump(sample_camera_models, f) 35 | 36 | # Create camera_models_overrides.json with some overrides 37 | overrides = { 38 | "Camera1": { 39 | "focal_ratio": 1.5 # Override just one parameter 40 | } 41 | } 42 | overrides_path = os.path.join(temp_workdir, "camera_models_overrides.json") 43 | with open(overrides_path, "w") as f: 44 | json.dump(overrides, f) 45 | 46 | return temp_workdir 47 | 48 | 49 | def test_camera_model_manager_init(setup_workdir, sample_camera_models): 50 | """Test CameraModelManager initialization and loading""" 51 | manager = CameraModelManager(setup_workdir) 52 | 53 | # Check that models are loaded 54 | assert manager.camera_models is not None 55 | 56 | # Check that Camera1's focal_ratio was overridden 57 | assert manager.camera_models["Camera1"]["focal_ratio"] == 1.5 58 | 59 | # Check that other parameters are preserved 60 | assert manager.camera_models["Camera1"]["projection_type"] == "perspective" 61 | assert manager.camera_models["Camera2"]["projection_type"] == "spherical" 62 | 63 | 64 | def test_camera_model_manager_get_models(setup_workdir): 65 | """Test get_camera_models method""" 66 | manager = CameraModelManager(setup_workdir) 67 | models = manager.get_camera_models() 68 | 69 | assert "Camera1" in models 70 | assert "Camera2" in models 71 | assert models["Camera1"]["focal_ratio"] == 1.5 # Overridden value 72 | 73 | 74 | 
@patch('app.camera_models.CameraModelEditor') 75 | def test_camera_model_manager_open_editor(mock_editor, setup_workdir): 76 | """Test open_camera_model_editor method""" 77 | # Set up the mock to return True from exec_() 78 | editor_instance = MagicMock() 79 | editor_instance.exec_.return_value = True 80 | mock_editor.return_value = editor_instance 81 | 82 | manager = CameraModelManager(setup_workdir) 83 | parent = MagicMock() 84 | 85 | # Call the method 86 | result = manager.open_camera_model_editor(parent) 87 | 88 | # Check that the editor was created with correct parameters 89 | mock_editor.assert_called_once_with(manager.camera_models, setup_workdir, parent=parent) 90 | assert result is True 91 | 92 | # Now test with no workdir 93 | manager.workdir = None 94 | result = manager.open_camera_model_editor(parent) 95 | assert result is False 96 | # QMessageBox warning should be called, but we can't test that directly here 97 | 98 | 99 | @patch('app.camera_models.QTableWidgetItem') 100 | @patch('app.camera_models.QComboBox') 101 | def test_camera_model_editor_load_models(mock_combo, mock_item, setup_workdir, sample_camera_models): 102 | """Test CameraModelEditor.load_camera_models""" 103 | # Create mock for QTableWidgetItem 104 | mock_item_instance = MagicMock() 105 | mock_item.return_value = mock_item_instance 106 | 107 | # Create mock for QComboBox 108 | mock_combo_instance = MagicMock() 109 | mock_combo.return_value = mock_combo_instance 110 | 111 | # Create mock for the table 112 | mock_table = MagicMock() 113 | 114 | # Create the editor with mocked table 115 | editor = CameraModelEditor(sample_camera_models, setup_workdir) 116 | editor.table = mock_table 117 | 118 | # Call the method 119 | editor.load_camera_models() 120 | 121 | # Check that rows were added to the table 122 | assert mock_table.setRowCount.call_count >= 1 123 | assert mock_table.insertRow.call_count > 0 124 | assert mock_table.setItem.call_count > 0 125 | 126 | # Check that combo boxes were added 
for projection_type 127 | assert mock_combo.call_count > 0 128 | 129 | 130 | @patch('app.camera_models.json.dump') 131 | @patch('app.camera_models.QMessageBox') 132 | def test_camera_model_editor_save_changes(mock_msgbox, mock_json_dump, setup_workdir, sample_camera_models): 133 | """Test CameraModelEditor.save_changes""" 134 | editor = CameraModelEditor(sample_camera_models, setup_workdir) 135 | 136 | # Create mock table with data 137 | editor.table = MagicMock() 138 | editor.table.rowCount.return_value = 3 139 | 140 | # Set up the table to return different types of cells 141 | def mock_get_item(row, col): 142 | # For the first row 143 | if row == 0: 144 | if col == 0: # Key column 145 | item = MagicMock(spec=QTableWidgetItem) 146 | item.text.return_value = "Camera1" 147 | return item 148 | elif col == 1: # Parameter column 149 | item = MagicMock(spec=QTableWidgetItem) 150 | item.text.return_value = "focal_ratio" 151 | return item 152 | elif col == 2: # Value column 153 | item = MagicMock(spec=QTableWidgetItem) 154 | item.text.return_value = "1.8" # Float value 155 | return item 156 | # For the second row 157 | elif row == 1: 158 | if col == 0: 159 | item = MagicMock(spec=QTableWidgetItem) 160 | item.text.return_value = "Camera1" 161 | return item 162 | elif col == 1: 163 | item = MagicMock(spec=QTableWidgetItem) 164 | item.text.return_value = "width" 165 | return item 166 | elif col == 2: 167 | item = MagicMock(spec=QTableWidgetItem) 168 | item.text.return_value = "1920" # Integer value 169 | return item 170 | # For the third row with ComboBox 171 | elif row == 2: 172 | if col == 0: 173 | item = MagicMock(spec=QTableWidgetItem) 174 | item.text.return_value = "Camera1" 175 | return item 176 | elif col == 1: 177 | item = MagicMock(spec=QTableWidgetItem) 178 | item.text.return_value = "projection_type" 179 | return item 180 | elif col == 2: 181 | # Return None for cell with widget 182 | return None 183 | return None 184 | 185 | editor.table.item.side_effect = 
mock_get_item 186 | 187 | # Mock cellWidget for the third row, returning a combo box 188 | def mock_cell_widget(row, col): 189 | if row == 2 and col == 2: 190 | combo = MagicMock(spec=QComboBox) 191 | combo.currentText.return_value = "spherical" 192 | return combo 193 | return None 194 | 195 | editor.table.cellWidget.side_effect = mock_cell_widget 196 | 197 | # Call the method 198 | editor.save_changes() 199 | 200 | # Check that json.dump was called with the expected values 201 | assert mock_json_dump.call_count == 1 202 | 203 | # Get the arguments passed to json.dump 204 | args, kwargs = mock_json_dump.call_args 205 | updated_models = args[0] 206 | 207 | # Check the updated models 208 | assert "Camera1" in updated_models 209 | assert updated_models["Camera1"]["focal_ratio"] == 1.8 210 | assert updated_models["Camera1"]["width"] == 1920 211 | assert updated_models["Camera1"]["projection_type"] == "spherical" 212 | 213 | # Check that success message was shown 214 | assert mock_msgbox.information.call_count == 1 215 | -------------------------------------------------------------------------------- /utils/gsplat_utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from sklearn.neighbors import NearestNeighbors 6 | from torch import Tensor 7 | import torch.nn.functional as F 8 | import matplotlib.pyplot as plt 9 | from matplotlib import colormaps 10 | 11 | 12 | class CameraOptModule(torch.nn.Module): 13 | """Camera pose optimization module.""" 14 | 15 | def __init__(self, n: int): 16 | super().__init__() 17 | # Delta positions (3D) + Delta rotations (6D) 18 | self.embeds = torch.nn.Embedding(n, 9) 19 | # Identity rotation in 6D representation 20 | self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])) 21 | 22 | def zero_init(self): 23 | torch.nn.init.zeros_(self.embeds.weight) 24 | 25 | def random_init(self, std: float): 26 | 
torch.nn.init.normal_(self.embeds.weight, std=std) 27 | 28 | def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor: 29 | """Adjust camera pose based on deltas. 30 | 31 | Args: 32 | camtoworlds: (..., 4, 4) 33 | embed_ids: (...,) 34 | 35 | Returns: 36 | updated camtoworlds: (..., 4, 4) 37 | """ 38 | assert camtoworlds.shape[:-2] == embed_ids.shape 39 | batch_shape = camtoworlds.shape[:-2] 40 | pose_deltas = self.embeds(embed_ids) # (..., 9) 41 | dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:] 42 | rot = rotation_6d_to_matrix( 43 | drot + self.identity.expand(*batch_shape, -1) 44 | ) # (..., 3, 3) 45 | transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1)) 46 | transform[..., :3, :3] = rot 47 | transform[..., :3, 3] = dx 48 | return torch.matmul(camtoworlds, transform) 49 | 50 | 51 | class AppearanceOptModule(torch.nn.Module): 52 | """Appearance optimization module.""" 53 | 54 | def __init__( 55 | self, 56 | n: int, 57 | feature_dim: int, 58 | embed_dim: int = 16, 59 | sh_degree: int = 3, 60 | mlp_width: int = 64, 61 | mlp_depth: int = 2, 62 | ): 63 | super().__init__() 64 | self.embed_dim = embed_dim 65 | self.sh_degree = sh_degree 66 | self.embeds = torch.nn.Embedding(n, embed_dim) 67 | layers = [] 68 | layers.append( 69 | torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width) 70 | ) 71 | layers.append(torch.nn.ReLU(inplace=True)) 72 | for _ in range(mlp_depth - 1): 73 | layers.append(torch.nn.Linear(mlp_width, mlp_width)) 74 | layers.append(torch.nn.ReLU(inplace=True)) 75 | layers.append(torch.nn.Linear(mlp_width, 3)) 76 | self.color_head = torch.nn.Sequential(*layers) 77 | 78 | def forward( 79 | self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int 80 | ) -> Tensor: 81 | """Adjust appearance based on embeddings. 
82 | 83 | Args: 84 | features: (N, feature_dim) 85 | embed_ids: (C,) 86 | dirs: (C, N, 3) 87 | 88 | Returns: 89 | colors: (C, N, 3) 90 | """ 91 | from gsplat.cuda._torch_impl import _eval_sh_bases_fast 92 | 93 | C, N = dirs.shape[:2] 94 | # Camera embeddings 95 | if embed_ids is None: 96 | embeds = torch.zeros(C, self.embed_dim, device=features.device) 97 | else: 98 | embeds = self.embeds(embed_ids) # [C, D2] 99 | embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2] 100 | # GS features 101 | features = features[None, :, :].expand(C, -1, -1) # [C, N, D1] 102 | # View directions 103 | dirs = F.normalize(dirs, dim=-1) # [C, N, 3] 104 | num_bases_to_use = (sh_degree + 1) ** 2 105 | num_bases = (self.sh_degree + 1) ** 2 106 | sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K] 107 | sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs) 108 | # Get colors 109 | if self.embed_dim > 0: 110 | h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K] 111 | else: 112 | h = torch.cat([features, sh_bases], dim=-1) 113 | colors = self.color_head(h) 114 | return colors 115 | 116 | 117 | def rotation_6d_to_matrix(d6: Tensor) -> Tensor: 118 | """ 119 | Converts 6D rotation representation by Zhou et al. [1] to rotation matrix 120 | using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d. 121 | Args: 122 | d6: 6D rotation representation, of size (*, 6) 123 | 124 | Returns: 125 | batch of rotation matrices of size (*, 3, 3) 126 | 127 | [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H. 128 | On the Continuity of Rotation Representations in Neural Networks. 129 | IEEE Conference on Computer Vision and Pattern Recognition, 2019. 
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return distances to the K nearest neighbors of each point in ``x``.

    Args:
        x: (N, D) tensor of points (any device; moved to CPU for sklearn).
        K: Number of neighbors per point. Note the query point itself is its
           own nearest neighbor, so the first distance column is 0.

    Returns:
        (N, K) tensor of Euclidean distances, cast back to ``x``'s
        device and dtype via ``.to(x)``.
    """
    x_np = x.cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to degree-0 spherical-harmonic coefficients.

    C0 is the constant SH basis value Y_0^0 = 0.5 * sqrt(1/pi); this inverts
    ``rgb = C0 * sh + 0.5``, the usual 3DGS color parameterization.
    """
    C0 = 0.28209479177387814
    return (rgb - 0.5) / C0


def set_random_seed(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
def colormap(img, cmap="jet"):
    """Render ``img`` through a Matplotlib colormap (with colorbar) to a tensor.

    Args:
        img: 2D array-like to colorize. NOTE: the first two dims are named
            (W, H) here but ``img.shape[:2]`` is (rows, cols); the figsize
            argument order below compensates, so the output is correct.
        cmap: Matplotlib colormap name.

    Returns:
        (3, H', W') float tensor of the rasterized figure (values 0-255).
    """
    W, H = img.shape[:2]
    dpi = 300
    fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
    im = ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.canvas.draw()
    # NOTE(review): tostring_rgb() is deprecated in Matplotlib >= 3.8 and
    # removed in 3.10; migrate to buffer_rgba() when bumping Matplotlib.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    img = torch.from_numpy(data).float().permute(2, 0, 1)
    plt.close()
    return img


def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert single channel to a color img.

    Args:
        img (torch.Tensor): (..., 1) float32 single channel image with values
            in [0, 1] (enforced by the asserts below after scaling to 0-255).
        colormap (str): Colormap for img; "gray" simply replicates the channel.

    Returns:
        (..., 3) colored img with colors in [0, 1].
    """
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        # assumes img is (H, W, 1) for the 3-arg repeat — TODO confirm callers
        return img.repeat(1, 1, 3)
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    # Index the Matplotlib listed-colormap LUT with the quantized values.
    return torch.tensor(
        colormaps[colormap].colors,  # type: ignore
        device=img.device,
    )[img_long[..., 0]]


def apply_depth_colormap(
    depth: torch.Tensor,
    acc: torch.Tensor = None,
    near_plane: float = None,
    far_plane: float = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask;
            pixels with low accumulation are blended toward white.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    # BUGFIX: the previous `near_plane or min(depth)` idiom treated an
    # explicit 0.0 plane as "not provided"; test None explicitly instead.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        # Blend toward white where accumulation is low.
        img = img * acc + (1.0 - acc)
    return img
def test_get_sample_image_dimensions(setup_image_folders):
    """Test get_sample_image_dimensions method for both populated and empty folders."""
    processor = ImageProcessor(setup_image_folders)
    width, height = processor.get_sample_image_dimensions()

    # Could be either of the test images, so check both possibilities
    assert (width == 100 and height == 50) or (width == 200 and height == 150)

    # Test with empty folder: remove every jpg the fixture created
    images_dir = processor.images_folder
    for file in os.listdir(images_dir):
        if file.endswith('.jpg'):
            os.remove(os.path.join(images_dir, file))

    # Now no images should be found
    width, height = processor.get_sample_image_dimensions()
    assert width is None and height is None


def test_resize_images(setup_image_folders):
    """Test resize_images for percentage-, width-, and height-based modes."""
    processor = ImageProcessor(setup_image_folders)

    # Get original dimensions of test image
    test_img_path = os.path.join(processor.images_folder, "test_image.jpg")
    with Image.open(test_img_path) as img:
        original_width = img.width
        original_height = img.height

    # Resize to 50%
    mock_callback = MagicMock()
    num_processed = processor.resize_images("Percentage (%)", 50, mock_callback)

    # Check that the progress callback was called
    assert mock_callback.call_count > 0

    # Check that the image was resized (integer halving of both dims)
    with Image.open(test_img_path) as img:
        assert img.width == original_width // 2
        assert img.height == original_height // 2

    # Check that backup folder was created
    assert os.path.exists(processor.images_org_folder)

    # Test width-based resizing
    processor.resize_images("Width (px)", 75)
    with Image.open(test_img_path) as img:
        assert img.width == 75

    # Test height-based resizing
    processor.resize_images("Height (px)", 30)
    with Image.open(test_img_path) as img:
        assert img.height == 30


def test_restore_original_images(setup_image_folders):
    """Test restore_original_images both with and without an existing backup."""
    processor = ImageProcessor(setup_image_folders)

    # First resize to create a backup
    processor.resize_images("Percentage (%)", 50)

    # Get current dimensions after resize
    test_img_path = os.path.join(processor.images_folder, "test_image.jpg")
    with Image.open(test_img_path) as img:
        resized_width = img.width
        resized_height = img.height

    # Now restore
    result = processor.restore_original_images()
    assert result is True

    # Check that the image was restored to original dimensions
    with Image.open(test_img_path) as img:
        assert img.width != resized_width  # Should not be the resized dimensions
        assert img.height != resized_height

    # Test when no backup exists: removing the backup folder makes restore fail
    shutil.rmtree(processor.images_org_folder)
    result = processor.restore_original_images()
    assert result is False


@patch('app.image_processing.piexif.load')
@patch('app.image_processing.piexif.dump')
@patch('app.image_processing.fractions.Fraction')
@patch('app.image_processing.Image.open')
def test_apply_exif_from_mapillary_json(mock_image_open, mock_fraction, mock_dump, mock_load, setup_image_folders, tmp_path):
    """Test apply_exif_from_mapillary_json with piexif/PIL fully mocked out."""
    # Set up mock Image (context-manager form: Image.open(...) used in a `with`)
    mock_img = MagicMock()
    mock_img.info = {}
    mock_image_open.return_value.__enter__.return_value = mock_img

    # Set up mock fractions.Fraction
    mock_fraction.return_value.numerator = 1
    mock_fraction.return_value.denominator = 2

    # Set up mock load
    mock_load.side_effect = KeyError  # Force new exif_dict creation

    # Set up mock dump
    mock_dump.return_value = b'fake exif data'

    # Create a test JSON file with Mapillary data
    json_path = os.path.join(tmp_path, "mapillary_image_description.json")
    images_dir = os.path.join(tmp_path, "video_frames")
    os.makedirs(images_dir, exist_ok=True)

    # Create a test image (plain text stand-in; PIL is mocked so content is irrelevant)
    test_img_path = os.path.join(images_dir, "frame_0001.jpg")
    with open(test_img_path, 'w') as f:
        f.write("fake image data")

    mapillary_data = [
        {
            "filename": "frame_0001.jpg",
            "MAPLatitude": 35.6895,
            "MAPLongitude": 139.6917,
            "MAPAltitude": 10.5,
            "MAPCaptureTime": "2023_01_01_12_30_45_000",
            "MAPOrientation": 1
        }
    ]

    with open(json_path, 'w') as f:
        json.dump(mapillary_data, f)

    # Create the processor and call the method
    processor = ImageProcessor(setup_image_folders)
    processed_count = processor.apply_exif_from_mapillary_json(json_path, images_dir)

    # Check results
    assert processed_count == 1
    assert mock_image_open.call_count == 1
    assert mock_dump.call_count == 1
    assert not os.path.exists(images_dir)  # Should be renamed
    assert os.path.exists(os.path.join(tmp_path, "images"))  # New folder name


@patch('app.image_processing.QDialogButtonBox')
def test_resolution_dialog(mock_dialog_button_box, mock_qapplication):
    """Test ResolutionDialog init, update_label per mode, and get_values."""
    dialog = ResolutionDialog(1920, 1080)

    # Test initial values
    assert dialog.original_width == 1920
    assert dialog.original_height == 1080
    assert dialog.aspect_ratio == 1920 / 1080

    # Test update_label method with the combo and input replaced by mocks
    dialog.resize_method_combo = MagicMock()
    dialog.value_input = MagicMock()

    # Test Percentage method
    dialog.resize_method_combo.currentText.return_value = "Percentage (%)"
    dialog.update_label()
    dialog.value_input.setText.assert_called_with("100")

    # Test Width method
    dialog.resize_method_combo.currentText.return_value = "Width (px)"
    dialog.update_label()
    dialog.value_input.setText.assert_called_with("1920")

    # Test Height method
    dialog.resize_method_combo.currentText.return_value = "Height (px)"
    dialog.update_label()
    dialog.value_input.setText.assert_called_with("1080")

    # Test get_values method (text "75" should come back as float 75.0)
    dialog.resize_method_combo.currentText.return_value = "Percentage (%)"
    dialog.value_input.text.return_value = "75"
    method, value = dialog.get_values()
    assert method == "Percentage (%)"
    assert value == 75.0


def test_exif_extract_progress_dialog(mock_qapplication):
    """Test ExifExtractProgressDialog title and modality."""
    # Just basic instantiation test
    dialog = ExifExtractProgressDialog("Processing...")
    assert dialog.windowTitle() == "Extracting EXIF Data"

    # Test with default message
    dialog = ExifExtractProgressDialog()
    assert dialog.isModal() is True
def average_pose(poses: np.ndarray) -> np.ndarray:
    """New pose using average position, z-axis, and up vector of input poses.

    Args:
        poses: (n, 3, 4) camera-to-world pose matrices.

    Returns:
        (3, 4) averaged camera-to-world matrix built via viewmatrix().
    """
    position = poses[:, :3, 3].mean(0)
    z_axis = poses[:, :3, 2].mean(0)
    up = poses[:, :3, 1].mean(0)
    cam2world = viewmatrix(z_axis, up, position)
    return cam2world


def generate_spiral_path(
    poses,
    bounds,
    n_frames=120,
    n_rots=2,
    zrate=0.5,
    spiral_scale_f=1.0,
    spiral_scale_r=1.0,
    focus_distance=0.75,
):
    """Calculates a forward facing spiral path for rendering.

    Args:
        poses: (n, 3, 4) camera-to-world pose matrices.
        bounds: per-view near/far scene bounds; only min/max are used.
        n_frames: number of poses on the returned path.
        n_rots: number of full spiral rotations.
        zrate: relative frequency of the z-axis oscillation.
        spiral_scale_f: multiplier on the computed focal depth.
        spiral_scale_r: multiplier on the spiral radii.
        focus_distance: disparity-space blend between near and far bounds.

    Returns:
        (n_frames, 3, 4) array of camera-to-world poses.
    """
    # Find a reasonable 'focus depth' for this dataset as a weighted average
    # of conservative near and far bounds in disparity space.
    near_bound = bounds.min()
    far_bound = bounds.max()
    # All cameras will point towards the world space point (0, 0, -focal).
    focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound))
    focal = focal * spiral_scale_f

    # Get radii for spiral path using 90th percentile of camera positions.
    positions = poses[:, :3, 3]
    radii = np.percentile(np.abs(positions), 90, 0)
    radii = radii * spiral_scale_r
    # Append 1.0 so radii broadcasts against the homogeneous [x, y, z, 1] below.
    radii = np.concatenate([radii, [1.0]])

    # Generate poses for spiral path.
    render_poses = []
    cam2world = average_pose(poses)
    up = poses[:, :3, 1].mean(0)
    for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False):
        t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
        position = cam2world @ t
        lookat = cam2world @ [0, 0, -focal, 1.0]
        z_axis = position - lookat
        render_poses.append(viewmatrix(z_axis, up, position))
    render_poses = np.stack(render_poses, axis=0)
    return render_poses


def generate_ellipse_path_z(
    poses: np.ndarray,
    n_frames: int = 120,
    # const_speed: bool = True,
    variation: float = 0.0,
    phase: float = 0.0,
    height: float = 0.0,
) -> np.ndarray:
    """Generate an elliptical render path based on the given poses.

    The ellipse lies in the x-y plane at z=height; cameras look at the scene's
    focal point. `variation`/`phase` optionally modulate the z coordinate.
    """
    # Calculate the focal point for the path (cameras point toward this).
    center = focus_point_fn(poses)
    # Path height sits at z=height (in middle of zero-mean capture pattern).
    offset = np.array([center[0], center[1], height])

    # Calculate scaling for ellipse axes based on input camera positions.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
    # Use ellipse that is symmetric about the focal point in xy.
    low = -sc + offset
    high = sc + offset
    # Optional height variation need not be symmetric
    z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
    z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)

    def get_positions(theta):
        # Interpolate between bounds with trig functions to get ellipse in x-y.
        # Optionally also interpolate in z to change camera height along path.
        return np.stack(
            [
                low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
                low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5),
                variation
                * (
                    z_low[2]
                    + (z_high - z_low)[2]
                    * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
                )
                + height,
            ],
            -1,
        )

    theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)

    # if const_speed:
    #     # Resample theta angles so that the velocity is closer to constant.
    #     lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
    #     theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
    #     positions = get_positions(theta)

    # Throw away duplicated last position.
    positions = positions[:-1]

    # Set path's up vector to axis closest to average of input pose up vectors.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    ind_up = np.argmax(np.abs(avg_up))
    up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])

    return np.stack([viewmatrix(center - p, up, p) for p in positions])


def generate_ellipse_path_y(
    poses: np.ndarray,
    n_frames: int = 120,
    # const_speed: bool = True,
    variation: float = 0.0,
    phase: float = 0.0,
    height: float = 0.0,
) -> np.ndarray:
    """Generate an elliptical render path based on the given poses.

    Same as generate_ellipse_path_z but with the ellipse in the x-z plane and
    the optional modulation applied to y (for y-up capture rigs).
    """
    # Calculate the focal point for the path (cameras point toward this).
    center = focus_point_fn(poses)
    # Path height sits at y=height (in middle of zero-mean capture pattern).
    offset = np.array([center[0], height, center[2]])

    # Calculate scaling for ellipse axes based on input camera positions.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
    # Use ellipse that is symmetric about the focal point in xy.
    low = -sc + offset
    high = sc + offset
    # Optional height variation need not be symmetric
    y_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
    y_high = np.percentile((poses[:, :3, 3]), 90, axis=0)

    def get_positions(theta):
        # Interpolate between bounds with trig functions to get ellipse in x-z.
        # Optionally also interpolate in y to change camera height along path.
        return np.stack(
            [
                low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
                variation
                * (
                    y_low[1]
                    + (y_high - y_low)[1]
                    * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
                )
                + height,
                low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5),
            ],
            -1,
        )

    theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)

    # if const_speed:
    #     # Resample theta angles so that the velocity is closer to constant.
    #     lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
    #     theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
    #     positions = get_positions(theta)

    # Throw away duplicated last position.
    positions = positions[:-1]

    # Set path's up vector to axis closest to average of input pose up vectors.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    ind_up = np.argmax(np.abs(avg_up))
    up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])

    return np.stack([viewmatrix(p - center, up, p) for p in positions])


def generate_interpolated_path(
    poses: np.ndarray,
    n_interp: int,
    spline_degree: int = 5,
    smoothness: float = 0.03,
    rot_weight: float = 0.1,
):
    """Creates a smooth spline path between input keyframe camera poses.

    Spline is calculated with poses in format (position, lookat-point, up-point).

    Args:
        poses: (n, 3, 4) array of input pose keyframes.
        n_interp: returned path will have n_interp * (n - 1) total poses.
        spline_degree: polynomial degree of B-spline.
        smoothness: parameter for spline smoothing, 0 forces exact interpolation.
        rot_weight: relative weighting of rotation/translation in spline solve.

    Returns:
        Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
    """

    def poses_to_points(poses, dist):
        """Converts from pose matrices to (position, lookat, up) format."""
        pos = poses[:, :3, -1]
        lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
        up = poses[:, :3, -1] + dist * poses[:, :3, 1]
        return np.stack([pos, lookat, up], 1)

    def points_to_poses(points):
        """Converts from (position, lookat, up) format to pose matrices."""
        return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])

    def interp(points, n, k, s):
        """Runs multidimensional B-spline interpolation on the input points."""
        sh = points.shape
        pts = np.reshape(points, (sh[0], -1))
        # splprep requires k < number of points; clamp for short keyframe lists.
        k = min(k, sh[0] - 1)
        tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
        u = np.linspace(0, 1, n, endpoint=False)
        new_points = np.array(scipy.interpolate.splev(u, tck))
        new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
        return new_points

    points = poses_to_points(poses, dist=rot_weight)
    new_points = interp(
        points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
    )
    return points_to_poses(new_points)
logger = setup_logger()


def safe_load_reconstruction(file_path: str, retries=3, delay=1):
    """Load an OpenSfM reconstruction JSON file, retrying on transient errors.

    The reconstruction process may be writing the file concurrently, so parse
    failures are retried with a delay.

    Args:
        file_path: Path to reconstruction.json.
        retries: Number of parse attempts before giving up.
        delay: Seconds to sleep between attempts.

    Returns:
        The parsed JSON content, or None if the file is absent or unreadable.
    """
    # If the file doesn't exist yet, just return None without logging errors
    if not os.path.exists(file_path):
        return None

    for attempt in range(retries):
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            logger.warning(f"Attempt {attempt + 1}/{retries}: Error loading {file_path}: {e}")
            time.sleep(delay)
    logger.error(f"Failed to load reconstruction data from {file_path} after {retries} attempts.")
    return None


class ReconstructionThread(QThread):
    """QThread running track creation followed by incremental reconstruction.

    Emits `finished` on full completion and `stopped` if stop() was requested;
    the stop flag is only checked between pipeline stages (cooperative stop).
    """

    finished = pyqtSignal()
    stopped = pyqtSignal()

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self._running = True

    def run(self):
        create_tracks.run_dataset(self.dataset)

        if not self._running:
            self.stopped.emit()
            return

        reconstruct.run_dataset(self.dataset, ReconstructionAlgorithm.INCREMENTAL)

        if not self._running:
            self.stopped.emit()
            return

        self.finished.emit()

    def stop(self):
        # Cooperative: takes effect only at the next stage boundary in run().
        self._running = False


class Reconstruction(QWidget):
    """3D viewer widget that runs OpenSfM reconstruction in a subprocess and
    visualizes the resulting point cloud and camera poses with pyqtgraph.

    The reconstruction.json file is polled every 10 s; the view refreshes
    whenever its modification time changes.
    """

    def __init__(self, workdir):
        super().__init__()
        self.workdir = workdir
        self.reconstruction_file = os.path.join(workdir, "reconstruction.json")
        self.dataset = dataset.DataSet(workdir)
        self.camera_size = 1.0
        self.last_mod_time = 0
        self.camera_items = []
        self.stop_event = multiprocessing.Event()
        # BUGFIX: self.process was previously unset until start_reconstruction
        # ran, so closeEvent -> stop_reconstruction raised AttributeError when
        # the widget was closed before ever starting a reconstruction.
        self.process = None

        self.setFocusPolicy(Qt.StrongFocus)
        self.setFocus()

        main_layout = QVBoxLayout(self)

        self.viewer = gl.GLViewWidget()
        self.viewer.setFocusPolicy(Qt.NoFocus)
        self.viewer.setCameraPosition(distance=50)
        main_layout.addWidget(self.viewer)

        button_layout = QHBoxLayout()

        self.run_button = QPushButton("Run Reconstruction")
        self.run_button.clicked.connect(self.start_reconstruction)
        button_layout.addWidget(self.run_button)

        self.stop_button = QPushButton("Stop Reconstruction")
        self.stop_button.clicked.connect(self.stop_reconstruction)
        self.stop_button.setEnabled(False)
        button_layout.addWidget(self.stop_button)

        self.config_button = QPushButton("Config")
        self.config_button.clicked.connect(self.configure_reconstruction)
        button_layout.addWidget(self.config_button)

        main_layout.addLayout(button_layout)

        # Poll reconstruction.json for changes every 10 seconds.
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.check_for_updates)
        self.timer.start(10000)

        self.update_visualization()

    def start_reconstruction(self):
        """Confirm with the user, then launch the pipeline in a subprocess."""
        reply = QMessageBox.question(
            self, 'Confirm', 'Start Reconstruction?',
            QMessageBox.Yes | QMessageBox.No
        )
        if reply == QMessageBox.Yes:
            self.run_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.stop_event.clear()
            self.process = Process(target=self.run_reconstruction, args=(self.stop_event,))
            self.process.start()

    def run_reconstruction(self, stop_event):
        """Subprocess entry point: tracks then incremental reconstruction.

        The stop event is only honored between the two stages.
        """
        create_tracks.run_dataset(self.dataset)
        if stop_event.is_set():
            return
        reconstruct.run_dataset(self.dataset, ReconstructionAlgorithm.INCREMENTAL)

    def stop_reconstruction(self):
        """Stop a running reconstruction subprocess (with confirmation)."""
        # self.process is None until start_reconstruction has run at least
        # once; closeEvent calls this unconditionally, so guard explicitly.
        if self.process is not None and self.process.is_alive():
            reply = QMessageBox.question(
                self, 'Confirm', 'Stop Reconstruction?',
                QMessageBox.Yes | QMessageBox.No
            )
            if reply == QMessageBox.Yes:
                self.stop_event.set()
                self.process.join(timeout=10)
                if self.process.is_alive():
                    # Did not exit cooperatively; force-terminate.
                    self.process.terminate()
                    self.process.join()
                QMessageBox.information(self, "Stopped", "Reconstruction stopped successfully.")
                self.run_button.setEnabled(True)
                self.stop_button.setEnabled(False)

    def on_reconstruction_finished(self):
        """Slot: pipeline completed — refresh the view and re-enable controls."""
        self.run_button.setEnabled(True)
        self.stop_button.setEnabled(False)
        self.update_visualization()
        QMessageBox.information(self, "Done", "Reconstruction completed successfully.")

    def on_reconstruction_stopped(self):
        """Slot: pipeline was stopped before completion."""
        self.run_button.setEnabled(True)
        self.stop_button.setEnabled(False)
        QMessageBox.information(self, "Stopped", "Reconstruction was stopped.")

    def configure_reconstruction(self):
        """Placeholder for a future configuration dialog."""
        QMessageBox.information(self, "Config", "Configuration dialog placeholder.")

    def check_for_updates(self):
        """Timer slot: refresh the view if reconstruction.json changed on disk."""
        # Check if the file exists before attempting to get its modification time
        if not os.path.exists(self.reconstruction_file):
            return

        try:
            mod_time = os.path.getmtime(self.reconstruction_file)
            if mod_time != getattr(self, 'last_mod_time', None):
                self.last_mod_time = mod_time
                self.update_visualization()
        except (FileNotFoundError, OSError) as e:
            # The file may disappear between the existence check and getmtime.
            logger.warning(f"Error checking reconstruction file: {e}")

    def update_visualization(self):
        """Rebuild the point-cloud and camera items from reconstruction.json."""
        self.viewer.clear()
        self.camera_items.clear()

        data = safe_load_reconstruction(self.reconstruction_file)
        if not data:
            self.show_placeholder()
            return

        for reconstruction in data:
            # Materialize the points dict once instead of iterating it twice.
            point_records = list(reconstruction.get("points", {}).values())
            points = np.array([p["coordinates"] for p in point_records], dtype=float)
            colors = np.array([p["color"] for p in point_records], dtype=float) / 255.0
            if points.size:
                scatter = gl.GLScatterPlotItem(pos=points, color=colors, size=2)
                scatter.setGLOptions('translucent')
                self.viewer.addItem(scatter)

            for shot_name, shot in reconstruction.get("shots", {}).items():
                # OpenSfM stores world->camera (rotvec, translation); camera
                # center in world coordinates is -R^T @ t.
                rotation = Rotation.from_rotvec(np.array(shot["rotation"])).as_matrix()
                translation = np.array(shot["translation"])
                position = -rotation.T @ translation
                cam_type = reconstruction["cameras"][shot["camera"]]["projection_type"]
                self.add_camera_visualization(shot_name, rotation, position, cam_type, self.camera_size)

    def closeEvent(self, event):
        # Safe even if no reconstruction was ever started (process is None).
        self.stop_reconstruction()
        super().closeEvent(event)

    def add_camera_visualization(self, cam_name, R, t, cam_model, size=1.0):
        """Add a camera marker: a sphere for panoramic models, a wireframe
        frustum for perspective-like models."""
        if cam_model in ["spherical", "equirectangular"]:
            sphere_mesh = pg.opengl.MeshData.sphere(rows=10, cols=20, radius=size)
            sphere = gl.GLMeshItem(meshdata=sphere_mesh, color=(1, 1, 1, 0.3), smooth=True)
            sphere.translate(*t)
            self.viewer.addItem(sphere)
            self.camera_items.append((cam_name, sphere, t, R, cam_model))
        else:
            frustum_size = size * 5
            # Apex at the camera center, base rectangle behind it (-z).
            vertices = np.array([
                [0, 0, 0],
                [1, 1, -2],
                [1, -1, -2],
                [-1, 1, -2],
                [-1, -1, -2],
            ]) * frustum_size
            # NOTE(review): `@ -R` orients the frustum; verify against the
            # expected world->camera convention if frustums look mirrored.
            vertices = vertices @ -R + t

            edges = [
                (0, 1), (0, 2), (0, 3), (0, 4),
                (1, 2), (2, 4), (4, 3), (3, 1)
            ]

            lines = []
            for s, e in edges:
                line = gl.GLLinePlotItem(pos=np.array([vertices[s], vertices[e]]), color=(1, 1, 1, 0.5), width=1)
                self.viewer.addItem(line)
                lines.append(line)

            self.camera_items.append((cam_name, lines, t, R, cam_model))

    def show_placeholder(self):
        """Show a gray sphere when no reconstruction data is available yet."""
        sphere_mesh = pg.opengl.MeshData.sphere(rows=10, cols=20, radius=10)
        sphere = gl.GLMeshItem(meshdata=sphere_mesh, color=(0.5, 0.5, 0.5, 0.3), smooth=True)
        self.viewer.addItem(sphere)

    def move_to_camera(self, image_name):
        """Fly the viewer to the named camera, matching its orientation."""
        for name, _, pos, R, _ in self.camera_items:
            if name == image_name:
                elevation = np.degrees(np.arcsin(-R[2, 0]))
                azimuth = np.degrees(np.arctan2(R[1, 0], R[0, 0]))
                self.viewer.setCameraPosition(pos=QVector3D(*pos), distance=20, elevation=elevation, azimuth=azimuth)
                break

    def highlight_camera(self, cam_name):
        """Color the selected camera red and reset all others to white."""
        for name, item, _, _, model in self.camera_items:
            color = (1, 0, 0, 0.8) if name == cam_name else (1, 1, 1, 0.3)
            if model in ["spherical", "equirectangular"]:
                item.setColor(color)
            else:
                # Frustum cameras are stored as a list of line items.
                for line in item:
                    line.setData(color=color)

    def on_camera_image_tree_click(self, image_name):
        """Single click in the image tree: highlight the camera."""
        self.highlight_camera(image_name)

    def on_camera_image_tree_double_click(self, image_name):
        """Double click in the image tree: fly to the camera."""
        self.move_to_camera(image_name)

    def keyPressEvent(self, event):
        """+/- keys scale the camera markers and redraw."""
        if event.key() == Qt.Key_Plus:
            self.camera_size *= 1.1
            self.update_visualization()
        elif event.key() == Qt.Key_Minus:
            self.camera_size /= 1.1
            self.update_visualization()
        else:
            super().keyPressEvent(event)
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /app/image_processing.py: -------------------------------------------------------------------------------- 1 | # image_processing.py 2 | 3 | import os 4 | import json 5 | import shutil 6 | import datetime 7 | import fractions 8 | from PIL import Image 9 | import piexif 10 | from PyQt5.QtWidgets import QDialog, QVBoxLayout, QFormLayout, QLabel, QLineEdit, QComboBox, QDialogButtonBox, QCheckBox 11 | from PyQt5.QtCore import Qt 12 | 13 | class ResolutionDialog(QDialog): 14 | """Dialog for resizing images""" 15 | def __init__(self, current_width, current_height, parent=None): 16 | super().__init__(parent) 17 | self.setWindowTitle("Resize Images") 18 | self.setFixedSize(300, 150) 19 | 20 | self.original_width = current_width 21 | self.original_height = current_height 22 | self.aspect_ratio = current_width / current_height 23 | 24 | layout = QVBoxLayout() 25 | 26 | self.resize_method_combo = QComboBox() 27 | self.resize_method_combo.addItems(["Percentage (%)", "Width (px)", "Height (px)"]) 28 | 
self.resize_method_combo.currentIndexChanged.connect(self.update_label) 29 | layout.addWidget(QLabel("Resize method:")) 30 | layout.addWidget(self.resize_method_combo) 31 | 32 | form_layout = QFormLayout() 33 | self.value_input = QLineEdit("100") 34 | form_layout.addRow(QLabel("Value:"), self.value_input) 35 | 36 | layout.addLayout(form_layout) 37 | 38 | button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) 39 | button_box.accepted.connect(self.accept) 40 | button_box.rejected.connect(self.reject) 41 | 42 | layout.addWidget(button_box) 43 | self.setLayout(layout) 44 | 45 | def update_label(self): 46 | method = self.resize_method_combo.currentText() 47 | if method == "Percentage (%)": 48 | self.value_input.setText("100") 49 | elif method == "Width (px)": 50 | self.value_input.setText(str(self.original_width)) 51 | else: 52 | self.value_input.setText(str(self.original_height)) 53 | 54 | def get_values(self): 55 | method = self.resize_method_combo.currentText() 56 | try: 57 | value = float(self.value_input.text()) 58 | except (ValueError, TypeError): 59 | value = 100.0 if method == "Percentage (%)" else (self.original_width if method == "Width (px)" else self.original_height) 60 | return method, value 61 | 62 | 63 | class ExifExtractProgressDialog(QDialog): 64 | """Progress dialog for EXIF extraction""" 65 | def __init__(self, message="Please Wait...", parent=None): 66 | super().__init__(parent) 67 | self.setWindowTitle("Extracting EXIF Data") 68 | self.setModal(True) 69 | self.setFixedSize(520, 120) 70 | 71 | layout = QVBoxLayout() 72 | label = QLabel(message) 73 | label.setAlignment(Qt.AlignCenter) 74 | label.setStyleSheet("font-size: 14px; padding: 10px;") 75 | layout.addWidget(label) 76 | 77 | self.setLayout(layout) 78 | 79 | 80 | class ImageProcessor: 81 | """Class for handling image processing operations""" 82 | def __init__(self, workdir): 83 | self.workdir = workdir 84 | self.images_folder = os.path.join(workdir, "images") 85 | 
self.images_org_folder = os.path.join(workdir, "images_org") 86 | self.exif_folder = os.path.join(workdir, "exif") 87 | 88 | # Ensure directories exist 89 | os.makedirs(self.images_folder, exist_ok=True) 90 | os.makedirs(self.exif_folder, exist_ok=True) 91 | 92 | def resize_images(self, method, value, progress_callback=None): 93 | """Resize images in the images folder""" 94 | # Backup original images if not already backed up 95 | if not os.path.exists(self.images_org_folder): 96 | shutil.copytree(self.images_folder, self.images_org_folder) 97 | 98 | # Get list of image files 99 | image_files = [f for f in os.listdir(self.images_folder) 100 | if f.lower().endswith(('.jpg', '.jpeg', '.png'))] 101 | 102 | # Process each image 103 | for i, image_file in enumerate(image_files): 104 | image_path = os.path.join(self.images_folder, image_file) 105 | 106 | try: 107 | # Open image 108 | with Image.open(image_path) as img: 109 | # Calculate new dimensions based on method 110 | if method == "Percentage (%)": 111 | scale = value / 100 112 | new_width = int(img.width * scale) 113 | new_height = int(img.height * scale) 114 | elif method == "Width (px)": 115 | new_width = int(value) 116 | new_height = int(new_width * img.height / img.width) 117 | else: # Height (px) 118 | new_height = int(value) 119 | new_width = int(new_height * img.width / img.height) 120 | 121 | # Ensure minimum dimensions 122 | new_width = max(1, new_width) 123 | new_height = max(1, new_height) 124 | 125 | # Resize image 126 | img_resized = img.resize((new_width, new_height), Image.LANCZOS) 127 | 128 | # Save resized image 129 | img_resized.save(image_path) 130 | except Exception as e: 131 | print(f"Error resizing image {image_file}: {e}") 132 | 133 | # Call progress callback if provided 134 | if progress_callback: 135 | progress = (i + 1) / len(image_files) * 100 136 | progress_callback(int(progress)) 137 | 138 | return len(image_files) 139 | 140 | def restore_original_images(self): 141 | """Restore original 
images from backup""" 142 | if os.path.exists(self.images_org_folder): 143 | try: 144 | shutil.rmtree(self.images_folder) 145 | shutil.copytree(self.images_org_folder, self.images_folder) 146 | return True 147 | except Exception as e: 148 | print(f"Error restoring original images: {e}") 149 | return False 150 | return False 151 | 152 | def get_sample_image_dimensions(self): 153 | """Get dimensions of a sample image in the folder""" 154 | try: 155 | image_files = [f for f in os.listdir(self.images_folder) 156 | if f.lower().endswith(('.jpg', '.jpeg', '.png'))] 157 | 158 | if not image_files: 159 | return None, None 160 | 161 | sample_image_path = os.path.join(self.images_folder, image_files[0]) 162 | with Image.open(sample_image_path) as img: 163 | return img.width, img.height 164 | except Exception as e: 165 | print(f"Error getting sample image dimensions: {e}") 166 | return None, None 167 | 168 | @staticmethod 169 | def convert_to_degrees(value): 170 | """Convert decimal GPS coordinates to degrees, minutes, seconds""" 171 | d = int(value) 172 | m = int((value - d) * 60) 173 | s = (value - d - m / 60) * 3600 174 | return d, m, s 175 | 176 | @staticmethod 177 | def convert_to_rational(number): 178 | """Convert a number to a rational for EXIF data""" 179 | f = fractions.Fraction(str(number)).limit_denominator() 180 | return f.numerator, f.denominator 181 | 182 | def apply_exif_from_mapillary_json(self, json_path, images_dir): 183 | """Apply EXIF data from Mapillary JSON to images""" 184 | if not os.path.exists(json_path): 185 | print(f"JSON file not found: {json_path}") 186 | return 0 187 | 188 | if not os.path.exists(images_dir): 189 | print(f"Images directory not found: {images_dir}") 190 | return 0 191 | 192 | try: 193 | with open(json_path, 'r') as file: 194 | metadata_list = json.load(file) 195 | except Exception as e: 196 | print(f"Error loading JSON file: {e}") 197 | return 0 198 | 199 | processed_count = 0 200 | for metadata in metadata_list: 201 | 
image_filename = metadata.get('filename') 202 | if not image_filename: 203 | continue 204 | 205 | image_path = os.path.join(images_dir, os.path.basename(image_filename)) 206 | if not os.path.exists(image_path): 207 | print(f"Image not found: {image_path}") 208 | continue 209 | 210 | try: 211 | img = Image.open(image_path) 212 | try: 213 | exif_dict = piexif.load(img.info['exif']) 214 | except (KeyError, piexif.InvalidImageDataError): 215 | exif_dict = {"0th": {}, "Exif": {}, "GPS": {}, "1st": {}} 216 | 217 | # Check that required keys exist in metadata 218 | if all(key in metadata for key in ['MAPLatitude', 'MAPLongitude', 'MAPAltitude', 'MAPCaptureTime']): 219 | lat_deg = self.convert_to_degrees(metadata['MAPLatitude']) 220 | lon_deg = self.convert_to_degrees(metadata['MAPLongitude']) 221 | 222 | gps_ifd = { 223 | piexif.GPSIFD.GPSLatitudeRef: 'N' if metadata['MAPLatitude'] >= 0 else 'S', 224 | piexif.GPSIFD.GPSLongitudeRef: 'E' if metadata['MAPLongitude'] >= 0 else 'W', 225 | piexif.GPSIFD.GPSLatitude: [ 226 | self.convert_to_rational(lat_deg[0]), 227 | self.convert_to_rational(lat_deg[1]), 228 | self.convert_to_rational(lat_deg[2]) 229 | ], 230 | piexif.GPSIFD.GPSLongitude: [ 231 | self.convert_to_rational(lon_deg[0]), 232 | self.convert_to_rational(lon_deg[1]), 233 | self.convert_to_rational(lon_deg[2]) 234 | ], 235 | piexif.GPSIFD.GPSAltitude: self.convert_to_rational(metadata['MAPAltitude']), 236 | piexif.GPSIFD.GPSAltitudeRef: 0 if metadata['MAPAltitude'] >= 0 else 1, 237 | } 238 | 239 | exif_dict['GPS'] = gps_ifd 240 | 241 | capture_time = datetime.datetime.strptime(metadata['MAPCaptureTime'], '%Y_%m_%d_%H_%M_%S_%f') 242 | exif_dict['Exif'][piexif.ExifIFD.DateTimeOriginal] = capture_time.strftime('%Y:%m:%d %H:%M:%S') 243 | exif_dict['0th'][piexif.ImageIFD.Orientation] = metadata.get('MAPOrientation', 1) 244 | 245 | exif_bytes = piexif.dump(exif_dict) 246 | img.save(image_path, "jpeg", exif=exif_bytes) 247 | 248 | processed_count += 1 249 | else: 250 | 
print(f"Missing required metadata for {image_filename}") 251 | 252 | img.close() 253 | 254 | except Exception as e: 255 | print(f"Failed to update EXIF for {image_path}: {e}") 256 | 257 | # After writing EXIF, rename the folder to 'images' 258 | try: 259 | images_parent_dir = os.path.dirname(images_dir) 260 | final_images_dir = os.path.join(images_parent_dir, "images") 261 | 262 | # If 'images' folder already exists, remove it and replace 263 | if os.path.exists(final_images_dir): 264 | shutil.rmtree(final_images_dir) 265 | os.rename(images_dir, final_images_dir) 266 | print(f"Renamed folder to {final_images_dir}") 267 | except Exception as e: 268 | print(f"Error renaming directory: {e}") 269 | 270 | return processed_count -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # Metadata 2 | use_exif_size: yes 3 | unknown_camera_models_are_different: no # Treat images from unknown camera models as comming from different cameras 4 | default_focal_prior: 0.85 5 | 6 | # Params for features 7 | feature_type: ALIKED # Feature type (AKAZE, SURF, SIFT, HAHOG, ORB) 8 | feature_root: 1 # If 1, apply square root mapping to features 9 | feature_min_frames: 4000 # If fewer frames are detected, sift_peak_threshold/surf_hessian_threshold is reduced. 10 | feature_min_frames_panorama: 16000 # Same as above but for panorama images 11 | feature_process_size: 2048 # Resize the image if its size is larger than specified. Set to -1 for original size 12 | feature_process_size_panorama: 4096 # Same as above but for panorama images 13 | feature_use_adaptive_suppression: no 14 | features_bake_segmentation: no # Bake segmentation info (class and instance) in the feature data. Thus it is done once for all at extraction time. 

# Params for SIFT
sift_peak_threshold: 0.1 # Smaller value -> more features
sift_edge_threshold: 10 # See OpenCV doc

# Params for SURF
surf_hessian_threshold: 3000 # Smaller value -> more features
surf_n_octaves: 4 # See OpenCV doc
surf_n_octavelayers: 2 # See OpenCV doc
surf_upright: 0 # See OpenCV doc

# Params for AKAZE (See details in lib/src/third_party/akaze/AKAZEConfig.h)
akaze_omax: 4 # Maximum octave evolution of the image 2^sigma (coarsest scale sigma units)
akaze_dthreshold: 0.001 # Detector response threshold to accept point
akaze_descriptor: MSURF # Feature type
akaze_descriptor_size: 0 # Size of the descriptor in bits. 0->Full size
akaze_descriptor_channels: 3 # Number of feature channels (1,2,3)
akaze_kcontrast_percentile: 0.7
akaze_use_isotropic_diffusion: no

# Params for HAHOG
hahog_peak_threshold: 0.00001
hahog_edge_threshold: 10
hahog_normalize_to_uchar: yes

# Params for general matching
lowes_ratio: 0.8 # Ratio test for matches
matcher_type: FLANN # FLANN, BRUTEFORCE, or WORDS
symmetric_matching: yes # Match symmetrically or one-way

# Params for FLANN matching
flann_algorithm: KMEANS # Algorithm type (KMEANS, KDTREE)
flann_branching: 8 # See OpenCV doc
flann_iterations: 10 # See OpenCV doc
flann_tree: 8 # See OpenCV doc
flann_checks: 20 # Smaller -> Faster (but might lose good matches)

# Params for BoW matching
bow_file: bow_hahog_root_uchar_10000.npz
bow_words_to_match: 50 # Number of words to explore per feature.
bow_num_checks: 20 # Number of matching features to check.
56 | bow_matcher_type: FLANN # Matcher type to assign words to features 57 | 58 | # Params for VLAD matching 59 | vlad_file: bow_hahog_root_uchar_64.npz 60 | 61 | # Params for matching 62 | matching_gps_distance: 150 # Maximum gps distance between two images for matching 63 | matching_gps_neighbors: 0 # Number of images to match selected by GPS distance. Set to 0 to use no limit (or disable if matching_gps_distance is also 0) 64 | matching_time_neighbors: 0 # Number of images to match selected by time taken. Set to 0 to disable 65 | matching_order_neighbors: 0 # Number of images to match selected by image name. Set to 0 to disable 66 | matching_bow_neighbors: 0 # Number of images to match selected by BoW distance. Set to 0 to disable 67 | matching_bow_gps_distance: 0 # Maximum GPS distance for preempting images before using selection by BoW distance. Set to 0 to disable 68 | matching_bow_gps_neighbors: 0 # Number of images (selected by GPS distance) to preempt before using selection by BoW distance. Set to 0 to use no limit (or disable if matching_bow_gps_distance is also 0) 69 | matching_bow_other_cameras: False # If True, BoW image selection will use N neighbors from the same camera + N neighbors from any different camera. If False, the selection will take the nearest neighbors from all cameras. 70 | matching_vlad_neighbors: 0 # Number of images to match selected by VLAD distance. Set to 0 to disable 71 | matching_vlad_gps_distance: 0 # Maximum GPS distance for preempting images before using selection by VLAD distance. Set to 0 to disable 72 | matching_vlad_gps_neighbors: 0 # Number of images (selected by GPS distance) to preempt before using selection by VLAD distance. Set to 0 to use no limit (or disable if matching_vlad_gps_distance is also 0) 73 | matching_vlad_other_cameras: False # If True, VLAD image selection will use N neighbors from the same camera + N neighbors from any different camera. 
If False, the selection will take the nearest neighbors from all cameras.
matching_graph_rounds: 0 # Number of rounds to run when running triangulation-based pair selection
matching_use_filters: False # If True, removes static matches using ad-hoc heuristics
matching_use_segmentation: no # Use segmentation information (if available) to improve matching

# Params for geometric estimation
robust_matching_threshold: 0.004 # Outlier threshold for fundamental matrix estimation as portion of image width
robust_matching_calib_threshold: 0.004 # Outlier threshold for essential matrix estimation during matching in radians
robust_matching_min_match: 20 # Minimum number of matches to accept matches between two images
five_point_algo_threshold: 0.004 # Outlier threshold for essential matrix estimation during incremental reconstruction in radians
five_point_algo_min_inliers: 20 # Minimum number of inliers for considering a two view reconstruction valid
five_point_refine_match_iterations: 10 # Number of LM iterations to run when refining relative pose during matching
five_point_refine_rec_iterations: 1000 # Number of LM iterations to run when refining relative pose during reconstruction
triangulation_threshold: 0.006 # Outlier threshold for accepting a triangulated point in radians
triangulation_min_ray_angle: 1.0 # Minimum angle between views to accept a triangulated point
triangulation_type: FULL # Triangulation type : either considering all rays (FULL), or using a RANSAC variant (ROBUST)
resection_threshold: 0.004 # Outlier threshold for resection in radians
resection_min_inliers: 10 # Minimum number of resection inliers to accept it

# Params for track creation
min_track_length: 2 # Minimum number of features/images per track

# Params for bundle adjustment
loss_function: SoftLOneLoss # Loss function for the ceres problem (see:
http://ceres-solver.org/modeling.html#lossfunction) 97 | loss_function_threshold: 1 # Threshold on the squared residuals. Usually cost is quadratic for smaller residuals and sub-quadratic above. 98 | reprojection_error_sd: 0.004 # The standard deviation of the reprojection error 99 | exif_focal_sd: 0.01 # The standard deviation of the exif focal length in log-scale 100 | principal_point_sd: 0.01 # The standard deviation of the principal point coordinates 101 | radial_distortion_k1_sd: 0.01 # The standard deviation of the first radial distortion parameter 102 | radial_distortion_k2_sd: 0.01 # The standard deviation of the second radial distortion parameter 103 | radial_distortion_k3_sd: 0.01 # The standard deviation of the third radial distortion parameter 104 | radial_distortion_k4_sd: 0.01 # The standard deviation of the fourth radial distortion parameter 105 | tangential_distortion_p1_sd: 0.01 # The standard deviation of the first tangential distortion parameter 106 | tangential_distortion_p2_sd: 0.01 # The standard deviation of the second tangential distortion parameter 107 | gcp_horizontal_sd: 0.01 # The default horizontal standard deviation of the GCPs (in meters) 108 | gcp_vertical_sd: 0.1 # The default vertical standard deviation of the GCPs (in meters) 109 | rig_translation_sd: 0.1 # The standard deviation of the rig translation 110 | rig_rotation_sd: 0.1 # The standard deviation of the rig rotation 111 | bundle_outlier_filtering_type: FIXED # Type of threshold for filtering outlier : either fixed value (FIXED) or based on actual distribution (AUTO) 112 | bundle_outlier_auto_ratio: 3.0 # For AUTO filtering type, projections with larger reprojection than ratio-times-mean, are removed 113 | bundle_outlier_fixed_threshold: 0.006 # For FIXED filtering type, projections with larger reprojection error after bundle adjustment are removed 114 | optimize_camera_parameters: yes # Optimize internal camera parameters during bundle 115 | bundle_max_iterations: 100 # 
Maximum optimizer iterations.

retriangulation: yes # Retriangulate all points from time to time
retriangulation_ratio: 1.2 # Retriangulate when the number of points grows by this ratio
bundle_analytic_derivatives: yes # Use analytic derivatives or auto-differentiated ones during bundle adjustment
bundle_interval: 999999 # Bundle after adding 'bundle_interval' cameras
bundle_new_points_ratio: 1.2 # Bundle when the number of points grows by this ratio
local_bundle_radius: 3 # Max image graph distance for images to be included in local bundle adjustment
local_bundle_min_common_points: 20 # Minimum number of common points between images to be considered neighbors
local_bundle_max_shots: 30 # Max number of shots to optimize during local bundle adjustment

save_partial_reconstructions: yes # Save reconstructions at every iteration

# Params for GPS alignment
use_altitude_tag: no # Use or ignore EXIF altitude tag
align_method: auto # orientation_prior or naive
align_orientation_prior: horizontal # horizontal, vertical or no_roll
bundle_use_gps: yes # Enforce GPS position in bundle adjustment
bundle_use_gcp: no # Enforce Ground Control Point position in bundle adjustment
bundle_compensate_gps_bias: yes # Compensate GPS with a per-camera similarity transform


# Params for rigs
rig_calibration_subset_size: 15 # Number of rig instances to use when calibrating rigs
rig_calibration_completeness: 0.85 # Ratio of reconstructed images needed to consider a reconstruction for rig calibration
rig_calibration_max_rounds: 10 # Number of SfM tentatives to run until we get a satisfying reconstruction

# Params for image undistortion
undistorted_image_format: jpg # Format in which to save the undistorted images
undistorted_image_max_size: 100000 # Max width and height of the undistorted image

# Params for
depth estimation 147 | depthmap_method: PATCH_MATCH_SAMPLE # Raw depthmap computation algorithm (PATCH_MATCH, BRUTE_FORCE, PATCH_MATCH_SAMPLE) 148 | depthmap_resolution: 640 # Resolution of the depth maps 149 | depthmap_num_neighbors: 10 # Number of neighboring views 150 | depthmap_num_matching_views: 6 # Number of neighboring views used for each depthmaps 151 | depthmap_min_depth: 0 # Minimum depth in meters. Set to 0 to auto-infer from the reconstruction. 152 | depthmap_max_depth: 0 # Maximum depth in meters. Set to 0 to auto-infer from the reconstruction. 153 | depthmap_patchmatch_iterations: 3 # Number of PatchMatch iterations to run 154 | depthmap_patch_size: 7 # Size of the correlation patch 155 | depthmap_min_patch_sd: 1.0 # Patches with lower standard deviation are ignored 156 | depthmap_min_correlation_score: 0.1 # Minimum correlation score to accept a depth value 157 | depthmap_same_depth_threshold: 0.01 # Threshold to measure depth closeness 158 | depthmap_min_consistent_views: 3 # Min number of views that should reconstruct a point for it to be valid 159 | depthmap_save_debug_files: no # Save debug files with partial reconstruction results 160 | 161 | # Other params 162 | processes: 1 # Number of threads to use 163 | read_processes: 1 # When processes > 1, number of threads used for reading images 164 | 165 | # Params for submodel split and merge 166 | submodel_size: 80 # Average number of images per submodel 167 | submodel_overlap: 30.0 # Radius of the overlapping region between submodels 168 | submodels_relpath: "submodels" # Relative path to the submodels directory 169 | submodel_relpath_template: "submodels/submodel_%04d" # Template to generate the relative path to a submodel directory 170 | submodel_images_relpath_template: "submodels/submodel_%04d/images" # Template to generate the relative path to a submodel images directory -------------------------------------------------------------------------------- /tests/test_main_app.py: 
--------------------------------------------------------------------------------
# test_main_app.py
# Unit tests for app.main_app.MainApp: startup dialog routing, workdir
# selection and video import, exercised with mocked Qt widgets and mocked
# filesystem operations (no real windows or disk I/O).

import os
import sys
import pytest
import shutil
from unittest.mock import MagicMock, patch
from PyQt5.QtWidgets import QApplication, QDialog, QMessageBox, QFileDialog
from app.main_app import MainApp, VideoProcessDialog

@pytest.fixture
def setup_main_app(mock_qapplication):
    """Set up MainApp for testing with mocked dependencies"""
    # NOTE(review): mock_qapplication is presumably provided by conftest.py —
    # not visible in this file.
    with patch('app.main_app.TabManager'), \
         patch('app.main_app.QTimer'), \
         patch('app.main_app.QToolBar'):
        app = MainApp()
        app.show_start_dialog = MagicMock()  # Prevent actual dialog on init
        return app


def test_main_app_init(setup_main_app):
    """Test MainApp initialization"""
    app = setup_main_app

    # Check attributes
    assert app.workdir is None
    assert app.image_list == []
    assert app.tab_manager is not None
    assert app.camera_model_manager is None
    assert app.image_processor is None

    # Check that QTimer was set to call show_start_dialog
    from app.main_app import QTimer
    assert QTimer.singleShot.call_count == 1


def test_setup_ui(setup_main_app):
    """Test setup_ui method"""
    app = setup_main_app

    # Reset mocks
    app.tab_manager.reset_mock()

    # Call method again
    app.setup_ui()

    # Check that TabManager was set as central widget
    assert app.centralWidget() == app.tab_manager


def test_register_tabs(setup_main_app):
    """Test register_tabs method"""
    app = setup_main_app

    # Reset mock
    app.tab_manager.reset_mock()

    # Call method
    app.register_tabs()

    # Check that register_tab was called
    assert app.tab_manager.register_tab.call_count >= 1


@patch('app.main_app.QMessageBox.question')
def test_show_start_dialog_video(mock_question, setup_main_app):
    """Test show_start_dialog selecting video processing"""
    app = setup_main_app
    app.process_video = MagicMock()
    app.select_image_folder = MagicMock()

    # Choose video processing
    mock_question.return_value = QMessageBox.Yes

    # Call method
    app.show_start_dialog()

    # Check that process_video was called
    app.process_video.assert_called_once()
    assert app.select_image_folder.call_count == 0


@patch('app.main_app.QMessageBox.question')
def test_show_start_dialog_image_folder(mock_question, setup_main_app):
    """Test show_start_dialog selecting image folder"""
    app = setup_main_app
    app.process_video = MagicMock()
    app.select_image_folder = MagicMock()

    # Choose image folder
    mock_question.return_value = QMessageBox.No

    # Call method
    app.show_start_dialog()

    # Check that select_image_folder was called
    app.select_image_folder.assert_called_once()
    assert app.process_video.call_count == 0


@patch('app.main_app.QFileDialog.getExistingDirectory')
@patch('os.makedirs')
@patch('os.listdir')
@patch('os.path.exists')
@patch('shutil.copy')
def test_select_image_folder(mock_copy, mock_exists, mock_listdir, mock_makedirs,
                             mock_get_dir, setup_main_app):
    """Test select_image_folder method"""
    app = setup_main_app
    app.load_workdir = MagicMock()

    # Set up mocks
    test_dir = "/test/workdir"
    mock_get_dir.return_value = test_dir
    mock_listdir.return_value = ["image1.jpg", "image2.jpg", "file.txt"]
    mock_exists.return_value = False  # Images don't exist in target dir

    # Call method
    app.select_image_folder()

    # Check that workdir was set
    assert app.workdir == test_dir

    # Check that images directory was created
    mock_makedirs.assert_called_with(os.path.join(test_dir, "images"), exist_ok=True)

    # Check that images were copied
    assert mock_copy.call_count == 2  # Two image files

    # Check that load_workdir was called
    app.load_workdir.assert_called_once()

    # Test cancel case
    mock_get_dir.return_value = ""
    with patch('app.main_app.QMessageBox.warning') as mock_warning, \
         patch('app.main_app.sys.exit') as mock_exit:
        app.select_image_folder()
        mock_warning.assert_called_once()
        mock_exit.assert_called_once_with(1)


@patch('app.main_app.QFileDialog.getOpenFileName')
@patch('app.main_app.VideoProcessDialog')
@patch('app.main_app.ExifExtractProgressDialog')
@patch('app.main_app.VideoProcessCommand')
@patch('app.main_app.ImageProcessor')
def test_process_video(mock_image_processor, mock_video_command, mock_progress_dialog,
                       mock_dialog, mock_get_file, setup_main_app):
    """Test process_video method"""
    app = setup_main_app
    app.load_workdir = MagicMock()

    # Set up mocks
    test_file = "/test/video.mp4"
    mock_get_file.return_value = (test_file, "")

    dialog_instance = MagicMock()
    dialog_instance.exec_.return_value = QDialog.Accepted
    dialog_instance.get_values.return_value = {
        "import_path": "/test/import_path",
        "method": "Interval",
        "interval": 0.5,
        "distance": -1,
        "geotag_source": "video",
        "geotag_source_path": None,
        "offset_time": 0,
        "use_gpx": True
    }
    mock_dialog.return_value = dialog_instance

    progress_instance = MagicMock()
    mock_progress_dialog.return_value = progress_instance

    video_command_instance = MagicMock()
    mock_video_command.return_value = video_command_instance

    processor_instance = MagicMock()
    mock_image_processor.return_value = processor_instance

    # Mock file existence checks
    with patch('os.makedirs') as mock_makedirs, \
         patch('os.path.exists') as mock_exists, \
         patch('app.main_app.QApplication.processEvents') as mock_process_events:

        mock_exists.return_value = True

        # Call method
        app.process_video()

        # Check that dialog was shown
        mock_dialog.assert_called_once()
        dialog_instance.exec_.assert_called_once()

        # Check that progress dialog was shown
        mock_progress_dialog.assert_called_once()
        progress_instance.show.assert_called_once()

        # Check that VideoProcessCommand was used
        mock_video_command.assert_called_once()
        video_command_instance.run.assert_called_once()

        # Check that ImageProcessor was used
        mock_image_processor.assert_called_once()
        processor_instance.apply_exif_from_mapillary_json.assert_called_once()

        # Check that workdir was set and load_workdir called
        assert app.workdir == "/test/import_path"
        app.load_workdir.assert_called_once()

    # Test cancel case
    dialog_instance.exec_.return_value = QDialog.Rejected
    with patch('app.main_app.QMessageBox.warning') as mock_warning, \
         patch('app.main_app.sys.exit') as mock_exit:
        app.process_video()
        mock_warning.assert_called_once()
        mock_exit.assert_called_once_with(1)


@patch('os.path.exists')
@patch('os.listdir')
def test_load_workdir_with_exif(mock_listdir, mock_exists, setup_main_app):
    """Test load_workdir method when EXIF data exists"""
    # NOTE(review): this test is cut off at the end of the visible chunk; the
    # remainder continues past this excerpt.
    app = setup_main_app
    app.workdir = "/test/workdir"

    # Mock file operations
    mock_exists.return_value = True  # All paths exist
    mock_listdir.return_value = ["image1.jpg", "image2.jpg"]

    # Mock camera model manager
    with patch('app.main_app.CameraModelManager') as mock_manager, \
         patch('app.main_app.ImageProcessor') as mock_processor:

        manager_instance = MagicMock()
        mock_manager.return_value = manager_instance
processor_instance = MagicMock() 239 | mock_processor.return_value = processor_instance 240 | 241 | # Call method 242 | app.load_workdir() 243 | 244 | # Check that image list was populated 245 | assert app.image_list == ["image1.jpg", "image2.jpg"] 246 | 247 | # Check that managers were created 248 | mock_manager.assert_called_once_with(app.workdir) 249 | mock_processor.assert_called_once_with(app.workdir) 250 | 251 | # Check that tab manager was updated 252 | if app.tab_manager: 253 | app.tab_manager.update_all_tabs.assert_called_once_with( 254 | workdir=app.workdir, image_list=app.image_list 255 | ) 256 | 257 | 258 | @patch('os.path.exists') 259 | @patch('os.listdir') 260 | def test_load_workdir_without_exif(mock_listdir, mock_exists, setup_main_app): 261 | """Test load_workdir method when EXIF data doesn't exist""" 262 | app = setup_main_app 263 | app.workdir = "/test/workdir" 264 | 265 | # Mock file operations to simulate missing EXIF 266 | def mock_exists_side_effect(path): 267 | return "images" in path # Only images directory exists 268 | 269 | mock_exists.side_effect = mock_exists_side_effect 270 | mock_listdir.return_value = ["image1.jpg", "image2.jpg"] 271 | 272 | # Mock progress dialog 273 | with patch('app.main_app.ExifExtractProgressDialog') as mock_progress, \ 274 | patch('app.main_app.CameraModelManager') as mock_manager, \ 275 | patch('app.main_app.ImageProcessor') as mock_processor, \ 276 | patch('app.main_app.QApplication.processEvents') as mock_process, \ 277 | patch('shutil.copy') as mock_copy, \ 278 | patch('app.main_app.DataSet') as mock_dataset, \ 279 | patch('app.main_app.extract_metadata') as mock_extract: 280 | 281 | progress_instance = MagicMock() 282 | mock_progress.return_value = progress_instance 283 | 284 | manager_instance = MagicMock() 285 | mock_manager.return_value = manager_instance 286 | 287 | processor_instance = MagicMock() 288 | mock_processor.return_value = processor_instance 289 | 290 | dataset_instance = MagicMock() 291 | 
mock_dataset.return_value = dataset_instance 292 | 293 | # Call method 294 | app.load_workdir() 295 | 296 | # Check that progress dialog was shown 297 | mock_progress.assert_called_once() 298 | progress_instance.show.assert_called_once() 299 | 300 | # Check that metadata extraction was called 301 | mock_copy.assert_called_once() # Config file copy 302 | mock_dataset.assert_called_once_with(app.workdir) 303 | mock_extract.run_dataset.assert_called_once_with(dataset_instance) 304 | 305 | # Check that camera model editor was opened 306 | manager_instance.open_camera_model_editor.assert_called_once_with(parent=app) 307 | 308 | 309 | def test_video_process_dialog(mock_qapplication): 310 | """Test VideoProcessDialog""" 311 | dialog = VideoProcessDialog() 312 | 313 | # Test toggle_sampling_inputs 314 | dialog.sampling_method_combo = MagicMock() 315 | dialog.distance_input = MagicMock() 316 | dialog.interval_input = MagicMock() 317 | 318 | # Test Interval method 319 | dialog.sampling_method_combo.currentText.return_value = "Interval" 320 | dialog.toggle_sampling_inputs() 321 | dialog.distance_input.setDisabled.assert_called_with(True) 322 | dialog.interval_input.setDisabled.assert_called_with(False) 323 | 324 | # Test Distance method 325 | dialog.sampling_method_combo.currentText.return_value = "Distance" 326 | dialog.toggle_sampling_inputs() 327 | dialog.distance_input.setDisabled.assert_called_with(False) 328 | dialog.interval_input.setDisabled.assert_called_with(True) 329 | 330 | # Test get_sampling_values 331 | dialog.distance_input.text.return_value = "5" 332 | dialog.interval_input.text.return_value = "0.5" 333 | 334 | # With Distance method 335 | dialog.sampling_method_combo.currentText.return_value = "Distance" 336 | method, interval, distance = dialog.get_sampling_values() 337 | assert method == "Distance" 338 | assert interval == -1 339 | assert distance == 5.0 340 | 341 | # With Interval method 342 | dialog.sampling_method_combo.currentText.return_value = 
"Interval" 343 | method, interval, distance = dialog.get_sampling_values() 344 | assert method == "Interval" 345 | assert interval == 0.5 346 | assert distance == -1 347 | 348 | # Test get_values 349 | dialog.import_path_input = MagicMock() 350 | dialog.import_path_input.text.return_value = "/test/import" 351 | 352 | dialog.geotag_source_combo = MagicMock() 353 | dialog.geotag_source_combo.currentText.return_value = "video" 354 | 355 | dialog.geotag_source_path_input = MagicMock() 356 | dialog.geotag_source_path_input.text.return_value = "" 357 | 358 | dialog.interpolation_offset_input = MagicMock() 359 | dialog.interpolation_offset_input.text.return_value = "1.5" 360 | 361 | dialog.interpolation_use_gpx_checkbox = MagicMock() 362 | dialog.interpolation_use_gpx_checkbox.isChecked.return_value = True 363 | 364 | # Mock get_sampling_values 365 | dialog.get_sampling_values = MagicMock(return_value=("Interval", 0.5, -1)) 366 | 367 | values = dialog.get_values() 368 | assert values["import_path"] == "/test/import" 369 | assert values["method"] == "Interval" 370 | assert values["interval"] == 0.5 371 | assert values["distance"] == -1 372 | assert values["geotag_source"] == "video" 373 | assert values["geotag_source_path"] is None 374 | assert values["offset_time"] == 1.5 375 | assert values["use_gpx"] is True -------------------------------------------------------------------------------- /app/camera_models.py: -------------------------------------------------------------------------------- 1 | # camera_models.py 2 | 3 | import os 4 | import json 5 | from PyQt5.QtWidgets import ( 6 | QDialog, QVBoxLayout, QTableWidget, QTableWidgetItem, 7 | QPushButton, QLabel, QComboBox, QDialogButtonBox, QMessageBox 8 | ) 9 | 10 | class CameraModelEditor(QDialog): 11 | """Camera model editor dialog""" 12 | def __init__(self, camera_models, workdir, parent=None): 13 | super().__init__(parent) 14 | self.setWindowTitle("Edit Camera Models") 15 | self.setFixedSize(600, 400) 16 | 17 | 
self.camera_models = camera_models 18 | self.workdir = workdir 19 | 20 | layout = QVBoxLayout() 21 | layout.addWidget(QLabel("Camera Model Overrides")) 22 | 23 | self.table = QTableWidget() 24 | self.table.setColumnCount(3) 25 | self.table.setHorizontalHeaderLabels(["Key", "Parameter", "Value"]) 26 | layout.addWidget(self.table) 27 | 28 | save_button = QPushButton("Save Changes") 29 | save_button.clicked.connect(self.save_changes) 30 | layout.addWidget(save_button) 31 | 32 | self.setLayout(layout) 33 | 34 | self.load_camera_models() 35 | 36 | def load_camera_models(self): 37 | """Load camera models into the table""" 38 | self.table.setRowCount(0) 39 | row = 0 40 | for key, params in self.camera_models.items(): 41 | # key = 'Perspective' etc. (e.g., "Spherical") 42 | # params = { 'projection_type': 'perspective', 'width': ..., etc. } 43 | for param, value in params.items(): 44 | self.table.insertRow(row) 45 | # Column 1 (Key) 46 | self.table.setItem(row, 0, QTableWidgetItem(key)) 47 | 48 | # Column 2 (Parameter) 49 | self.table.setItem(row, 1, QTableWidgetItem(param)) 50 | 51 | # Column 3 (Value) - Use ComboBox for projection_type 52 | if param == "projection_type": 53 | combo = QComboBox() 54 | combo.addItems(["perspective", "spherical"]) 55 | # Set selection based on current value 56 | if str(value) in ["perspective", "spherical"]: 57 | combo.setCurrentText(str(value)) 58 | else: 59 | # If the value isn't registered, add it to the front and select it 60 | combo.insertItem(0, str(value)) 61 | combo.setCurrentIndex(0) 62 | self.table.setCellWidget(row, 2, combo) 63 | else: 64 | # For other params, show as text 65 | self.table.setItem(row, 2, QTableWidgetItem(str(value))) 66 | 67 | row += 1 68 | 69 | def save_changes(self): 70 | """Save changes to camera_models_overrides.json""" 71 | updated_models = {} 72 | for row in range(self.table.rowCount()): 73 | key_item = self.table.item(row, 0) 74 | param_item = self.table.item(row, 1) 75 | 76 | if not key_item or not 
param_item: 77 | continue 78 | 79 | key = key_item.text() 80 | param = param_item.text() 81 | 82 | # Check if Value is a ComboBox 83 | cell_widget = self.table.cellWidget(row, 2) 84 | if isinstance(cell_widget, QComboBox): 85 | # If ComboBox, get current text 86 | value = cell_widget.currentText() 87 | else: 88 | # If QTableWidgetItem, get text 89 | value_item = self.table.item(row, 2) 90 | if value_item: 91 | value = value_item.text() 92 | else: 93 | value = "" 94 | 95 | # Convert to float/int if it's a number 96 | # projection_type is usually a string, so it normally won't be converted 97 | try: 98 | # If '.' is included, convert to float, otherwise int 99 | if '.' in value: 100 | num = float(value) 101 | value = num 102 | else: 103 | num = int(value) 104 | value = num 105 | except ValueError: 106 | # If conversion fails, keep as string 107 | pass 108 | 109 | if key not in updated_models: 110 | updated_models[key] = {} 111 | updated_models[key][param] = value 112 | 113 | # Write JSON 114 | try: 115 | overrides_path = os.path.join(self.workdir, "camera_models_overrides.json") 116 | with open(overrides_path, "w") as f: 117 | json.dump(updated_models, f, indent=4) 118 | 119 | # After saving overrides, update the camera_models.json file as well 120 | # This fixes the issue where camera models aren't updated 121 | self.update_base_camera_models(updated_models) 122 | 123 | # Update the EXIF data in the exifs folder with the new camera model info 124 | self.update_exif_files(updated_models) 125 | 126 | QMessageBox.information(self, "Success", "Camera models saved successfully!") 127 | self.accept() 128 | except Exception as e: 129 | QMessageBox.critical(self, "Error", f"Failed to save camera models: {e}") 130 | 131 | def update_base_camera_models(self, overrides): 132 | """Update the camera_models.json file with the latest overrides 133 | This fixes the issue where camera models aren't updated when overrides is generated first time""" 134 | try: 135 | camera_models_path 
= os.path.join(self.workdir, "camera_models.json") 136 | 137 | # Load existing camera models if available 138 | if os.path.exists(camera_models_path): 139 | with open(camera_models_path, "r") as f: 140 | base_models = json.load(f) 141 | else: 142 | base_models = {} 143 | 144 | # Merge with overrides 145 | merged_models = base_models.copy() 146 | for key, params in overrides.items(): 147 | if key in merged_models: 148 | merged_models[key].update(params) 149 | else: 150 | merged_models[key] = params 151 | 152 | # Save updated camera models 153 | with open(camera_models_path, "w") as f: 154 | json.dump(merged_models, f, indent=4) 155 | 156 | return True 157 | except Exception as e: 158 | print(f"Error updating camera_models.json: {e}") 159 | return False 160 | 161 | def update_exif_files(self, camera_models_overrides): 162 | """Update EXIF files with new camera model information 163 | This ensures EXIF data is consistent with camera model overrides""" 164 | try: 165 | exif_dir = os.path.join(self.workdir, "exif") 166 | if not os.path.exists(exif_dir): 167 | print(f"EXIF directory does not exist: {exif_dir}") 168 | return False 169 | 170 | # Get the merged camera models first 171 | camera_models_path = os.path.join(self.workdir, "camera_models.json") 172 | if os.path.exists(camera_models_path): 173 | with open(camera_models_path, "r") as f: 174 | camera_models = json.load(f) 175 | else: 176 | print(f"Camera models file does not exist: {camera_models_path}") 177 | return False 178 | 179 | # Process each EXIF file 180 | for filename in os.listdir(exif_dir): 181 | if not filename.endswith(".exif"): 182 | continue 183 | 184 | exif_path = os.path.join(exif_dir, filename) 185 | 186 | # Read the EXIF data 187 | with open(exif_path, "r") as f: 188 | exif_data = json.load(f) 189 | 190 | # Get the camera model name 191 | camera_name = exif_data.get("camera", "Unknown Camera") 192 | 193 | # Check if there are overrides for this camera 194 | if camera_name in camera_models: 195 | 
# Get the parameters that need to be updated 196 | camera_params = camera_models[camera_name] 197 | 198 | # Update EXIF data with the new camera parameters 199 | for param, value in camera_params.items(): 200 | # Update specific fields based on the parameter 201 | if param == "projection_type": 202 | exif_data["projection_type"] = value 203 | elif param == "width": 204 | exif_data["width"] = value 205 | elif param == "height": 206 | exif_data["height"] = value 207 | elif param == "focal_ratio": 208 | exif_data["focal_ratio"] = value 209 | # If focal_x and focal_y are present, update them too 210 | if "width" in camera_params and "focal" in exif_data: 211 | exif_data["focal_x"] = value * camera_params["width"] 212 | if "height" in camera_params and "focal" in exif_data: 213 | exif_data["focal_y"] = value * camera_params["height"] 214 | 215 | # Write the updated EXIF data back to the file 216 | with open(exif_path, "w") as f: 217 | json.dump(exif_data, f, indent=4) 218 | 219 | return True 220 | except Exception as e: 221 | print(f"Error updating EXIF files: {e}") 222 | return False 223 | 224 | 225 | class CameraModelManager: 226 | """Manager class for camera models""" 227 | def __init__(self, workdir): 228 | self.workdir = workdir 229 | self.camera_models = {} 230 | self._default_model = { 231 | "Perspective": { 232 | "projection_type": "perspective", 233 | "width": 1920, 234 | "height": 1080, 235 | "focal_ratio": 1.0 236 | } 237 | } 238 | self.load_camera_models() 239 | 240 | def load_camera_models(self): 241 | """Load camera_models.json and apply camera_models_overrides.json""" 242 | camera_models_path = os.path.join(self.workdir, "camera_models.json") 243 | overrides_path = os.path.join(self.workdir, "camera_models_overrides.json") 244 | 245 | # Load base models, if missing use default 246 | if os.path.exists(camera_models_path): 247 | try: 248 | with open(camera_models_path, "r") as f: 249 | base_models = json.load(f) 250 | except Exception as e: 251 | 
print(f"Error loading camera_models.json: {e}") 252 | base_models = self._default_model.copy() 253 | # Try to create the file with default model 254 | try: 255 | with open(camera_models_path, "w") as f: 256 | json.dump(base_models, f, indent=4) 257 | except Exception as e: 258 | print(f"Error creating default camera_models.json: {e}") 259 | else: 260 | base_models = self._default_model.copy() 261 | # Try to create the file with default model 262 | try: 263 | with open(camera_models_path, "w") as f: 264 | json.dump(base_models, f, indent=4) 265 | except Exception as e: 266 | print(f"Error creating default camera_models.json: {e}") 267 | 268 | # Load overrides if they exist 269 | overrides = {} 270 | if os.path.exists(overrides_path): 271 | try: 272 | with open(overrides_path, "r") as f: 273 | overrides = json.load(f) 274 | except Exception as e: 275 | print(f"Error loading camera_models_overrides.json: {e}") 276 | 277 | # Merge models with overrides 278 | merged_models = base_models.copy() 279 | for key, params in overrides.items(): 280 | if key in merged_models: 281 | merged_models[key].update(params) 282 | else: 283 | merged_models[key] = params 284 | 285 | # Write the merged models back to camera_models.json to ensure it's always up-to-date 286 | # This fixes the issue where camera models aren't updated when overrides changes 287 | try: 288 | with open(camera_models_path, "w") as f: 289 | json.dump(merged_models, f, indent=4) 290 | except Exception as e: 291 | print(f"Error updating camera_models.json: {e}") 292 | 293 | self.camera_models = merged_models 294 | return merged_models 295 | 296 | def get_camera_models(self): 297 | """Get the current camera models""" 298 | return self.camera_models 299 | 300 | def open_camera_model_editor(self, parent=None): 301 | """Open camera model editor dialog""" 302 | if self.workdir: 303 | # Reload models to get latest changes 304 | self.load_camera_models() 305 | 306 | try: 307 | dialog = CameraModelEditor(self.camera_models, 
self.workdir, parent=parent) 308 | if dialog.exec_(): 309 | # Reload after saving 310 | self.load_camera_models() 311 | 312 | return True 313 | except Exception as e: 314 | if parent: 315 | QMessageBox.critical(parent, "Error", f"Failed to open camera model editor: {e}") 316 | return False 317 | else: 318 | if parent: 319 | QMessageBox.warning(parent, "Error", "Workdir is not set.") 320 | return False 321 | 322 | def update_exif_with_camera_models(self): 323 | """Update all EXIF files with current camera model information 324 | This can be called explicitly when needed to ensure EXIF data is in sync""" 325 | editor = CameraModelEditor(self.camera_models, self.workdir) 326 | return editor.update_exif_files(self.camera_models) -------------------------------------------------------------------------------- /app/mask_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from loguru import logger as guru 6 | from PyQt5.QtWidgets import ( 7 | QWidget, QVBoxLayout, QLabel, QPushButton, QHBoxLayout, QMessageBox 8 | ) 9 | from PyQt5.QtGui import QPixmap, QImage 10 | from PyQt5.QtCore import Qt 11 | from sam2.build_sam import build_sam2 12 | from sam2.sam2_image_predictor import SAM2ImagePredictor 13 | 14 | class ClickableImageLabel(QLabel): 15 | def __init__(self, parent=None): 16 | super().__init__(parent) 17 | self.setAlignment(Qt.AlignCenter) 18 | self.click_callback = None # Callback function to be called on click 19 | 20 | def mousePressEvent(self, event): 21 | if event.button() == Qt.LeftButton: 22 | if self.click_callback is not None: 23 | self.click_callback(event) 24 | super().mousePressEvent(event) 25 | 26 | class MaskManager(QWidget): 27 | def __init__(self, checkpoint_path, config_path, mask_dir, img_dir, image_list): 28 | super().__init__() 29 | self.checkpoint_path = checkpoint_path 30 | self.config_path = config_path 31 | self.mask_dir = mask_dir 
32 | self.img_dir = img_dir 33 | self.image_list = image_list 34 | self.current_index = 0 35 | self.current_image = None 36 | self.current_mask = None 37 | self.sam_model = None 38 | self.predictor = None 39 | self.input_points = [] 40 | self.input_labels = [] 41 | self.label_toggle = 1 # Start with positive point (1) 42 | self.image_name = None # Name of the current image 43 | 44 | # Initialize UI 45 | self.init_ui() 46 | # Initialize SAM2 model and predictor 47 | self.init_sam_model() 48 | 49 | # Point: Initially, do not load the image, display text instead 50 | # self.load_current_image() # ← Comment this out 51 | self.image_label.setText("Select an image to view masks.") 52 | self.image_label.setStyleSheet("color: gray; font-size: 16px;") # Styling for better visibility 53 | 54 | def init_ui(self): 55 | """Initialize the UI components.""" 56 | layout = QVBoxLayout() 57 | 58 | # Clickable image label 59 | self.image_label = ClickableImageLabel() 60 | # Set the click callback function 61 | self.image_label.click_callback = self.on_image_clicked 62 | layout.addWidget(self.image_label) 63 | 64 | # Navigation and Reset Buttons Layout 65 | button_layout = QHBoxLayout() 66 | 67 | # Previous Image Button 68 | self.prev_button = QPushButton("< Previous Image") 69 | self.prev_button.clicked.connect(self.prev_image) 70 | button_layout.addWidget(self.prev_button) 71 | 72 | # Reset Mask Button 73 | self.reset_button = QPushButton("Reset Mask") 74 | self.reset_button.clicked.connect(self.reset_mask) 75 | button_layout.addWidget(self.reset_button) 76 | 77 | # Next Image Button 78 | self.next_button = QPushButton("Next Image >") 79 | self.next_button.clicked.connect(self.next_image) 80 | button_layout.addWidget(self.next_button) 81 | 82 | layout.addLayout(button_layout) 83 | self.setLayout(layout) 84 | 85 | def init_sam_model(self): 86 | """Initialize the SAM2 model and predictor.""" 87 | if self.sam_model is None: 88 | device = "cuda" if torch.cuda.is_available() else "cpu" 
89 | self.sam_model = build_sam2(self.config_path, self.checkpoint_path, device=device) 90 | self.predictor = SAM2ImagePredictor(self.sam_model) 91 | guru.info(f"SAM2 model loaded with checkpoint: {self.checkpoint_path}") 92 | 93 | def unload_sam_model(self): 94 | """Unload the SAM2 model and free resources.""" 95 | if self.sam_model is not None: 96 | del self.sam_model 97 | del self.predictor 98 | self.sam_model = None 99 | self.predictor = None 100 | guru.info("SAM2 model and predictor have been unloaded.") 101 | 102 | # Optionally clear CUDA cache if using GPU 103 | if torch.cuda.is_available(): 104 | torch.cuda.empty_cache() 105 | guru.info("CUDA cache has been cleared.") 106 | 107 | def load_current_image(self): 108 | """Load and display the current image.""" 109 | # If the image list is empty, display a message and exit 110 | if not self.image_list: 111 | self.image_label.setText("Select an image to view masks.") 112 | return 113 | 114 | self.input_points = [] 115 | self.input_labels = [] 116 | self.label_toggle = 1 # Reset label toggle 117 | self.current_mask = None 118 | 119 | if 0 <= self.current_index < len(self.image_list): 120 | self.image_name = self.image_list[self.current_index] 121 | self.load_image_by_name(self.image_name) 122 | else: 123 | QMessageBox.warning(self, "Error", "No images to display.") 124 | 125 | def load_image_by_name(self, image_name): 126 | """Load and display the image specified by image_name.""" 127 | self.image_name = image_name 128 | image_path = os.path.join(self.img_dir, image_name) 129 | self.current_image = cv2.imread(image_path) 130 | 131 | if self.current_image is None: 132 | QMessageBox.warning(self, "Error", f"Failed to load image {self.image_name}") 133 | return 134 | 135 | # Load the mask if it exists 136 | mask_path = os.path.join(self.mask_dir, f"{self.image_name}.png") 137 | if os.path.exists(mask_path): 138 | self.current_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) 139 | guru.info(f"Loaded existing mask 
from {mask_path}") 140 | self.plot_mask(self.image_name, self.current_mask) # Display the image with mask overlay 141 | else: 142 | # Display original image if no mask exists 143 | self.display_image(self.current_image) 144 | 145 | def display_image(self, image): 146 | """Display the given image in the QLabel, resizing it to fit the QLabel size.""" 147 | rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 148 | self.set_image_to_label(rgb_image) 149 | 150 | def plot_mask(self, image_name, mask_data): 151 | # Load image from the given path 152 | image_path = os.path.join(self.img_dir, image_name) 153 | image = cv2.imread(image_path) 154 | if image is None: 155 | QMessageBox.warning(self, "Error", f"Failed to load image {image_name}") 156 | return 157 | 158 | # Create a binary mask where black regions (value 0) become white (255) 159 | black_region_mask = cv2.threshold(mask_data, 0, 255, cv2.THRESH_BINARY_INV)[1] 160 | # Copy the original image to create an overlay 161 | overlayed_image = image.copy() 162 | # Replace pixels in the overlay where the mask is active with blue (BGR: [255, 0, 0]) 163 | overlayed_image[black_region_mask == 255] = [255, 0, 0] 164 | 165 | # Draw selected points on the overlay (optional) 166 | for idx, point in enumerate(self.input_points): 167 | color = (0, 255, 0) if self.input_labels[idx] == 1 else (0, 0, 255) 168 | cv2.circle(overlayed_image, (point[0], point[1]), radius=5, color=color, thickness=-1) 169 | 170 | # Convert the overlayed image from BGR to RGB and display it in the QLabel 171 | rgb_image = cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB) 172 | self.set_image_to_label(rgb_image) 173 | 174 | def set_image_to_label(self, rgb_image): 175 | """Resize the image to fit the QLabel width while maintaining aspect ratio and display it.""" 176 | label_width = self.image_label.width() 177 | if label_width == 0: 178 | label_width = 400 179 | 180 | h, w, _ = rgb_image.shape 181 | aspect_ratio = h / w # Calculate aspect ratio 182 | 183 | # 
Calculate new dimensions while maintaining aspect ratio 184 | new_width = label_width 185 | new_height = int(new_width * aspect_ratio) 186 | 187 | # Resize image 188 | resized_image = cv2.resize(rgb_image, (new_width, new_height), interpolation=cv2.INTER_AREA) 189 | 190 | # Convert to QImage and set to QLabel 191 | height, width, channel = resized_image.shape 192 | bytes_per_line = channel * width 193 | q_image = QImage(resized_image.data, width, height, bytes_per_line, QImage.Format_RGB888) 194 | pixmap = QPixmap.fromImage(q_image) 195 | 196 | self.image_label.setPixmap(pixmap) 197 | # Clear the text to remove any previous text 198 | self.image_label.setText("") 199 | 200 | def on_image_clicked(self, event): 201 | """Handle image click events with position correction.""" 202 | if self.current_image is not None: 203 | label_pos = self.image_label.mapFromGlobal(event.globalPos()) 204 | pixmap = self.image_label.pixmap() 205 | if pixmap is None: 206 | return 207 | 208 | pixmap_width = pixmap.width() 209 | pixmap_height = pixmap.height() 210 | label_width = self.image_label.width() 211 | label_height = self.image_label.height() 212 | offset_x = (label_width - pixmap_width) / 2 213 | offset_y = (label_height - pixmap_height) / 2 214 | 215 | if (offset_x <= label_pos.x() <= offset_x + pixmap_width) and \ 216 | (offset_y <= label_pos.y() <= offset_y + pixmap_height): 217 | 218 | scale_x = self.current_image.shape[1] / pixmap_width 219 | scale_y = self.current_image.shape[0] / pixmap_height 220 | 221 | corrected_x = int((label_pos.x() - offset_x) * scale_x) 222 | corrected_y = int((label_pos.y() - offset_y) * scale_y) 223 | corrected_x = max(0, min(corrected_x, self.current_image.shape[1] - 1)) 224 | corrected_y = max(0, min(corrected_y, self.current_image.shape[0] - 1)) 225 | 226 | self.input_points.append([corrected_x, corrected_y]) 227 | self.input_labels.append(self.label_toggle) 228 | 229 | # Toggle the label for the next point (1→0, 0→1) 230 | self.label_toggle = 1 
- self.label_toggle 231 | self.generate_mask() 232 | 233 | def process_single_image(self, image, image_name, point_coords, point_labels): 234 | """Generate a mask for a single image and save it to the specified mask directory.""" 235 | self.predictor.set_image(image) 236 | 237 | # Generate mask with SAM2 238 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): 239 | masks, scores, logits = self.predictor.predict( 240 | point_coords=point_coords, 241 | point_labels=point_labels, 242 | multimask_output=False 243 | ) 244 | 245 | mask_output_path = os.path.join(self.mask_dir, f"{image_name}.png") 246 | inverted_mask = 1 - masks[0] 247 | mask_to_save = (inverted_mask * 255).astype(np.uint8) 248 | cv2.imwrite(mask_output_path, mask_to_save) 249 | guru.info(f"Mask saved to {mask_output_path}") 250 | 251 | return inverted_mask * 255 252 | 253 | def generate_mask(self): 254 | """Generate and display the mask.""" 255 | if self.current_image is None or not self.input_points: 256 | return 257 | 258 | point_coords = np.array(self.input_points) 259 | point_labels = np.array(self.input_labels) 260 | mask = self.process_single_image( 261 | self.current_image, self.image_name, point_coords, point_labels 262 | ) 263 | self.current_mask = mask 264 | self.display_image_with_mask() 265 | 266 | def display_image_with_mask(self): 267 | # Check if current image and mask exist 268 | if self.current_image is not None and self.current_mask is not None: 269 | # Create a copy of the original image 270 | overlay = self.current_image.copy() 271 | # Create a binary mask where the mask's black areas become white (255) 272 | black_region_mask = cv2.threshold(self.current_mask, 0, 255, cv2.THRESH_BINARY_INV)[1] 273 | # Set the corresponding pixels in the overlay to blue (BGR: [255, 0, 0]) 274 | overlay[black_region_mask == 255] = [255, 0, 0] 275 | 276 | # Draw selected points on the overlay 277 | for idx, point in enumerate(self.input_points): 278 | color = (0, 255, 0) if 
self.input_labels[idx] == 1 else (0, 0, 255) 279 | cv2.circle(overlay, (point[0], point[1]), radius=5, color=color, thickness=-1) 280 | 281 | # Display the modified image 282 | self.display_image(overlay) 283 | elif self.current_image is not None: 284 | self.display_image(self.current_image) 285 | 286 | def reset_mask(self): 287 | """Reset the current mask to a blank white mask and update the display.""" 288 | if self.current_image is None: 289 | QMessageBox.warning(self, "Error", "No image loaded to reset the mask.") 290 | return 291 | 292 | self.current_mask = np.ones(self.current_image.shape[:2], dtype=np.uint8) * 255 293 | self.save_current_mask() 294 | self.clear_points() 295 | self.display_image_with_mask() 296 | QMessageBox.information(self, "Reset Mask", "The mask has been reset to a blank state.") 297 | 298 | def save_current_mask(self): 299 | """Save the current mask to a file.""" 300 | if self.current_mask is not None: 301 | mask_output_path = os.path.join(self.mask_dir, f"{self.image_name}.png") 302 | mask_to_save = self.current_mask.astype(np.uint8) 303 | cv2.imwrite(mask_output_path, mask_to_save) 304 | guru.info(f"Mask saved to {mask_output_path}") 305 | 306 | def clear_points(self): 307 | """Clear the list of points and labels.""" 308 | self.input_points.clear() 309 | self.input_labels.clear() 310 | self.label_toggle = 1 311 | 312 | def next_image(self): 313 | """Move to the next image.""" 314 | if self.current_index < len(self.image_list) - 1: 315 | self.current_index += 1 316 | self.load_current_image() 317 | else: 318 | QMessageBox.information(self, "Info", "This is the last image.") 319 | 320 | def prev_image(self): 321 | """Move to the previous image.""" 322 | if self.current_index > 0: 323 | self.current_index -= 1 324 | self.load_current_image() 325 | else: 326 | QMessageBox.information(self, "Info", "This is the first image.") 327 | -------------------------------------------------------------------------------- /app/tabs/images_tab.py: 
# images_tab.py

import os
import json
from PyQt5.QtWidgets import (
    QWidget, QVBoxLayout, QHBoxLayout, QSplitter,
    QLabel, QPushButton, QTreeWidget, QTableWidget, QTableWidgetItem,
    QSizePolicy, QMessageBox, QApplication, QDialog
)
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPixmap

from app.base_tab import BaseTab
from app.camera_models import CameraModelManager
from app.image_processing import ImageProcessor, ResolutionDialog, ExifExtractProgressDialog

class ImagesTab(BaseTab):
    """Images tab: browse cameras/images, inspect EXIF data and run image tools.

    Layout: a camera/image tree on the left; an image preview, an EXIF table
    and action buttons (edit camera models, resize, restore) on the right.
    """

    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        self.image_viewer = None         # QLabel previewing the selected image
        self.exif_table = None           # two-column Field/Value table
        self.camera_image_tree = None    # QTreeWidget of cameras and their images
        self.camera_model_manager = None
        self.image_processor = None

        # The widget skeleton is always built, even without a workdir.
        self.setup_basic_ui()

        # Managers need a workdir; failures here are non-fatal because the
        # _create_* helpers retry lazily when a feature is first used.
        if self.workdir:
            try:
                self.camera_model_manager = CameraModelManager(self.workdir)
                self.image_processor = ImageProcessor(self.workdir)
            except Exception as e:
                print(f"Error initializing managers: {e}")

    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Images"

    def setup_basic_ui(self):
        """Build the static widget layout; no data is loaded here."""
        layout = self.create_horizontal_splitter()

        # Left side: camera and image tree
        self.camera_image_tree = QTreeWidget()
        self.camera_image_tree.setHeaderLabel("Cameras and Images")
        self.camera_image_tree.setFixedWidth(250)
        layout.addWidget(self.camera_image_tree)

        # Right side: image viewer, EXIF data, buttons
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        right_layout.setContentsMargins(0, 0, 0, 0)

        # Image viewer
        self.image_viewer = QLabel("Image Viewer")
        self.image_viewer.setAlignment(Qt.AlignCenter)
        self.image_viewer.setMinimumHeight(300)
        right_layout.addWidget(self.image_viewer, stretch=3)

        # EXIF data table
        self.exif_table = QTableWidget()
        self.exif_table.setColumnCount(2)
        self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
        self.exif_table.horizontalHeader().setStretchLastSection(True)
        right_layout.addWidget(self.exif_table, stretch=2)

        # Action buttons, centered horizontally via surrounding stretches
        button_widget = QWidget()
        button_layout = QHBoxLayout()
        button_layout.setContentsMargins(5, 5, 5, 5)
        button_layout.setSpacing(10)

        self.camera_model_button = QPushButton("Edit Camera Models")
        self.camera_model_button.clicked.connect(self.open_camera_model_editor)

        self.resize_button = QPushButton("Change Resolution")
        self.resize_button.clicked.connect(self.resize_images_in_folder)

        self.restore_button = QPushButton("Restore Images")
        self.restore_button.clicked.connect(self.restore_original_images)

        button_layout.addStretch(1)
        button_layout.addWidget(self.camera_model_button)
        button_layout.addWidget(self.resize_button)
        button_layout.addWidget(self.restore_button)
        button_layout.addStretch(1)

        button_widget.setLayout(button_layout)

        # Make button widget expand horizontally
        button_widget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        right_layout.addWidget(button_widget, stretch=0)

        right_widget.setLayout(right_layout)
        layout.addWidget(right_widget)

        layout.setStretchFactor(0, 1)
        layout.setStretchFactor(1, 4)

        self._layout.addWidget(layout)

        # Clicking a tree item previews the image and its EXIF data.
        self.camera_image_tree.itemClicked.connect(self.display_image_and_exif)

    def _create_camera_model_manager(self):
        """Try to (re)create the CameraModelManager; report failure via dialog.

        Returns:
            bool: True when the manager was created successfully.
        """
        try:
            self.camera_model_manager = CameraModelManager(self.workdir)
            return True
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to initialize camera model manager: {e}")
            return False

    def _create_image_processor(self):
        """Try to (re)create the ImageProcessor; report failure via dialog.

        Returns:
            bool: True when the processor was created successfully.
        """
        try:
            self.image_processor = ImageProcessor(self.workdir)
            return True
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to initialize image processor: {e}")
            return False

    def initialize(self):
        """Initialize the Images tab with data (managers + camera tree)."""
        # The widget skeleton already exists from __init__; only load data.
        if self.workdir:
            # Retry manager creation if __init__ could not build them.
            if self.camera_model_manager is None:
                self._create_camera_model_manager()

            if self.image_processor is None:
                self._create_image_processor()

            # Populate the camera tree
            try:
                self.setup_camera_image_tree(self.camera_image_tree)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to populate camera tree: {e}")

        self.is_initialized = True

    def _show_exif_error(self, message):
        """Reset the EXIF table and display a single error row with *message*."""
        self.exif_table.setRowCount(0)
        self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
        self.exif_table.clearContents()
        self.exif_table.setRowCount(1)
        self.exif_table.setItem(0, 0, QTableWidgetItem("Error"))
        self.exif_table.setItem(0, 1, QTableWidgetItem(message))

    def display_image_and_exif(self, item, column):
        """Show the clicked image and its EXIF data (leaf tree items only)."""
        if not self.is_initialized:
            # Make sure we're initialized before attempting to display data
            self.initialize()

        # Only leaf nodes (images below a camera node) are actionable.
        if item.childCount() == 0 and item.parent() is not None:
            image_name = item.text(0)
            image_path = os.path.join(self.workdir, "images", image_name)
            exif_path = os.path.join(self.workdir, "exif", image_name + '.exif')

            # Display image
            if os.path.exists(image_path):
                pixmap = QPixmap(image_path)
                self.image_viewer.setPixmap(pixmap.scaled(
                    self.image_viewer.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
            else:
                self.image_viewer.setText("Image not found.")

            # Display EXIF data
            if os.path.exists(exif_path):
                # Robustness fix: a corrupt or unreadable EXIF file used to
                # raise and break the click handler; show an error row instead.
                try:
                    with open(exif_path, 'r') as f:
                        exif_data = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    self._show_exif_error(f"Invalid EXIF data: {e}")
                else:
                    self.display_exif_data(exif_data)
            else:
                self._show_exif_error("EXIF data not found.")

    def display_exif_data(self, exif_data):
        """Fill the EXIF table from *exif_data*, applying overrides if available."""
        self.exif_table.setRowCount(0)
        self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
        self.exif_table.clearContents()

        # Apply camera_models_overrides.json settings if available; an
        # override value wins over the raw EXIF value for the same field.
        camera_name = exif_data.get("camera", "Unknown Camera")
        overrides = {}
        if self.camera_model_manager:
            try:
                camera_models = self.camera_model_manager.get_camera_models()
                overrides = camera_models.get(camera_name, {})
            except Exception as e:
                print(f"Error getting camera models: {e}")

        # Fields to display, in this fixed order
        fields = [
            "make",
            "model",
            "width",
            "height",
            "projection_type",
            "focal_ratio",
            "orientation",
            "capture_time",
            "gps",
            "camera"
        ]

        for i, field in enumerate(fields):
            self.exif_table.insertRow(i)
            key_item = QTableWidgetItem(field)

            # Get value from overrides first, then EXIF data
            value = overrides.get(field, exif_data.get(field, "N/A"))

            # Render dicts as JSON and floats with two decimals.
            if isinstance(value, dict):
                value = json.dumps(value)
            elif isinstance(value, float):
                value = f"{value:.2f}"

            value_item = QTableWidgetItem(str(value))
            self.exif_table.setItem(i, 0, key_item)
            self.exif_table.setItem(i, 1, value_item)

    def open_camera_model_editor(self):
        """Open the camera model editor dialog, seeding a default model file if needed."""
        if not self.is_initialized:
            self.initialize()

        # Lazily create the manager if it is still missing.
        if self.camera_model_manager is None and self.workdir:
            if not self._create_camera_model_manager():
                return

        if self.camera_model_manager:
            # Ensure camera_models.json exists
            camera_models_path = os.path.join(self.workdir, "camera_models.json")
            if not os.path.exists(camera_models_path):
                # Seed the file with a sensible default camera model.
                try:
                    default_model = {
                        "Perspective": {
                            "projection_type": "perspective",
                            "width": 1920,
                            "height": 1080,
                            "focal_ratio": 1.0
                        }
                    }
                    with open(camera_models_path, 'w') as f:
                        json.dump(default_model, f, indent=4)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to create default camera model: {e}")
                    return

                # Reload the manager so it picks up the new file.
                try:
                    self.camera_model_manager = CameraModelManager(self.workdir)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to reinitialize camera model manager: {e}")
                    return

            # Now open the editor
            try:
                self.camera_model_manager.open_camera_model_editor(parent=self)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to open camera model editor: {e}")
        else:
            QMessageBox.warning(self, "Error", "Camera Model Manager is not initialized and cannot be created.")

    def resize_images_in_folder(self):
        """Ask the user for a target resolution and resize all project images."""
        if not self.is_initialized:
            self.initialize()

        # Lazily create the processor if it is still missing.
        if self.image_processor is None and self.workdir:
            if not self._create_image_processor():
                return

        if not self.image_processor:
            QMessageBox.warning(self, "Error", "Image Processor is not initialized and cannot be created.")
            return

        # Pre-fill the dialog with the dimensions of a sample image.
        width, height = self.image_processor.get_sample_image_dimensions()
        if width is None or height is None:
            QMessageBox.warning(self, "Error", "No images found to determine default resolution.")
            return

        # Show resolution dialog
        dialog = ResolutionDialog(width, height, parent=self)
        if dialog.exec_() != QDialog.Accepted:
            return

        method, value = dialog.get_values()

        # Show progress dialog while resizing runs.
        progress_dialog = ExifExtractProgressDialog("Resizing images...", self)
        progress_dialog.show()
        QApplication.processEvents()

        try:
            self.image_processor.resize_images(method, value)
        except Exception as e:
            QMessageBox.warning(self, "Error", f"Error during resizing: {e}")
            # Bug fix: previously the success dialog below was shown even
            # after a failure; the finally block still closes the dialog.
            return
        finally:
            progress_dialog.close()

        QMessageBox.information(self, "Completed", "Images resized successfully!")

    def restore_original_images(self):
        """Restore original images from the backup created before resizing."""
        if not self.is_initialized:
            self.initialize()

        # Lazily create the processor if it is still missing.
        if self.image_processor is None and self.workdir:
            if not self._create_image_processor():
                return

        if not self.image_processor:
            QMessageBox.warning(self, "Error", "Image Processor is not initialized and cannot be created.")
            return

        result = self.image_processor.restore_original_images()
        if result:
            QMessageBox.information(self, "Restored", "Original images restored successfully!")
        else:
            QMessageBox.warning(self, "Error", "No original backup images found.")

    def refresh(self):
        """Re-create the managers and repopulate the camera tree."""
        if self.is_initialized:
            if self.camera_model_manager:
                try:
                    self.camera_model_manager = CameraModelManager(self.workdir)
                except Exception as e:
                    print(f"Error refreshing camera model manager: {e}")

            if self.image_processor:
                try:
                    self.image_processor = ImageProcessor(self.workdir)
                except Exception as e:
                    print(f"Error refreshing image processor: {e}")

            # Refresh camera image tree
            if self.camera_image_tree:
                try:
                    self.setup_camera_image_tree(self.camera_image_tree)
                except Exception as e:
                    print(f"Error refreshing camera image tree: {e}")