├── tests
├── __init__.py
├── README.md
├── conftest.py
├── test_base_tab.py
├── test_tab_manager.py
├── test_images_tab.py
├── test_camera_models.py
├── test_image_processing.py
└── test_main_app.py
├── utils
├── file_manager.py
├── logger.py
├── datasets
│ ├── download_dataset.py
│ ├── normalize.py
│ └── traj.py
└── gsplat_utils
│ └── utils.py
├── app
├── config_loader.py
├── tabs
│ ├── __init__.py
│ ├── gsplat_tab.py
│ ├── reconstruct_tab.py
│ ├── features_tab.py
│ ├── masks_tab.py
│ ├── matching_tab.py
│ └── images_tab.py
├── tab_manager.py
├── base_tab.py
├── point_cloud_utils.py
├── point_cloud_viewer.py
├── point_cloud_visualizer.py
├── image_processing.py
├── camera_models.py
└── mask_manager.py
├── entrypoint.sh
├── main.py
├── .gitmodules
├── .github
└── workflows
│ └── docker-image.yml
├── models
└── README.md
├── README_TESTS.md
├── README.md
├── Dockerfile
├── LICENSE
└── config
└── config.yaml
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Tests package initialization
--------------------------------------------------------------------------------
/utils/file_manager.py:
--------------------------------------------------------------------------------
import os

def create_directory(path):
    """Create *path* (including any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate ``os.path.exists`` check,
    which removes the race window between checking and creating when
    several processes create the same directory concurrently.
    """
    os.makedirs(path, exist_ok=True)
--------------------------------------------------------------------------------
/app/config_loader.py:
--------------------------------------------------------------------------------
import yaml

def load_config(config_path):
    """Read the YAML file at *config_path* and return its parsed contents."""
    with open(config_path, 'r') as config_file:
        return yaml.safe_load(config_file)
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
from loguru import logger
import os

def setup_logger(log_dir='logs', log_file='app.log'):
    """Ensure *log_dir* exists, attach a file sink to loguru, and return the logger."""
    os.makedirs(log_dir, exist_ok=True)
    sink_path = os.path.join(log_dir, log_file)
    logger.add(sink_path)
    return logger
--------------------------------------------------------------------------------
/entrypoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Container entrypoint.
#
# With "--dac" as the first argument: build/install the Depth Any Camera
# custom ops in-place, then return to the application directory and drop
# into an interactive shell.
# Otherwise: run the given command unchanged (exec replaces this shell).

if [[ "$1" == "--dac" ]]; then
    echo "Running DAC setup..."
    # Each step is chained with && so a failed build does not fall through
    # to the interactive shell in the wrong directory.
    cd /depth_any_camera/dac/models/ops && pip install -e . && cd /source/splat_one && bash
else
    echo "Default command executed"
    exec "$@"
fi
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# main.py
#
# Application entry point: create the Qt application, show the main
# window, and run the event loop until the GUI exits.
import sys
from PyQt5.QtWidgets import QApplication
from app.main_app import MainApp

if __name__ == "__main__":
    app = QApplication(sys.argv)
    main_window = MainApp()
    main_window.show()
    # exec_() blocks until the last window closes; propagate its status code.
    sys.exit(app.exec_())
--------------------------------------------------------------------------------
/app/tabs/__init__.py:
--------------------------------------------------------------------------------
# tabs package initialization
#
# Re-exports every tab implementation so callers can write
# ``from app.tabs import ImagesTab`` instead of importing each module.
from app.tabs.images_tab import ImagesTab
from app.tabs.masks_tab import MasksTab
# NOTE(review): depth_tab.py does not appear in the app/tabs directory
# listing - confirm the module exists, otherwise importing this package
# fails at import time.
from app.tabs.depth_tab import DepthTab
from app.tabs.features_tab import FeaturesTab
from app.tabs.matching_tab import MatchingTab
from app.tabs.reconstruct_tab import ReconstructTab
from app.tabs.gsplat_tab import GsplatTab

# All available tabs
__all__ = [
    'ImagesTab',
    'MasksTab',
    'DepthTab',
    'FeaturesTab',
    'MatchingTab',
    'ReconstructTab',
    'GsplatTab'
]
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # SPLAT_ONE Tests
2 |
3 | This directory contains tests for the SPLAT_ONE application.
4 |
5 | ## Running Tests
6 |
7 | To run the tests, make sure you have pytest installed:
8 |
9 | ```bash
10 | pip install pytest pytest-cov
11 | ```
12 |
13 | Then run the tests from the project root directory:
14 |
15 | ```bash
16 | pytest tests/
17 | ```
18 |
19 | For coverage reports:
20 |
21 | ```bash
22 | pytest --cov=app tests/
23 | ```
24 |
25 | ## Test Organization
26 |
27 | - `conftest.py`: Common fixtures and test configurations
28 | - `test_*.py`: Test modules for different components
29 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "submodules/sam2"]
2 | path = submodules/sam2
3 | url = https://github.com/facebookresearch/sam2.git
4 | [submodule "submodules/opensfm"]
5 | path = submodules/opensfm
6 | url = https://github.com/inuex35/ind-bermuda-opensfm.git
7 | [submodule "submodules/lightglue"]
8 | path = submodules/lightglue
9 | url = https://github.com/cvg/LightGlue.git
10 | [submodule "submodules/aliked"]
11 | path = submodules/aliked
12 | url = https://github.com/Shiaoming/ALIKED.git
13 | [submodule "submodules/gsplat"]
14 | path = submodules/gsplat
15 | url = https://github.com/inuex35/gsplat.git
16 | branch = spherical_render
17 | [submodule "submodules/Depth-Anything-V2"]
18 | path = submodules/Depth-Anything-V2
19 | url = https://github.com/DepthAnything/Depth-Anything-V2.git
20 | [submodule "submodules/depth_any_camera"]
21 | path = submodules/depth_any_camera
22 | url = https://github.com/DepthAnything/Depth-Any-Camera.git
23 |
--------------------------------------------------------------------------------
/.github/workflows/docker-image.yml:
--------------------------------------------------------------------------------
1 | # .github/workflows/docker-image.yml
2 |
3 | name: Build and Push Docker Image
4 |
5 | on:
6 | push:
7 | branches:
8 | - main
9 | pull_request:
10 | branches:
11 | - main
12 |
13 | jobs:
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - name: Free Disk Space
18 | uses: jlumbroso/free-disk-space@main
19 | with:
          # Options can be adjusted to match the project's requirements
21 | tool-cache: false
22 | android: true
23 | dotnet: true
24 | haskell: true
25 | large-packages: true
26 | docker-images: true
27 | swap-storage: true
28 | - name: Checkout code
29 | uses: actions/checkout@v3
30 |
31 | - name: Set up Docker Buildx
32 | uses: docker/setup-buildx-action@v2
33 |
34 | - name: Log in to Docker Hub
35 | uses: docker/login-action@v2
36 | with:
37 | username: ${{ secrets.DOCKER_USERNAME }}
38 | password: ${{ secrets.DOCKER_PASSWORD }}
39 |
40 | - name: Build and push Docker image
41 | uses: docker/build-push-action@v5
42 | with:
43 | context: .
44 | push: true
45 | tags: inuex35/splat_one:latest
46 |
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | # Depth Estimation Models
2 |
3 | This directory should contain the pre-trained models for depth estimation.
4 |
5 | ## Depth Anything V2 Models
6 |
7 | Download the models from the official repository:
8 | - https://github.com/DepthAnything/Depth-Anything-V2
9 |
10 | ### Recommended models:
11 | - **depth_anything_v2_vitl.pth** - Large model (best quality)
12 | - **depth_anything_v2_vits.pth** - Small model (faster inference)
13 |
14 | Download links:
15 | - ViT-L: https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth
16 | - ViT-S: https://huggingface.co/depth-anything/Depth-Anything-V2-Small/resolve/main/depth_anything_v2_vits.pth
17 |
18 | ## DAC (Depth Anything Camera) Models
19 |
20 | Download the models from:
21 | - https://github.com/xanderchf/depth_any_camera
22 |
23 | ### Available models:
24 | - **dac_vitl_hypersim.pth** - Trained on HyperSim dataset
25 | - **dac_vitl_vkitti.pth** - Trained on Virtual KITTI dataset
26 |
27 | Choose based on your use case:
28 | - HyperSim: Better for indoor scenes
29 | - VKITTI: Better for outdoor/driving scenes
30 |
31 | ## Installation
32 |
33 | 1. Create this models directory if it doesn't exist
34 | 2. Download the desired model files
35 | 3. Place them in this directory
36 | 4. The depth tab will automatically detect and use them
37 |
38 | ## Note
39 |
40 | The models are large files (several hundred MB each). Make sure you have enough disk space.
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Common test fixtures and configurations
2 | import os
3 | import sys
4 | import pytest
5 | from unittest.mock import MagicMock
6 | import tempfile
7 | import shutil
8 |
9 | # Add the parent directory to path so we can import app modules
10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
11 |
@pytest.fixture
def mock_qapplication(monkeypatch):
    """Replace QApplication with a MagicMock so tests never open a real GUI."""
    fake_app = MagicMock()
    monkeypatch.setattr('PyQt5.QtWidgets.QApplication', MagicMock(return_value=fake_app))
    return fake_app
18 |
@pytest.fixture
def temp_workdir():
    """Yield a throwaway working directory with the standard sub-folders."""
    temp_dir = tempfile.mkdtemp()

    # The application expects these sub-directories to exist
    for subdir in ('images', 'exif', 'masks'):
        os.makedirs(os.path.join(temp_dir, subdir), exist_ok=True)

    try:
        yield temp_dir
    finally:
        # Always remove the directory, even when the test fails
        shutil.rmtree(temp_dir)
34 |
@pytest.fixture
def sample_image_data():
    """Static image metadata records shared across tests."""
    def record(name, width, height, camera):
        return {'name': name, 'width': width, 'height': height, 'camera': camera}

    return [
        record('image1.jpg', 1920, 1080, 'Camera1'),
        record('image2.jpg', 1920, 1080, 'Camera1'),
        record('image3.jpg', 3840, 2160, 'Camera2'),
    ]
43 |
--------------------------------------------------------------------------------
/README_TESTS.md:
--------------------------------------------------------------------------------
1 | # SPLAT_ONE Testing Documentation
2 |
3 | ## Overview
4 |
5 | This document provides information about the testing system implemented for SPLAT_ONE. The test suite is designed to ensure code reliability, maintainability, and proper functionality across the application.
6 |
7 | ## Changes in Code Structure
8 |
9 | The original codebase has been refactored to improve modularity and testability:
10 |
11 | 1. **Module Separation**:
12 | - Core functionality has been separated into domain-specific modules
13 | - UI components are now organized into a logical hierarchy
14 | - Business logic has been moved from UI classes where possible
15 |
16 | 2. **Object-Oriented Design**:
17 | - Added abstraction through base classes and inheritance
18 | - Improved encapsulation by moving related functionality into specific classes
19 | - Reduced code duplication through shared behaviors
20 |
21 | 3. **Testing Infrastructure**:
22 | - Added comprehensive test suite using pytest
23 | - Created fixtures and test helpers for common testing scenarios
24 | - Implemented mocking for external dependencies
25 |
26 | ## Testing Structure
27 |
28 | The tests are organized as follows:
29 |
30 | - `tests/conftest.py`: Common fixtures and test utilities
31 | - `tests/test_*.py`: Individual test modules for each application component
32 | - `tests/README.md`: Documentation for the test system
33 |
34 | ## Running Tests
35 |
36 | To run the test suite, make sure you have pytest and the required dependencies installed:
37 |
38 | ```bash
39 | pip install pytest pytest-cov pytest-mock
40 | ```
41 |
42 | From the project root directory, run:
43 |
44 | ```bash
45 | pytest tests/
46 | ```
47 |
48 | For a coverage report:
49 |
50 | ```bash
51 | pytest --cov=app tests/
52 | ```
53 |
54 | To run specific tests:
55 |
56 | ```bash
57 | pytest tests/test_camera_models.py
58 | ```
59 |
60 | ## Test Types
61 |
62 | 1. **Unit Tests**:
63 | - Test individual functions and classes in isolation
64 | - Mock dependencies to focus on the unit under test
65 | - Fast execution for quick feedback
66 |
67 | 2. **Integration Tests**:
68 | - Test interactions between components
69 | - Verify proper communication between modules
70 | - Ensure system works as a whole
71 |
72 | ## Mocking Strategy
73 |
74 | To avoid dependencies on GUI components and external libraries during testing, we use a combination of:
75 |
76 | - `unittest.mock` for Python standard library mocking
77 | - `pytest-mock` for pytest-style fixtures
78 | - Custom mock objects for complex dependencies
79 |
80 | ## Future Testing Improvements
81 |
82 | 1. Add more tests for other tabs and components as they are refactored
83 | 2. Implement UI testing for checking GUI interactions
84 | 3. Add performance tests for critical operations
85 | 4. Implement continuous integration for automatic test runs
--------------------------------------------------------------------------------
/app/tab_manager.py:
--------------------------------------------------------------------------------
1 | # tab_manager.py
2 |
3 | from PyQt5.QtWidgets import QTabWidget
4 | from app.base_tab import BaseTab
5 |
class TabManager(QTabWidget):
    """Manager for application tabs.

    Owns the tab instances, dispatches activation/deactivation callbacks
    as the user switches tabs, and broadcasts workdir / image-list updates
    to every registered tab.
    """
    def __init__(self, parent=None):
        super().__init__(parent)
        # index -> tab widget instance, populated by register_tab()
        self.tab_instances = {}
        self.parent_app = parent
        self.currentChanged.connect(self.on_tab_changed)
        # Index of the most recently activated tab; -1 means "none yet"
        self.active_tab_index = -1

    def register_tab(self, tab_class, tab_name=None, *args, **kwargs):
        """Register a tab with the manager.

        Instantiates ``tab_class`` (passing the parent app), adds it to the
        tab widget under ``tab_name`` (defaulting to the tab's own
        ``get_tab_name()`` for BaseTab subclasses, else the class name),
        and returns the new tab's index.
        """
        # Create the tab instance
        tab_instance = tab_class(*args, parent=self.parent_app, **kwargs)

        # Use provided name or get from the tab
        if tab_name is None and isinstance(tab_instance, BaseTab):
            tab_name = tab_instance.get_tab_name()
        elif tab_name is None:
            tab_name = tab_class.__name__

        # Add to tab widget
        index = self.addTab(tab_instance, tab_name)

        # Store reference
        self.tab_instances[index] = tab_instance

        # If this is the first tab, set it as active.
        # NOTE(review): addTab() emits currentChanged before the instance is
        # stored above, so on_tab_changed() finds no tab for index 0; the
        # first tab is therefore activated manually here.
        if index == 0:
            self.active_tab_index = 0
            # Initialize the first tab immediately
            if isinstance(tab_instance, BaseTab) and hasattr(tab_instance, 'on_tab_activated'):
                tab_instance.on_tab_activated()

        return index

    def get_tab_instance(self, index):
        """Get the tab instance at the given index (None if not registered)."""
        return self.tab_instances.get(index)

    def get_current_tab(self):
        """Get the currently active tab instance."""
        index = self.currentIndex()
        return self.get_tab_instance(index)

    def on_tab_changed(self, index):
        """Handle tab changed event: deactivate the old tab, activate the new one."""
        # Deactivate previous tab
        prev_tab = self.get_tab_instance(self.active_tab_index)
        if prev_tab and hasattr(prev_tab, 'on_tab_deactivated'):
            prev_tab.on_tab_deactivated()

        # Activate new tab
        current_tab = self.get_tab_instance(index)
        if current_tab and hasattr(current_tab, 'on_tab_activated'):
            current_tab.on_tab_activated()

        # Update active index
        self.active_tab_index = index

    def update_all_tabs(self, workdir=None, image_list=None):
        """Update all tabs with a new workdir and/or image list.

        Either argument may be None, in which case that update is skipped.
        """
        for tab_instance in self.tab_instances.values():
            if hasattr(tab_instance, 'update_workdir') and workdir is not None:
                tab_instance.update_workdir(workdir)

            if hasattr(tab_instance, 'update_image_list') and image_list is not None:
                tab_instance.update_image_list(image_list)
/app/tabs/gsplat_tab.py:
--------------------------------------------------------------------------------
1 | # gsplat_tab.py
2 |
3 | import os
4 | from PyQt5.QtWidgets import (
5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox
6 | )
7 | from PyQt5.QtCore import Qt
8 |
9 | from app.base_tab import BaseTab
10 |
class GsplatTab(BaseTab):
    """Gsplat tab: image tree on the left, GsplatManager view on the right."""
    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        # Both are created lazily in initialize()
        self.gsplat_manager = None
        self.gsplat_image_tree = None

    def get_tab_name(self):
        return "Gsplat"

    def initialize(self):
        """Initialize the Gsplat tab.

        Builds a horizontal splitter with the image tree on the left and the
        GsplatManager widget on the right. On failure, shows an error dialog
        and a placeholder label instead of the manager.
        """
        if not self.workdir:
            QMessageBox.warning(self, "Error", "Work directory is not set.")
            return

        layout = self.create_horizontal_splitter()

        # Left side: Tree of images
        self.gsplat_image_tree = QTreeWidget()
        self.gsplat_image_tree.setHeaderLabel("Images")
        self.gsplat_image_tree.setFixedWidth(250)
        layout.addWidget(self.gsplat_image_tree)

        try:
            # Import GsplatManager here to avoid circular imports
            from app.gsplat_manager import GsplatManager

            # Right side: GsplatManager widget
            self.gsplat_manager = GsplatManager(self.workdir)
            layout.addWidget(self.gsplat_manager)

            # Set stretch factors
            layout.setStretchFactor(0, 1) # Left side (image tree)
            layout.setStretchFactor(1, 4) # Right side (gsplat manager)

            # Set layout for gsplat tab
            self._layout.addWidget(layout)

            # Populate the tree with camera data
            self.setup_camera_image_tree(self.gsplat_image_tree)

            # Connect double-click signal to handler
            self.gsplat_image_tree.itemDoubleClicked.connect(self.handle_image_double_click)

            self.is_initialized = True

        except Exception as e:
            # Keep the tab usable with a placeholder when the manager fails
            error_message = f"Failed to initialize GsplatManager: {str(e)}"
            QMessageBox.critical(self, "Error", error_message)
            placeholder = QLabel("GsplatManager could not be initialized. See error message.")
            placeholder.setAlignment(Qt.AlignCenter)
            layout.addWidget(placeholder)
            self._layout.addWidget(layout)

    def handle_image_double_click(self, item, column):
        """Handle double-click event on an image tree item.

        Leaf items (no children, with a parent) are image entries; forward
        the image name to the manager when it supports the callback.
        """
        if item.childCount() == 0 and item.parent() is not None:
            image_name = item.text(0)
            if self.gsplat_manager and hasattr(self.gsplat_manager, 'on_camera_image_tree_double_click'):
                self.gsplat_manager.on_camera_image_tree_double_click(image_name)

    def refresh(self):
        """Refresh the tab content by tearing down and rebuilding the UI."""
        if self.is_initialized:
            # Remove old widgets
            # NOTE(review): itemAt(i).widget() can be None for non-widget
            # layout items (spacers, nested layouts); confirm only widgets
            # are ever added to this layout.
            for i in reversed(range(self._layout.count())):
                self._layout.itemAt(i).widget().setParent(None)

            # Reinitialize
            self.is_initialized = False
            self.initialize()
/app/base_tab.py:
--------------------------------------------------------------------------------
1 | # base_tab.py
2 |
3 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QSplitter, QTreeWidget, QTreeWidgetItem
4 | from PyQt5.QtCore import Qt
5 |
class BaseTab(QWidget):
    """Base class for application tabs.

    Provides lazy initialization (tabs build their UI the first time they
    are activated), shared helpers for camera/image trees and splitters,
    and hooks for reacting to workdir / image-list changes.
    """
    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(parent)
        # Working directory expected to contain images/, exif/, masks/, ...
        self.workdir = workdir
        self.image_list = image_list or []
        self.parent_app = parent
        self._layout = QVBoxLayout(self)
        self._layout.setContentsMargins(0, 0, 0, 0)
        # Set once initialize() has run (see on_tab_activated)
        self.is_initialized = False

    def get_tab_name(self):
        """Get the name of the tab. To be overridden by child classes."""
        return "Unnamed Tab"

    def initialize(self):
        """Initialize the tab. Must be implemented by child classes."""
        raise NotImplementedError("This method should be implemented by child classes.")

    def setup_camera_image_tree(self, tree_widget, select_callback=None):
        """Set up a camera-grouped image tree widget.

        Groups each image in ``self.image_list`` by the ``camera`` field of
        its sidecar ``<workdir>/exif/<image>.exif`` JSON file, falling back
        to 'Unknown Camera' when the file is missing, then fills
        ``tree_widget`` with one top-level item per camera.

        NOTE(review): ``select_callback`` is connected on every call; Qt
        allows duplicate connections, so calling this repeatedly on the
        same widget would fire the callback multiple times per click -
        confirm callers invoke this only once per widget.
        """
        tree_widget.clear()
        camera_groups = {}

        # Nothing to build without a workdir or images
        if not self.workdir or not self.image_list:
            return

        import json
        import os
        exif_dir = os.path.join(self.workdir, "exif")

        for image_name in self.image_list:
            exif_file = os.path.join(exif_dir, image_name + '.exif')
            if os.path.exists(exif_file):
                with open(exif_file, 'r') as f:
                    exif_data = json.load(f)
                camera = exif_data.get('camera', 'Unknown Camera')
            else:
                camera = 'Unknown Camera'

            if camera not in camera_groups:
                camera_groups[camera] = []

            camera_groups[camera].append(image_name)

        # One top-level item per camera, with its images as children
        for camera, images in camera_groups.items():
            camera_item = QTreeWidgetItem(tree_widget)
            camera_item.setText(0, camera)
            for img in images:
                img_item = QTreeWidgetItem(camera_item)
                img_item.setText(0, img)

        # Connect callback if provided
        if select_callback:
            tree_widget.itemClicked.connect(select_callback)

    def create_horizontal_splitter(self):
        """Create and return a horizontal splitter."""
        return QSplitter(Qt.Horizontal)

    def update_workdir(self, workdir):
        """Update the working directory, refreshing the UI if already built."""
        self.workdir = workdir
        if self.is_initialized:
            self.refresh()

    def update_image_list(self, image_list):
        """Update the image list, refreshing the UI if already built."""
        self.image_list = image_list
        if self.is_initialized:
            self.refresh()

    def refresh(self):
        """Refresh the tab contents. Can be overridden by child classes."""
        pass

    def on_tab_activated(self):
        """Called when the tab is activated. Builds the UI on first activation."""
        if not self.is_initialized:
            self.initialize()
            self.is_initialized = True

    def on_tab_deactivated(self):
        """Called when the tab is deactivated. Can be overridden by child classes."""
        pass
--------------------------------------------------------------------------------
/tests/test_base_tab.py:
--------------------------------------------------------------------------------
1 | # test_base_tab.py
2 |
3 | import pytest
4 | from unittest.mock import MagicMock, patch
5 | from PyQt5.QtWidgets import QWidget, QVBoxLayout
6 | from app.base_tab import BaseTab
7 |
class TestTab(BaseTab):
    """Test implementation of BaseTab for testing.

    Records which lifecycle hooks were invoked so the tests below can
    verify BaseTab's dispatch behavior. (Having an ``__init__`` also keeps
    pytest from collecting this helper class as a test class.)
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flags flipped by the lifecycle hooks below
        self.initialize_called = False
        self.refresh_called = False
        self.tab_activated_called = False
        self.tab_deactivated_called = False

    def get_tab_name(self):
        return "Test Tab"

    def initialize(self):
        self.initialize_called = True
        self.is_initialized = True

    def refresh(self):
        self.refresh_called = True

    def on_tab_activated(self):
        super().on_tab_activated()
        self.tab_activated_called = True

    def on_tab_deactivated(self):
        super().on_tab_deactivated()
        self.tab_deactivated_called = True
34 |
35 |
def test_base_tab_init():
    """The constructor stores its arguments and starts uninitialized."""
    parent = MagicMock()
    tab = TestTab("/test/workdir", ["image1.jpg", "image2.jpg"], parent)

    assert tab.workdir == "/test/workdir"
    assert tab.image_list == ["image1.jpg", "image2.jpg"]
    assert tab.parent_app is parent
    assert isinstance(tab._layout, QVBoxLayout)
    assert tab.is_initialized is False
49 |
50 |
def test_base_tab_get_tab_name():
    """The tab reports its display name."""
    assert TestTab().get_tab_name() == "Test Tab"
55 |
56 |
def test_base_tab_on_tab_activated():
    """Activating an uninitialized tab triggers initialize() exactly as expected."""
    tab = TestTab()

    tab.on_tab_activated()

    assert tab.initialize_called is True
    assert tab.is_initialized is True
    assert tab.tab_activated_called is True
65 |
66 |
def test_base_tab_update_workdir():
    """Changing the workdir on an initialized tab refreshes it."""
    tab = TestTab("/old/workdir")
    tab.is_initialized = True

    tab.update_workdir("/new/workdir")

    assert tab.workdir == "/new/workdir"
    assert tab.refresh_called is True
76 |
77 |
def test_base_tab_update_image_list():
    """Replacing the image list on an initialized tab triggers refresh()."""
    tab = TestTab(image_list=["old.jpg"])
    tab.is_initialized = True

    fresh_images = ["new1.jpg", "new2.jpg"]
    tab.update_image_list(fresh_images)

    assert tab.image_list == fresh_images
    assert tab.refresh_called is True
88 |
89 |
@patch('app.base_tab.QTreeWidget')
def test_base_tab_setup_camera_image_tree(mock_tree_widget):
    """Test setup_camera_image_tree method"""
    # This test is simplified due to complexity of mocking file operations
    # A more comprehensive test would use a fixture with actual files
    # NOTE(review): only QTreeWidget is patched, but setup_camera_image_tree
    # also constructs QTreeWidgetItem with the mocked tree as parent -
    # confirm this does not raise a TypeError under real PyQt5.
    tab = TestTab("/test/workdir", ["image1.jpg"])

    tree = MagicMock()
    callback = MagicMock()

    # Call the method
    tab.setup_camera_image_tree(tree, callback)

    # Check that tree was cleared
    assert tree.clear.call_count == 1

    # If callback provided, it should be connected
    if callback:
        assert tree.itemClicked.connect.call_count == 1
--------------------------------------------------------------------------------
/app/point_cloud_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import os
4 |
def rgb_depth_to_pointcloud(rgb_image, depth_image, fx=1000, fy=1000, cx=None, cy=None):
    """
    Generate point cloud from RGB image and depth image

    Args:
        rgb_image: RGB image (PIL Image or numpy array)
        depth_image: Depth image (PIL Image or numpy array), 8-bit values
        fx, fy: Camera focal length
        cx, cy: Camera optical center (uses image center if None)

    Returns:
        points: Point cloud coordinates (N, 3)
        colors: Point cloud colors (N, 3), normalized to 0-1
    """
    # np.asarray accepts both PIL Images (via the array interface) and
    # plain numpy arrays, so no isinstance checks are needed.
    rgb_array = np.asarray(rgb_image)
    depth_array = np.asarray(depth_image)

    # Use the first channel when the depth image is multi-channel.
    # (The original stacked a 2-D depth to 3 channels and immediately took
    # channel 0 again - a redundant round trip, now removed.)
    if depth_array.ndim == 3:
        depth_array = depth_array[:, :, 0]

    # Get image dimensions
    height, width = depth_array.shape

    # Use image center if optical center is not specified
    if cx is None:
        cx = width / 2
    if cy is None:
        cy = height / 2

    # Pixel-coordinate grids (row index = y, column index = x)
    y, x = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

    # Normalize depth values (0-255 to 0-1)
    depth_normalized = depth_array.astype(np.float32) / 255.0

    # Back-project with the pinhole camera model
    Z = depth_normalized * 10.0  # Adjust depth scale as needed
    X = (x - cx) * Z / fx
    Y = (y - cy) * Z / fy

    # Keep only points whose depth is neither near zero nor saturated
    valid_mask = (depth_normalized > 0.01) & (depth_normalized < 0.99)

    # Extract point cloud data
    points = np.stack([X[valid_mask], Y[valid_mask], Z[valid_mask]], axis=1)
    colors = rgb_array[valid_mask] / 255.0  # Normalize colors to 0-1

    return points, colors
66 |
def load_image_and_depth(image_path, depth_path):
    """
    Load an RGB image and its matching depth image from disk.

    Args:
        image_path: Path to RGB image
        depth_path: Path to depth image

    Returns:
        rgb_image: RGB image (PIL Image)
        depth_image: Depth image (PIL Image, grayscale)

    Raises:
        FileNotFoundError: if either file does not exist.
    """
    # Validate both paths up front so the error names the missing file
    for label, path in (("Image", image_path), ("Depth", depth_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    rgb_image = Image.open(image_path).convert('RGB')
    depth_image = Image.open(depth_path).convert('L')  # grayscale

    return rgb_image, depth_image
89 |
def downsample_pointcloud(points, colors, target_points=10000):
    """
    Randomly downsample a point cloud to at most ``target_points`` points.

    Args:
        points: Point cloud coordinates (N, 3)
        colors: Point cloud colors (N, 3)
        target_points: Target number of points

    Returns:
        The (points, colors) pair unchanged when it is already small enough,
        otherwise a random subset of both arrays taken at the same indices.
    """
    n_points = len(points)
    if n_points <= target_points:
        return points, colors

    # Sample without replacement so no point appears twice
    keep = np.random.choice(n_points, target_points, replace=False)
    return points[keep], colors[keep]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # splat_one
2 |
3 | splat_one is an integrated application that combines Gaussian Splatting and Structure from Motion (SfM) into a single workflow. It allows users to visualize and process image data through multiple interactive tabs, making it easier to generate 3D point clouds and analyze camera parameters.
4 |
5 | **Note:** This project is currently under development. The Masks tab is not functioning properly at this time.
6 |
7 | ## Tab Descriptions
8 |
9 | - **Images Tab**
10 | Displays information about the imported images, such as file names, resolutions, and metadata.
11 |
12 | - **Masks Tab**
13 | Provides functionality to generate masks for images.
14 | *Currently under development; may not work as expected.*
15 |
16 | - **Features Tab**
17 | Shows the extracted image features, including keypoints and descriptors, for visualization and analysis.
18 |
19 | - **Matching Tab**
20 | Displays the results of feature point matching, allowing you to inspect the quality and accuracy of the correspondences.
21 |
22 | - **Reconstruct Tab**
23 | Visualizes the 3D point cloud reconstructed via SfM, along with camera positions and overall scene structure.
24 |
25 | - **Gsplat Tab**
26 | Presents the output of the Gaussian Splatting process. You can adjust parameters to explore different renderings of the 3D point cloud.
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | ## Data Preparation
35 |
36 | Before running the application, organize your images in the following directory structure inside the `dataset` folder:
37 |
38 | ```bash
39 | splat_one
40 | └─ dataset
41 | └─ your_data
42 | └─ images
43 | ```
44 |
45 | Place all the images you want to process under the `images` directory.
46 |
47 | ## Docker Deployment
48 |
49 | This application is designed to run via Docker. The Docker image is available as `inuex35/splat_one`. Please refer to the `Dockerfile` for more details on the installation.
50 |
51 | To launch the Docker container with GPU support and proper X11 forwarding (for GUI display), run the following command:
52 |
53 | ```bash
54 | docker run --gpus all -e DISPLAY=host.docker.internal:0.0 -v /tmp/.X11-unix:/tmp/.X11-unix -v ${PWD}/dataset:/source/splat_one/dataset -v C:\Users\$env:USERNAME\.cache:/home/user/.cache/ -p 7007:7007 --rm -it --shm-size=12gb inuex35/splat_one
55 | ```
56 |
57 | Once inside the container, start the application by executing:
58 |
59 | ```bash
60 | python main.py
61 | ```
62 |
63 | ### Running with Depth Any Camera (DAC) Mode
64 |
65 | To enable the Depth Any Camera feature for advanced depth estimation, you can run the application with the `--dac` flag:
66 |
67 | ```bash
68 | docker run --gpus all -e DISPLAY=host.docker.internal:0.0 -v /tmp/.X11-unix:/tmp/.X11-unix -v ${PWD}/dataset:/source/splat_one/dataset -v C:\Users\$env:USERNAME\.cache:/home/user/.cache/ -p 7007:7007 --rm -it --shm-size=12gb inuex35/splat_one --dac
69 | ```
70 |
71 | The `--dac` option enables depth estimation capabilities using the Depth Any Camera model, which provides robust depth prediction for various camera types including perspective, fisheye, and 360-degree cameras.
72 |
73 | ## Dependencies
74 | - [OpenSfM](https://github.com/inuex35/ind-bermuda-opensfm/)
75 | - [Gsplat](https://github.com/inuex35/gsplat/)
76 |
77 | We use code from these repositories. Please note that the license for these components follows their respective original repositories.
78 |
79 |
80 | ## Development Status
81 | The project is under active development.
82 | Some features, such as the mask generation in the Masks tab, are not yet fully functional.
83 | Contributions via Issues and Pull Requests are welcome.
84 |
85 | ## License
86 | This project is licensed under the MIT License. See the LICENSE file for details.
86 |
--------------------------------------------------------------------------------
/app/tabs/reconstruct_tab.py:
--------------------------------------------------------------------------------
1 | # reconstruct_tab.py
2 |
3 | import os
4 | from PyQt5.QtWidgets import (
5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox
6 | )
7 | from PyQt5.QtCore import Qt
8 |
9 | from app.base_tab import BaseTab
10 |
class ReconstructTab(BaseTab):
    """Tab showing the SfM reconstruction viewer next to a per-camera image tree."""

    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        # Both are created lazily in initialize()
        self.reconstruction_viewer = None
        self.camera_image_tree = None

    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Reconstruct"

    def initialize(self):
        """Build the tab UI: image tree on the left, 3D viewer on the right."""
        if not self.workdir:
            QMessageBox.warning(self, "Error", "Work directory is not set.")
            return

        splitter = self.create_horizontal_splitter()

        # Left pane: images grouped per camera
        self.camera_image_tree = QTreeWidget()
        self.camera_image_tree.setHeaderLabel("Cameras and Images")
        self.camera_image_tree.setFixedWidth(250)
        splitter.addWidget(self.camera_image_tree)

        try:
            # Imported lazily to avoid a circular import at module load time
            from app.point_cloud_visualizer import Reconstruction

            # Right pane: the reconstruction (point cloud) viewer
            self.reconstruction_viewer = Reconstruction(self.workdir)
            splitter.addWidget(self.reconstruction_viewer)

            self._layout.addWidget(splitter)

            # Populate the tree and wire up single/double click handlers
            self.setup_camera_image_tree(self.camera_image_tree)
            self.camera_image_tree.itemClicked.connect(self.handle_camera_image_tree_click)
            self.camera_image_tree.itemDoubleClicked.connect(self.handle_camera_image_tree_double_click)

            # Tree : viewer = 1 : 4
            splitter.setStretchFactor(0, 1)
            splitter.setStretchFactor(1, 4)

            self.is_initialized = True

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to initialize Reconstruction: {str(e)}")
            placeholder = QLabel("Reconstruction could not be initialized. See error message.")
            placeholder.setAlignment(Qt.AlignCenter)
            splitter.addWidget(placeholder)
            self._layout.addWidget(splitter)

    def handle_camera_image_tree_click(self, item, column):
        """Forward a single click on a leaf (image) item to the viewer."""
        is_leaf = item.childCount() == 0 and item.parent() is not None
        if is_leaf and self.reconstruction_viewer:
            self.reconstruction_viewer.on_camera_image_tree_click(item.text(0))

    def handle_camera_image_tree_double_click(self, item, column):
        """Forward a double click on a leaf (image) item to the viewer."""
        is_leaf = item.childCount() == 0 and item.parent() is not None
        if is_leaf and self.reconstruction_viewer:
            self.reconstruction_viewer.on_camera_image_tree_double_click(item.text(0))

    def refresh(self):
        """Tear down the current widgets and rebuild the tab from scratch."""
        if not self.is_initialized:
            return
        # Detach every widget currently owned by the layout
        for index in reversed(range(self._layout.count())):
            self._layout.itemAt(index).widget().setParent(None)
        self.is_initialized = False
        self.initialize()

    def on_tab_activated(self):
        """Refresh the 3D visualization whenever the tab becomes active."""
        super().on_tab_activated()
        if self.is_initialized and self.reconstruction_viewer:
            self.reconstruction_viewer.update_visualization()
--------------------------------------------------------------------------------
/app/tabs/features_tab.py:
--------------------------------------------------------------------------------
1 | # features_tab.py
2 |
3 | import os
4 | from PyQt5.QtWidgets import (
5 | QWidget, QVBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox
6 | )
7 | from PyQt5.QtCore import Qt
8 |
9 | from app.base_tab import BaseTab
10 |
class FeaturesTab(BaseTab):
    """Tab pairing a per-camera image tree with the feature-extractor view."""

    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        self.feature_extractor = None
        self.camera_image_tree = None

        # Build the static UI skeleton immediately; data loads lazily.
        self.setup_basic_ui()

    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Features"

    def setup_basic_ui(self):
        """Create the splitter holding the image tree and a placeholder pane."""
        splitter = self.create_horizontal_splitter()

        # Left pane: images grouped per camera
        self.camera_image_tree = QTreeWidget()
        self.camera_image_tree.setHeaderLabel("Cameras and Images")
        self.camera_image_tree.setFixedWidth(250)
        splitter.addWidget(self.camera_image_tree)

        # Right pane: placeholder until the extractor is created
        placeholder = QLabel("Feature Extractor will be displayed here.")
        placeholder.setAlignment(Qt.AlignCenter)
        splitter.addWidget(placeholder)

        # Tree : extractor = 1 : 4
        splitter.setStretchFactor(0, 1)
        splitter.setStretchFactor(1, 4)

        self._layout.addWidget(splitter)

        self.camera_image_tree.itemClicked.connect(self.display_features_for_image)

    def initialize(self):
        """Swap the placeholder for a real FeatureExtractor widget."""
        if not self.workdir:
            QMessageBox.warning(self, "Error", "Work directory is not set.")
            return

        try:
            # Imported lazily to avoid a circular import at module load time
            from app.feature_extractor import FeatureExtractor

            splitter = self._layout.itemAt(0).widget()

            # Drop the placeholder that occupies the right pane
            splitter.widget(1).setParent(None)

            self.feature_extractor = FeatureExtractor(self.workdir, self.image_list)
            splitter.addWidget(self.feature_extractor)

            # Re-apply stretch factors after swapping the widget
            splitter.setStretchFactor(0, 1)
            splitter.setStretchFactor(1, 4)

            # Fill the tree with camera/image entries
            self.setup_camera_image_tree(self.camera_image_tree)

            self.is_initialized = True

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to initialize FeatureExtractor: {str(e)}")

    def display_features_for_image(self, item, column):
        """Load the clicked image into the extractor, initializing on demand."""
        if not self.is_initialized:
            self.initialize()

        is_leaf = item.childCount() == 0 and item.parent() is not None
        if is_leaf and self.feature_extractor and hasattr(self.feature_extractor, 'load_image_by_name'):
            self.feature_extractor.load_image_by_name(item.text(0))

    def refresh(self):
        """Tear down and rebuild the tab content."""
        if not self.is_initialized:
            return
        # Detach every widget currently owned by the layout
        for index in reversed(range(self._layout.count())):
            self._layout.itemAt(index).widget().setParent(None)
        # Rebuild the skeleton, then reload the data
        self.setup_basic_ui()
        self.is_initialized = False
        self.initialize()
--------------------------------------------------------------------------------
/utils/datasets/download_dataset.py:
--------------------------------------------------------------------------------
1 | """Script to download benchmark dataset(s)"""
2 |
3 | import os
4 | import subprocess
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from typing import Literal
8 |
9 | import tyro
10 |
# Dataset names accepted on the command line (CLI choices for DownloadData).
dataset_names = Literal[
    "mipnerf360",
    "mipnerf360_extra",
    "bilarf_data",
    "zipnerf",
    "zipnerf_undistorted",
]

# Download URL(s) per dataset; a list value means the dataset ships as
# several per-scene archives that are all extracted into one directory.
urls = {
    "mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip",
    "mipnerf360_extra": "https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip",
    "bilarf_data": "https://huggingface.co/datasets/Yuehao/bilarf_data/resolve/main/bilarf_data.zip",
    "zipnerf": [
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf/berlin.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf/london.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf/nyc.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf/alameda.zip",
    ],
    "zipnerf_undistorted": [
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/berlin.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/london.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/nyc.zip",
        "https://storage.googleapis.com/gresearch/refraw360/zipnerf-undistorted/alameda.zip",
    ],
}

# Maps a dataset name to the sub-directory (under save_dir) it is
# extracted into; note both mipnerf360 variants share "360_v2".
dataset_rename_map = {
    "mipnerf360": "360_v2",
    "mipnerf360_extra": "360_v2",
    "bilarf_data": "bilarf",
    "zipnerf": "zipnerf",
    "zipnerf_undistorted": "zipnerf_undistorted",
}
47 |
48 |
@dataclass
class DownloadData:
    """CLI entry point: download and extract one benchmark dataset.

    Attributes:
        dataset: Name of the dataset to fetch (a key of ``urls``).
        save_dir: Root directory the dataset is extracted under.
    """

    dataset: dataset_names = "mipnerf360"
    save_dir: Path = Path(os.getcwd() + "/data")

    def main(self):
        """Create the save directory and download the selected dataset."""
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.dataset_download(self.dataset)

    def dataset_download(self, dataset: dataset_names):
        """Download and extract every archive belonging to *dataset*.

        Single-URL datasets and multi-archive datasets are handled
        uniformly by normalizing the URL entry to a list first (the
        original code duplicated the download logic in both branches).
        """
        entry = urls[dataset]
        url_list = entry if isinstance(entry, list) else [entry]
        extract_path = self.save_dir / dataset_rename_map[dataset]
        for url in url_list:
            download_path = extract_path / Path(url).name
            download_and_extract(url, download_path, extract_path)
71 |
72 |
def download_and_extract(url: str, download_path: Path, extract_path: Path) -> None:
    """Download *url* to *download_path* and extract it into *extract_path*.

    Uses ``curl`` for the download and ``unzip``/``tar`` for extraction
    (Windows has no ``unzip``, but its bundled ``tar`` reads zip files).
    The archive is deleted only after a successful extraction.

    Args:
        url: Archive URL (``.zip`` or a gzipped tarball).
        download_path: File path the archive is written to.
        extract_path: Directory the archive contents are extracted into.
    """
    download_path.parent.mkdir(parents=True, exist_ok=True)
    extract_path.mkdir(parents=True, exist_ok=True)

    # Download with curl, following redirects (-L).
    download_command = [
        "curl",
        "-L",
        "-o",
        str(download_path),
        url,
    ]
    try:
        subprocess.run(download_command, check=True)
        print("File downloaded successfully.")
    except subprocess.CalledProcessError as e:
        # Bail out: extracting a missing/partial archive would only
        # produce a second, confusing error (the original fell through).
        print(f"Error downloading file: {e}")
        return

    # Pick the extraction tool based on archive type and platform.
    if Path(url).suffix == ".zip":
        if os.name == "nt":  # Windows doesn't have 'unzip' but 'tar' works
            extract_command = [
                "tar",
                "-xvf",
                download_path,
                "-C",
                extract_path,
            ]
        else:
            extract_command = [
                "unzip",
                download_path,
                "-d",
                extract_path,
            ]
    else:
        # Anything that is not a zip is assumed to be a gzipped tarball.
        extract_command = [
            "tar",
            "-xvzf",
            download_path,
            "-C",
            extract_path,
        ]

    try:
        subprocess.run(extract_command, check=True)
        # Remove the archive only once extraction succeeded.
        os.remove(download_path)
        print("Extraction complete.")
    except subprocess.CalledProcessError as e:
        print(f"Extraction failed: {e}")
125 |
126 |
127 | if __name__ == "__main__":
128 | tyro.cli(DownloadData).main()
129 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Build image for splat_one: CUDA 12.1 devel base with Python 3.10,
# PyTorch, and OpenSfM / gsplat / SAM2 / Depth-Any-Camera built from source.
FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04

ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8
# GPU architectures targeted by CUDA extension builds (Turing/Ampere)
ARG TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}"
ARG DEBIAN_FRONTEND=noninteractive

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    software-properties-common && \
    add-apt-repository ppa:deadsnakes/ppa && \
    apt-get install -y --no-install-recommends \
    build-essential cmake git libeigen3-dev libopencv-dev libceres-dev \
    python3.10 python3.10-dev python3-pip python3.10-distutils \
    python-is-python3 curl ninja-build libglm-dev \
    libboost-all-dev libflann-dev libfreeimage-dev libmetis-dev \
    libgoogle-glog-dev libgtest-dev libgmock-dev libsqlite3-dev \
    libglew-dev qtbase5-dev libqt5opengl5-dev wget libcgal-dev \
    graphviz mesa-utils libgraphviz-dev libgl1 libglib2.0-0 \
    ffmpeg && \
    rm -rf /var/lib/apt/lists/*

# Set Python 3.10 as default and install pip
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 && \
    update-alternatives --config python3 --skip-auto && \
    curl -sS https://bootstrap.pypa.io/get-pip.py | python3 && \
    pip install --no-cache-dir --upgrade pip setuptools "setuptools<68.0.0"

# Install PyTorch
RUN pip install --no-cache-dir torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Core Python dependencies
RUN pip install --no-cache-dir \
    numpy pandas matplotlib scipy seaborn \
    opencv-python-headless Pillow tqdm \
    flask loguru h5py scikit-learn jupyter jupyterlab \
    pyqtgraph pyyaml packaging pyparsing==3.0.9 networkx==2.5 \
    kornia==0.7.3 torchmetrics[image] imageio[ffmpeg] \
    black coverage mypy pylint pytest flake8 isort tyro>=0.8.8 \
    open3d colour tabulate simplejson parameterized pydegensac \
    tensorboard tensorly xmltodict cloudpickle==0.4.0 \
    fpdf2==2.4.6 python-dateutil Sphinx==4.2.0 \
    wheel viser nerfview \
    jsonschema==4.17.3 jupyter-events==0.6.3 \
    "PyOpenGL==3.1.1a1" "PyQt5" bokeh==2.4.3 \
    splines pyproj \
    mapillary_tools

# NOTE(review): TORCH_CUDA_ARCH_LIST is already exported above with the
# same value; this repeat looks redundant in a single-stage build — confirm.
ENV TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}"
# Install additional GitHub-based packages
RUN pip install --no-cache-dir \
    git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e

# Clone and patch fused-ssim
# (the sed lines drop compute_100/101 arch entries this toolchain can't build)
RUN git clone https://github.com/rahul-goel/fused-ssim.git /tmp/fused-ssim && \
    sed -i '/compute_100/d' /tmp/fused-ssim/setup.py && \
    sed -i '/compute_101/d' /tmp/fused-ssim/setup.py && \
    pip install /tmp/fused-ssim && \
    rm -rf /tmp/fused-ssim

# FlashAttention
RUN git clone --depth 1 https://github.com/Dao-AILab/flash-attention.git /flash-attention && \
    cd /flash-attention && pip install --no-cache-dir .

# LightGlue
RUN git clone --depth 1 https://github.com/cvg/LightGlue.git /LightGlue && \
    cd /LightGlue && pip install --no-cache-dir .

# Clone and build splat_one and submodules
RUN git clone https://github.com/inuex35/splat_one.git /source/splat_one && \
    cd /source/splat_one && \
    git submodule update --init --recursive

# Build OpenSfM
RUN cd /source/splat_one/submodules/opensfm && \
    python3 setup.py build && \
    python3 setup.py install

# Build gsplat (pinned to a known-good commit)
RUN cd /source/splat_one/submodules/gsplat && \
    git checkout b0e978da67fb4364611c6683c5f4e6e6c1d8d8cb && \
    MAX_JOBS=4 pip install -e .

# Build sam2 (relax its setuptools pin so it installs in this environment)
RUN sed -i 's/setuptools>=61.0/setuptools>=62.3.8,<75.9/' /source/splat_one/submodules/sam2/pyproject.toml && \
    cd /source/splat_one/submodules/sam2 && \
    pip install -e ".[notebooks]" && \
    cd checkpoints && ./download_ckpts.sh

# Clone and setup depth_any_camera
RUN git clone https://github.com/yuliangguo/depth_any_camera /depth_any_camera && \
    cd /depth_any_camera && \
    pip install -r requirements.txt
ENV PYTHONPATH="/depth_any_camera:$PYTHONPATH"

# Pre-download PyTorch model (avoids a network fetch at first run)
RUN mkdir -p /root/.cache/torch/hub/checkpoints && \
    wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth -O /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth

WORKDIR /source/splat_one

COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
CMD ["bash"]
107 |
--------------------------------------------------------------------------------
/app/point_cloud_viewer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QLabel, QSlider, QSpinBox
3 | from PyQt5.QtCore import Qt, QTimer
4 | from PyQt5.QtGui import QVector3D
5 | import pyqtgraph.opengl as gl
6 | import pyqtgraph as pg
7 | import os
8 |
9 | from app.point_cloud_utils import rgb_depth_to_pointcloud, load_image_and_depth, downsample_pointcloud
10 |
class PointCloudViewer(QWidget):
    """RGB-D point cloud viewer built on pyqtgraph's OpenGL widget."""

    def __init__(self, workdir=None):
        super().__init__()
        self.workdir = workdir
        # Full-resolution cloud; the rendered cloud is a downsampled copy.
        self.current_points = None
        self.current_colors = None
        self.point_cloud_item = None

        self.setup_ui()

    def setup_ui(self):
        """Build the 3D viewport plus the control row beneath it."""
        layout = QVBoxLayout(self)

        # 3D viewport (same approach as the reconstruction tab)
        self.viewer = gl.GLViewWidget()
        self.viewer.setFocusPolicy(Qt.NoFocus)
        self.viewer.setCameraPosition(distance=10)
        layout.addWidget(self.viewer)

        controls = QHBoxLayout()

        # Target number of displayed points
        controls.addWidget(QLabel("Point Count:"))
        self.point_count_spinbox = QSpinBox()
        self.point_count_spinbox.setRange(1000, 100000)
        self.point_count_spinbox.setValue(10000)
        self.point_count_spinbox.valueChanged.connect(self.update_point_cloud)
        controls.addWidget(self.point_count_spinbox)

        # On-screen point size
        controls.addWidget(QLabel("Point Size:"))
        self.point_size_slider = QSlider(Qt.Horizontal)
        self.point_size_slider.setRange(1, 10)
        self.point_size_slider.setValue(3)
        self.point_size_slider.valueChanged.connect(self.update_point_size)
        controls.addWidget(self.point_size_slider)

        # Camera reset
        self.reset_button = QPushButton("Reset Camera")
        self.reset_button.clicked.connect(self.reset_camera)
        controls.addWidget(self.reset_button)

        controls.addStretch()
        layout.addLayout(controls)

    def load_point_cloud_from_images(self, image_name):
        """Build and show a point cloud from an image and its depth map."""
        if not self.workdir:
            return

        try:
            image_path = os.path.join(self.workdir, "images", image_name)
            depth_path = os.path.join(self.workdir, "depth", f"{image_name}_depth.png")

            rgb_image, depth_image = load_image_and_depth(image_path, depth_path)
            points, colors = rgb_depth_to_pointcloud(rgb_image, depth_image)

            # Keep the full cloud; update_point_cloud() downsamples for display.
            self.current_points = points
            self.current_colors = colors

            self.update_point_cloud()

        except Exception as e:
            print(f"Point cloud generation error: {e}")
            self.clear_point_cloud()

    def update_point_cloud(self):
        """Re-render the stored cloud at the currently selected density."""
        if self.current_points is None or self.current_colors is None:
            return

        target = self.point_count_spinbox.value()
        points, colors = downsample_pointcloud(self.current_points, self.current_colors, target)

        # Replace whatever is currently shown
        self.clear_point_cloud()

        if len(points) > 0:
            self.point_cloud_item = gl.GLScatterPlotItem(
                pos=points,
                color=colors,
                size=self.point_size_slider.value(),
                pxMode=True,
            )
            self.viewer.addItem(self.point_cloud_item)

    def update_point_size(self):
        """Apply the slider value to the rendered points."""
        if self.point_cloud_item:
            self.point_cloud_item.setData(size=self.point_size_slider.value())

    def clear_point_cloud(self):
        """Remove the rendered point cloud, if any."""
        if self.point_cloud_item:
            self.viewer.removeItem(self.point_cloud_item)
            self.point_cloud_item = None

    def reset_camera(self):
        """Return the camera to its default distance."""
        self.viewer.setCameraPosition(distance=10)

    def set_workdir(self, workdir):
        """Point the viewer at a different working directory."""
        self.workdir = workdir
--------------------------------------------------------------------------------
/utils/datasets/normalize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
    """
    reference: nerf-factory
    Get a similarity transform to normalize dataset
    from c2w (OpenCV convention) cameras

    :param c2w: (N, 4, 4) camera-to-world matrices
    :param strict_scaling: if True, scale by the farthest camera distance
        instead of the median (more conservative)
    :param center_method: "focus" centers on the median of each camera's
        closest point-to-origin along its view ray; "poses" centers on
        the median camera position
    :return: (4, 4) similarity transform (the scale is folded into the
        top three rows; no separate scale value is returned)
    """
    t = c2w[:, :3, 3]
    R = c2w[:, :3, :3]

    # (1) Rotate the world so that z+ is the up axis
    # we estimate the up axis by averaging the camera up axes
    ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
    world_up = np.mean(ups, axis=0)
    world_up /= np.linalg.norm(world_up)

    up_camspace = np.array([0.0, -1.0, 0.0])
    c = (up_camspace * world_up).sum()
    cross = np.cross(world_up, up_camspace)
    # Skew-symmetric cross-product matrix of `cross`
    skew = np.array(
        [
            [0.0, -cross[2], cross[1]],
            [cross[2], 0.0, -cross[0]],
            [-cross[1], cross[0], 0.0],
        ]
    )
    if c > -1:
        # Rodrigues-style rotation taking world_up onto up_camspace
        R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
    else:
        # In the unlikely case the original data has y+ up axis,
        # rotate 180-deg about x axis
        R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

    # R_align = np.eye(3) # DEBUG
    R = R_align @ R
    # Forward (z+) directions of the aligned cameras
    fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
    t = (R_align @ t[..., None])[..., 0]

    # (2) Recenter the scene.
    if center_method == "focus":
        # find the closest point to the origin for each camera's center ray
        nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
        translate = -np.median(nearest, axis=0)
    elif center_method == "poses":
        # use center of the camera positions
        translate = -np.median(t, axis=0)
    else:
        raise ValueError(f"Unknown center_method {center_method}")

    transform = np.eye(4)
    transform[:3, 3] = translate
    transform[:3, :3] = R_align

    # (3) Rescale the scene using camera distances
    scale_fn = np.max if strict_scaling else np.median
    scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
    transform[:3, :] *= scale

    return transform
64 |
65 |
def align_principle_axes(point_cloud):
    """Return a 4x4 SE(3) transform aligning *point_cloud* to its principal axes.

    The rotation sends the largest-variance direction to x and the
    smallest to z; the translation moves the (median) centroid to the
    origin.
    """
    # Robust center: per-axis median
    centroid = np.median(point_cloud, axis=0)
    centered = point_cloud - centroid

    # Principal directions from the covariance of the centered cloud
    eigenvalues, eigenvectors = np.linalg.eigh(np.cov(centered, rowvar=False))

    # eigh yields ascending eigenvalues; reverse so z ends up on the
    # direction of smallest variance.
    order = eigenvalues.argsort()[::-1]
    eigenvectors = eigenvectors[:, order]

    # Keep the basis right-handed (det must be +1 for a proper rotation)
    if np.linalg.det(eigenvectors) < 0:
        eigenvectors[:, 0] *= -1

    rotation = eigenvectors.T

    # Assemble SE(3): rotate, then translate the centroid onto the origin.
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = -rotation @ centroid

    return transform
98 |
99 |
def transform_points(matrix, points):
    """Apply an SE(3) transform to a set of 3D points.

    Args:
        matrix: 4x4 SE(3) matrix
        points: Nx3 array of points

    Returns:
        Nx3 array of transformed points
    """
    assert matrix.shape == (4, 4)
    assert points.ndim == 2 and points.shape[1] == 3
    rotation, translation = matrix[:3, :3], matrix[:3, 3]
    return points @ rotation.T + translation
113 |
114 |
def transform_cameras(matrix, camtoworlds):
    """Apply an SE(3)/similarity transform to camera-to-world matrices.

    Each result's rotation block is renormalized so cameras keep
    unit-scale rotations even when *matrix* carries a uniform scale.

    Args:
        matrix: 4x4 SE(3) matrix
        camtoworlds: Nx4x4 array of camera-to-world matrices

    Returns:
        Nx4x4 array of transformed camera-to-world matrices
    """
    assert matrix.shape == (4, 4)
    assert camtoworlds.ndim == 3 and camtoworlds.shape[1:] == (4, 4)
    # Batched left-multiplication: equivalent to matrix @ camtoworlds[n]
    # (the original spelled this as einsum("nij, ki -> nkj", ...)).
    transformed = np.matmul(matrix, camtoworlds)
    # Divide out any uniform scale picked up by the rotation block
    scale = np.linalg.norm(transformed[:, 0, :3], axis=1)
    transformed[:, :3, :3] = transformed[:, :3, :3] / scale[:, None, None]
    return transformed
131 |
132 |
def normalize(camtoworlds, points=None):
    """Normalize cameras (and optionally points) into a canonical frame.

    A similarity transform derived from the cameras is applied first;
    when points are supplied, the scene is additionally aligned to the
    point cloud's principal axes.

    Returns:
        (camtoworlds, points, T2 @ T1) when points are given, else
        (camtoworlds, T1).
    """
    T1 = similarity_from_cameras(camtoworlds)
    camtoworlds = transform_cameras(T1, camtoworlds)
    if points is None:
        return camtoworlds, T1
    points = transform_points(T1, points)
    T2 = align_principle_axes(points)
    return (
        transform_cameras(T2, camtoworlds),
        transform_points(T2, points),
        T2 @ T1,
    )
144 |
--------------------------------------------------------------------------------
/app/tabs/masks_tab.py:
--------------------------------------------------------------------------------
1 | # masks_tab.py
2 |
3 | import os
4 | import sys
5 | import importlib.util
6 | from PyQt5.QtWidgets import (
7 | QWidget, QVBoxLayout, QHBoxLayout, QSplitter, QLabel,
8 | QTreeWidget, QTreeWidgetItem, QMessageBox
9 | )
10 | from PyQt5.QtCore import Qt
11 |
12 | from app.base_tab import BaseTab
13 |
class MaskManagerWidget(QWidget):
    """Thin QWidget wrapper hosting the MaskManager in its own layout."""

    def __init__(self, mask_manager, parent=None):
        super().__init__(parent)
        self.mask_manager = mask_manager
        container = QVBoxLayout()
        container.addWidget(self.mask_manager)
        self.setLayout(container)
21 |
class MasksTab(BaseTab):
    """Masks tab: SAM2-based mask editing per image.

    The left pane lists images grouped by camera; the right pane hosts
    the MaskManager widget. The SAM model is loaded on tab activation
    and unloaded on deactivation to save memory.
    """

    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        self.mask_manager = None
        self.mask_manager_widget = None
        self.camera_image_tree = None

    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Masks"

    def initialize(self):
        """Initialize the Masks tab: build the tree and the MaskManager pane."""
        if not self.workdir:
            QMessageBox.warning(self, "Error", "Work directory is not set.")
            return

        layout = self.create_horizontal_splitter()

        # Left side: Tree of images grouped by camera
        self.camera_image_tree = QTreeWidget()
        self.camera_image_tree.setHeaderLabel("Cameras and Images")
        self.camera_image_tree.setFixedWidth(250)
        layout.addWidget(self.camera_image_tree)

        # Initialize MaskManager
        try:
            sam2_dir = self.get_sam2_directory()
            checkpoint_path = os.path.join(sam2_dir, "checkpoints", "sam2.1_hiera_large.pt")
            # Relative path; presumably resolved against sam2's own config
            # search path — TODO confirm against MaskManager.
            config_path = os.path.join("configs", "sam2.1", "sam2.1_hiera_l.yaml")
            mask_dir = os.path.join(self.workdir, "masks")
            img_dir = os.path.join(self.workdir, "images")

            # Ensure mask directory exists
            os.makedirs(mask_dir, exist_ok=True)

            # Import MaskManager here to avoid importing when not needed
            from app.mask_manager import MaskManager

            self.mask_manager = MaskManager(
                checkpoint_path, config_path, mask_dir, img_dir, self.image_list
            )

            # Right side: MaskManager widget
            self.mask_manager_widget = MaskManagerWidget(self.mask_manager)
            layout.addWidget(self.mask_manager_widget)

            # Set stretch factors (tree : manager = 1 : 4)
            layout.setStretchFactor(0, 1)
            layout.setStretchFactor(1, 4)

            # Set layout for mask tab
            self._layout.addWidget(layout)

            # Populate the camera image tree
            self.setup_camera_image_tree(self.camera_image_tree, self.display_mask)

            self.is_initialized = True

        except Exception as e:
            error_message = f"Failed to initialize MaskManager: {str(e)}"
            QMessageBox.critical(self, "Error", error_message)
            placeholder = QLabel("MaskManager could not be initialized. See error message.")
            placeholder.setAlignment(Qt.AlignCenter)
            layout.addWidget(placeholder)
            self._layout.addWidget(layout)

    def display_mask(self, item, column):
        """Display the mask for the selected (leaf) image item."""
        if item.childCount() == 0 and item.parent() is not None:
            image_name = item.text(0)
            if self.mask_manager is not None:
                self.mask_manager.load_image_by_name(image_name)

    def on_tab_activated(self):
        """Load the SAM model when the tab becomes active."""
        super().on_tab_activated()
        if self.mask_manager and hasattr(self.mask_manager, 'init_sam_model'):
            self.mask_manager.init_sam_model()

    def on_tab_deactivated(self):
        """Unload the SAM model when leaving the tab to save memory."""
        super().on_tab_deactivated()
        if self.mask_manager and hasattr(self.mask_manager, 'unload_sam_model'):
            self.mask_manager.unload_sam_model()

    def refresh(self):
        """Refresh the tab content by rebuilding the UI from scratch."""
        if self.is_initialized:
            # Remove old widgets
            for i in reversed(range(self._layout.count())):
                self._layout.itemAt(i).widget().setParent(None)

            # BUGFIX: previously called on_tab_activated(), which only
            # reloads the SAM model and never rebuilds the torn-down UI.
            self.is_initialized = False
            self.initialize()

    def get_sam2_directory(self):
        """Return the install root directory of the sam2 package.

        Raises:
            ModuleNotFoundError: if sam2 is not installed.
        """
        module_name = 'sam2'
        spec = importlib.util.find_spec(module_name)
        if spec is None:
            raise ModuleNotFoundError(f"{module_name} is not installed.")
        return os.path.dirname(os.path.dirname(spec.origin))
--------------------------------------------------------------------------------
/app/tabs/matching_tab.py:
--------------------------------------------------------------------------------
1 | # matching_tab.py
2 |
3 | import os
4 | from PyQt5.QtWidgets import (
5 | QWidget, QVBoxLayout, QHBoxLayout, QSplitter, QLabel, QTreeWidget, QMessageBox
6 | )
7 | from PyQt5.QtCore import Qt
8 |
9 | from app.base_tab import BaseTab
10 |
11 | class MatchingTab(BaseTab):
12 | """Matching tab implementation"""
    def __init__(self, workdir=None, image_list=None, parent=None):
        """Create the tab; the viewer and both trees are filled in later."""
        super().__init__(workdir, image_list, parent)
        # Populated by setup_basic_ui() / initialize()
        self.matching_viewer = None
        self.camera_image_tree_left = None
        self.camera_image_tree_right = None

        # Set up basic UI structure
        self.setup_basic_ui()
21 |
    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Matching"
24 |
25 | def setup_basic_ui(self):
26 | """Set up the basic UI structure"""
27 | # Create a main horizontal splitter
28 | main_layout = self.create_horizontal_splitter()
29 |
30 | # Left side: Combined tree widget container
31 | left_widget = QWidget()
32 | left_layout = QVBoxLayout(left_widget)
33 | left_layout.setContentsMargins(0, 0, 0, 0)
34 |
35 | # Create label for left and right image selection
36 | left_layout.addWidget(QLabel("Left Image Selection:"))
37 |
38 | # Tree for left image
39 | self.camera_image_tree_left = QTreeWidget()
40 | self.camera_image_tree_left.setHeaderLabel("Cameras and Images - Left")
41 | left_layout.addWidget(self.camera_image_tree_left)
42 |
43 | # Add a separator label
44 | left_layout.addWidget(QLabel("Right Image Selection:"))
45 |
46 | # Tree for right image
47 | self.camera_image_tree_right = QTreeWidget()
48 | self.camera_image_tree_right.setHeaderLabel("Cameras and Images - Right")
49 | left_layout.addWidget(self.camera_image_tree_right)
50 |
51 | # Add the left widget to the main layout with fixed width
52 | left_widget.setFixedWidth(300) # A bit wider to accommodate the trees
53 | main_layout.addWidget(left_widget)
54 |
55 | # Right side: Placeholder for matching viewer
56 | right_widget = QLabel("Matching Viewer will be displayed here.")
57 | right_widget.setAlignment(Qt.AlignCenter)
58 | main_layout.addWidget(right_widget)
59 |
60 | # Set stretch factors
61 | main_layout.setStretchFactor(0, 1) # Left side (trees)
62 | main_layout.setStretchFactor(1, 3) # Right side (matching viewer)
63 |
64 | # Set layout for matching tab
65 | self._layout.addWidget(main_layout)
66 |
67 | def initialize(self):
68 | """Initialize the Matching tab with data"""
69 | if not self.workdir:
70 | QMessageBox.warning(self, "Error", "Work directory is not set.")
71 | return
72 |
73 | try:
74 | # Import FeatureMatching here to avoid circular imports
75 | from app.feature_matching import FeatureMatching
76 |
77 | # Get the main layout and its widgets
78 | main_layout = self._layout.itemAt(0).widget()
79 | right_widget = main_layout.widget(1) # The placeholder
80 |
81 | # Remove the placeholder
82 | right_widget.setParent(None)
83 |
84 | # Create the matching viewer
85 | self.matching_viewer = FeatureMatching(workdir=self.workdir, image_list=self.image_list)
86 | main_layout.addWidget(self.matching_viewer)
87 |
88 | # Set stretch factors again (they might be reset after widget changes)
89 | main_layout.setStretchFactor(0, 1) # Left side (trees)
90 | main_layout.setStretchFactor(1, 3) # Right side (matching viewer)
91 |
92 | # Populate the trees with camera data
93 | self.setup_camera_image_tree(self.camera_image_tree_left, self.on_image_selected_left)
94 | self.setup_camera_image_tree(self.camera_image_tree_right, self.on_image_selected_right)
95 |
96 | self.is_initialized = True
97 |
98 | except Exception as e:
99 | error_message = f"Failed to initialize FeatureMatching: {str(e)}"
100 | QMessageBox.critical(self, "Error", error_message)
101 |
102 | def on_image_selected_left(self, item, column):
103 | """Handle image selection in the left camera tree"""
104 | if not self.is_initialized:
105 | self.initialize()
106 |
107 | if item.childCount() == 0 and item.parent() is not None:
108 | image_name = item.text(0)
109 | if self.matching_viewer:
110 | self.matching_viewer.load_image_by_name(image_name, position="left")
111 |
112 | def on_image_selected_right(self, item, column):
113 | """Handle image selection in the right camera tree"""
114 | if not self.is_initialized:
115 | self.initialize()
116 |
117 | if item.childCount() == 0 and item.parent() is not None:
118 | image_name = item.text(0)
119 | if self.matching_viewer:
120 | self.matching_viewer.load_image_by_name(image_name, position="right")
121 |
122 | def refresh(self):
123 | """Refresh the tab content"""
124 | if self.is_initialized:
125 | # Remove old widgets
126 | for i in reversed(range(self._layout.count())):
127 | self._layout.itemAt(i).widget().setParent(None)
128 |
129 | # Reinitialize
130 | self.setup_basic_ui()
131 | self.is_initialized = False
132 | self.initialize()
--------------------------------------------------------------------------------
/tests/test_tab_manager.py:
--------------------------------------------------------------------------------
1 | # test_tab_manager.py
2 |
3 | import pytest
4 | from unittest.mock import MagicMock, patch
5 | from PyQt5.QtWidgets import QWidget
6 | from app.tab_manager import TabManager
7 | from app.base_tab import BaseTab
8 |
class MockTab(BaseTab):
    """Minimal BaseTab subclass that records lifecycle events for assertions."""
    def __init__(self, workdir=None, image_list=None, parent=None, custom_arg=None):
        super().__init__(workdir, image_list, parent)
        self.custom_arg = custom_arg
        # Flags flipped by the lifecycle hooks below.
        self.activated = False
        self.deactivated = False

    def get_tab_name(self):
        """Fixed display name checked by the registration tests."""
        return "Mock Tab"

    def initialize(self):
        """Mark the tab initialized without building any UI."""
        self.is_initialized = True

    def on_tab_activated(self):
        super().on_tab_activated()
        self.activated = True

    def on_tab_deactivated(self):
        super().on_tab_deactivated()
        self.deactivated = True
30 |
31 |
class SimpleWidget(QWidget):
    """Bare QWidget used to exercise registration of non-BaseTab tabs."""
    def __init__(self, parent=None):
        super().__init__(parent)
36 |
37 |
@pytest.fixture
def tab_manager(mock_qapplication):
    """Provide a fresh TabManager (a QApplication must already exist)."""
    manager = TabManager()
    return manager
42 |
43 |
def test_tab_manager_init(tab_manager):
    """A freshly constructed TabManager is empty with no active tab."""
    assert isinstance(tab_manager, TabManager)
    assert tab_manager.active_tab_index == -1
    assert tab_manager.tab_instances == {}
49 |
50 |
def test_register_tab_with_base_tab(tab_manager):
    """Registering a BaseTab subclass stores the instance and uses its name."""
    idx = tab_manager.register_tab(MockTab, workdir="/test/workdir", custom_arg="test")

    # Instance is stored under the returned index.
    assert idx == 0
    assert idx in tab_manager.tab_instances
    instance = tab_manager.tab_instances[idx]
    assert isinstance(instance, MockTab)
    # Constructor kwargs were forwarded.
    assert instance.workdir == "/test/workdir"
    assert instance.custom_arg == "test"
    # Tab label comes from get_tab_name().
    assert tab_manager.tabText(idx) == "Mock Tab"
63 |
64 |
def test_register_tab_with_widget(tab_manager):
    """Plain QWidget subclasses can be registered, with or without a name."""
    # Explicit tab name is used verbatim.
    idx = tab_manager.register_tab(SimpleWidget, tab_name="Simple Widget")
    assert idx == 0
    assert idx in tab_manager.tab_instances
    assert isinstance(tab_manager.tab_instances[idx], SimpleWidget)
    assert tab_manager.tabText(idx) == "Simple Widget"

    # Without a name the class name becomes the label.
    idx = tab_manager.register_tab(SimpleWidget)
    assert idx == 1
    assert tab_manager.tabText(idx) == "SimpleWidget"
82 |
83 |
def test_get_tab_instance(tab_manager):
    """get_tab_instance returns the stored widget, or None for bad indices."""
    idx_mock = tab_manager.register_tab(MockTab)
    idx_widget = tab_manager.register_tab(SimpleWidget)

    assert isinstance(tab_manager.get_tab_instance(idx_mock), MockTab)
    assert isinstance(tab_manager.get_tab_instance(idx_widget), SimpleWidget)

    # Unknown index falls back to None instead of raising.
    assert tab_manager.get_tab_instance(999) is None
100 |
101 |
@patch('app.tab_manager.TabManager.currentIndex')
def test_get_current_tab(mock_current_index, tab_manager):
    """get_current_tab follows whatever currentIndex() reports."""
    idx_mock = tab_manager.register_tab(MockTab)
    idx_widget = tab_manager.register_tab(SimpleWidget)

    # Point currentIndex at the first tab.
    mock_current_index.return_value = idx_mock
    assert isinstance(tab_manager.get_current_tab(), MockTab)

    # Move it to the second tab.
    mock_current_index.return_value = idx_widget
    assert isinstance(tab_manager.get_current_tab(), SimpleWidget)
122 |
123 |
def test_on_tab_changed(tab_manager):
    """Switching tabs deactivates the old tab and activates the new one."""
    idx_first = tab_manager.register_tab(MockTab)
    idx_second = tab_manager.register_tab(MockTab)
    first = tab_manager.get_tab_instance(idx_first)
    second = tab_manager.get_tab_instance(idx_second)

    # Nothing is active before any change event.
    assert tab_manager.active_tab_index == -1

    # Activate the first tab: only its activation hook fires.
    tab_manager.on_tab_changed(idx_first)
    assert (first.activated, first.deactivated) == (True, False)
    assert (second.activated, second.deactivated) == (False, False)
    assert tab_manager.active_tab_index == idx_first

    # Switch to the second tab: first is deactivated, second activated.
    tab_manager.on_tab_changed(idx_second)
    assert (first.activated, first.deactivated) == (True, True)
    assert (second.activated, second.deactivated) == (True, False)
    assert tab_manager.active_tab_index == idx_second
152 |
153 |
def test_update_all_tabs(tab_manager):
    """update_all_tabs forwards only the arguments that were provided."""
    idx = tab_manager.register_tab(MockTab, workdir="/old/workdir", image_list=["old.jpg"])
    tab_manager.register_tab(SimpleWidget)  # non-BaseTab tabs must be skipped safely

    tab = tab_manager.get_tab_instance(idx)
    tab.update_workdir = MagicMock()
    tab.update_image_list = MagicMock()

    # Both arguments supplied -> both update methods called.
    tab_manager.update_all_tabs(workdir="/new/workdir", image_list=["new1.jpg", "new2.jpg"])
    tab.update_workdir.assert_called_once_with("/new/workdir")
    tab.update_image_list.assert_called_once_with(["new1.jpg", "new2.jpg"])

    # Only the workdir -> image list untouched.
    tab.update_workdir.reset_mock()
    tab.update_image_list.reset_mock()
    tab_manager.update_all_tabs(workdir="/another/workdir")
    tab.update_workdir.assert_called_once_with("/another/workdir")
    assert tab.update_image_list.call_count == 0

    # Only the image list -> workdir untouched.
    tab.update_workdir.reset_mock()
    tab.update_image_list.reset_mock()
    tab_manager.update_all_tabs(image_list=["another.jpg"])
    assert tab.update_workdir.call_count == 0
    tab.update_image_list.assert_called_once_with(["another.jpg"])
189 |
--------------------------------------------------------------------------------
/tests/test_images_tab.py:
--------------------------------------------------------------------------------
1 | # test_images_tab.py
2 |
3 | import os
4 | import json
5 | import pytest
6 | from unittest.mock import MagicMock, patch
7 | from PyQt5.QtWidgets import QApplication, QTreeWidgetItem, QDialog
8 | from PyQt5.QtGui import QPixmap
9 | from app.tabs.images_tab import ImagesTab
10 |
@pytest.fixture
def mock_camera_model_manager():
    """CameraModelManager stand-in exposing a single perspective camera."""
    manager = MagicMock()
    # Overridden focal_ratio value the tests expect to see applied.
    manager.get_camera_models.return_value = {
        "Camera1": {
            "projection_type": "perspective",
            "focal_ratio": 1.5
        }
    }
    return manager
22 |
@pytest.fixture
def mock_image_processor():
    """ImageProcessor stand-in with canned dimensions and success results."""
    processor = MagicMock()
    # (width, height) of the sample image.
    processor.get_sample_image_dimensions.return_value = (1920, 1080)
    # Two images "processed" on resize; restore succeeds by default.
    processor.resize_images.return_value = 2
    processor.restore_original_images.return_value = True
    return processor
31 |
@pytest.fixture
def setup_images_tab(mock_qapplication, setup_image_folders):
    """Build an ImagesTab rooted at the temporary image-folder fixture."""
    return ImagesTab(setup_image_folders)
37 |
@patch('app.tabs.images_tab.CameraModelManager')
@patch('app.tabs.images_tab.ImageProcessor')
def test_images_tab_initialize(mock_image_processor_class, mock_model_manager_class, setup_images_tab):
    """initialize() builds the managers and the main UI widgets."""
    mock_model_manager_class.return_value = MagicMock()
    mock_image_processor_class.return_value = MagicMock()

    setup_images_tab.initialize()

    # Exactly one manager of each kind was constructed.
    assert mock_model_manager_class.call_count == 1
    assert mock_image_processor_class.call_count == 1

    # The tab is fully wired up afterwards.
    assert setup_images_tab.is_initialized is True
    assert setup_images_tab.camera_image_tree is not None
    assert setup_images_tab.image_viewer is not None
    assert setup_images_tab.exif_table is not None
61 |
@patch('app.tabs.images_tab.QPixmap')
@patch('os.path.exists')
def test_display_image_and_exif(mock_exists, mock_pixmap, setup_images_tab, monkeypatch):
    """Test display_image_and_exif method"""
    # Setup mocks: pretend the image file and its EXIF sidecar both exist.
    mock_exists.return_value = True
    mock_pixmap_instance = MagicMock()
    mock_pixmap.return_value = mock_pixmap_instance

    # Mock open and json.load globally so no real file I/O happens when
    # the method reads the EXIF JSON for the selected image.
    mock_open = MagicMock()
    mock_json_load = MagicMock(return_value={"camera": "Camera1", "width": 1920})

    monkeypatch.setattr('builtins.open', mock_open)
    monkeypatch.setattr('json.load', mock_json_load)

    # Create mock item: a leaf node (no children, has a parent) represents
    # an image entry in the camera/image tree.
    item = MagicMock()
    item.childCount.return_value = 0
    item.parent.return_value = MagicMock() # Not None
    item.text.return_value = "test_image.jpg"

    # Initialize necessary UI components
    setup_images_tab.image_viewer = MagicMock()
    setup_images_tab.exif_table = MagicMock()
    setup_images_tab.display_exif_data = MagicMock() # Mock the method to avoid testing it here

    # Call the method
    setup_images_tab.display_image_and_exif(item, 0)

    # Check that image was displayed (one pixmap built, one setPixmap call)
    assert setup_images_tab.image_viewer.setPixmap.call_count == 1
    assert mock_pixmap.call_count == 1

    # Check that EXIF was loaded and handed off to display_exif_data
    mock_open.assert_called_once()
    mock_json_load.assert_called_once()
    setup_images_tab.display_exif_data.assert_called_once()
100 |
@patch('json.dumps')
def test_display_exif_data(mock_json_dumps, setup_images_tab, mock_camera_model_manager):
    """display_exif_data fills the table and consults camera model overrides."""
    setup_images_tab.exif_table = MagicMock()
    setup_images_tab.camera_model_manager = mock_camera_model_manager
    mock_json_dumps.return_value = '{"lat": 35.123, "lon": 139.456}'

    # Representative EXIF payload including nested GPS data.
    exif_data = {
        "camera": "Camera1",
        "make": "Test Make",
        "model": "Test Model",
        "width": 1920,
        "height": 1080,
        "gps": {"lat": 35.123, "lon": 139.456}
    }

    setup_images_tab.display_exif_data(exif_data)

    # The table was cleared and repopulated row by row.
    table = setup_images_tab.exif_table
    assert table.setRowCount.call_count >= 1
    assert table.insertRow.call_count > 0
    assert table.setItem.call_count > 0

    # Overrides were looked up (focal_ratio comes from the mocked manager).
    mock_camera_model_manager.get_camera_models.assert_called_once()
129 |
@patch('app.tabs.images_tab.QMessageBox')
def test_open_camera_model_editor(mock_msgbox, setup_images_tab, mock_camera_model_manager):
    """The editor is delegated to the manager; a missing manager warns instead."""
    setup_images_tab.camera_model_manager = mock_camera_model_manager

    setup_images_tab.open_camera_model_editor()
    mock_camera_model_manager.open_camera_model_editor.assert_called_once_with(parent=setup_images_tab)

    # Without a manager the user only gets a warning dialog.
    setup_images_tab.camera_model_manager = None
    setup_images_tab.open_camera_model_editor()
    mock_msgbox.warning.assert_called_once()
146 |
@patch('app.tabs.images_tab.ResolutionDialog')
@patch('app.tabs.images_tab.ExifExtractProgressDialog')
@patch('app.tabs.images_tab.QMessageBox')
def test_resize_images_in_folder(mock_msgbox, mock_progress_dialog, mock_resolution_dialog,
                                 setup_images_tab, mock_image_processor):
    """Accepting the dialog resizes; rejection and a missing processor do not."""
    setup_images_tab.image_processor = mock_image_processor

    # Dialog accepted with a 50% reduction chosen.
    dialog = MagicMock()
    dialog.exec_.return_value = QDialog.Accepted
    dialog.get_values.return_value = ("Percentage (%)", 50)
    mock_resolution_dialog.return_value = dialog
    mock_progress_dialog.return_value = MagicMock()

    setup_images_tab.resize_images_in_folder()

    # Dialog seeded with the sample dimensions; resize invoked; user informed.
    mock_resolution_dialog.assert_called_once_with(1920, 1080, parent=setup_images_tab)
    mock_image_processor.resize_images.assert_called_once_with("Percentage (%)", 50)
    mock_msgbox.information.assert_called_once()

    # Rejecting the dialog must not trigger another resize.
    dialog.exec_.return_value = QDialog.Rejected
    setup_images_tab.resize_images_in_folder()
    assert mock_image_processor.resize_images.call_count == 1

    # No processor at all -> warning instead of a crash.
    setup_images_tab.image_processor = None
    setup_images_tab.resize_images_in_folder()
    assert mock_msgbox.warning.call_count == 1
183 |
@patch('app.tabs.images_tab.QMessageBox')
def test_restore_original_images(mock_msgbox, setup_images_tab, mock_image_processor):
    """Successful restore informs; failure and a missing processor warn."""
    setup_images_tab.image_processor = mock_image_processor

    # Success path.
    setup_images_tab.restore_original_images()
    mock_image_processor.restore_original_images.assert_called_once()
    mock_msgbox.information.assert_called_once()

    # Processor reports failure -> first warning.
    mock_image_processor.restore_original_images.return_value = False
    setup_images_tab.restore_original_images()
    assert mock_msgbox.warning.call_count == 1

    # No processor at all -> second warning.
    setup_images_tab.image_processor = None
    setup_images_tab.restore_original_images()
    assert mock_msgbox.warning.call_count == 2
206 |
--------------------------------------------------------------------------------
/tests/test_camera_models.py:
--------------------------------------------------------------------------------
1 | # test_camera_models.py
2 |
3 | import os
4 | import json
5 | import pytest
6 | from unittest.mock import MagicMock, patch
7 | from PyQt5.QtWidgets import QTableWidgetItem, QComboBox
8 | from app.camera_models import CameraModelManager, CameraModelEditor
9 |
@pytest.fixture
def sample_camera_models():
    """Two camera models: a perspective HD camera and a spherical 4K one."""
    return {
        "Camera1": {
            "projection_type": "perspective",
            "width": 1920,
            "height": 1080,
            "focal_ratio": 1.2
        },
        "Camera2": {
            "projection_type": "spherical",
            "width": 3840,
            "height": 2160,
            "focal_ratio": 1.0
        }
    }
27 |
@pytest.fixture
def setup_workdir(temp_workdir, sample_camera_models):
    """Write camera_models.json plus an overrides file into a temp workdir."""
    def write_json(name, payload):
        # Small helper: dump a JSON payload under the temp workdir.
        with open(os.path.join(temp_workdir, name), "w") as f:
            json.dump(payload, f)

    write_json("camera_models.json", sample_camera_models)
    # Override a single Camera1 parameter to exercise the merge logic.
    write_json("camera_models_overrides.json", {"Camera1": {"focal_ratio": 1.5}})

    return temp_workdir
47 |
48 |
def test_camera_model_manager_init(setup_workdir, sample_camera_models):
    """Loading merges camera_models.json with the overrides file."""
    manager = CameraModelManager(setup_workdir)
    models = manager.camera_models

    assert models is not None
    # The override file replaces Camera1's focal_ratio...
    assert models["Camera1"]["focal_ratio"] == 1.5
    # ...while parameters not mentioned in the overrides survive the merge.
    assert models["Camera1"]["projection_type"] == "perspective"
    assert models["Camera2"]["projection_type"] == "spherical"
62 |
63 |
def test_camera_model_manager_get_models(setup_workdir):
    """get_camera_models exposes the merged model dictionary."""
    models = CameraModelManager(setup_workdir).get_camera_models()

    assert "Camera1" in models
    assert "Camera2" in models
    # Overridden value wins over the base file.
    assert models["Camera1"]["focal_ratio"] == 1.5
72 |
73 |
@patch('app.camera_models.CameraModelEditor')
def test_camera_model_manager_open_editor(mock_editor, setup_workdir):
    """The editor dialog is created with the manager's models and workdir."""
    # Editor dialog reports acceptance.
    editor_instance = MagicMock()
    editor_instance.exec_.return_value = True
    mock_editor.return_value = editor_instance

    manager = CameraModelManager(setup_workdir)
    parent = MagicMock()

    assert manager.open_camera_model_editor(parent) is True
    mock_editor.assert_called_once_with(manager.camera_models, setup_workdir, parent=parent)

    # With no workdir the editor cannot open; the call reports failure.
    # (A QMessageBox warning is also shown, but is not asserted here.)
    manager.workdir = None
    assert manager.open_camera_model_editor(parent) is False
96 | # QMessageBox warning should be called, but we can't test that directly here
97 |
98 |
@patch('app.camera_models.QTableWidgetItem')
@patch('app.camera_models.QComboBox')
def test_camera_model_editor_load_models(mock_combo, mock_item, setup_workdir, sample_camera_models):
    """load_camera_models fills the table, using combo boxes for projection_type."""
    # Widget factories return plain mocks.
    mock_item.return_value = MagicMock()
    mock_combo.return_value = MagicMock()

    # Editor with a mocked-out table widget.
    editor = CameraModelEditor(sample_camera_models, setup_workdir)
    editor.table = MagicMock()

    editor.load_camera_models()

    # Rows were created and populated.
    assert editor.table.setRowCount.call_count >= 1
    assert editor.table.insertRow.call_count > 0
    assert editor.table.setItem.call_count > 0

    # projection_type cells are rendered as combo boxes.
    assert mock_combo.call_count > 0
128 |
129 |
@patch('app.camera_models.json.dump')
@patch('app.camera_models.QMessageBox')
def test_camera_model_editor_save_changes(mock_msgbox, mock_json_dump, setup_workdir, sample_camera_models):
    """Test CameraModelEditor.save_changes"""
    editor = CameraModelEditor(sample_camera_models, setup_workdir)

    # Create mock table with data
    editor.table = MagicMock()
    editor.table.rowCount.return_value = 3

    # Set up the table to return different types of cells.
    # Rows exercise every value-parsing branch of save_changes:
    # row 0 -> float parameter, row 1 -> integer parameter,
    # row 2 -> combo-box cell (item() returns None, cellWidget() returns it).
    def mock_get_item(row, col):
        # For the first row
        if row == 0:
            if col == 0: # Key column
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "Camera1"
                return item
            elif col == 1: # Parameter column
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "focal_ratio"
                return item
            elif col == 2: # Value column
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "1.8" # Float value
                return item
        # For the second row
        elif row == 1:
            if col == 0:
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "Camera1"
                return item
            elif col == 1:
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "width"
                return item
            elif col == 2:
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "1920" # Integer value
                return item
        # For the third row with ComboBox
        elif row == 2:
            if col == 0:
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "Camera1"
                return item
            elif col == 1:
                item = MagicMock(spec=QTableWidgetItem)
                item.text.return_value = "projection_type"
                return item
            elif col == 2:
                # Return None for cell with widget
                return None
        return None

    editor.table.item.side_effect = mock_get_item

    # Mock cellWidget for the third row, returning a combo box
    def mock_cell_widget(row, col):
        if row == 2 and col == 2:
            combo = MagicMock(spec=QComboBox)
            combo.currentText.return_value = "spherical"
            return combo
        return None

    editor.table.cellWidget.side_effect = mock_cell_widget

    # Call the method
    editor.save_changes()

    # Check that json.dump was called with the expected values
    assert mock_json_dump.call_count == 1

    # Get the arguments passed to json.dump (first positional arg is the data)
    args, kwargs = mock_json_dump.call_args
    updated_models = args[0]

    # Check the updated models: strings were coerced to float/int, and the
    # combo-box selection was read via currentText().
    assert "Camera1" in updated_models
    assert updated_models["Camera1"]["focal_ratio"] == 1.8
    assert updated_models["Camera1"]["width"] == 1920
    assert updated_models["Camera1"]["projection_type"] == "spherical"

    # Check that success message was shown
    assert mock_msgbox.information.call_count == 1
215 |
--------------------------------------------------------------------------------
/utils/gsplat_utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from sklearn.neighbors import NearestNeighbors
6 | from torch import Tensor
7 | import torch.nn.functional as F
8 | import matplotlib.pyplot as plt
9 | from matplotlib import colormaps
10 |
11 |
class CameraOptModule(torch.nn.Module):
    """Learnable per-camera SE(3) pose corrections.

    Each camera owns a 9-D embedding: a 3-D translation delta followed by a
    6-D rotation delta in the continuous Zhou et al. representation.
    """

    def __init__(self, n: int):
        super().__init__()
        # Delta positions (3D) + Delta rotations (6D)
        self.embeds = torch.nn.Embedding(n, 9)
        # 6-D encoding of the identity rotation, used as the rotation baseline.
        self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def zero_init(self):
        """Start from "no correction" for every camera."""
        torch.nn.init.zeros_(self.embeds.weight)

    def random_init(self, std: float):
        """Start from small Gaussian corrections of scale ``std``."""
        torch.nn.init.normal_(self.embeds.weight, std=std)

    def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
        """Apply each camera's learned delta to its cam-to-world matrix.

        Args:
            camtoworlds: (..., 4, 4)
            embed_ids: (...,)

        Returns:
            updated camtoworlds: (..., 4, 4)
        """
        assert camtoworlds.shape[:-2] == embed_ids.shape
        batch_dims = camtoworlds.shape[:-2]
        deltas = self.embeds(embed_ids)  # (..., 9)
        trans_delta = deltas[..., :3]
        rot_delta = deltas[..., 3:]
        # Offset from the identity so a zero embedding means "no change".
        rot = rotation_6d_to_matrix(
            rot_delta + self.identity.expand(*batch_dims, -1)
        )  # (..., 3, 3)
        adjust = torch.eye(4, device=deltas.device).repeat((*batch_dims, 1, 1))
        adjust[..., :3, :3] = rot
        adjust[..., :3, 3] = trans_delta
        return torch.matmul(camtoworlds, adjust)
49 |
50 |
class AppearanceOptModule(torch.nn.Module):
    """Per-camera appearance embedding plus an MLP color head."""

    def __init__(
        self,
        n: int,
        feature_dim: int,
        embed_dim: int = 16,
        sh_degree: int = 3,
        mlp_width: int = 64,
        mlp_depth: int = 2,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.sh_degree = sh_degree
        self.embeds = torch.nn.Embedding(n, embed_dim)
        # MLP input: camera embedding + GS feature + SH bases of full degree.
        in_dim = embed_dim + feature_dim + (sh_degree + 1) ** 2
        layers = [torch.nn.Linear(in_dim, mlp_width), torch.nn.ReLU(inplace=True)]
        for _ in range(mlp_depth - 1):
            layers.append(torch.nn.Linear(mlp_width, mlp_width))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Linear(mlp_width, 3))
        self.color_head = torch.nn.Sequential(*layers)

    def forward(
        self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
    ) -> Tensor:
        """Predict per-camera colors from features, embeddings and view dirs.

        Args:
            features: (N, feature_dim)
            embed_ids: (C,), or None for all-zero embeddings
            dirs: (C, N, 3)
            sh_degree: SH degree actually evaluated (<= self.sh_degree)

        Returns:
            colors: (C, N, 3)
        """
        from gsplat.cuda._torch_impl import _eval_sh_bases_fast

        C, N = dirs.shape[:2]
        # Per-camera embeddings, broadcast over the N gaussians.
        if embed_ids is None:
            embeds = torch.zeros(C, self.embed_dim, device=features.device)
        else:
            embeds = self.embeds(embed_ids)  # [C, D2]
        embeds = embeds[:, None, :].expand(-1, N, -1)  # [C, N, D2]
        # GS features, broadcast over the C cameras.
        features = features[None, :, :].expand(C, -1, -1)  # [C, N, D1]
        # View directions
        dirs = F.normalize(dirs, dim=-1)  # [C, N, 3]
        # Evaluate only the requested degree; zero-pad up to the width the
        # MLP was built for.
        num_bases_to_use = (sh_degree + 1) ** 2
        num_bases = (self.sh_degree + 1) ** 2
        sh_bases = torch.zeros(C, N, num_bases, device=features.device)  # [C, N, K]
        sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
        # Concatenate inputs and decode to RGB.
        if self.embed_dim > 0:
            h = torch.cat([embeds, features, sh_bases], dim=-1)  # [C, N, D1 + D2 + K]
        else:
            h = torch.cat([features, sh_bases], dim=-1)
        return self.color_head(h)
115 |
116 |
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """
    Convert the 6D rotation representation of Zhou et al. [1] into rotation
    matrices via Gram-Schmidt orthogonalization (Section B of [1]); adapted
    from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
        On the Continuity of Rotation Representations in Neural Networks.
        CVPR 2019. http://arxiv.org/abs/1812.07035
    """
    first, second = d6[..., :3], d6[..., 3:]
    row0 = F.normalize(first, dim=-1)
    # Remove row0's component from the second vector, then normalize.
    proj = (row0 * second).sum(-1, keepdim=True)
    row1 = F.normalize(second - proj * row0, dim=-1)
    # Third row completes the right-handed orthonormal frame.
    row2 = torch.cross(row0, row1, dim=-1)
    return torch.stack((row0, row1, row2), dim=-2)
139 |
140 |
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Distances from each point in ``x`` to its K nearest neighbours
    (the point itself included), computed on CPU via scikit-learn."""
    pts = x.cpu().numpy()
    neighbors = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(pts)
    dists, _ = neighbors.kneighbors(pts)
    # Back to a tensor on x's device/dtype.
    return torch.from_numpy(dists).to(x)
146 |
147 |
def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to 0th-order SH coefficients."""
    SH_C0 = 0.28209479177387814  # Y_0^0 basis constant
    return (rgb - 0.5) / SH_C0
151 |
152 |
def set_random_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
157 |
158 |
159 | # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
def colormap(img, cmap="jet"):
    """Render ``img`` through a matplotlib colormap (with a colorbar) and
    return the rendered figure as a CHW float tensor.

    Args:
        img: 2-D array-like to visualize; shape[:2] is treated as (rows, cols).
        cmap: matplotlib colormap name.

    Returns:
        (3, H', W') float tensor of the rendered RGB figure (0-255 range).
    """
    # NOTE(review): rows are named W here to mirror the upstream reference
    # implementation; the figure size below swaps them back.
    W, H = img.shape[:2]
    dpi = 300
    fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
    im = ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.canvas.draw()
    try:
        # matplotlib < 3.8
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    except AttributeError:
        # FigureCanvasAgg.tostring_rgb was removed in matplotlib 3.8+;
        # read the RGBA buffer and drop the alpha channel instead.
        data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[..., :3].copy()
    img = torch.from_numpy(data).float().permute(2, 0, 1)
    plt.close()
    return img
174 |
175 |
def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert single channel to a color img.

    Args:
        img (torch.Tensor): (..., 1) float32 single channel image.
        colormap (str): Colormap for img.

    Returns:
        (..., 3) colored img with colors in [0, 1].
    """
    # NaNs would poison the integer lookup indices, so zero them first.
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        # Grayscale: replicate the single channel three times.
        return img.repeat(1, 1, 3)
    # Quantize [0, 1] floats into 8-bit indices for the colormap LUT.
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    # Index the matplotlib listed-colormap LUT with the quantized values.
    lut = torch.tensor(
        colormaps[colormap].colors,  # type: ignore
        device=img.device,
    )
    return lut[img_long[..., 0]]
198 |
199 |
def apply_depth_colormap(
    depth: torch.Tensor,
    acc: torch.Tensor = None,
    near_plane: float = None,
    far_plane: float = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask;
            output is blended toward white where acc < 1.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    # Explicit None checks: the previous `near_plane or ...` idiom silently
    # discarded a legitimate plane value of 0.0 (0.0 is falsy) and fell
    # back to the image min/max, contradicting the documented contract.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))
    # Normalize into [0, 1]; the epsilon guards a degenerate (flat) range.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        img = img * acc + (1.0 - acc)
    return img
225 |
--------------------------------------------------------------------------------
/tests/test_image_processing.py:
--------------------------------------------------------------------------------
1 | # test_image_processing.py
2 |
3 | import os
4 | import json
5 | import pytest
6 | import shutil
7 | from unittest.mock import MagicMock, patch
8 | from PIL import Image
9 | from app.image_processing import ImageProcessor, ResolutionDialog, ExifExtractProgressDialog
10 |
@pytest.fixture
def setup_image_folders(temp_workdir):
    """Set up image folders in temporary directory"""
    # Build the folder layout the ImageProcessor expects.
    for sub in ("images", "exif"):
        os.makedirs(os.path.join(temp_workdir, sub), exist_ok=True)
    images_dir = os.path.join(temp_workdir, "images")

    # Two small fixture images with different sizes and colors.
    Image.new('RGB', (100, 50), color='red').save(
        os.path.join(images_dir, "test_image.jpg"))
    Image.new('RGB', (200, 150), color='blue').save(
        os.path.join(images_dir, "test_image2.jpg"))

    return temp_workdir
31 |
32 |
def test_image_processor_init(setup_image_folders):
    """Test ImageProcessor initialization"""
    workdir = setup_image_folders
    processor = ImageProcessor(workdir)

    # The processor should record the workdir and derive its subfolders.
    assert processor.workdir == workdir
    assert processor.images_folder == os.path.join(workdir, "images")
    assert os.path.exists(processor.images_folder)
    assert processor.exif_folder == os.path.join(workdir, "exif")
41 |
42 |
def test_get_sample_image_dimensions(setup_image_folders):
    """Test get_sample_image_dimensions method"""
    processor = ImageProcessor(setup_image_folders)

    # Either fixture image may be sampled, so accept both sizes.
    width, height = processor.get_sample_image_dimensions()
    assert (width, height) in [(100, 50), (200, 150)]

    # Remove every jpg so the images folder is effectively empty.
    for name in os.listdir(processor.images_folder):
        if name.endswith('.jpg'):
            os.remove(os.path.join(processor.images_folder, name))

    # With no images present the method should report no dimensions.
    width, height = processor.get_sample_image_dimensions()
    assert width is None
    assert height is None
60 |
61 |
def test_resize_images(setup_image_folders):
    """Test resize_images method"""
    processor = ImageProcessor(setup_image_folders)
    test_img_path = os.path.join(processor.images_folder, "test_image.jpg")

    # Record the dimensions before any resizing.
    with Image.open(test_img_path) as img:
        start_w, start_h = img.width, img.height

    # Percentage-based resize to half size, reporting progress via callback.
    progress_cb = MagicMock()
    processor.resize_images("Percentage (%)", 50, progress_cb)
    assert progress_cb.call_count > 0

    with Image.open(test_img_path) as img:
        assert img.width == start_w // 2
        assert img.height == start_h // 2

    # Resizing must have produced a backup of the originals.
    assert os.path.exists(processor.images_org_folder)

    # Width-based resize pins the width exactly.
    processor.resize_images("Width (px)", 75)
    with Image.open(test_img_path) as img:
        assert img.width == 75

    # Height-based resize pins the height exactly.
    processor.resize_images("Height (px)", 30)
    with Image.open(test_img_path) as img:
        assert img.height == 30
96 |
97 |
def test_restore_original_images(setup_image_folders):
    """Test restore_original_images method"""
    processor = ImageProcessor(setup_image_folders)
    test_img_path = os.path.join(processor.images_folder, "test_image.jpg")

    # Resizing creates the backup folder we will restore from.
    processor.resize_images("Percentage (%)", 50)
    with Image.open(test_img_path) as img:
        shrunk_w, shrunk_h = img.width, img.height

    # Restoring should succeed and bring back the original dimensions.
    assert processor.restore_original_images() is True
    with Image.open(test_img_path) as img:
        assert img.width != shrunk_w
        assert img.height != shrunk_h

    # Without a backup folder the restore must report failure.
    shutil.rmtree(processor.images_org_folder)
    assert processor.restore_original_images() is False
124 |
125 |
# NOTE: @patch decorators are applied bottom-up, so the mock arguments arrive
# in reverse decorator order: Image.open -> mock_image_open,
# Fraction -> mock_fraction, piexif.dump -> mock_dump, piexif.load -> mock_load.
@patch('app.image_processing.piexif.load')
@patch('app.image_processing.piexif.dump')
@patch('app.image_processing.fractions.Fraction')
@patch('app.image_processing.Image.open')
def test_apply_exif_from_mapillary_json(mock_image_open, mock_fraction, mock_dump, mock_load, setup_image_folders, tmp_path):
    """Test apply_exif_from_mapillary_json method.

    Builds a Mapillary sidecar JSON plus a dummy frame file, mocks PIL and
    piexif so no real image decoding happens, then checks the processor
    reports one processed frame and renames the input folder to "images".
    """
    # Image.open is used as a context manager; mock the __enter__ result
    mock_img = MagicMock()
    mock_img.info = {}
    mock_image_open.return_value.__enter__.return_value = mock_img

    # Fixed rational value used for GPS coordinate conversion
    mock_fraction.return_value.numerator = 1
    mock_fraction.return_value.denominator = 2

    # piexif.load raising forces the code under test to build a fresh dict
    mock_load.side_effect = KeyError # Force new exif_dict creation

    # piexif.dump returns the bytes that get embedded into the image
    mock_dump.return_value = b'fake exif data'

    # Create a test JSON file with Mapillary data
    # (tmp_path is a pathlib.Path; os.path.join accepts it)
    json_path = os.path.join(tmp_path, "mapillary_image_description.json")
    images_dir = os.path.join(tmp_path, "video_frames")
    os.makedirs(images_dir, exist_ok=True)

    # Create a placeholder "image"; content is never decoded (PIL is mocked)
    test_img_path = os.path.join(images_dir, "frame_0001.jpg")
    with open(test_img_path, 'w') as f:
        f.write("fake image data")

    # One frame's worth of Mapillary GPS/orientation metadata
    mapillary_data = [
        {
            "filename": "frame_0001.jpg",
            "MAPLatitude": 35.6895,
            "MAPLongitude": 139.6917,
            "MAPAltitude": 10.5,
            "MAPCaptureTime": "2023_01_01_12_30_45_000",
            "MAPOrientation": 1
        }
    ]

    with open(json_path, 'w') as f:
        json.dump(mapillary_data, f)

    # Create the processor and call the method
    processor = ImageProcessor(setup_image_folders)
    processed_count = processor.apply_exif_from_mapillary_json(json_path, images_dir)

    # Check results: one frame processed, each mocked call hit exactly once
    assert processed_count == 1
    assert mock_image_open.call_count == 1
    assert mock_dump.call_count == 1
    assert not os.path.exists(images_dir) # Should be renamed
    assert os.path.exists(os.path.join(tmp_path, "images")) # New folder name
181 |
182 |
@patch('app.image_processing.QDialogButtonBox')
def test_resolution_dialog(mock_dialog_button_box, mock_qapplication):
    """Test ResolutionDialog"""
    dialog = ResolutionDialog(1920, 1080)

    # The constructor records the source resolution and its aspect ratio.
    assert dialog.original_width == 1920
    assert dialog.original_height == 1080
    assert dialog.aspect_ratio == 1920 / 1080

    # Replace the widgets with mocks so update_label can be driven directly.
    dialog.resize_method_combo = MagicMock()
    dialog.value_input = MagicMock()

    # update_label should seed the input with each method's default value.
    for method_name, expected_text in [
        ("Percentage (%)", "100"),
        ("Width (px)", "1920"),
        ("Height (px)", "1080"),
    ]:
        dialog.resize_method_combo.currentText.return_value = method_name
        dialog.update_label()
        dialog.value_input.setText.assert_called_with(expected_text)

    # get_values should return the method name and parse the text as float.
    dialog.resize_method_combo.currentText.return_value = "Percentage (%)"
    dialog.value_input.text.return_value = "75"
    assert dialog.get_values() == ("Percentage (%)", 75.0)
218 |
219 |
def test_exif_extract_progress_dialog(mock_qapplication):
    """Test ExifExtractProgressDialog"""
    # A custom message still gets the standard window title.
    custom = ExifExtractProgressDialog("Processing...")
    assert custom.windowTitle() == "Extracting EXIF Data"

    # The default-constructed dialog must be modal.
    default = ExifExtractProgressDialog()
    assert default.isModal() is True
229 |
--------------------------------------------------------------------------------
/utils/datasets/traj.py:
--------------------------------------------------------------------------------
1 | """
2 | Code borrowed from
3 |
4 | https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py
5 | """
6 |
7 | import numpy as np
8 | import scipy
9 |
10 |
def normalize(x: np.ndarray) -> np.ndarray:
    """Normalization helper function."""
    length = np.linalg.norm(x)
    return x / length
14 |
15 |
def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:
    """Construct lookat view matrix."""
    def _unit(v):
        # Inline unit-vector helper (same math as module-level normalize).
        return v / np.linalg.norm(v)

    forward = _unit(lookdir)
    right = _unit(np.cross(up, forward))
    true_up = _unit(np.cross(forward, right))
    # Columns are (right, up, forward, position): a 3x4 camera-to-world pose.
    return np.stack([right, true_up, forward, position], axis=1)
23 |
24 |
def focus_point_fn(poses: np.ndarray) -> np.ndarray:
    """Calculate nearest point to all focal axes in poses."""
    directions = poses[:, :3, 2:3]
    origins = poses[:, :3, 3:4]
    # Per-camera projector onto the plane orthogonal to its viewing direction.
    m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
    mt_m = np.transpose(m, [0, 2, 1]) @ m
    # Least-squares point minimizing distance to every focal axis.
    return np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
32 |
33 |
def average_pose(poses: np.ndarray) -> np.ndarray:
    """New pose using average position, z-axis, and up vector of input poses."""
    mean_position = poses[:, :3, 3].mean(0)
    mean_z = poses[:, :3, 2].mean(0)
    mean_up = poses[:, :3, 1].mean(0)
    return viewmatrix(mean_z, mean_up, mean_position)
41 |
42 |
def generate_spiral_path(
    poses,
    bounds,
    n_frames=120,
    n_rots=2,
    zrate=0.5,
    spiral_scale_f=1.0,
    spiral_scale_r=1.0,
    focus_distance=0.75,
):
    """Calculates a forward facing spiral path for rendering."""
    # Pick a focus depth: a weighted blend of conservative near/far bounds,
    # with the weighting done in disparity (inverse-depth) space.
    near_bound = bounds.min()
    far_bound = bounds.max()
    # All cameras will point towards the world space point (0, 0, -focal).
    focal = spiral_scale_f / (
        (1 - focus_distance) / near_bound + focus_distance / far_bound
    )

    # Spiral radii from the 90th percentile of |camera positions|; the
    # trailing 1.0 makes the radii vector match homogeneous coordinates.
    radii = np.percentile(np.abs(poses[:, :3, 3]), 90, 0) * spiral_scale_r
    radii = np.concatenate([radii, [1.0]])

    cam2world = average_pose(poses)
    up = poses[:, :3, 1].mean(0)

    thetas = np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False)
    render_poses = []
    for theta in thetas:
        # Offset in the average camera's frame: circle in x/y, bob in z.
        offset = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
        position = cam2world @ offset
        lookat = cam2world @ [0, 0, -focal, 1.0]
        render_poses.append(viewmatrix(position - lookat, up, position))
    return np.stack(render_poses, axis=0)
80 |
81 |
def generate_ellipse_path_z(
    poses: np.ndarray,
    n_frames: int = 120,
    # const_speed: bool = True,
    variation: float = 0.0,
    phase: float = 0.0,
    height: float = 0.0,
) -> np.ndarray:
    """Generate an elliptical render path based on the given poses."""
    # Cameras point toward the point nearest all focal axes.
    center = focus_point_fn(poses)
    # Path sits at z=height, centered on the focal point in x-y.
    offset = np.array([center[0], center[1], height])

    # Ellipse semi-axes from the 90th percentile of camera offsets,
    # symmetric about the focal point in x-y.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
    low = offset - sc
    high = offset + sc
    # Optional height variation need not be symmetric.
    z_low = np.percentile(poses[:, :3, 3], 10, axis=0)
    z_high = np.percentile(poses[:, :3, 3], 90, axis=0)

    def get_positions(theta):
        # Trig interpolation between bounds traces the ellipse in x-y;
        # the optional z term varies camera height along the path.
        x = low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5)
        y = low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5)
        z = height + variation * (
            z_low[2]
            + (z_high - z_low)[2] * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
        )
        return np.stack([x, y, z], -1)

    # Sample one extra angle to close the loop, then drop the duplicate.
    theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)[:-1]

    # Snap the up vector to the world axis closest to the mean camera up.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    axis = np.argmax(np.abs(avg_up))
    up = np.eye(3)[axis] * np.sign(avg_up[axis])

    return np.stack([viewmatrix(center - p, up, p) for p in positions])
142 |
143 |
def generate_ellipse_path_y(
    poses: np.ndarray,
    n_frames: int = 120,
    # const_speed: bool = True,
    variation: float = 0.0,
    phase: float = 0.0,
    height: float = 0.0,
) -> np.ndarray:
    """Generate an elliptical render path based on the given poses."""
    # Cameras orient relative to the point nearest all focal axes.
    center = focus_point_fn(poses)
    # Path sits at y=height, centered on the focal point in x-z.
    offset = np.array([center[0], height, center[2]])

    # Ellipse semi-axes from the 90th percentile of camera offsets,
    # symmetric about the focal point in x-z.
    sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
    low = offset - sc
    high = offset + sc
    # Optional height variation need not be symmetric.
    y_low = np.percentile(poses[:, :3, 3], 10, axis=0)
    y_high = np.percentile(poses[:, :3, 3], 90, axis=0)

    def get_positions(theta):
        # Trig interpolation between bounds traces the ellipse in x-z;
        # the optional y term varies camera height along the path.
        x = low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5)
        y = height + variation * (
            y_low[1]
            + (y_high - y_low)[1] * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
        )
        z = low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5)
        return np.stack([x, y, z], -1)

    # Sample one extra angle to close the loop, then drop the duplicate.
    theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
    positions = get_positions(theta)[:-1]

    # Snap the up vector to the world axis closest to the mean camera up.
    avg_up = poses[:, :3, 1].mean(0)
    avg_up = avg_up / np.linalg.norm(avg_up)
    axis = np.argmax(np.abs(avg_up))
    up = np.eye(3)[axis] * np.sign(avg_up[axis])

    # NOTE: look direction is p - center here (mirrors the original), the
    # opposite sign of generate_ellipse_path_z.
    return np.stack([viewmatrix(p - center, up, p) for p in positions])
204 |
205 |
def generate_interpolated_path(
    poses: np.ndarray,
    n_interp: int,
    spline_degree: int = 5,
    smoothness: float = 0.03,
    rot_weight: float = 0.1,
):
    """Creates a smooth spline path between input keyframe camera poses.

    Spline is calculated with poses in format (position, lookat-point, up-point).

    Args:
        poses: (n, 3, 4) array of input pose keyframes.
        n_interp: returned path will have n_interp * (n - 1) total poses.
        spline_degree: polynomial degree of B-spline.
        smoothness: parameter for spline smoothing, 0 forces exact interpolation.
        rot_weight: relative weighting of rotation/translation in spline solve.

    Returns:
        Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
    """

    def poses_to_points(poses, dist):
        """Converts from pose matrices to (position, lookat, up) format."""
        centers = poses[:, :3, -1]
        # Encode orientation as two auxiliary points offset along the
        # camera's z (lookat) and y (up) axes.
        lookats = centers - dist * poses[:, :3, 2]
        ups = centers + dist * poses[:, :3, 1]
        return np.stack([centers, lookats, ups], 1)

    def points_to_poses(points):
        """Converts from (position, lookat, up) format to pose matrices."""
        return np.array(
            [viewmatrix(pos - look, up_pt - pos, pos) for pos, look, up_pt in points]
        )

    def interp(points, n, k, s):
        """Runs multidimensional B-spline interpolation on the input points."""
        sh = points.shape
        flat = np.reshape(points, (sh[0], -1))
        # The spline degree cannot exceed (number of keyframes - 1).
        k = min(k, sh[0] - 1)
        tck, _ = scipy.interpolate.splprep(flat.T, k=k, s=s)
        u = np.linspace(0, 1, n, endpoint=False)
        sampled = np.array(scipy.interpolate.splev(u, tck))
        return np.reshape(sampled.T, (n, sh[1], sh[2]))

    keyframe_points = poses_to_points(poses, dist=rot_weight)
    interpolated = interp(
        keyframe_points,
        n_interp * (keyframe_points.shape[0] - 1),
        k=spline_degree,
        s=smoothness,
    )
    return points_to_poses(interpolated)
255 |
--------------------------------------------------------------------------------
/app/point_cloud_visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import json
4 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QMessageBox
5 | from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
6 | from PyQt5.QtGui import QVector3D
7 | import pyqtgraph.opengl as gl
8 | import pyqtgraph as pg
9 | from opensfm import dataset
10 | from opensfm.actions import reconstruct, create_tracks
11 | from opensfm.reconstruction import ReconstructionAlgorithm
12 | from scipy.spatial.transform import Rotation
13 | from utils.logger import setup_logger
14 | import multiprocessing
15 | from multiprocessing import Process, Event
16 | import time
17 | logger = setup_logger()
18 |
def safe_load_reconstruction(file_path: str, retries=3, delay=1):
    """Load a reconstruction JSON file, retrying on transient read errors.

    The file may be mid-write by the background reconstruction process, so
    partial/invalid JSON is retried ``retries`` times with a ``delay``-second
    pause between attempts.

    Args:
        file_path: Path to reconstruction.json.
        retries: Maximum number of read attempts.
        delay: Seconds to wait between failed attempts.

    Returns:
        The parsed JSON object, or None if the file does not exist or could
        not be read/parsed after all retries.
    """
    # If the file doesn't exist yet, just return None without logging errors
    if not os.path.exists(file_path):
        return None

    for attempt in range(retries):
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            logger.warning(f"Attempt {attempt + 1}/{retries}: Error loading {file_path}: {e}")
            # Fix: only sleep between attempts — previously we also slept
            # after the final failure, delaying the caller for nothing.
            if attempt < retries - 1:
                time.sleep(delay)
    logger.error(f"Failed to load reconstruction data from {file_path} after {retries} attempts.")
    return None
33 |
class ReconstructionThread(QThread):
    """Background thread running OpenSfM track creation then reconstruction.

    Emits ``finished`` on success; emits ``stopped`` when stop() was requested
    between the two pipeline stages. Cancellation is cooperative only — each
    blocking stage runs to completion before the flag is checked.
    """

    finished = pyqtSignal()
    stopped = pyqtSignal()

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self._running = True

    def run(self):
        stages = (
            lambda: create_tracks.run_dataset(self.dataset),
            lambda: reconstruct.run_dataset(
                self.dataset, ReconstructionAlgorithm.INCREMENTAL
            ),
        )
        for stage in stages:
            stage()
            # Check the cooperative cancellation flag after each stage.
            if not self._running:
                self.stopped.emit()
                return
        self.finished.emit()

    def stop(self):
        # Request cancellation; honored at the next stage boundary.
        self._running = False
60 |
class Reconstruction(QWidget):
    """3D viewer widget that runs an OpenSfM reconstruction out-of-process.

    Parses ``<workdir>/reconstruction.json`` into a point cloud plus camera
    glyphs in a GLViewWidget, polls that file on a timer so the view
    refreshes while the background process writes new results, and exposes
    Run/Stop/Config buttons.
    """

    def __init__(self, workdir):
        super().__init__()
        self.workdir = workdir
        self.reconstruction_file = os.path.join(workdir, "reconstruction.json")
        self.dataset = dataset.DataSet(workdir)
        self.camera_size = 1.0   # scale factor for camera glyphs (+/- keys)
        self.last_mod_time = 0   # mtime of reconstruction.json at last redraw
        self.camera_items = []   # (name, item(s), position, R, model) per camera
        self.stop_event = multiprocessing.Event()
        # Fix: initialize the process handle. stop_reconstruction() is called
        # unconditionally from closeEvent(), so without this the widget
        # raised AttributeError when closed before any reconstruction ran.
        self.process = None

        self.setFocusPolicy(Qt.StrongFocus)
        self.setFocus()

        main_layout = QVBoxLayout(self)

        self.viewer = gl.GLViewWidget()
        self.viewer.setFocusPolicy(Qt.NoFocus)
        self.viewer.setCameraPosition(distance=50)
        main_layout.addWidget(self.viewer)

        button_layout = QHBoxLayout()

        self.run_button = QPushButton("Run Reconstruction")
        self.run_button.clicked.connect(self.start_reconstruction)
        button_layout.addWidget(self.run_button)

        self.stop_button = QPushButton("Stop Reconstruction")
        self.stop_button.clicked.connect(self.stop_reconstruction)
        self.stop_button.setEnabled(False)
        button_layout.addWidget(self.stop_button)

        self.config_button = QPushButton("Config")
        self.config_button.clicked.connect(self.configure_reconstruction)
        button_layout.addWidget(self.config_button)

        main_layout.addLayout(button_layout)

        # Poll reconstruction.json for changes every 10 seconds.
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.check_for_updates)
        self.timer.start(10000)

        self.update_visualization()

    def start_reconstruction(self):
        """Confirm with the user, then launch the pipeline in a child process."""
        reply = QMessageBox.question(
            self, 'Confirm', 'Start Reconstruction?',
            QMessageBox.Yes | QMessageBox.No
        )
        if reply == QMessageBox.Yes:
            self.run_button.setEnabled(False)
            self.stop_button.setEnabled(True)
            self.stop_event.clear()
            # NOTE(review): the target is a bound method, so the widget must
            # be picklable under the active multiprocessing start method —
            # fine with 'fork'; verify if 'spawn' is ever used.
            self.process = Process(target=self.run_reconstruction, args=(self.stop_event,))
            self.process.start()

    def run_reconstruction(self, stop_event):
        """Child-process entry point: track creation, then incremental SfM."""
        create_tracks.run_dataset(self.dataset)
        # Honor a stop request between the two blocking stages.
        if stop_event.is_set():
            return
        reconstruct.run_dataset(self.dataset, ReconstructionAlgorithm.INCREMENTAL)

    def stop_reconstruction(self):
        """Confirm with the user and stop the reconstruction process.

        Safe to call when no process was ever started (``self.process`` is
        None) or when it has already exited.
        """
        if self.process is not None and self.process.is_alive():
            reply = QMessageBox.question(
                self, 'Confirm', 'Stop Reconstruction?',
                QMessageBox.Yes | QMessageBox.No
            )
            if reply == QMessageBox.Yes:
                self.stop_event.set()
                # Grace period before forcing termination.
                self.process.join(timeout=10)
                if self.process.is_alive():
                    self.process.terminate()
                    self.process.join()
                QMessageBox.information(self, "Stopped", "Reconstruction stopped successfully.")
                self.run_button.setEnabled(True)
                self.stop_button.setEnabled(False)

    def on_reconstruction_finished(self):
        """Slot: re-enable controls and redraw once reconstruction succeeds."""
        self.run_button.setEnabled(True)
        self.stop_button.setEnabled(False)
        self.update_visualization()
        QMessageBox.information(self, "Done", "Reconstruction completed successfully.")

    def on_reconstruction_stopped(self):
        """Slot: re-enable controls after a user-requested stop."""
        self.run_button.setEnabled(True)
        self.stop_button.setEnabled(False)
        QMessageBox.information(self, "Stopped", "Reconstruction was stopped.")

    def configure_reconstruction(self):
        """Placeholder for a future configuration dialog."""
        QMessageBox.information(self, "Config", "Configuration dialog placeholder.")

    def check_for_updates(self):
        """Timer slot: redraw if reconstruction.json changed on disk."""
        # Check if the file exists before attempting to get its modification time
        if not os.path.exists(self.reconstruction_file):
            return

        try:
            mod_time = os.path.getmtime(self.reconstruction_file)
            if mod_time != getattr(self, 'last_mod_time', None):
                self.last_mod_time = mod_time
                self.update_visualization()
        except (FileNotFoundError, OSError) as e:
            # The file may disappear between the exists() check and getmtime().
            logger.warning(f"Error checking reconstruction file: {e}")

    def update_visualization(self):
        """Rebuild the scene (points + camera glyphs) from reconstruction.json."""
        self.viewer.clear()
        self.camera_items.clear()

        data = safe_load_reconstruction(self.reconstruction_file)
        if not data:
            self.show_placeholder()
            return

        for reconstruction in data:
            points = np.array([p["coordinates"] for p in reconstruction.get("points", {}).values()], dtype=float)
            colors = np.array([p["color"] for p in reconstruction.get("points", {}).values()], dtype=float) / 255.0
            if points.size:
                scatter = gl.GLScatterPlotItem(pos=points, color=colors, size=2)
                scatter.setGLOptions('translucent')
                self.viewer.addItem(scatter)

            for shot_name, shot in reconstruction.get("shots", {}).items():
                # Camera center computed as -R^T t (assumes the shot stores a
                # world-to-camera axis-angle + translation — verify against
                # the OpenSfM reconstruction format).
                rotation = Rotation.from_rotvec(np.array(shot["rotation"])).as_matrix()
                translation = np.array(shot["translation"])
                position = -rotation.T @ translation
                cam_type = reconstruction["cameras"][shot["camera"]]["projection_type"]
                self.add_camera_visualization(shot_name, rotation, position, cam_type, self.camera_size)

    def closeEvent(self, event):
        """Stop any running reconstruction before the widget closes."""
        self.stop_reconstruction()
        super().closeEvent(event)

    def add_camera_visualization(self, cam_name, R, t, cam_model, size=1.0):
        """Add a camera glyph: a sphere for panoramic models, else a wire frustum."""
        if cam_model in ["spherical", "equirectangular"]:
            sphere_mesh = pg.opengl.MeshData.sphere(rows=10, cols=20, radius=size)
            sphere = gl.GLMeshItem(meshdata=sphere_mesh, color=(1, 1, 1, 0.3), smooth=True)
            sphere.translate(*t)
            self.viewer.addItem(sphere)
            self.camera_items.append((cam_name, sphere, t, R, cam_model))
        else:
            frustum_size = size * 5
            vertices = np.array([
                [0, 0, 0],
                [1, 1, -2],
                [1, -1, -2],
                [-1, 1, -2],
                [-1, -1, -2],
            ]) * frustum_size
            # Orient the frustum corners and move them to the camera position.
            # NOTE(review): the -R sign convention looks deliberate but is
            # unverified here — confirm against the rendered frustums.
            vertices = vertices @ -R + t

            edges = [
                (0, 1), (0, 2), (0, 3), (0, 4),
                (1, 2), (2, 4), (4, 3), (3, 1)
            ]

            lines = []
            for s, e in edges:
                line = gl.GLLinePlotItem(pos=np.array([vertices[s], vertices[e]]), color=(1, 1, 1, 0.5), width=1)
                self.viewer.addItem(line)
                lines.append(line)

            self.camera_items.append((cam_name, lines, t, R, cam_model))

    def show_placeholder(self):
        """Show a translucent gray sphere when no reconstruction data exists."""
        sphere_mesh = pg.opengl.MeshData.sphere(rows=10, cols=20, radius=10)
        sphere = gl.GLMeshItem(meshdata=sphere_mesh, color=(0.5, 0.5, 0.5, 0.3), smooth=True)
        self.viewer.addItem(sphere)

    def move_to_camera(self, image_name):
        """Fly the view to the named camera, matching its orientation."""
        for name, _, pos, R, _ in self.camera_items:
            if name == image_name:
                elevation = np.degrees(np.arcsin(-R[2, 0]))
                azimuth = np.degrees(np.arctan2(R[1, 0], R[0, 0]))
                self.viewer.setCameraPosition(pos=QVector3D(*pos), distance=20, elevation=elevation, azimuth=azimuth)
                break

    def highlight_camera(self, cam_name):
        """Color the named camera red; reset all others to translucent white."""
        for name, item, _, _, model in self.camera_items:
            color = (1, 0, 0, 0.8) if name == cam_name else (1, 1, 1, 0.3)
            if model in ["spherical", "equirectangular"]:
                item.setColor(color)
            else:
                # Frustum cameras are stored as a list of line items.
                for line in item:
                    line.setData(color=color)

    def on_camera_image_tree_click(self, image_name):
        """Tree single-click: highlight the corresponding camera."""
        self.highlight_camera(image_name)

    def on_camera_image_tree_double_click(self, image_name):
        """Tree double-click: move the view to the corresponding camera."""
        self.move_to_camera(image_name)

    def keyPressEvent(self, event):
        """+/- keys grow/shrink camera glyphs; other keys pass through."""
        if event.key() == Qt.Key_Plus:
            self.camera_size *= 1.1
            self.update_visualization()
        elif event.key() == Qt.Key_Minus:
            self.camera_size /= 1.1
            self.update_visualization()
        else:
            super().keyPressEvent(event)
263 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/app/image_processing.py:
--------------------------------------------------------------------------------
1 | # image_processing.py
2 |
3 | import os
4 | import json
5 | import shutil
6 | import datetime
7 | import fractions
8 | from PIL import Image
9 | import piexif
10 | from PyQt5.QtWidgets import QDialog, QVBoxLayout, QFormLayout, QLabel, QLineEdit, QComboBox, QDialogButtonBox, QCheckBox
11 | from PyQt5.QtCore import Qt
12 |
class ResolutionDialog(QDialog):
    """Dialog that lets the user choose how to resize images.

    Offers three resize methods (percentage, fixed width, fixed height);
    the value field is pre-filled with a sensible default whenever the
    method selection changes.
    """
    def __init__(self, current_width, current_height, parent=None):
        """
        Args:
            current_width: Width in pixels of the images being resized.
            current_height: Height in pixels of the images being resized.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Resize Images")
        self.setFixedSize(300, 150)

        self.original_width = current_width
        self.original_height = current_height
        # Guard against a zero height so a malformed image cannot crash
        # the dialog with a ZeroDivisionError.
        self.aspect_ratio = current_width / current_height if current_height else 0

        layout = QVBoxLayout()

        self.resize_method_combo = QComboBox()
        self.resize_method_combo.addItems(["Percentage (%)", "Width (px)", "Height (px)"])
        self.resize_method_combo.currentIndexChanged.connect(self.update_label)
        layout.addWidget(QLabel("Resize method:"))
        layout.addWidget(self.resize_method_combo)

        form_layout = QFormLayout()
        self.value_input = QLineEdit("100")
        form_layout.addRow(QLabel("Value:"), self.value_input)

        layout.addLayout(form_layout)

        button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)

        layout.addWidget(button_box)
        self.setLayout(layout)

    def update_label(self):
        """Reset the value field to a default matching the selected method."""
        method = self.resize_method_combo.currentText()
        if method == "Percentage (%)":
            self.value_input.setText("100")
        elif method == "Width (px)":
            self.value_input.setText(str(self.original_width))
        else:
            self.value_input.setText(str(self.original_height))

    def get_values(self):
        """Return the chosen (method, value) pair.

        Falls back to the method's default (100% or the original
        dimension) when the entered value is not a valid number.
        """
        method = self.resize_method_combo.currentText()
        try:
            value = float(self.value_input.text())
        except (ValueError, TypeError):
            value = 100.0 if method == "Percentage (%)" else (self.original_width if method == "Width (px)" else self.original_height)
        return method, value
61 |
62 |
class ExifExtractProgressDialog(QDialog):
    """Modal, fixed-size dialog shown while EXIF data is being extracted."""
    def __init__(self, message="Please Wait...", parent=None):
        super().__init__(parent)

        # Window configuration: modal and non-resizable.
        self.setWindowTitle("Extracting EXIF Data")
        self.setModal(True)
        self.setFixedSize(520, 120)

        # A single centered message is the dialog's only content.
        message_label = QLabel(message)
        message_label.setAlignment(Qt.AlignCenter)
        message_label.setStyleSheet("font-size: 14px; padding: 10px;")

        container = QVBoxLayout()
        container.addWidget(message_label)
        self.setLayout(container)
78 |
79 |
class ImageProcessor:
    """Filesystem-level image operations for a working directory.

    Manages three sibling folders under ``workdir``:

    * ``images``      -- the images currently used for processing
    * ``images_org``  -- a pristine backup taken before the first resize
    * ``exif``        -- extracted EXIF metadata
    """
    def __init__(self, workdir):
        """
        Args:
            workdir: Root working directory; the ``images`` and ``exif``
                subdirectories are created if missing.
        """
        self.workdir = workdir
        self.images_folder = os.path.join(workdir, "images")
        self.images_org_folder = os.path.join(workdir, "images_org")
        self.exif_folder = os.path.join(workdir, "exif")

        # Ensure directories exist
        os.makedirs(self.images_folder, exist_ok=True)
        os.makedirs(self.exif_folder, exist_ok=True)

    def resize_images(self, method, value, progress_callback=None):
        """Resize every image in the images folder in place.

        Args:
            method: One of "Percentage (%)", "Width (px)" or "Height (px)".
            value: Percentage or pixel value, interpreted per ``method``.
            progress_callback: Optional callable receiving an int 0-100.

        Returns:
            Number of image files considered (including any that failed).
        """
        # Back up originals once, before the first destructive resize.
        if not os.path.exists(self.images_org_folder):
            shutil.copytree(self.images_folder, self.images_org_folder)

        image_files = [f for f in os.listdir(self.images_folder)
                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        for i, image_file in enumerate(image_files):
            image_path = os.path.join(self.images_folder, image_file)

            try:
                # Context manager releases the file handle even on failure.
                with Image.open(image_path) as img:
                    # Calculate new dimensions based on method
                    if method == "Percentage (%)":
                        scale = value / 100
                        new_width = int(img.width * scale)
                        new_height = int(img.height * scale)
                    elif method == "Width (px)":
                        new_width = int(value)
                        # Preserve aspect ratio from the requested width.
                        new_height = int(new_width * img.height / img.width)
                    else:  # Height (px)
                        new_height = int(value)
                        new_width = int(new_height * img.width / img.height)

                    # Never collapse an image to zero pixels.
                    new_width = max(1, new_width)
                    new_height = max(1, new_height)

                    img_resized = img.resize((new_width, new_height), Image.LANCZOS)

                    # Save resized image over the original path.
                    img_resized.save(image_path)
            except Exception as e:
                print(f"Error resizing image {image_file}: {e}")

            # Call progress callback if provided
            if progress_callback:
                progress = (i + 1) / len(image_files) * 100
                progress_callback(int(progress))

        return len(image_files)

    def restore_original_images(self):
        """Replace the images folder with the backed-up originals.

        Returns:
            True when the backup existed and was restored, False otherwise.
        """
        if os.path.exists(self.images_org_folder):
            try:
                shutil.rmtree(self.images_folder)
                shutil.copytree(self.images_org_folder, self.images_folder)
                return True
            except Exception as e:
                print(f"Error restoring original images: {e}")
                return False
        return False

    def get_sample_image_dimensions(self):
        """Return (width, height) of the first image found, or (None, None)."""
        try:
            image_files = [f for f in os.listdir(self.images_folder)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            if not image_files:
                return None, None

            sample_image_path = os.path.join(self.images_folder, image_files[0])
            with Image.open(sample_image_path) as img:
                return img.width, img.height
        except Exception as e:
            print(f"Error getting sample image dimensions: {e}")
            return None, None

    @staticmethod
    def convert_to_degrees(value):
        """Convert a decimal coordinate to (degrees, minutes, seconds).

        The sign of ``value`` is preserved in the components; callers
        writing EXIF GPS tags must pass the absolute value and encode the
        sign in the corresponding reference tag ('N'/'S' or 'E'/'W').
        """
        d = int(value)
        m = int((value - d) * 60)
        s = (value - d - m / 60) * 3600
        return d, m, s

    @staticmethod
    def convert_to_rational(number):
        """Convert a number to a (numerator, denominator) pair for EXIF."""
        f = fractions.Fraction(str(number)).limit_denominator()
        return f.numerator, f.denominator

    def apply_exif_from_mapillary_json(self, json_path, images_dir):
        """Write GPS and capture-time EXIF tags from a Mapillary JSON file.

        Args:
            json_path: Path to the Mapillary metadata JSON (a list of dicts
                with 'filename', 'MAPLatitude', 'MAPLongitude',
                'MAPAltitude', 'MAPCaptureTime' keys).
            images_dir: Directory holding the images named in the JSON;
                renamed to ``images`` afterwards.

        Returns:
            Number of images whose EXIF data was successfully updated.
        """
        if not os.path.exists(json_path):
            print(f"JSON file not found: {json_path}")
            return 0

        if not os.path.exists(images_dir):
            print(f"Images directory not found: {images_dir}")
            return 0

        try:
            with open(json_path, 'r') as file:
                metadata_list = json.load(file)
        except Exception as e:
            print(f"Error loading JSON file: {e}")
            return 0

        processed_count = 0
        for metadata in metadata_list:
            image_filename = metadata.get('filename')
            if not image_filename:
                continue

            image_path = os.path.join(images_dir, os.path.basename(image_filename))
            if not os.path.exists(image_path):
                print(f"Image not found: {image_path}")
                continue

            try:
                # Context manager guarantees the file handle is released
                # even when EXIF writing fails part-way (the original code
                # leaked the handle on exception).
                with Image.open(image_path) as img:
                    try:
                        exif_dict = piexif.load(img.info['exif'])
                    except (KeyError, piexif.InvalidImageDataError):
                        exif_dict = {"0th": {}, "Exif": {}, "GPS": {}, "1st": {}}

                    # Check that required keys exist in metadata
                    if all(key in metadata for key in ['MAPLatitude', 'MAPLongitude', 'MAPAltitude', 'MAPCaptureTime']):
                        latitude = metadata['MAPLatitude']
                        longitude = metadata['MAPLongitude']
                        altitude = metadata['MAPAltitude']

                        # EXIF GPS rationals must be non-negative; the sign
                        # is carried by the Ref tags, so take abs() here.
                        # (Previously negative lat/lon/altitude produced
                        # invalid negative rationals.)
                        lat_deg = self.convert_to_degrees(abs(latitude))
                        lon_deg = self.convert_to_degrees(abs(longitude))

                        gps_ifd = {
                            piexif.GPSIFD.GPSLatitudeRef: 'N' if latitude >= 0 else 'S',
                            piexif.GPSIFD.GPSLongitudeRef: 'E' if longitude >= 0 else 'W',
                            piexif.GPSIFD.GPSLatitude: [
                                self.convert_to_rational(lat_deg[0]),
                                self.convert_to_rational(lat_deg[1]),
                                self.convert_to_rational(lat_deg[2])
                            ],
                            piexif.GPSIFD.GPSLongitude: [
                                self.convert_to_rational(lon_deg[0]),
                                self.convert_to_rational(lon_deg[1]),
                                self.convert_to_rational(lon_deg[2])
                            ],
                            piexif.GPSIFD.GPSAltitude: self.convert_to_rational(abs(altitude)),
                            piexif.GPSIFD.GPSAltitudeRef: 0 if altitude >= 0 else 1,
                        }

                        exif_dict['GPS'] = gps_ifd

                        capture_time = datetime.datetime.strptime(metadata['MAPCaptureTime'], '%Y_%m_%d_%H_%M_%S_%f')
                        exif_dict['Exif'][piexif.ExifIFD.DateTimeOriginal] = capture_time.strftime('%Y:%m:%d %H:%M:%S')
                        exif_dict['0th'][piexif.ImageIFD.Orientation] = metadata.get('MAPOrientation', 1)

                        exif_bytes = piexif.dump(exif_dict)
                        img.save(image_path, "jpeg", exif=exif_bytes)

                        processed_count += 1
                    else:
                        print(f"Missing required metadata for {image_filename}")

            except Exception as e:
                print(f"Failed to update EXIF for {image_path}: {e}")

        # After writing EXIF, rename the folder to 'images'
        try:
            images_parent_dir = os.path.dirname(images_dir)
            final_images_dir = os.path.join(images_parent_dir, "images")

            # If 'images' folder already exists, remove it and replace
            if os.path.exists(final_images_dir):
                shutil.rmtree(final_images_dir)
            os.rename(images_dir, final_images_dir)
            print(f"Renamed folder to {final_images_dir}")
        except Exception as e:
            print(f"Error renaming directory: {e}")

        return processed_count
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # Metadata
2 | use_exif_size: yes
3 | unknown_camera_models_are_different: no # Treat images from unknown camera models as coming from different cameras
4 | default_focal_prior: 0.85
5 |
6 | # Params for features
7 | feature_type: ALIKED # Feature type (ALIKED, AKAZE, SURF, SIFT, HAHOG, ORB)
8 | feature_root: 1 # If 1, apply square root mapping to features
9 | feature_min_frames: 4000 # If fewer frames are detected, sift_peak_threshold/surf_hessian_threshold is reduced.
10 | feature_min_frames_panorama: 16000 # Same as above but for panorama images
11 | feature_process_size: 2048 # Resize the image if its size is larger than specified. Set to -1 for original size
12 | feature_process_size_panorama: 4096 # Same as above but for panorama images
13 | feature_use_adaptive_suppression: no
14 | features_bake_segmentation: no # Bake segmentation info (class and instance) in the feature data. Thus it is done once for all at extraction time.
15 |
16 | # Params for SIFT
17 | sift_peak_threshold: 0.1 # Smaller value -> more features
18 | sift_edge_threshold: 10 # See OpenCV doc
19 |
20 | # Params for SURF
21 | surf_hessian_threshold: 3000 # Smaller value -> more features
22 | surf_n_octaves: 4 # See OpenCV doc
23 | surf_n_octavelayers: 2 # See OpenCV doc
24 | surf_upright: 0 # See OpenCV doc
25 |
26 | # Params for AKAZE (See details in lib/src/third_party/akaze/AKAZEConfig.h)
27 | akaze_omax: 4 # Maximum octave evolution of the image 2^sigma (coarsest scale sigma units)
28 | akaze_dthreshold: 0.001 # Detector response threshold to accept point
29 | akaze_descriptor: MSURF # Feature type
30 | akaze_descriptor_size: 0 # Size of the descriptor in bits. 0->Full size
31 | akaze_descriptor_channels: 3 # Number of feature channels (1,2,3)
32 | akaze_kcontrast_percentile: 0.7
33 | akaze_use_isotropic_diffusion: no
34 |
35 | # Params for HAHOG
36 | hahog_peak_threshold: 0.00001
37 | hahog_edge_threshold: 10
38 | hahog_normalize_to_uchar: yes
39 |
40 | # Params for general matching
41 | lowes_ratio: 0.8 # Ratio test for matches
42 | matcher_type: FLANN # FLANN, BRUTEFORCE, or WORDS
43 | symmetric_matching: yes # Match symmetrically or one-way
44 |
45 | # Params for FLANN matching
46 | flann_algorithm: KMEANS # Algorithm type (KMEANS, KDTREE)
47 | flann_branching: 8 # See OpenCV doc
48 | flann_iterations: 10 # See OpenCV doc
49 | flann_tree: 8 # See OpenCV doc
50 | flann_checks: 20 # Smaller -> Faster (but might lose good matches)
51 |
52 | # Params for BoW matching
53 | bow_file: bow_hahog_root_uchar_10000.npz
54 | bow_words_to_match: 50 # Number of words to explore per feature.
55 | bow_num_checks: 20 # Number of matching features to check.
56 | bow_matcher_type: FLANN # Matcher type to assign words to features
57 |
58 | # Params for VLAD matching
59 | vlad_file: bow_hahog_root_uchar_64.npz
60 |
61 | # Params for matching
62 | matching_gps_distance: 150 # Maximum gps distance between two images for matching
63 | matching_gps_neighbors: 0 # Number of images to match selected by GPS distance. Set to 0 to use no limit (or disable if matching_gps_distance is also 0)
64 | matching_time_neighbors: 0 # Number of images to match selected by time taken. Set to 0 to disable
65 | matching_order_neighbors: 0 # Number of images to match selected by image name. Set to 0 to disable
66 | matching_bow_neighbors: 0 # Number of images to match selected by BoW distance. Set to 0 to disable
67 | matching_bow_gps_distance: 0 # Maximum GPS distance for preempting images before using selection by BoW distance. Set to 0 to disable
68 | matching_bow_gps_neighbors: 0 # Number of images (selected by GPS distance) to preempt before using selection by BoW distance. Set to 0 to use no limit (or disable if matching_bow_gps_distance is also 0)
69 | matching_bow_other_cameras: False # If True, BoW image selection will use N neighbors from the same camera + N neighbors from any different camera. If False, the selection will take the nearest neighbors from all cameras.
70 | matching_vlad_neighbors: 0 # Number of images to match selected by VLAD distance. Set to 0 to disable
71 | matching_vlad_gps_distance: 0 # Maximum GPS distance for preempting images before using selection by VLAD distance. Set to 0 to disable
72 | matching_vlad_gps_neighbors: 0 # Number of images (selected by GPS distance) to preempt before using selection by VLAD distance. Set to 0 to use no limit (or disable if matching_vlad_gps_distance is also 0)
73 | matching_vlad_other_cameras: False # If True, VLAD image selection will use N neighbors from the same camera + N neighbors from any different camera. If False, the selection will take the nearest neighbors from all cameras.
74 | matching_graph_rounds: 0 # Number of rounds to run when running triangulation-based pair selection
75 | matching_use_filters: False # If True, removes static matches using ad-hoc heuristics
76 | matching_use_segmentation: no # Use segmentation information (if available) to improve matching
77 |
78 | # Params for geometric estimation
79 | robust_matching_threshold: 0.004 # Outlier threshold for fundamental matrix estimation as portion of image width
80 | robust_matching_calib_threshold: 0.004 # Outlier threshold for essential matrix estimation during matching in radians
81 | robust_matching_min_match: 20 # Minimum number of matches to accept matches between two images
82 | five_point_algo_threshold: 0.004 # Outlier threshold for essential matrix estimation during incremental reconstruction in radians
83 | five_point_algo_min_inliers: 20 # Minimum number of inliers for considering a two view reconstruction valid
84 | five_point_refine_match_iterations: 10 # Number of LM iterations to run when refining relative pose during matching
85 | five_point_refine_rec_iterations: 1000 # Number of LM iterations to run when refining relative pose during reconstruction
86 | triangulation_threshold: 0.006 # Outlier threshold for accepting a triangulated point in radians
87 | triangulation_min_ray_angle: 1.0 # Minimum angle between views to accept a triangulated point
88 | triangulation_type: FULL # Triangulation type: either considering all rays (FULL), or using a RANSAC variant (ROBUST)
89 | resection_threshold: 0.004 # Outlier threshold for resection in radians
90 | resection_min_inliers: 10 # Minimum number of resection inliers to accept it
91 |
92 | # Params for track creation
93 | min_track_length: 2 # Minimum number of features/images per track
94 |
95 | # Params for bundle adjustment
96 | loss_function: SoftLOneLoss # Loss function for the ceres problem (see: http://ceres-solver.org/modeling.html#lossfunction)
97 | loss_function_threshold: 1 # Threshold on the squared residuals. Usually cost is quadratic for smaller residuals and sub-quadratic above.
98 | reprojection_error_sd: 0.004 # The standard deviation of the reprojection error
99 | exif_focal_sd: 0.01 # The standard deviation of the exif focal length in log-scale
100 | principal_point_sd: 0.01 # The standard deviation of the principal point coordinates
101 | radial_distortion_k1_sd: 0.01 # The standard deviation of the first radial distortion parameter
102 | radial_distortion_k2_sd: 0.01 # The standard deviation of the second radial distortion parameter
103 | radial_distortion_k3_sd: 0.01 # The standard deviation of the third radial distortion parameter
104 | radial_distortion_k4_sd: 0.01 # The standard deviation of the fourth radial distortion parameter
105 | tangential_distortion_p1_sd: 0.01 # The standard deviation of the first tangential distortion parameter
106 | tangential_distortion_p2_sd: 0.01 # The standard deviation of the second tangential distortion parameter
107 | gcp_horizontal_sd: 0.01 # The default horizontal standard deviation of the GCPs (in meters)
108 | gcp_vertical_sd: 0.1 # The default vertical standard deviation of the GCPs (in meters)
109 | rig_translation_sd: 0.1 # The standard deviation of the rig translation
110 | rig_rotation_sd: 0.1 # The standard deviation of the rig rotation
111 | bundle_outlier_filtering_type: FIXED # Type of threshold for filtering outlier : either fixed value (FIXED) or based on actual distribution (AUTO)
112 | bundle_outlier_auto_ratio: 3.0 # For AUTO filtering type, projections with larger reprojection than ratio-times-mean, are removed
113 | bundle_outlier_fixed_threshold: 0.006 # For FIXED filtering type, projections with larger reprojection error after bundle adjustment are removed
114 | optimize_camera_parameters: yes # Optimize internal camera parameters during bundle
115 | bundle_max_iterations: 100 # Maximum optimizer iterations.
116 |
117 | retriangulation: yes # Retriangulate all points from time to time
118 | retriangulation_ratio: 1.2 # Retriangulate when the number of points grows by this ratio
119 | bundle_analytic_derivatives: yes # Use analytic derivatives or auto-differentiated ones during bundle adjustment
120 | bundle_interval: 999999 # Bundle after adding 'bundle_interval' cameras
121 | bundle_new_points_ratio: 1.2 # Bundle when the number of points grows by this ratio
122 | local_bundle_radius: 3 # Max image graph distance for images to be included in local bundle adjustment
123 | local_bundle_min_common_points: 20 # Minimum number of common points between images to be considered neighbors
124 | local_bundle_max_shots: 30 # Max number of shots to optimize during local bundle adjustment
125 |
126 | save_partial_reconstructions: yes # Save reconstructions at every iteration
127 |
128 | # Params for GPS alignment
129 | use_altitude_tag: no # Use or ignore EXIF altitude tag
130 | align_method: auto # orientation_prior or naive
131 | align_orientation_prior: horizontal # horizontal, vertical or no_roll
132 | bundle_use_gps: yes # Enforce GPS position in bundle adjustment
133 | bundle_use_gcp: no # Enforce Ground Control Point position in bundle adjustment
134 | bundle_compensate_gps_bias: yes # Compensate GPS with a per-camera similarity transform
135 |
136 |
137 | # Params for rigs
138 | rig_calibration_subset_size: 15 # Number of rig instances to use when calibrating rigs
139 | rig_calibration_completeness: 0.85 # Ratio of reconstructed images needed to consider a reconstruction for rig calibration
140 | rig_calibration_max_rounds: 10 # Number of SfM tentatives to run until we get a satisfying reconstruction
141 |
142 | # Params for image undistortion
143 | undistorted_image_format: jpg # Format in which to save the undistorted images
144 | undistorted_image_max_size: 100000 # Max width and height of the undistorted image
145 |
146 | # Params for depth estimation
147 | depthmap_method: PATCH_MATCH_SAMPLE # Raw depthmap computation algorithm (PATCH_MATCH, BRUTE_FORCE, PATCH_MATCH_SAMPLE)
148 | depthmap_resolution: 640 # Resolution of the depth maps
149 | depthmap_num_neighbors: 10 # Number of neighboring views
150 | depthmap_num_matching_views: 6 # Number of neighboring views used for each depthmaps
151 | depthmap_min_depth: 0 # Minimum depth in meters. Set to 0 to auto-infer from the reconstruction.
152 | depthmap_max_depth: 0 # Maximum depth in meters. Set to 0 to auto-infer from the reconstruction.
153 | depthmap_patchmatch_iterations: 3 # Number of PatchMatch iterations to run
154 | depthmap_patch_size: 7 # Size of the correlation patch
155 | depthmap_min_patch_sd: 1.0 # Patches with lower standard deviation are ignored
156 | depthmap_min_correlation_score: 0.1 # Minimum correlation score to accept a depth value
157 | depthmap_same_depth_threshold: 0.01 # Threshold to measure depth closeness
158 | depthmap_min_consistent_views: 3 # Min number of views that should reconstruct a point for it to be valid
159 | depthmap_save_debug_files: no # Save debug files with partial reconstruction results
160 |
161 | # Other params
162 | processes: 1 # Number of threads to use
163 | read_processes: 1 # When processes > 1, number of threads used for reading images
164 |
165 | # Params for submodel split and merge
166 | submodel_size: 80 # Average number of images per submodel
167 | submodel_overlap: 30.0 # Radius of the overlapping region between submodels
168 | submodels_relpath: "submodels" # Relative path to the submodels directory
169 | submodel_relpath_template: "submodels/submodel_%04d" # Template to generate the relative path to a submodel directory
170 | submodel_images_relpath_template: "submodels/submodel_%04d/images" # Template to generate the relative path to a submodel images directory
--------------------------------------------------------------------------------
/tests/test_main_app.py:
--------------------------------------------------------------------------------
1 | # test_main_app.py
2 |
3 | import os
4 | import sys
5 | import pytest
6 | import shutil
7 | from unittest.mock import MagicMock, patch
8 | from PyQt5.QtWidgets import QApplication, QDialog, QMessageBox, QFileDialog
9 | from app.main_app import MainApp, VideoProcessDialog
10 |
@pytest.fixture
def setup_main_app(mock_qapplication):
    """Provide a MainApp instance with its heavy Qt collaborators patched out.

    TabManager/QTimer/QToolBar are patched so construction performs no real Qt
    work; show_start_dialog is stubbed so no dialog can open during a test.
    """
    with patch('app.main_app.TabManager'), patch('app.main_app.QTimer'), \
            patch('app.main_app.QToolBar'):
        instance = MainApp()
        # Prevent the startup dialog from appearing when the fixture is used.
        instance.show_start_dialog = MagicMock()
        return instance
20 |
21 |
def test_main_app_init(setup_main_app):
    """A freshly constructed MainApp has empty state and a deferred start dialog."""
    main_app = setup_main_app

    # Nothing is loaded until a project directory is chosen.
    assert main_app.workdir is None
    assert main_app.image_list == []
    assert main_app.camera_model_manager is None
    assert main_app.image_processor is None
    assert main_app.tab_manager is not None

    # Construction schedules show_start_dialog exactly once via QTimer.
    from app.main_app import QTimer
    assert QTimer.singleShot.call_count == 1
36 |
37 |
def test_setup_ui(setup_main_app):
    """setup_ui installs the tab manager as the window's central widget."""
    main_app = setup_main_app
    main_app.tab_manager.reset_mock()

    main_app.setup_ui()

    # The tab manager must end up as the central widget of the main window.
    assert main_app.centralWidget() == main_app.tab_manager
50 |
51 |
def test_register_tabs(setup_main_app):
    """register_tabs registers at least one tab with the tab manager."""
    main_app = setup_main_app
    main_app.tab_manager.reset_mock()

    main_app.register_tabs()

    # At least one tab must have been registered.
    assert main_app.tab_manager.register_tab.call_count >= 1
64 |
65 |
@patch('app.main_app.QMessageBox.question')
def test_show_start_dialog_video(mock_question, setup_main_app):
    """Answering 'Yes' in the start dialog routes to video processing."""
    main_app = setup_main_app
    main_app.process_video = MagicMock()
    main_app.select_image_folder = MagicMock()

    # Simulate the user choosing video processing.
    mock_question.return_value = QMessageBox.Yes

    main_app.show_start_dialog()

    # Only the video path should have been taken.
    main_app.process_video.assert_called_once()
    main_app.select_image_folder.assert_not_called()
82 |
83 |
@patch('app.main_app.QMessageBox.question')
def test_show_start_dialog_image_folder(mock_question, setup_main_app):
    """Answering 'No' in the start dialog routes to image-folder selection."""
    main_app = setup_main_app
    main_app.process_video = MagicMock()
    main_app.select_image_folder = MagicMock()

    # Simulate the user choosing an existing image folder.
    mock_question.return_value = QMessageBox.No

    main_app.show_start_dialog()

    # Only the image-folder path should have been taken.
    main_app.select_image_folder.assert_called_once()
    main_app.process_video.assert_not_called()
100 |
101 |
@patch('app.main_app.QFileDialog.getExistingDirectory')
@patch('os.makedirs')
@patch('os.listdir')
@patch('os.path.exists')
@patch('shutil.copy')
def test_select_image_folder(mock_copy, mock_exists, mock_listdir, mock_makedirs,
                             mock_get_dir, setup_main_app):
    """select_image_folder copies image files into the workdir and loads it."""
    main_app = setup_main_app
    main_app.load_workdir = MagicMock()

    chosen_dir = "/test/workdir"
    mock_get_dir.return_value = chosen_dir
    # Only the two .jpg entries should be treated as images; file.txt is skipped.
    mock_listdir.return_value = ["image1.jpg", "image2.jpg", "file.txt"]
    mock_exists.return_value = False  # target images dir starts empty

    main_app.select_image_folder()

    # The chosen directory becomes the working directory.
    assert main_app.workdir == chosen_dir

    # An images/ subdirectory is created under the workdir.
    mock_makedirs.assert_called_with(os.path.join(chosen_dir, "images"), exist_ok=True)

    # Exactly the two image files are copied.
    assert mock_copy.call_count == 2

    main_app.load_workdir.assert_called_once()

    # Cancelling the directory picker warns the user and exits the app.
    mock_get_dir.return_value = ""
    with patch('app.main_app.QMessageBox.warning') as mock_warning, \
            patch('app.main_app.sys.exit') as mock_exit:
        main_app.select_image_folder()
        mock_warning.assert_called_once()
        mock_exit.assert_called_once_with(1)
141 |
142 |
@patch('app.main_app.QFileDialog.getOpenFileName')
@patch('app.main_app.VideoProcessDialog')
@patch('app.main_app.ExifExtractProgressDialog')
@patch('app.main_app.VideoProcessCommand')
@patch('app.main_app.ImageProcessor')
def test_process_video(mock_image_processor, mock_video_command, mock_progress_dialog,
                       mock_dialog, mock_get_file, setup_main_app):
    """Full happy-path and cancel-path test of MainApp.process_video.

    Decorator patches are applied bottom-up, so parameters arrive in reverse
    order: ImageProcessor first, the file dialog last.
    """
    app = setup_main_app
    app.load_workdir = MagicMock()

    # User picks a video file in the open-file dialog.
    test_file = "/test/video.mp4"
    mock_get_file.return_value = (test_file, "")

    # The options dialog is accepted and returns a full set of values.
    dialog_instance = MagicMock()
    dialog_instance.exec_.return_value = QDialog.Accepted
    dialog_instance.get_values.return_value = {
        "import_path": "/test/import_path",
        "method": "Interval",
        "interval": 0.5,
        "distance": -1,
        "geotag_source": "video",
        "geotag_source_path": None,
        "offset_time": 0,
        "use_gpx": True
    }
    mock_dialog.return_value = dialog_instance

    progress_instance = MagicMock()
    mock_progress_dialog.return_value = progress_instance

    video_command_instance = MagicMock()
    mock_video_command.return_value = video_command_instance

    processor_instance = MagicMock()
    mock_image_processor.return_value = processor_instance

    # Pretend every filesystem path queried during processing exists.
    with patch('os.makedirs') as mock_makedirs, \
         patch('os.path.exists') as mock_exists, \
         patch('app.main_app.QApplication.processEvents') as mock_process_events:

        mock_exists.return_value = True

        # Call method
        app.process_video()

        # The options dialog was constructed and shown exactly once.
        mock_dialog.assert_called_once()
        dialog_instance.exec_.assert_called_once()

        # The progress dialog was constructed and shown.
        mock_progress_dialog.assert_called_once()
        progress_instance.show.assert_called_once()

        # The video extraction command ran once.
        mock_video_command.assert_called_once()
        video_command_instance.run.assert_called_once()

        # EXIF data from the Mapillary JSON was applied to the frames.
        mock_image_processor.assert_called_once()
        processor_instance.apply_exif_from_mapillary_json.assert_called_once()

        # The dialog's import_path became the workdir and it was loaded.
        assert app.workdir == "/test/import_path"
        app.load_workdir.assert_called_once()

    # Cancel path: rejecting the dialog warns the user and exits with code 1.
    dialog_instance.exec_.return_value = QDialog.Rejected
    with patch('app.main_app.QMessageBox.warning') as mock_warning, \
         patch('app.main_app.sys.exit') as mock_exit:
        app.process_video()
        mock_warning.assert_called_once()
        mock_exit.assert_called_once_with(1)
218 |
219 |
@patch('os.path.exists')
@patch('os.listdir')
def test_load_workdir_with_exif(mock_listdir, mock_exists, setup_main_app):
    """When EXIF data already exists, load_workdir wires managers and tabs."""
    main_app = setup_main_app
    main_app.workdir = "/test/workdir"

    # Every queried filesystem path is reported as present.
    mock_exists.return_value = True
    mock_listdir.return_value = ["image1.jpg", "image2.jpg"]

    with patch('app.main_app.CameraModelManager') as mock_manager, \
            patch('app.main_app.ImageProcessor') as mock_processor:
        mock_manager.return_value = MagicMock()
        mock_processor.return_value = MagicMock()

        main_app.load_workdir()

        # The directory listing becomes the image list.
        assert main_app.image_list == ["image1.jpg", "image2.jpg"]

        # Both managers are constructed against the workdir.
        mock_manager.assert_called_once_with(main_app.workdir)
        mock_processor.assert_called_once_with(main_app.workdir)

        # All tabs are refreshed with the new workdir/image list.
        if main_app.tab_manager:
            main_app.tab_manager.update_all_tabs.assert_called_once_with(
                workdir=main_app.workdir, image_list=main_app.image_list
            )
256 |
257 |
@patch('os.path.exists')
@patch('os.listdir')
def test_load_workdir_without_exif(mock_listdir, mock_exists, setup_main_app):
    """Missing EXIF data triggers metadata extraction and the camera editor.

    Simulates a workdir where only images/ exists, so load_workdir must copy
    the config, run OpenSfM metadata extraction and open the camera editor.
    """
    app = setup_main_app
    app.workdir = "/test/workdir"

    # Only paths containing "images" exist; exif/ and friends are missing.
    def mock_exists_side_effect(path):
        return "images" in path  # Only images directory exists

    mock_exists.side_effect = mock_exists_side_effect
    mock_listdir.return_value = ["image1.jpg", "image2.jpg"]

    # Patch every collaborator touched on the no-EXIF path.
    with patch('app.main_app.ExifExtractProgressDialog') as mock_progress, \
         patch('app.main_app.CameraModelManager') as mock_manager, \
         patch('app.main_app.ImageProcessor') as mock_processor, \
         patch('app.main_app.QApplication.processEvents') as mock_process, \
         patch('shutil.copy') as mock_copy, \
         patch('app.main_app.DataSet') as mock_dataset, \
         patch('app.main_app.extract_metadata') as mock_extract:

        progress_instance = MagicMock()
        mock_progress.return_value = progress_instance

        manager_instance = MagicMock()
        mock_manager.return_value = manager_instance

        processor_instance = MagicMock()
        mock_processor.return_value = processor_instance

        dataset_instance = MagicMock()
        mock_dataset.return_value = dataset_instance

        # Call method
        app.load_workdir()

        # The extraction progress dialog was created and shown.
        mock_progress.assert_called_once()
        progress_instance.show.assert_called_once()

        # Metadata extraction ran over a DataSet built on the workdir.
        mock_copy.assert_called_once()  # Config file copy into the workdir
        mock_dataset.assert_called_once_with(app.workdir)
        mock_extract.run_dataset.assert_called_once_with(dataset_instance)

        # The camera model editor opened with the main window as parent.
        manager_instance.open_camera_model_editor.assert_called_once_with(parent=app)
307 |
308 |
def test_video_process_dialog(mock_qapplication):
    """Exercise VideoProcessDialog's input toggling and value collection.

    Widgets are replaced with MagicMocks after construction, so each section
    below sets up mock state and then asserts the method under test reads it
    correctly. Ordering matters: later sections reuse earlier mocks.
    """
    dialog = VideoProcessDialog()

    # --- toggle_sampling_inputs: enables exactly one of the two inputs ---
    dialog.sampling_method_combo = MagicMock()
    dialog.distance_input = MagicMock()
    dialog.interval_input = MagicMock()

    # "Interval" disables the distance field and enables the interval field.
    dialog.sampling_method_combo.currentText.return_value = "Interval"
    dialog.toggle_sampling_inputs()
    dialog.distance_input.setDisabled.assert_called_with(True)
    dialog.interval_input.setDisabled.assert_called_with(False)

    # "Distance" flips both states.
    dialog.sampling_method_combo.currentText.return_value = "Distance"
    dialog.toggle_sampling_inputs()
    dialog.distance_input.setDisabled.assert_called_with(False)
    dialog.interval_input.setDisabled.assert_called_with(True)

    # --- get_sampling_values: the inactive value is reported as -1 ---
    dialog.distance_input.text.return_value = "5"
    dialog.interval_input.text.return_value = "0.5"

    # With Distance method: interval is -1, distance parsed as float.
    dialog.sampling_method_combo.currentText.return_value = "Distance"
    method, interval, distance = dialog.get_sampling_values()
    assert method == "Distance"
    assert interval == -1
    assert distance == 5.0

    # With Interval method: distance is -1, interval parsed as float.
    dialog.sampling_method_combo.currentText.return_value = "Interval"
    method, interval, distance = dialog.get_sampling_values()
    assert method == "Interval"
    assert interval == 0.5
    assert distance == -1

    # --- get_values: aggregates every field into one dict ---
    dialog.import_path_input = MagicMock()
    dialog.import_path_input.text.return_value = "/test/import"

    dialog.geotag_source_combo = MagicMock()
    dialog.geotag_source_combo.currentText.return_value = "video"

    # Empty path input should surface as None in the result dict.
    dialog.geotag_source_path_input = MagicMock()
    dialog.geotag_source_path_input.text.return_value = ""

    dialog.interpolation_offset_input = MagicMock()
    dialog.interpolation_offset_input.text.return_value = "1.5"

    dialog.interpolation_use_gpx_checkbox = MagicMock()
    dialog.interpolation_use_gpx_checkbox.isChecked.return_value = True

    # Stub the already-tested sampling helper so get_values is isolated.
    dialog.get_sampling_values = MagicMock(return_value=("Interval", 0.5, -1))

    values = dialog.get_values()
    assert values["import_path"] == "/test/import"
    assert values["method"] == "Interval"
    assert values["interval"] == 0.5
    assert values["distance"] == -1
    assert values["geotag_source"] == "video"
    assert values["geotag_source_path"] is None
    assert values["offset_time"] == 1.5
    assert values["use_gpx"] is True
--------------------------------------------------------------------------------
/app/camera_models.py:
--------------------------------------------------------------------------------
1 | # camera_models.py
2 |
3 | import os
4 | import json
5 | from PyQt5.QtWidgets import (
6 | QDialog, QVBoxLayout, QTableWidget, QTableWidgetItem,
7 | QPushButton, QLabel, QComboBox, QDialogButtonBox, QMessageBox
8 | )
9 |
class CameraModelEditor(QDialog):
    """Dialog for viewing and editing camera model parameters.

    Displays every (camera key, parameter, value) triple from *camera_models*
    in a table. Saving writes ``camera_models_overrides.json``, merges the
    overrides into ``camera_models.json`` and propagates the new parameters
    into the per-image ``.exif`` files so all three stay consistent.
    """

    def __init__(self, camera_models, workdir, parent=None):
        """
        Args:
            camera_models: Mapping of camera key -> {parameter: value}.
            workdir: Directory containing camera_models*.json and the exif folder.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.setWindowTitle("Edit Camera Models")
        self.setFixedSize(600, 400)

        self.camera_models = camera_models
        self.workdir = workdir

        layout = QVBoxLayout()
        layout.addWidget(QLabel("Camera Model Overrides"))

        self.table = QTableWidget()
        self.table.setColumnCount(3)
        self.table.setHorizontalHeaderLabels(["Key", "Parameter", "Value"])
        layout.addWidget(self.table)

        save_button = QPushButton("Save Changes")
        save_button.clicked.connect(self.save_changes)
        layout.addWidget(save_button)

        self.setLayout(layout)

        self.load_camera_models()

    def load_camera_models(self):
        """Populate the table with one row per (key, parameter) pair."""
        self.table.setRowCount(0)
        row = 0
        for key, params in self.camera_models.items():
            # key = camera identifier (e.g. "Perspective", "Spherical")
            # params = {'projection_type': 'perspective', 'width': ..., etc.}
            for param, value in params.items():
                self.table.insertRow(row)
                # Column 0 (Key)
                self.table.setItem(row, 0, QTableWidgetItem(key))
                # Column 1 (Parameter)
                self.table.setItem(row, 1, QTableWidgetItem(param))

                # Column 2 (Value) - combo box for projection_type, text otherwise
                if param == "projection_type":
                    combo = QComboBox()
                    combo.addItems(["perspective", "spherical"])
                    if str(value) in ["perspective", "spherical"]:
                        combo.setCurrentText(str(value))
                    else:
                        # Unknown value: prepend it so it remains selectable.
                        combo.insertItem(0, str(value))
                        combo.setCurrentIndex(0)
                    self.table.setCellWidget(row, 2, combo)
                else:
                    self.table.setItem(row, 2, QTableWidgetItem(str(value)))

                row += 1

    @staticmethod
    def _coerce_value(text):
        """Return *text* converted to int or float when it parses as a number.

        Non-numeric values (e.g. projection_type strings) are returned
        unchanged. EAFP replaces the old "contains '.'" heuristic, which
        additionally accepts exponent notation like "1e-3".
        """
        try:
            return int(text)
        except ValueError:
            try:
                return float(text)
            except ValueError:
                return text

    def save_changes(self):
        """Persist the table contents to camera_models_overrides.json.

        Also merges the overrides into camera_models.json, syncs the EXIF
        files and closes the dialog on success; shows an error box on failure.
        """
        updated_models = {}
        for row in range(self.table.rowCount()):
            key_item = self.table.item(row, 0)
            param_item = self.table.item(row, 1)
            if not key_item or not param_item:
                continue

            key = key_item.text()
            param = param_item.text()

            # The value cell is either a combo box (projection_type) or text.
            cell_widget = self.table.cellWidget(row, 2)
            if isinstance(cell_widget, QComboBox):
                value = cell_widget.currentText()
            else:
                value_item = self.table.item(row, 2)
                value = value_item.text() if value_item else ""

            value = self._coerce_value(value)

            updated_models.setdefault(key, {})[param] = value

        try:
            overrides_path = os.path.join(self.workdir, "camera_models_overrides.json")
            with open(overrides_path, "w") as f:
                json.dump(updated_models, f, indent=4)

            # Keep camera_models.json in sync with the just-saved overrides.
            self.update_base_camera_models(updated_models)

            # Propagate the new camera parameters into the EXIF files.
            self.update_exif_files(updated_models)

            QMessageBox.information(self, "Success", "Camera models saved successfully!")
            self.accept()
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to save camera models: {e}")

    def update_base_camera_models(self, overrides):
        """Merge *overrides* into camera_models.json (override params win).

        Returns:
            True on success, False if reading or writing failed.
        """
        try:
            camera_models_path = os.path.join(self.workdir, "camera_models.json")

            # Load existing camera models if available.
            if os.path.exists(camera_models_path):
                with open(camera_models_path, "r") as f:
                    base_models = json.load(f)
            else:
                base_models = {}

            # Merge with overrides.
            merged_models = base_models.copy()
            for key, params in overrides.items():
                if key in merged_models:
                    merged_models[key].update(params)
                else:
                    merged_models[key] = params

            with open(camera_models_path, "w") as f:
                json.dump(merged_models, f, indent=4)

            return True
        except Exception as e:
            print(f"Error updating camera_models.json: {e}")
            return False

    def update_exif_files(self, camera_models_overrides):
        """Rewrite every exif/*.exif file with the merged camera parameters.

        Returns:
            True on success, False when a required file/directory is missing
            or an I/O error occurred.
        """
        try:
            exif_dir = os.path.join(self.workdir, "exif")
            if not os.path.exists(exif_dir):
                print(f"EXIF directory does not exist: {exif_dir}")
                return False

            # Use the merged models written by update_base_camera_models.
            camera_models_path = os.path.join(self.workdir, "camera_models.json")
            if os.path.exists(camera_models_path):
                with open(camera_models_path, "r") as f:
                    camera_models = json.load(f)
            else:
                print(f"Camera models file does not exist: {camera_models_path}")
                return False

            for filename in os.listdir(exif_dir):
                if not filename.endswith(".exif"):
                    continue

                exif_path = os.path.join(exif_dir, filename)
                with open(exif_path, "r") as f:
                    exif_data = json.load(f)

                camera_name = exif_data.get("camera", "Unknown Camera")
                if camera_name in camera_models:
                    camera_params = camera_models[camera_name]
                    for param, value in camera_params.items():
                        if param in ("projection_type", "width", "height", "focal_ratio"):
                            exif_data[param] = value
                        if param == "focal_ratio":
                            # BUG FIX: focal_x/focal_y derive from the focal
                            # ratio, so recompute them only in this branch.
                            # Previously this ran for every parameter, e.g.
                            # multiplying the width value by the width.
                            if "width" in camera_params and "focal" in exif_data:
                                exif_data["focal_x"] = value * camera_params["width"]
                            if "height" in camera_params and "focal" in exif_data:
                                exif_data["focal_y"] = value * camera_params["height"]

                with open(exif_path, "w") as f:
                    json.dump(exif_data, f, indent=4)

            return True
        except Exception as e:
            print(f"Error updating EXIF files: {e}")
            return False
223 |
224 |
class CameraModelManager:
    """Manages camera model definitions stored in a working directory.

    Reads ``camera_models.json`` plus ``camera_models_overrides.json`` from
    *workdir*, merges them (override parameters win) and persists the merged
    result back to ``camera_models.json`` so it always reflects the overrides.
    """

    def __init__(self, workdir):
        """
        Args:
            workdir: Directory holding the camera model JSON files.
        """
        self.workdir = workdir
        self.camera_models = {}
        # Fallback used when no camera_models.json exists yet.
        self._default_model = {
            "Perspective": {
                "projection_type": "perspective",
                "width": 1920,
                "height": 1080,
                "focal_ratio": 1.0
            }
        }
        self.load_camera_models()

    @staticmethod
    def _copy_models(models):
        """Return a copy of a {name: params} mapping with copied inner dicts.

        BUG FIX: a plain dict.copy() is shallow, so merging overrides into it
        mutated the shared inner dicts — including self._default_model's —
        corrupting the defaults for later loads.
        """
        return {
            key: (dict(params) if isinstance(params, dict) else params)
            for key, params in models.items()
        }

    def _write_models(self, path, models, error_prefix):
        """Best-effort JSON dump of *models* to *path*; failures are logged."""
        try:
            with open(path, "w") as f:
                json.dump(models, f, indent=4)
        except Exception as e:
            print(f"{error_prefix}: {e}")

    def load_camera_models(self):
        """Load camera_models.json, apply overrides, persist and return the merge.

        Returns:
            The merged {camera key: {parameter: value}} mapping (also stored
            on self.camera_models).
        """
        camera_models_path = os.path.join(self.workdir, "camera_models.json")
        overrides_path = os.path.join(self.workdir, "camera_models_overrides.json")

        # Load base models; on a missing or unreadable file fall back to the
        # default model and try to persist it (previously duplicated code).
        base_models = None
        if os.path.exists(camera_models_path):
            try:
                with open(camera_models_path, "r") as f:
                    base_models = json.load(f)
            except Exception as e:
                print(f"Error loading camera_models.json: {e}")
        if base_models is None:
            base_models = self._copy_models(self._default_model)
            self._write_models(camera_models_path, base_models,
                               "Error creating default camera_models.json")

        # Load overrides if they exist.
        overrides = {}
        if os.path.exists(overrides_path):
            try:
                with open(overrides_path, "r") as f:
                    overrides = json.load(f)
            except Exception as e:
                print(f"Error loading camera_models_overrides.json: {e}")

        # Merge: override params win. Inner dicts are copied so neither the
        # defaults nor the freshly loaded base models are mutated in place.
        merged_models = self._copy_models(base_models)
        for key, params in overrides.items():
            if key in merged_models:
                merged_models[key].update(params)
            else:
                merged_models[key] = params

        # Persist the merge so camera_models.json always reflects overrides.
        self._write_models(camera_models_path, merged_models,
                           "Error updating camera_models.json")

        self.camera_models = merged_models
        return merged_models

    def get_camera_models(self):
        """Return the current merged camera models mapping."""
        return self.camera_models

    def open_camera_model_editor(self, parent=None):
        """Open the CameraModelEditor dialog; reload models around the edit.

        Returns:
            True when the editor opened (regardless of accept/cancel),
            False when the workdir is unset or the editor failed to open.
        """
        if not self.workdir:
            if parent:
                QMessageBox.warning(parent, "Error", "Workdir is not set.")
            return False

        # Reload models to pick up any external changes before editing.
        self.load_camera_models()

        try:
            dialog = CameraModelEditor(self.camera_models, self.workdir, parent=parent)
            if dialog.exec_():
                # Reload after the user saved changes.
                self.load_camera_models()
            return True
        except Exception as e:
            if parent:
                QMessageBox.critical(parent, "Error", f"Failed to open camera model editor: {e}")
            return False

    def update_exif_with_camera_models(self):
        """Sync all EXIF files with the current camera models on demand."""
        editor = CameraModelEditor(self.camera_models, self.workdir)
        return editor.update_exif_files(self.camera_models)
--------------------------------------------------------------------------------
/app/mask_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 | from loguru import logger as guru
6 | from PyQt5.QtWidgets import (
7 | QWidget, QVBoxLayout, QLabel, QPushButton, QHBoxLayout, QMessageBox
8 | )
9 | from PyQt5.QtGui import QPixmap, QImage
10 | from PyQt5.QtCore import Qt
11 | from sam2.build_sam import build_sam2
12 | from sam2.sam2_image_predictor import SAM2ImagePredictor
13 |
class ClickableImageLabel(QLabel):
    """QLabel that forwards left-button presses to an optional callback."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setAlignment(Qt.AlignCenter)
        # Assigned externally; invoked with the QMouseEvent on left click.
        self.click_callback = None

    def mousePressEvent(self, event):
        """Dispatch left clicks to click_callback, then defer to QLabel."""
        callback = self.click_callback
        if event.button() == Qt.LeftButton and callback is not None:
            callback(event)
        super().mousePressEvent(event)
25 |
26 | class MaskManager(QWidget):
    def __init__(self, checkpoint_path, config_path, mask_dir, img_dir, image_list):
        """Interactive SAM2-based mask editing widget.

        Args:
            checkpoint_path: Path to the SAM2 checkpoint file.
            config_path: Path to the SAM2 model config.
            mask_dir: Directory where mask PNGs are read from / written to.
            img_dir: Directory containing the source images.
            image_list: Ordered list of image file names to edit.
        """
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path
        self.mask_dir = mask_dir
        self.img_dir = img_dir
        self.image_list = image_list
        self.current_index = 0    # index into image_list of the displayed image
        self.current_image = None  # BGR array of the loaded image (cv2.imread)
        self.current_mask = None   # grayscale mask for the current image
        self.sam_model = None      # lazily built by init_sam_model()
        self.predictor = None      # SAM2ImagePredictor wrapping sam_model
        self.input_points = []     # clicked [x, y] points in image coordinates
        self.input_labels = []     # per-point labels: 1 = positive, 0 = negative
        self.label_toggle = 1  # Start with positive point (1)
        self.image_name = None  # Name of the current image

        # Initialize UI
        self.init_ui()
        # Initialize SAM2 model and predictor
        self.init_sam_model()

        # Intentionally do NOT load an image yet; show placeholder text until
        # the user picks one.
        # self.load_current_image()  # ← kept disabled on purpose
        self.image_label.setText("Select an image to view masks.")
        self.image_label.setStyleSheet("color: gray; font-size: 16px;")  # Styling for better visibility
53 |
    def init_ui(self):
        """Build the widget layout: clickable image area plus nav/reset buttons."""
        layout = QVBoxLayout()

        # Clickable image label; clicks feed mask points via on_image_clicked.
        self.image_label = ClickableImageLabel()
        self.image_label.click_callback = self.on_image_clicked
        layout.addWidget(self.image_label)

        # Row of buttons below the image: prev / reset / next (in that order).
        button_layout = QHBoxLayout()

        # Previous Image Button
        self.prev_button = QPushButton("< Previous Image")
        self.prev_button.clicked.connect(self.prev_image)
        button_layout.addWidget(self.prev_button)

        # Reset Mask Button
        self.reset_button = QPushButton("Reset Mask")
        self.reset_button.clicked.connect(self.reset_mask)
        button_layout.addWidget(self.reset_button)

        # Next Image Button
        self.next_button = QPushButton("Next Image >")
        self.next_button.clicked.connect(self.next_image)
        button_layout.addWidget(self.next_button)

        layout.addLayout(button_layout)
        self.setLayout(layout)
84 |
85 | def init_sam_model(self):
86 | """Initialize the SAM2 model and predictor."""
87 | if self.sam_model is None:
88 | device = "cuda" if torch.cuda.is_available() else "cpu"
89 | self.sam_model = build_sam2(self.config_path, self.checkpoint_path, device=device)
90 | self.predictor = SAM2ImagePredictor(self.sam_model)
91 | guru.info(f"SAM2 model loaded with checkpoint: {self.checkpoint_path}")
92 |
93 | def unload_sam_model(self):
94 | """Unload the SAM2 model and free resources."""
95 | if self.sam_model is not None:
96 | del self.sam_model
97 | del self.predictor
98 | self.sam_model = None
99 | self.predictor = None
100 | guru.info("SAM2 model and predictor have been unloaded.")
101 |
102 | # Optionally clear CUDA cache if using GPU
103 | if torch.cuda.is_available():
104 | torch.cuda.empty_cache()
105 | guru.info("CUDA cache has been cleared.")
106 |
107 | def load_current_image(self):
108 | """Load and display the current image."""
109 | # If the image list is empty, display a message and exit
110 | if not self.image_list:
111 | self.image_label.setText("Select an image to view masks.")
112 | return
113 |
114 | self.input_points = []
115 | self.input_labels = []
116 | self.label_toggle = 1 # Reset label toggle
117 | self.current_mask = None
118 |
119 | if 0 <= self.current_index < len(self.image_list):
120 | self.image_name = self.image_list[self.current_index]
121 | self.load_image_by_name(self.image_name)
122 | else:
123 | QMessageBox.warning(self, "Error", "No images to display.")
124 |
125 | def load_image_by_name(self, image_name):
126 | """Load and display the image specified by image_name."""
127 | self.image_name = image_name
128 | image_path = os.path.join(self.img_dir, image_name)
129 | self.current_image = cv2.imread(image_path)
130 |
131 | if self.current_image is None:
132 | QMessageBox.warning(self, "Error", f"Failed to load image {self.image_name}")
133 | return
134 |
135 | # Load the mask if it exists
136 | mask_path = os.path.join(self.mask_dir, f"{self.image_name}.png")
137 | if os.path.exists(mask_path):
138 | self.current_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
139 | guru.info(f"Loaded existing mask from {mask_path}")
140 | self.plot_mask(self.image_name, self.current_mask) # Display the image with mask overlay
141 | else:
142 | # Display original image if no mask exists
143 | self.display_image(self.current_image)
144 |
145 | def display_image(self, image):
146 | """Display the given image in the QLabel, resizing it to fit the QLabel size."""
147 | rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
148 | self.set_image_to_label(rgb_image)
149 |
150 | def plot_mask(self, image_name, mask_data):
151 | # Load image from the given path
152 | image_path = os.path.join(self.img_dir, image_name)
153 | image = cv2.imread(image_path)
154 | if image is None:
155 | QMessageBox.warning(self, "Error", f"Failed to load image {image_name}")
156 | return
157 |
158 | # Create a binary mask where black regions (value 0) become white (255)
159 | black_region_mask = cv2.threshold(mask_data, 0, 255, cv2.THRESH_BINARY_INV)[1]
160 | # Copy the original image to create an overlay
161 | overlayed_image = image.copy()
162 | # Replace pixels in the overlay where the mask is active with blue (BGR: [255, 0, 0])
163 | overlayed_image[black_region_mask == 255] = [255, 0, 0]
164 |
165 | # Draw selected points on the overlay (optional)
166 | for idx, point in enumerate(self.input_points):
167 | color = (0, 255, 0) if self.input_labels[idx] == 1 else (0, 0, 255)
168 | cv2.circle(overlayed_image, (point[0], point[1]), radius=5, color=color, thickness=-1)
169 |
170 | # Convert the overlayed image from BGR to RGB and display it in the QLabel
171 | rgb_image = cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB)
172 | self.set_image_to_label(rgb_image)
173 |
174 | def set_image_to_label(self, rgb_image):
175 | """Resize the image to fit the QLabel width while maintaining aspect ratio and display it."""
176 | label_width = self.image_label.width()
177 | if label_width == 0:
178 | label_width = 400
179 |
180 | h, w, _ = rgb_image.shape
181 | aspect_ratio = h / w # Calculate aspect ratio
182 |
183 | # Calculate new dimensions while maintaining aspect ratio
184 | new_width = label_width
185 | new_height = int(new_width * aspect_ratio)
186 |
187 | # Resize image
188 | resized_image = cv2.resize(rgb_image, (new_width, new_height), interpolation=cv2.INTER_AREA)
189 |
190 | # Convert to QImage and set to QLabel
191 | height, width, channel = resized_image.shape
192 | bytes_per_line = channel * width
193 | q_image = QImage(resized_image.data, width, height, bytes_per_line, QImage.Format_RGB888)
194 | pixmap = QPixmap.fromImage(q_image)
195 |
196 | self.image_label.setPixmap(pixmap)
197 | # Clear the text to remove any previous text
198 | self.image_label.setText("")
199 |
    def on_image_clicked(self, event):
        """Handle image click events with position correction.

        Maps a mouse click from global screen coordinates onto the displayed
        pixmap (which is centered inside the label), rescales it to
        original-image pixel coordinates, records the point with an
        alternating positive/negative label, and regenerates the mask.
        """
        if self.current_image is not None:
            label_pos = self.image_label.mapFromGlobal(event.globalPos())
            pixmap = self.image_label.pixmap()
            if pixmap is None:
                return

            # The pixmap is centered inside the label, so the click must be
            # shifted by the surrounding margins before rescaling.
            pixmap_width = pixmap.width()
            pixmap_height = pixmap.height()
            label_width = self.image_label.width()
            label_height = self.image_label.height()
            offset_x = (label_width - pixmap_width) / 2
            offset_y = (label_height - pixmap_height) / 2

            # Ignore clicks that land on the margins outside the pixmap.
            if (offset_x <= label_pos.x() <= offset_x + pixmap_width) and \
               (offset_y <= label_pos.y() <= offset_y + pixmap_height):

                # Scale from displayed-pixmap coordinates back to the
                # original image resolution.
                scale_x = self.current_image.shape[1] / pixmap_width
                scale_y = self.current_image.shape[0] / pixmap_height

                corrected_x = int((label_pos.x() - offset_x) * scale_x)
                corrected_y = int((label_pos.y() - offset_y) * scale_y)
                # Clamp to valid pixel indices to guard against rounding.
                corrected_x = max(0, min(corrected_x, self.current_image.shape[1] - 1))
                corrected_y = max(0, min(corrected_y, self.current_image.shape[0] - 1))

                self.input_points.append([corrected_x, corrected_y])
                self.input_labels.append(self.label_toggle)

                # Toggle the label for the next point (1→0, 0→1)
                self.label_toggle = 1 - self.label_toggle
                self.generate_mask()
232 |
    def process_single_image(self, image, image_name, point_coords, point_labels):
        """Generate a mask for a single image and save it to the specified mask directory.

        Args:
            image: Image array fed to the SAM2 predictor.
            image_name: Base name used for the saved ``<image_name>.png`` mask.
            point_coords: Array of (x, y) prompt coordinates.
            point_labels: Array of 1/0 labels matching ``point_coords``.

        Returns:
            The inverted mask scaled to the 0..255 range (0 = selected object).
        """
        self.predictor.set_image(image)

        # Generate mask with SAM2
        # NOTE(review): autocast is hard-coded to "cuda" with bfloat16 —
        # presumably this tab requires a CUDA device; confirm CPU-only behavior.
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            masks, scores, logits = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=False
            )

        mask_output_path = os.path.join(self.mask_dir, f"{image_name}.png")
        # Invert so the selected object becomes 0 and the background 255.
        inverted_mask = 1 - masks[0]
        mask_to_save = (inverted_mask * 255).astype(np.uint8)
        cv2.imwrite(mask_output_path, mask_to_save)
        guru.info(f"Mask saved to {mask_output_path}")

        # NOTE: the returned array is not cast to uint8 (unlike the saved
        # file); callers only threshold it, so the dtype difference is benign.
        return inverted_mask * 255
252 |
253 | def generate_mask(self):
254 | """Generate and display the mask."""
255 | if self.current_image is None or not self.input_points:
256 | return
257 |
258 | point_coords = np.array(self.input_points)
259 | point_labels = np.array(self.input_labels)
260 | mask = self.process_single_image(
261 | self.current_image, self.image_name, point_coords, point_labels
262 | )
263 | self.current_mask = mask
264 | self.display_image_with_mask()
265 |
266 | def display_image_with_mask(self):
267 | # Check if current image and mask exist
268 | if self.current_image is not None and self.current_mask is not None:
269 | # Create a copy of the original image
270 | overlay = self.current_image.copy()
271 | # Create a binary mask where the mask's black areas become white (255)
272 | black_region_mask = cv2.threshold(self.current_mask, 0, 255, cv2.THRESH_BINARY_INV)[1]
273 | # Set the corresponding pixels in the overlay to blue (BGR: [255, 0, 0])
274 | overlay[black_region_mask == 255] = [255, 0, 0]
275 |
276 | # Draw selected points on the overlay
277 | for idx, point in enumerate(self.input_points):
278 | color = (0, 255, 0) if self.input_labels[idx] == 1 else (0, 0, 255)
279 | cv2.circle(overlay, (point[0], point[1]), radius=5, color=color, thickness=-1)
280 |
281 | # Display the modified image
282 | self.display_image(overlay)
283 | elif self.current_image is not None:
284 | self.display_image(self.current_image)
285 |
286 | def reset_mask(self):
287 | """Reset the current mask to a blank white mask and update the display."""
288 | if self.current_image is None:
289 | QMessageBox.warning(self, "Error", "No image loaded to reset the mask.")
290 | return
291 |
292 | self.current_mask = np.ones(self.current_image.shape[:2], dtype=np.uint8) * 255
293 | self.save_current_mask()
294 | self.clear_points()
295 | self.display_image_with_mask()
296 | QMessageBox.information(self, "Reset Mask", "The mask has been reset to a blank state.")
297 |
298 | def save_current_mask(self):
299 | """Save the current mask to a file."""
300 | if self.current_mask is not None:
301 | mask_output_path = os.path.join(self.mask_dir, f"{self.image_name}.png")
302 | mask_to_save = self.current_mask.astype(np.uint8)
303 | cv2.imwrite(mask_output_path, mask_to_save)
304 | guru.info(f"Mask saved to {mask_output_path}")
305 |
306 | def clear_points(self):
307 | """Clear the list of points and labels."""
308 | self.input_points.clear()
309 | self.input_labels.clear()
310 | self.label_toggle = 1
311 |
312 | def next_image(self):
313 | """Move to the next image."""
314 | if self.current_index < len(self.image_list) - 1:
315 | self.current_index += 1
316 | self.load_current_image()
317 | else:
318 | QMessageBox.information(self, "Info", "This is the last image.")
319 |
320 | def prev_image(self):
321 | """Move to the previous image."""
322 | if self.current_index > 0:
323 | self.current_index -= 1
324 | self.load_current_image()
325 | else:
326 | QMessageBox.information(self, "Info", "This is the first image.")
327 |
--------------------------------------------------------------------------------
/app/tabs/images_tab.py:
--------------------------------------------------------------------------------
1 | # images_tab.py
2 |
3 | import os
4 | import json
5 | from PyQt5.QtWidgets import (
6 | QWidget, QVBoxLayout, QHBoxLayout, QSplitter,
7 | QLabel, QPushButton, QTreeWidget, QTableWidget, QTableWidgetItem,
8 | QSizePolicy, QMessageBox, QApplication, QDialog
9 | )
10 | from PyQt5.QtCore import Qt
11 | from PyQt5.QtGui import QPixmap
12 |
13 | from app.base_tab import BaseTab
14 | from app.camera_models import CameraModelManager
15 | from app.image_processing import ImageProcessor, ResolutionDialog, ExifExtractProgressDialog
16 |
class ImagesTab(BaseTab):
    """Images tab implementation.

    Left pane: a tree of cameras and their images.  Right pane: an image
    preview, an EXIF field/value table, and buttons to edit camera models,
    change the dataset resolution, or restore the original images.
    """

    def __init__(self, workdir=None, image_list=None, parent=None):
        super().__init__(workdir, image_list, parent)
        self.image_viewer = None
        self.exif_table = None
        self.camera_image_tree = None
        self.camera_model_manager = None
        self.image_processor = None

        # Always set up the basic UI structure
        self.setup_basic_ui()

        # Initialize managers if workdir is available.  Failures are
        # non-fatal here because initialize() retries on first use.
        if self.workdir:
            try:
                self.camera_model_manager = CameraModelManager(self.workdir)
                self.image_processor = ImageProcessor(self.workdir)
            except Exception as e:
                print(f"Error initializing managers: {e}")

    def get_tab_name(self):
        """Return the display name of this tab."""
        return "Images"

    def setup_basic_ui(self):
        """Set up the basic UI structure without data initialization"""
        layout = self.create_horizontal_splitter()

        # Left side: Camera and image tree
        self.camera_image_tree = QTreeWidget()
        self.camera_image_tree.setHeaderLabel("Cameras and Images")
        self.camera_image_tree.setFixedWidth(250)
        layout.addWidget(self.camera_image_tree)

        # Right side: Image viewer, EXIF data, buttons
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        right_layout.setContentsMargins(0, 0, 0, 0)

        # Image viewer
        self.image_viewer = QLabel("Image Viewer")
        self.image_viewer.setAlignment(Qt.AlignCenter)
        self.image_viewer.setMinimumHeight(300)
        right_layout.addWidget(self.image_viewer, stretch=3)

        # EXIF data table
        self.exif_table = QTableWidget()
        self.exif_table.setColumnCount(2)
        self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
        self.exif_table.horizontalHeader().setStretchLastSection(True)
        right_layout.addWidget(self.exif_table, stretch=2)

        # Buttons
        button_widget = QWidget()
        button_layout = QHBoxLayout()
        button_layout.setContentsMargins(5, 5, 5, 5)
        button_layout.setSpacing(10)

        self.camera_model_button = QPushButton("Edit Camera Models")
        self.camera_model_button.clicked.connect(self.open_camera_model_editor)

        self.resize_button = QPushButton("Change Resolution")
        self.resize_button.clicked.connect(self.resize_images_in_folder)

        self.restore_button = QPushButton("Restore Images")
        self.restore_button.clicked.connect(self.restore_original_images)

        # Add stretch to center the buttons
        button_layout.addStretch(1)
        button_layout.addWidget(self.camera_model_button)
        button_layout.addWidget(self.resize_button)
        button_layout.addWidget(self.restore_button)
        button_layout.addStretch(1)

        button_widget.setLayout(button_layout)

        # Make button widget expand horizontally
        button_widget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        right_layout.addWidget(button_widget, stretch=0)

        right_widget.setLayout(right_layout)
        layout.addWidget(right_widget)

        layout.setStretchFactor(0, 1)
        layout.setStretchFactor(1, 4)

        self._layout.addWidget(layout)

        # Connect signals
        self.camera_image_tree.itemClicked.connect(self.display_image_and_exif)

    def initialize(self):
        """Initialize the Images tab with data.

        The basic UI is already built in __init__, so this only (re)creates
        the managers that were not constructed yet and populates the tree.
        """
        if self.workdir:
            # Initialize managers if not already initialized
            if self.camera_model_manager is None:
                try:
                    self.camera_model_manager = CameraModelManager(self.workdir)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to initialize camera model manager: {e}")

            if self.image_processor is None:
                try:
                    self.image_processor = ImageProcessor(self.workdir)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to initialize image processor: {e}")

            # Populate the camera tree
            try:
                self.setup_camera_image_tree(self.camera_image_tree)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to populate camera tree: {e}")

        self.is_initialized = True

    def display_image_and_exif(self, item, column):
        """Display the selected image and its EXIF data"""
        if not self.is_initialized:
            # Make sure we're initialized before attempting to display data
            self.initialize()

        # Leaf nodes with a parent are images; top-level nodes are cameras.
        if item.childCount() == 0 and item.parent() is not None:
            image_name = item.text(0)
            image_path = os.path.join(self.workdir, "images", image_name)
            exif_path = os.path.join(self.workdir, "exif", image_name + '.exif')

            # Display image
            if os.path.exists(image_path):
                pixmap = QPixmap(image_path)
                self.image_viewer.setPixmap(pixmap.scaled(
                    self.image_viewer.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
            else:
                self.image_viewer.setText("Image not found.")

            # Display EXIF data
            if os.path.exists(exif_path):
                with open(exif_path, 'r') as f:
                    exif_data = json.load(f)
                self.display_exif_data(exif_data)
            else:
                # Show a single-row error entry in the table.
                self.exif_table.setRowCount(0)
                self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
                self.exif_table.clearContents()
                self.exif_table.setRowCount(1)
                self.exif_table.setItem(0, 0, QTableWidgetItem("Error"))
                self.exif_table.setItem(0, 1, QTableWidgetItem("EXIF data not found."))

    def display_exif_data(self, exif_data):
        """Display EXIF data in the table, applying overrides if available"""
        self.exif_table.setRowCount(0)
        self.exif_table.setHorizontalHeaderLabels(["Field", "Value"])
        self.exif_table.clearContents()

        # Apply camera_models_overrides.json settings if available
        camera_name = exif_data.get("camera", "Unknown Camera")
        overrides = {}
        if self.camera_model_manager:
            try:
                camera_models = self.camera_model_manager.get_camera_models()
                overrides = camera_models.get(camera_name, {})
            except Exception as e:
                print(f"Error getting camera models: {e}")

        # Define fields to display and their order
        fields = [
            "make",
            "model",
            "width",
            "height",
            "projection_type",
            "focal_ratio",
            "orientation",
            "capture_time",
            "gps",
            "camera"
        ]

        for i, field in enumerate(fields):
            self.exif_table.insertRow(i)
            key_item = QTableWidgetItem(field)

            # Overrides take precedence over the raw EXIF value.
            value = overrides.get(field, exif_data.get(field, "N/A"))

            # Convert dictionary to string; format floats to two decimals.
            if isinstance(value, dict):
                value = json.dumps(value)
            elif isinstance(value, float):
                value = f"{value:.2f}"

            value_item = QTableWidgetItem(str(value))
            self.exif_table.setItem(i, 0, key_item)
            self.exif_table.setItem(i, 1, value_item)

    def open_camera_model_editor(self):
        """Open the camera model editor dialog"""
        if not self.is_initialized:
            self.initialize()

        # Check if camera model manager exists, if not, try to create it
        if self.camera_model_manager is None and self.workdir:
            try:
                self.camera_model_manager = CameraModelManager(self.workdir)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to initialize camera model manager: {e}")
                return

        if self.camera_model_manager:
            # Ensure camera_models.json exists
            camera_models_path = os.path.join(self.workdir, "camera_models.json")
            if not os.path.exists(camera_models_path):
                # Create a default camera model if needed
                try:
                    with open(camera_models_path, 'w') as f:
                        default_model = {
                            "Perspective": {
                                "projection_type": "perspective",
                                "width": 1920,
                                "height": 1080,
                                "focal_ratio": 1.0
                            }
                        }
                        json.dump(default_model, f, indent=4)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to create default camera model: {e}")
                    return

                # Reload the camera model manager so it picks up the new file
                try:
                    self.camera_model_manager = CameraModelManager(self.workdir)
                except Exception as e:
                    QMessageBox.critical(self, "Error", f"Failed to reinitialize camera model manager: {e}")
                    return

            # Now open the editor
            try:
                self.camera_model_manager.open_camera_model_editor(parent=self)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to open camera model editor: {e}")
        else:
            QMessageBox.warning(self, "Error", "Camera Model Manager is not initialized and cannot be created.")

    def resize_images_in_folder(self):
        """Open dialog to resize images"""
        if not self.is_initialized:
            self.initialize()

        # Check if image processor exists, if not, try to create it
        if self.image_processor is None and self.workdir:
            try:
                self.image_processor = ImageProcessor(self.workdir)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to initialize image processor: {e}")
                return

        if not self.image_processor:
            QMessageBox.warning(self, "Error", "Image Processor is not initialized and cannot be created.")
            return

        # Get dimensions of a sample image
        width, height = self.image_processor.get_sample_image_dimensions()
        if width is None or height is None:
            QMessageBox.warning(self, "Error", "No images found to determine default resolution.")
            return

        # Show resolution dialog
        dialog = ResolutionDialog(width, height, parent=self)
        if dialog.exec_() != QDialog.Accepted:
            return

        method, value = dialog.get_values()

        # Show progress dialog
        progress_dialog = ExifExtractProgressDialog("Resizing images...", self)
        progress_dialog.show()
        QApplication.processEvents()

        succeeded = False
        try:
            # Resize images
            self.image_processor.resize_images(method, value)
            succeeded = True
        except Exception as e:
            QMessageBox.warning(self, "Error", f"Error during resizing: {e}")
        finally:
            progress_dialog.close()

        # BUGFIX: only report success when resizing actually completed;
        # previously this popup appeared even after the error dialog.
        if succeeded:
            QMessageBox.information(self, "Completed", "Images resized successfully!")

    def restore_original_images(self):
        """Restore original images from backup"""
        if not self.is_initialized:
            self.initialize()

        # Check if image processor exists, if not, try to create it
        if self.image_processor is None and self.workdir:
            try:
                self.image_processor = ImageProcessor(self.workdir)
            except Exception as e:
                QMessageBox.critical(self, "Error", f"Failed to initialize image processor: {e}")
                return

        if not self.image_processor:
            QMessageBox.warning(self, "Error", "Image Processor is not initialized and cannot be created.")
            return

        result = self.image_processor.restore_original_images()
        if result:
            QMessageBox.information(self, "Restored", "Original images restored successfully!")
        else:
            QMessageBox.warning(self, "Error", "No original backup images found.")

    def refresh(self):
        """Refresh the tab contents"""
        if self.is_initialized:
            # Recreate the managers so they re-read the working directory.
            if self.camera_model_manager:
                try:
                    self.camera_model_manager = CameraModelManager(self.workdir)
                except Exception as e:
                    print(f"Error refreshing camera model manager: {e}")

            if self.image_processor:
                try:
                    self.image_processor = ImageProcessor(self.workdir)
                except Exception as e:
                    print(f"Error refreshing image processor: {e}")

            # Refresh camera image tree
            if self.camera_image_tree:
                try:
                    self.setup_camera_image_tree(self.camera_image_tree)
                except Exception as e:
                    print(f"Error refreshing camera image tree: {e}")
--------------------------------------------------------------------------------