├── printguard ├── __init__.py ├── utils │ ├── __init__.py │ ├── backends │ │ ├── __init__.py │ │ ├── protonets │ │ │ ├── __init__.py │ │ │ └── models │ │ │ │ ├── __init__.py │ │ │ │ └── few_shot.py │ │ └── pytorch_engine.py │ ├── printer_services │ │ ├── __init__.py │ │ └── octoprint.py │ ├── model_utils.py │ ├── inference_lib.py │ ├── alert_utils.py │ ├── setup_utils.py │ ├── detection_utils.py │ ├── notification_utils.py │ ├── printer_utils.py │ ├── camera_utils.py │ ├── sse_utils.py │ ├── shared_video_stream.py │ ├── camera_state_manager.py │ ├── inference_engine.py │ └── model_downloader.py ├── routes │ ├── __init__.py │ ├── sse_routes.py │ ├── alert_routes.py │ ├── detection_routes.py │ ├── printer_routes.py │ ├── notification_routes.py │ ├── camera_routes.py │ └── index_routes.py ├── requirements.txt ├── static │ ├── js │ │ ├── utils.js │ │ ├── sw.js │ │ ├── cloudflare_setup.js │ │ ├── notifications.js │ │ └── sse.js │ └── css │ │ └── universal.css ├── templates │ └── cloudflare_setup.html ├── app.py └── models.py ├── docs ├── media │ └── images │ │ ├── interface-index.png │ │ ├── interface-setup-settings.png │ │ ├── data-communication-diagram.png │ │ ├── interface-camera-settings.png │ │ └── interface-alerts-notifications.png └── setup.md ├── .gitignore ├── Dockerfile ├── pyproject.toml ├── .dockerignore ├── .github └── workflows │ ├── python-publish.yml │ └── docker-publish.yml ├── README.md └── scripts └── convert_pytorch_to_onnx.py /printguard/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/routes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/utils/backends/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/utils/backends/protonets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/utils/printer_services/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /printguard/utils/backends/protonets/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/media/images/interface-index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oliverbravery/PrintGuard/HEAD/docs/media/images/interface-index.png -------------------------------------------------------------------------------- /docs/media/images/interface-setup-settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oliverbravery/PrintGuard/HEAD/docs/media/images/interface-setup-settings.png -------------------------------------------------------------------------------- 
/docs/media/images/data-communication-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oliverbravery/PrintGuard/HEAD/docs/media/images/data-communication-diagram.png -------------------------------------------------------------------------------- /docs/media/images/interface-camera-settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oliverbravery/PrintGuard/HEAD/docs/media/images/interface-camera-settings.png -------------------------------------------------------------------------------- /docs/media/images/interface-alerts-notifications.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oliverbravery/PrintGuard/HEAD/docs/media/images/interface-alerts-notifications.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | success 4 | failure 5 | build 6 | *.egg-info 7 | !.github 8 | /models 9 | *.pte 10 | model/ 11 | data/ 12 | .vscode/ 13 | .DS_Store -------------------------------------------------------------------------------- /printguard/utils/backends/protonets/models/few_shot.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Protonet(nn.Module): 4 | def __init__(self, encoder): 5 | super(Protonet, self).__init__() 6 | self.encoder = encoder 7 | -------------------------------------------------------------------------------- /printguard/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.116.1 2 | uvicorn[standard]==0.35.0 3 | pywebpush==2.0.3 4 | apscheduler==3.11.0 5 | python-dotenv==1.1.1 6 | torch==2.7.0 7 | torchvision==0.22.0 8 | pillow==11.3.0 9 | opencv-python-headless==4.12.0.88 10 | python-multipart==0.0.20 11 | keyring==25.6.0 12 | trustme==1.2.1 13 | platformdirs==4.3.8 14 | sse-starlette==2.4.1 15 | ngrok==1.4.0 16 | cryptography==45.0.5 17 | numpy==2.2.6 18 | onnxruntime==1.22.1 19 | huggingface_hub==0.33.4 -------------------------------------------------------------------------------- /printguard/static/js/utils.js: -------------------------------------------------------------------------------- 1 | export { render_ascii_title }; 2 | 3 | function render_ascii_title(doc_element, text) { 4 | figlet.defaults({ fontPath: '/static/fonts/' }); 5 | figlet.text(text, { 6 | font: 'Big Money-ne', 7 | horizontalLayout: 'default', 8 | verticalLayout: 'default' 9 | }, function(err, data) { 10 | if (err) { 11 | console.error(err); 12 | return; 13 | } 14 | doc_element.textContent = data; 15 | }); 16 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1.4 2 | FROM --platform=$BUILDPLATFORM python:3.11-slim-bookworm AS base 3 | 4 | RUN apt-get update \ 5 | && apt-get install -y --no-install-recommends \ 6 | build-essential python3-dev libffi-dev \ 7 | libjpeg-dev zlib1g-dev libtiff-dev \ 8 | libfreetype6-dev libwebp-dev libopenjp2-7-dev \ 9 | libgomp1 \ 10 | ffmpeg libgl1 \ 11 | && rm -rf /var/lib/apt/lists/* 12 | 13 | WORKDIR /printguard 14 | COPY . /printguard 15 | 16 | RUN pip install --upgrade pip \ 17 | && pip install . 
18 | 19 | FROM --platform=$TARGETPLATFORM python:3.11-slim-bookworm AS runtime 20 | 21 | COPY --from=base /usr/local /usr/local 22 | 23 | WORKDIR /printguard 24 | COPY --from=base /printguard /printguard 25 | 26 | EXPOSE 8000 27 | VOLUME ["/data"] 28 | ENTRYPOINT ["printguard"] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "printguard" 7 | version = "1.0.0b3" 8 | description = "PrintGuard - Real-time Defect Detection on Edge-devices" 9 | authors = [ 10 | { name = "Oliver Bravery", email = "dev@oliverbravery.uk" } 11 | ] 12 | readme = "README.md" 13 | requires-python = ">=3.11" 14 | 15 | dependencies = [ 16 | "fastapi==0.116.1", 17 | "uvicorn[standard]==0.35.0", 18 | "pywebpush==2.0.3", 19 | "apscheduler==3.11.0", 20 | "python-dotenv==1.1.1", 21 | "torch==2.7.0", 22 | "torchvision==0.22.0", 23 | "pillow==11.3.0", 24 | "opencv-python-headless==4.12.0.88", 25 | "python-multipart==0.0.20", 26 | "keyring==25.6.0", 27 | "trustme==1.2.1", 28 | "platformdirs==4.3.8", 29 | "sse-starlette==2.4.1", 30 | "ngrok==1.4.0", 31 | "cryptography==45.0.5", 32 | "numpy==2.2.6", 33 | "onnxruntime==1.22.1", 34 | "huggingface_hub==0.33.4" 35 | ] 36 | 37 | [project.scripts] 38 | printguard = "printguard.app:run" 39 | 40 | [tool.black] 41 | line-length = 88 42 | 43 | [tool.setuptools] 44 | include-package-data = true 45 | 46 | [tool.setuptools.packages.find] 47 | include = ["printguard", "printguard.*"] 48 | 49 | [tool.setuptools.package-data] 50 | "printguard" = [ 51 | "static/**/*", 52 | "templates/**/*", 53 | "model/**/*" 54 | ] -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Python artifacts 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | 22 | # Virtual environments 23 | .env 24 | .venv 25 | env/ 26 | venv/ 27 | ENV/ 28 | env.bak/ 29 | venv.bak/ 30 | 31 | # IDE and editor files 32 | .vscode/ 33 | .idea/ 34 | *.swp 35 | *.swo 36 | *~ 37 | 38 | # OS generated files 39 | .DS_Store 40 | .DS_Store? 
41 | ._* 42 | .Spotlight-V100 43 | .Trashes 44 | ehthumbs.db 45 | Thumbs.db 46 | 47 | # Git 48 | .git/ 49 | .gitignore 50 | .gitattributes 51 | 52 | # Docker 53 | .dockerignore 54 | Dockerfile* 55 | docker-compose*.yml 56 | 57 | # Documentation 58 | docs/ 59 | README.md 60 | *.md 61 | 62 | # Scripts and development tools 63 | scripts/ 64 | 65 | # Test directories and files 66 | tests/ 67 | test/ 68 | *_test.py 69 | test_*.py 70 | 71 | # Model files (large binaries that shouldn't be in container) 72 | printguard/model/ 73 | models/ 74 | *.onnx 75 | *.pth 76 | *.pt 77 | *.pte 78 | *.safetensors 79 | 80 | # Cache directories 81 | .cache/ 82 | *.cache/ 83 | cache/ 84 | 85 | # Build artifacts 86 | success/ 87 | failure/ 88 | data/ 89 | 90 | # Temporary files 91 | tmp/ 92 | temp/ 93 | *.tmp 94 | *.log 95 | 96 | # Environment and config files that shouldn't be in container 97 | .env.local 98 | .env.development 99 | .env.test 100 | .env.production -------------------------------------------------------------------------------- /printguard/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any 3 | import logging 4 | 5 | from .config import SENSITIVITY 6 | from .inference_lib import get_inference_engine 7 | 8 | async def _run_inference(model: Any, 9 | batch_tensor: Any, 10 | prototypes: Any, 11 | defect_idx: int, 12 | device: Any) -> Any: 13 | """Run model inference on a batch of image tensors. 14 | 15 | Args: 16 | model (Any): The neural network model to use. 17 | batch_tensor (Any): Batch of preprocessed image tensors. 18 | prototypes (Any): Class prototype tensors for comparison. 19 | defect_idx (int): Index of the defect class. 20 | device (Any): Device to run inference on. 21 | 22 | Returns: 23 | Any: Inference results (typically class predictions). 24 | 25 | Raises: 26 | TypeError: If the model doesn't have required methods. 27 | RuntimeError: If inference execution fails. 28 | """ 29 | inference_engine = get_inference_engine() 30 | loop = asyncio.get_running_loop() 31 | try: 32 | results = await loop.run_in_executor( 33 | None, 34 | inference_engine.predict_batch, 35 | model, 36 | batch_tensor, 37 | prototypes, 38 | defect_idx, 39 | SENSITIVITY, 40 | str(device) 41 | ) 42 | return results 43 | except Exception as e: 44 | logging.error("Error during inference execution: %s", e) 45 | raise RuntimeError(f"Inference execution failed: {e}") from e 46 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package to PyPI when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | workflow_dispatch: 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | release-build: 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.11" 29 | 30 | - name: Build release distributions 31 | run: | 32 | python -m pip install build 33 | python -m build 34 | 35 | - name: Upload distributions 36 | uses: actions/upload-artifact@v4 37 | with: 38 | name: release-dists 39 | path: dist/ 40 | 41 | pypi-publish: 42 | runs-on: ubuntu-latest 43 | needs: 44 | - release-build 45 | permissions: 46 | id-token: write 47 | 48 | environment: 49 | name: pypi 50 | url: https://pypi.org/p/PrintGuard 51 | 52 | steps: 53 | - name: Retrieve release distributions 54 | uses: actions/download-artifact@v4 55 | with: 56 | name: release-dists 57 | path: dist/ 58 | 59 | - name: Publish release distributions to PyPI 60 | uses: pypa/gh-action-pypi-publish@release/v1 61 | with: 62 | packages-dir: dist/ 63 | -------------------------------------------------------------------------------- /printguard/utils/inference_lib.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | from .inference_engine import UniversalInferenceEngine, InferenceBackend 4 | 5 | _inference_engine: Optional[UniversalInferenceEngine] = None 6 | 7 | def _detect_backend() -> InferenceBackend: 8 | """Detect the best available backend based on installed packages.""" 9 | # Check for ONNX Runtime (optimized backend) 10 | try: 11 | import onnxruntime 12 | logging.info("ONNX Runtime detected, using ONNX Runtime backend") 13 | return InferenceBackend.ONNXRUNTIME 14 | except ImportError: 15 | pass 16 | # Check for PyTorch (fallback backend) 17 | try: 18 | import torch 19 | logging.info("PyTorch detected, using PyTorch backend") 20 | return InferenceBackend.PYTORCH 21 | except ImportError: 22 | pass 23 | logging.warning("No specific backend detected, defaulting to PyTorch") 24 | return InferenceBackend.PYTORCH 25 | 26 | 27 | def get_inference_engine() -> UniversalInferenceEngine: 28 | """Get or create the global inference engine instance.""" 29 | # pylint: disable=import-outside-toplevel 30 | from .model_downloader import ensure_model_files 31 | # pylint: disable=global-statement 32 | global _inference_engine 33 | if _inference_engine is None: 34 | backend = _detect_backend() 35 | try: 36 | if not ensure_model_files(backend): 37 | logging.warning("Failed to download model files for %s backend", backend.value) 38 | except ImportError: 39 | logging.warning("Model downloader not available, assuming models are present") 40 | _inference_engine = UniversalInferenceEngine(backend) 41 | logging.info("Created inference engine with %s backend", backend.value) 42 | return _inference_engine 43 | -------------------------------------------------------------------------------- /printguard/routes/sse_routes.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Request, Body 2 | from sse_starlette.sse import EventSourceResponse 3 | from ..utils.sse_utils import outbound_packet_fetch, stop_and_remove_polling_task 4 | from ..utils.printer_utils import start_printer_state_polling 5 | 6 | router = APIRouter() 7 | 8 | @router.get("/sse") 9 | async def sse_connect(request: Request): 10 | """Establish Server-Sent Events connection 
for real-time updates. 11 | 12 | Args: 13 | request (Request): The FastAPI request object for connection management. 14 | 15 | Returns: 16 | EventSourceResponse: SSE stream for real-time data updates. 17 | """ 18 | async def send_packet(): 19 | async for packet in outbound_packet_fetch(): 20 | if await request.is_disconnected(): 21 | break 22 | yield packet 23 | return EventSourceResponse(send_packet()) 24 | 25 | @router.post("/sse/start-polling") 26 | async def start_polling(request: Request, camera_uuid: str = Body(..., embed=True)): 27 | """Start polling for printer state updates on a specific camera. 28 | 29 | Args: 30 | request (Request): The FastAPI request object. 31 | camera_uuid (str): UUID of the camera to start polling for. 32 | 33 | Returns: 34 | dict: Confirmation message that polling was started. 35 | """ 36 | await start_printer_state_polling(camera_uuid) 37 | return {"message": "Polling started for camera UUID {}".format(camera_uuid)} 38 | 39 | @router.post("/sse/stop-polling") 40 | async def stop_polling(request: Request, camera_uuid: str = Body(..., embed=True)): 41 | """Stop polling for printer state updates on a specific camera. 42 | 43 | Args: 44 | request (Request): The FastAPI request object. 45 | camera_uuid (str): UUID of the camera to stop polling for. 46 | 47 | Returns: 48 | dict: Confirmation message that polling was stopped. 49 | """ 50 | stop_and_remove_polling_task(camera_uuid) 51 | return {"message": "Polling stopped for camera UUID {}".format(camera_uuid)} 52 | -------------------------------------------------------------------------------- /printguard/static/js/sw.js: -------------------------------------------------------------------------------- 1 | // Service worker for PrintGuard Push Notifications 2 | self.addEventListener('install', (event) => { 3 | self.skipWaiting(); 4 | }); 5 | 6 | self.addEventListener('activate', (event) => { 7 | event.waitUntil(self.clients.claim()); 8 | }); 9 | 10 | self.addEventListener('push', event => { 11 | let notificationTitle = 'PrintGuard Alert'; 12 | let notificationBody = 'Print detection alert!'; 13 | if (event.data) { 14 | try { 15 | const jsonData = event.data.json(); 16 | if (typeof jsonData === 'object' && jsonData !== null) { 17 | notificationBody = jsonData.body || jsonData.message || JSON.stringify(jsonData); 18 | if(jsonData.title) notificationTitle = jsonData.title; 19 | } else { 20 | notificationBody = jsonData; 21 | } 22 | } catch (e) { 23 | console.warn('Failed to parse payload directly as JSON, trying as text:', e); 24 | const textData = event.data.text(); 25 | try { 26 | const parsedTextData = JSON.parse(textData); 27 | if (typeof parsedTextData === 'object' && parsedTextData !== null) { 28 | notificationBody = parsedTextData.body || parsedTextData.message || JSON.stringify(parsedTextData); 29 | if(parsedTextData.title) notificationTitle = parsedTextData.title; 30 | } else { 31 | notificationBody = parsedTextData; 32 | } 33 | } catch (e2) { 34 | console.warn('Failed to parse text data as JSON, using text data as body:', e2); 35 | notificationBody = textData; 36 | } 37 | } 38 | } 39 | 40 | event.waitUntil( 41 | self.registration.showNotification(notificationTitle, { 42 | body: notificationBody, 43 | vibrate: [100, 50, 100], 44 | timestamp: Date.now(), 45 | requireInteraction: true 46 | }).catch(err => { 47 | console.error('Error showing notification:', err); 48 | }) 49 | ); 50 | }); 51 | 52 | self.addEventListener('notificationclick', event => { 53 | event.notification.close(); 54 | event.waitUntil( 55 
| clients.openWindow('/') 56 | ); 57 | }); 58 | 59 | -------------------------------------------------------------------------------- /printguard/routes/alert_routes.py: -------------------------------------------------------------------------------- 1 | import json 2 | from fastapi import APIRouter, Body, Request 3 | from ..models import AlertAction 4 | from ..utils.alert_utils import (alert_to_response_json, dismiss_alert, 5 | get_alert) 6 | from ..utils.printer_utils import suspend_print_job 7 | 8 | router = APIRouter() 9 | 10 | @router.post("/alert/dismiss") 11 | async def alert_response(request: Request, 12 | alert_id: str = Body(..., embed=True), 13 | action: AlertAction = Body(..., embed=True)): 14 | """Handle alert response actions including dismiss, cancel, and pause. 15 | 16 | Args: 17 | request (Request): The FastAPI request object. 18 | alert_id (str): Unique identifier of the alert to act upon. 19 | action (AlertAction): The action to perform on the alert. 20 | 21 | Returns: 22 | dict: Response containing the result of the action or error message. 23 | """ 24 | alert = get_alert(alert_id) 25 | camera_uuid = alert.camera_uuid if alert else None 26 | if not alert or camera_uuid is None: 27 | return {"message": f"Alert {alert_id} not found."} 28 | response = None 29 | match action: 30 | case AlertAction.DISMISS: 31 | response = await dismiss_alert(alert_id) 32 | case AlertAction.CANCEL_PRINT | AlertAction.PAUSE_PRINT: 33 | suspend_print_job(camera_uuid, action) 34 | return await dismiss_alert(alert_id) 35 | if not response: 36 | response = {"message": f"Alert {alert_id} not found."} 37 | return response 38 | 39 | @router.get("/alert/active") 40 | async def get_active_alerts(request: Request): 41 | """Retrieve all currently active alerts. 42 | 43 | Args: 44 | request (Request): The FastAPI request object containing app state. 45 | 46 | Returns: 47 | dict: Dictionary containing a list of active alerts with their details. 48 | """ 49 | alerts = [] 50 | for alert in request.app.state.alerts.values(): 51 | alerts.append(json.loads(alert_to_response_json(alert))) 52 | return {"active_alerts": alerts} 53 | -------------------------------------------------------------------------------- /printguard/utils/alert_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import json 4 | 5 | from PIL import Image 6 | 7 | from .camera_utils import update_camera_state 8 | 9 | 10 | def append_new_alert(alert): 11 | """Appends a new alert to the application's state. 12 | 13 | Args: 14 | alert (Alert): The alert object to be added. 15 | The alert object should have the following structure: 16 | { 17 | "id": str, 18 | "snapshot": bytes, 19 | "title": str, 20 | "message": str, 21 | "timestamp": float, 22 | "countdown_time": float, 23 | "camera_uuid": str, 24 | "has_printer": bool, 25 | "countdown_action": str 26 | } 27 | """ 28 | # pylint: disable=import-outside-toplevel 29 | from ..app import app 30 | app.state.alerts[alert.id] = alert 31 | 32 | def get_alert(alert_id): 33 | """Retrieves a single alert by its ID from the application's state. 34 | 35 | Args: 36 | alert_id (str): The ID of the alert to retrieve. 37 | 38 | Returns: 39 | Alert: The alert object if found, otherwise None. 
40 | """ 41 | # pylint: disable=import-outside-toplevel 42 | from ..app import app 43 | alert = app.state.alerts.get(alert_id, None) 44 | return alert 45 | 46 | async def dismiss_alert(alert_id): 47 | """Dismisses an alert by its ID, removing it from the application's state. 48 | 49 | Args: 50 | alert_id (str): The ID of the alert to dismiss. 51 | 52 | Returns: 53 | bool: True if the alert was successfully dismissed, False otherwise. 54 | """ 55 | # pylint: disable=import-outside-toplevel 56 | from ..app import app 57 | if alert_id in app.state.alerts: 58 | del app.state.alerts[alert_id] 59 | camera_uuid = alert_id.split('_')[0] 60 | await update_camera_state(camera_uuid, {"current_alert_id": None}) 61 | return True 62 | return False 63 | 64 | def alert_to_response_json(alert): 65 | """Converts an Alert object to a JSON string for API responses. 66 | 67 | The snapshot image is base64 encoded within the JSON. 68 | 69 | Args: 70 | alert (Alert): The alert object to convert. 71 | 72 | Returns: 73 | str: A JSON string representing the alert. 74 | The structure is: 75 | { 76 | "id": str, 77 | "snapshot": str (base64 encoded image), 78 | "title": str, 79 | "message": str, 80 | "timestamp": float, 81 | "countdown_time": float, 82 | "camera_uuid": str, 83 | "has_printer": bool, 84 | "countdown_action": str 85 | } 86 | """ 87 | img_bytes = alert.snapshot 88 | if isinstance(img_bytes, str): 89 | img_bytes = base64.b64decode(img_bytes) 90 | buffer = io.BytesIO() 91 | Image.open(io.BytesIO(img_bytes)).save(buffer, format="JPEG") 92 | base64_snapshot = base64.b64encode(buffer.getvalue()).decode("utf-8") 93 | alert_dict = alert.model_dump() 94 | alert_dict['snapshot'] = base64_snapshot 95 | return json.dumps(alert_dict) 96 | -------------------------------------------------------------------------------- /printguard/routes/detection_routes.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | 5 | from fastapi import APIRouter, Body, File, HTTPException, Request, UploadFile 6 | from fastapi.responses import StreamingResponse 7 | 8 | from ..utils.camera_utils import get_camera_state, update_camera_state 9 | from ..utils.detection_utils import _live_detection_loop 10 | 11 | router = APIRouter() 12 | 13 | @router.post("/detect/live/start") 14 | async def start_live_detection(request: Request, camera_uuid: str = Body(..., embed=True)): 15 | """Start continuous live detection on a specified camera. 16 | 17 | Args: 18 | request (Request): The FastAPI request object containing app state. 19 | camera_uuid (str): UUID of the camera to start live detection on. 20 | 21 | Returns: 22 | dict: Message indicating whether live detection was started or already running. 
23 | """ 24 | camera_state = await get_camera_state(camera_uuid) 25 | if camera_state.live_detection_running: 26 | return {"message": f"Live detection already running for camera {camera_state.nickname}"} 27 | else: 28 | await update_camera_state(camera_uuid, { 29 | "current_alert_id": None, 30 | "detection_history": [], 31 | "last_result": None, 32 | "last_time": None, 33 | "error": None 34 | }) 35 | await update_camera_state(camera_uuid, {"start_time": time.time(), 36 | "live_detection_running": True, 37 | "live_detection_task": asyncio.create_task( 38 | _live_detection_loop(request.app.state, camera_uuid) 39 | )}) 40 | return {"message": f"Live detection started for camera {camera_state.nickname}"} 41 | 42 | @router.post("/detect/live/stop") 43 | async def stop_live_detection(request: Request, camera_uuid: str = Body(..., embed=True)): 44 | """Stop continuous live detection on a specified camera. 45 | 46 | Args: 47 | request (Request): The FastAPI request object containing app state. 48 | camera_uuid (str): UUID of the camera to stop live detection on. 49 | 50 | Returns: 51 | dict: Message indicating whether live detection was stopped or not running. 52 | """ 53 | camera_state = await get_camera_state(camera_uuid) 54 | if not camera_state.live_detection_running: 55 | return {"message": f"Live detection not running for camera {camera_state.nickname}"} 56 | live_detection_task = camera_state.live_detection_task 57 | if live_detection_task: 58 | try: 59 | await asyncio.wait_for(live_detection_task, timeout=0.25) 60 | logging.debug("Live detection task for camera %s finished successfully.", camera_uuid) 61 | except asyncio.TimeoutError: 62 | logging.debug("Live detection task for camera %s did not finish in time.", camera_uuid) 63 | if live_detection_task: 64 | live_detection_task.cancel() 65 | except Exception as e: 66 | logging.error("Error stopping live detection task for camera %s: %s", camera_uuid, e) 67 | finally: 68 | live_detection_task = None 69 | await update_camera_state(camera_uuid, {"start_time": None, 70 | "live_detection_running": False, 71 | "live_detection_task": None}) 72 | return {"message": f"Live detection stopped for camera {camera_state.nickname}"} 73 | -------------------------------------------------------------------------------- /printguard/utils/setup_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ..models import SavedConfig, SavedKey, SiteStartupMode 4 | from .config import SSL_CERT_FILE, get_config, get_key 5 | 6 | 7 | def setup_ngrok_tunnel(close: bool = False) -> bool: 8 | """ 9 | Start a ngrok tunnel at port 8000 using the provided auth key and domain. 10 | Requirements: 11 | - TUNNEL_API_KEY must be set. 12 | - SITE_DOMAIN must be set. 13 | 14 | Args: 15 | close (bool | optional): If True, disconnect the tunnel after starting. Defaults to False. 16 | 17 | Returns: 18 | bool: True if the tunnel was successfully started, False otherwise. 
19 | """ 20 | config = get_config() 21 | tunnel_auth_key = get_key(SavedKey.TUNNEL_API_KEY) 22 | tunnel_domain = config.get(SavedConfig.SITE_DOMAIN, None) 23 | if not tunnel_auth_key and not tunnel_domain: 24 | return False 25 | try: 26 | # pylint: disable=import-outside-toplevel 27 | import ngrok 28 | # pylint: disable=E1101 29 | listener = ngrok.forward(8000, authtoken=tunnel_auth_key, domain=tunnel_domain) 30 | if listener: 31 | if close: 32 | ngrok.disconnect() 33 | return True 34 | else: 35 | return False 36 | except Exception as e: 37 | logging.error("Failed to start ngrok tunnel. Error: %s", e) 38 | return False 39 | 40 | def check_ssl_certificates_exist() -> bool: 41 | """ 42 | Check if SSL certificates exist. 43 | 44 | Requirements: 45 | - SSL private key must exist. 46 | - SITE_DOMAIN must be set. 47 | - SSL_CERT_FILE must be set. 48 | 49 | Returns: 50 | bool: True if SSL requirements exist, False otherwise. 51 | """ 52 | config = get_config() 53 | site_domain = config.get(SavedConfig.SITE_DOMAIN, None) 54 | return True if ( 55 | get_key(SavedKey.SSL_PRIVATE_KEY) 56 | and site_domain 57 | and SSL_CERT_FILE 58 | ) else False 59 | 60 | def check_vapid_keys_exist() -> bool: 61 | """ 62 | Check if VAPID keys exist. 63 | 64 | Requirements: 65 | - VAPID private key must exist. 66 | - VAPID public key must exist. 67 | - VAPID claims must be set. 68 | 69 | Returns: 70 | bool: True if VAPID requirements exist, False otherwise. 71 | """ 72 | config = get_config() 73 | vapid_public_key = config.get(SavedConfig.VAPID_PUBLIC_KEY, None) 74 | vapid_subject = config.get(SavedConfig.VAPID_SUBJECT, None) 75 | return True if ( 76 | get_key(SavedKey.VAPID_PRIVATE_KEY) 77 | and vapid_subject 78 | and vapid_public_key 79 | ) else False 80 | 81 | def check_tunnel_requirements_met() -> bool: 82 | """ 83 | Check if the requirements for the tunnel are met. 84 | 85 | Requirements: 86 | - TUNNEL_PROVIDER must be set. 87 | - Tunnel API keys must exist. 88 | 89 | Returns: 90 | bool: True if tunnel requirements are met, False otherwise. 91 | """ 92 | config = get_config() 93 | tunnel_provider = config.get(SavedConfig.TUNNEL_PROVIDER, None) 94 | return True if ( 95 | tunnel_provider 96 | and get_key(SavedKey.TUNNEL_API_KEY) 97 | ) else False 98 | 99 | def startup_mode_requirements_met() -> SiteStartupMode: 100 | """ 101 | Check if the requirements for the current startup mode are met. 102 | 103 | Returns: 104 | SiteStartupMode: The site startup mode if requirements are met, SETUP otherwise. 
105 | """ 106 | startup_mode = get_config().get(SavedConfig.STARTUP_MODE, None) 107 | match startup_mode: 108 | case SiteStartupMode.SETUP: 109 | return SiteStartupMode.SETUP 110 | case SiteStartupMode.LOCAL: 111 | if check_ssl_certificates_exist() and check_vapid_keys_exist(): 112 | return SiteStartupMode.LOCAL 113 | case SiteStartupMode.TUNNEL: 114 | if check_vapid_keys_exist() and check_tunnel_requirements_met(): 115 | return SiteStartupMode.TUNNEL 116 | return SiteStartupMode.SETUP 117 | -------------------------------------------------------------------------------- /printguard/routes/printer_routes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fastapi import APIRouter, HTTPException 4 | 5 | from ..models import PrinterConfigRequest, AlertAction 6 | from ..utils.printer_services.octoprint import OctoPrintClient 7 | from ..utils.printer_utils import (get_printer_id, remove_printer, 8 | set_printer, suspend_print_job) 9 | from ..utils.camera_utils import get_camera_state 10 | 11 | router = APIRouter() 12 | 13 | @router.post("/printer/add/{camera_uuid}", include_in_schema=False) 14 | async def add_printer_ep(camera_uuid: str, printer_config: PrinterConfigRequest): 15 | """Add a printer configuration to a specific camera. 16 | 17 | Args: 18 | camera_uuid (str): UUID of the camera to associate the printer with. 19 | printer_config (PrinterConfigRequest): Printer configuration including 20 | base URL, API key, and name. 21 | 22 | Returns: 23 | dict: Success status and generated printer ID, or error details. 24 | 25 | Raises: 26 | HTTPException: If printer connection test fails or configuration is invalid. 27 | """ 28 | try: 29 | client = OctoPrintClient(printer_config.base_url, printer_config.api_key) 30 | client.get_job_info() 31 | printer_id = f"{camera_uuid}_{printer_config.name.replace(' ', '_')}" 32 | await set_printer(camera_uuid, printer_id, printer_config.model_dump()) 33 | return {"success": True, "printer_id": printer_id} 34 | except Exception as e: 35 | logging.error("Error adding printer: %s", e) 36 | raise HTTPException(status_code=500, detail=f"Failed to add printer: {str(e)}") 37 | 38 | @router.post("/printer/remove/{camera_uuid}", include_in_schema=False) 39 | async def remove_printer_ep(camera_uuid: str): 40 | """Remove printer configuration from a specific camera. 41 | 42 | Args: 43 | camera_uuid (str): UUID of the camera to remove printer configuration from. 44 | 45 | Returns: 46 | dict: Success status and confirmation message, or error if no printer configured. 47 | 48 | Raises: 49 | HTTPException: If removal fails due to system errors. 50 | """ 51 | try: 52 | printer_id = get_printer_id(camera_uuid) 53 | if printer_id: 54 | await remove_printer(camera_uuid) 55 | camera_state = await get_camera_state(camera_uuid) 56 | camera_nickname = camera_state.nickname if camera_state else camera_uuid 57 | return {"success": True, "message": f"Printer removed from camera {camera_nickname}"} 58 | else: 59 | return {"success": False, "error": "No printer configured for this camera"} 60 | except Exception as e: 61 | logging.error("Error removing printer from camera %s: %s", camera_uuid, e) 62 | raise HTTPException(status_code=500, detail=f"Failed to remove printer: {str(e)}") 63 | 64 | @router.post("/printer/cancel/{camera_uuid}", include_in_schema=False) 65 | async def cancel_print_job_ep(camera_uuid: str): 66 | """Cancel the current print job for a specific camera's printer. 
67 | 68 | Args: 69 | camera_uuid (str): UUID of the camera whose printer job should be cancelled. 70 | 71 | Returns: 72 | dict: Success status and confirmation message. 73 | """ 74 | suspend_print_job(camera_uuid, AlertAction.CANCEL_PRINT) 75 | camera_state = await get_camera_state(camera_uuid) 76 | camera_nickname = camera_state.nickname if camera_state else camera_uuid 77 | return {"success": True, "message": f"Print job cancelled for camera {camera_nickname}"} 78 | 79 | @router.post("/printer/pause/{camera_uuid}", include_in_schema=False) 80 | async def pause_print_job_ep(camera_uuid: str): 81 | """Pause the current print job for a specific camera's printer. 82 | 83 | Args: 84 | camera_uuid (str): UUID of the camera whose printer job should be paused. 85 | 86 | Returns: 87 | dict: Success status and confirmation message. 88 | """ 89 | suspend_print_job(camera_uuid, AlertAction.PAUSE_PRINT) 90 | camera_state = await get_camera_state(camera_uuid) 91 | camera_nickname = camera_state.nickname if camera_state else camera_uuid 92 | return {"success": True, "message": f"Print job paused for camera {camera_nickname}"} 93 | -------------------------------------------------------------------------------- /printguard/static/js/cloudflare_setup.js: -------------------------------------------------------------------------------- 1 | document.addEventListener('DOMContentLoaded', async function() { 2 | const enrollmentUrl = `${window.location.protocol}//${window.location.hostname}:${window.location.port || (window.location.protocol === 'https:' ? '443' : '8000')}/setup/cloudflare/add-device`; 3 | document.getElementById('enrollment-url').textContent = enrollmentUrl; 4 | await fetchWarpConfig(); 5 | const screenWidth = window.innerWidth; 6 | let qrSize; 7 | if (screenWidth <= 480) { 8 | qrSize = 150; 9 | } else if (screenWidth <= 768) { 10 | qrSize = 180; 11 | } else { 12 | qrSize = 200; 13 | } 14 | 15 | const qr = new QRious({ 16 | element: document.getElementById('qr-code'), 17 | value: enrollmentUrl, 18 | size: qrSize, 19 | level: 'M', 20 | background: 'white', 21 | foreground: 'black', 22 | padding: 2 23 | }); 24 | try { 25 | qr.value = enrollmentUrl; 26 | } catch (err) { 27 | console.error('QR Code generation failed:', err); 28 | document.getElementById('qr-code').style.display = 'none'; 29 | } 30 | document.getElementById('copy-url-btn').addEventListener('click', function() { 31 | const button = this; 32 | navigator.clipboard.writeText(enrollmentUrl).then(() => { 33 | const originalText = button.textContent; 34 | button.textContent = '✓ Copied!'; 35 | setTimeout(() => { 36 | button.textContent = originalText; 37 | }, 2000); 38 | }).catch((error) => { 39 | console.error('Clipboard API failed, using fallback:', error); 40 | const textArea = document.createElement('textarea'); 41 | textArea.value = enrollmentUrl; 42 | document.body.appendChild(textArea); 43 | textArea.select(); 44 | document.execCommand('copy'); 45 | document.body.removeChild(textArea); 46 | const originalText = button.textContent; 47 | button.textContent = '✓ Copied!'; 48 | setTimeout(() => { 49 | button.textContent = originalText; 50 | }, 2000); 51 | }); 52 | }); 53 | document.getElementById('device-info').textContent = 54 | `${navigator.platform} - ${navigator.userAgent.split(' ')[0]}`; 55 | 56 | async function fetchWarpConfig() { 57 | try { 58 | const response = await fetch('/setup/cloudflare/organisation'); 59 | if (response.ok) { 60 | const data = await response.json(); 61 | document.getElementById('team-name').textContent = 
data.team_name || 'your-organization'; 62 | const siteDomain = data.site_domain || 'your-cloudflare-domain.com'; 63 | const domainElement = document.getElementById('site-domain'); 64 | if (domainElement) { 65 | domainElement.textContent = `https://${siteDomain}`; 66 | } 67 | } else { 68 | console.warn('Could not fetch WARP config, using defaults'); 69 | document.getElementById('team-name').textContent = 'your-organization'; 70 | const domainElement = document.getElementById('site-domain'); 71 | if (domainElement) { 72 | domainElement.textContent = 'your-cloudflare-domain.com'; 73 | } 74 | } 75 | } catch (error) { 76 | console.error('Error fetching WARP config:', error); 77 | document.getElementById('team-name').textContent = 'your-organization'; 78 | const domainElement = document.getElementById('site-domain'); 79 | if (domainElement) { 80 | domainElement.textContent = 'your-cloudflare-domain.com'; 81 | } 82 | } 83 | } 84 | window.addEventListener('resize', function() { 85 | const screenWidth = window.innerWidth; 86 | let newQrSize; 87 | if (screenWidth <= 480) { 88 | newQrSize = 150; 89 | } else if (screenWidth <= 768) { 90 | newQrSize = 180; 91 | } else { 92 | newQrSize = 200; 93 | } 94 | if (qr && newQrSize !== qr.size) { 95 | qr.size = newQrSize; 96 | } 97 | }); 98 | }); 99 | -------------------------------------------------------------------------------- /.github/workflows/docker-publish.yml: -------------------------------------------------------------------------------- 1 | name: Docker 2 | 3 | # This workflow uses actions that are not certified by GitHub. 4 | # They are provided by a third-party and are governed by 5 | # separate terms of service, privacy policy, and support 6 | # documentation. 7 | 8 | on: 9 | release: 10 | types: [ published ] 11 | workflow_dispatch: 12 | 13 | env: 14 | # Use docker.io for Docker Hub if empty 15 | REGISTRY: ghcr.io 16 | # github.repository as / 17 | IMAGE_NAME: ${{ github.repository }} 18 | 19 | 20 | jobs: 21 | build: 22 | 23 | runs-on: ubuntu-latest 24 | permissions: 25 | contents: read 26 | packages: write 27 | # This is used to complete the identity challenge 28 | # with sigstore/fulcio when running outside of PRs. 
29 | id-token: write 30 | 31 | steps: 32 | - name: Checkout repository 33 | uses: actions/checkout@v4 34 | 35 | # Install the cosign tool except on PR 36 | # https://github.com/sigstore/cosign-installer 37 | - name: Install cosign 38 | if: github.event_name != 'pull_request' 39 | uses: sigstore/cosign-installer@59acb6260d9c0ba8f4a2f9d9b48431a222b68e20 #v3.5.0 40 | with: 41 | cosign-release: 'v2.2.4' 42 | 43 | # Set up BuildKit Docker container builder to be able to build 44 | # multi-platform images and export cache 45 | # https://github.com/docker/setup-buildx-action 46 | - name: Set up Docker Buildx 47 | uses: docker/setup-buildx-action@f95db51fddba0c2d1ec667646a06c2ce06100226 # v3.0.0 48 | 49 | # Login against a Docker registry except on PR 50 | # https://github.com/docker/login-action 51 | - name: Log into registry ${{ env.REGISTRY }} 52 | if: github.event_name != 'pull_request' 53 | uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d # v3.0.0 54 | with: 55 | registry: ${{ env.REGISTRY }} 56 | username: ${{ github.actor }} 57 | password: ${{ secrets.GITHUB_TOKEN }} 58 | 59 | # Extract metadata (tags, labels) for Docker 60 | # https://github.com/docker/metadata-action 61 | - name: Extract Docker metadata 62 | id: meta 63 | uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # v5.0.0 64 | with: 65 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 66 | tags: | 67 | type=semver,pattern={{version}} 68 | type=semver,pattern={{major}}.{{minor}} 69 | type=raw,value=latest 70 | 71 | # Free up space before building 72 | - name: Free Disk Space (Ubuntu) 73 | uses: jlumbroso/free-disk-space@main 74 | with: 75 | tool-cache: false 76 | android: true 77 | dotnet: true 78 | haskell: true 79 | large-packages: true 80 | docker-images: true 81 | 82 | # Build and push Docker image with Buildx (don't push on PR) 83 | # https://github.com/docker/build-push-action 84 | - name: Build and push Docker image 85 | id: build-and-push 86 | uses: docker/build-push-action@0565240e2d4ab88bba5387d719585280857ece09 # v5.0.0 87 | with: 88 | context: . 89 | platforms: linux/amd64,linux/arm64,linux/arm/v7 90 | push: ${{ github.event_name != 'pull_request' }} 91 | tags: ${{ steps.meta.outputs.tags }} 92 | labels: ${{ steps.meta.outputs.labels }} 93 | cache-from: type=gha 94 | cache-to: type=gha,mode=min 95 | 96 | # Sign the resulting Docker image digest except on PRs. 97 | # This will only write to the public Rekor transparency log when the Docker 98 | # repository is public to avoid leaking data. If you would like to publish 99 | # transparency data even for private images, pass --force to cosign below. 100 | # https://github.com/sigstore/cosign 101 | - name: Sign the published Docker image 102 | if: ${{ github.event_name != 'pull_request' }} 103 | env: 104 | # https://docs.github.com/en/actions/security-guides/security-hardening-for-github-actions#using-an-intermediate-environment-variable 105 | TAGS: ${{ steps.meta.outputs.tags }} 106 | DIGEST: ${{ steps.build-and-push.outputs.digest }} 107 | # This step uses the identity token to provision an ephemeral certificate 108 | # against the sigstore community Fulcio instance. 
109 | run: echo "${TAGS}" | xargs -I {} cosign sign --yes {}@${DIGEST} 110 | -------------------------------------------------------------------------------- /printguard/static/js/notifications.js: -------------------------------------------------------------------------------- 1 | export { registerPush, unsubscribeFromPush }; 2 | 3 | function urlBase64ToUint8Array(base64String) { 4 | const padding = '='.repeat((4 - base64String.length % 4) % 4); 5 | const base64 = (base64String + padding) 6 | .replace(/-/g, '+') 7 | .replace(/_/g, '/'); 8 | const rawData = atob(base64); 9 | const outputArray = new Uint8Array(rawData.length); 10 | for (let i = 0; i < rawData.length; ++i) { 11 | outputArray[i] = rawData.charCodeAt(i); 12 | } 13 | return outputArray; 14 | } 15 | 16 | async function registerPush() { 17 | try { 18 | if ('Notification' in window && Notification.permission !== 'granted') { 19 | const permission = await Notification.requestPermission(); 20 | if (permission !== 'granted') { 21 | throw new Error('Notification permission denied'); 22 | } 23 | console.debug('Notification permission newly granted'); 24 | } 25 | 26 | if ('serviceWorker' in navigator) { 27 | const registrations = await navigator.serviceWorker.getRegistrations(); 28 | for (let registration of registrations) { 29 | await registration.unregister(); 30 | } 31 | } 32 | 33 | const {publicKey} = await fetch('/notification/public_key').then(r => r.json()); 34 | const registration = await navigator.serviceWorker.getRegistration('/static/js/sw.js'); 35 | const sw = registration || await navigator.serviceWorker.register('/static/js/sw.js'); 36 | 37 | if (sw.active === null) { 38 | await new Promise(resolve => { 39 | if (sw.installing) { 40 | sw.installing.addEventListener('statechange', e => { 41 | if (e.target.state === 'activated') { 42 | resolve(); 43 | } 44 | }); 45 | } else if (sw.waiting) { 46 | sw.waiting.addEventListener('statechange', e => { 47 | if (e.target.state === 'activated') { 48 | resolve(); 49 | } 50 | }); 51 | } else { 52 | resolve(); 53 | } 54 | }); 55 | } 56 | 57 | const sub = await sw.pushManager.subscribe({ 58 | userVisibleOnly: true, 59 | applicationServerKey: urlBase64ToUint8Array(publicKey) 60 | }); 61 | 62 | await fetch('/notification/subscribe', { 63 | method: 'POST', 64 | headers: {'Content-Type':'application/json'}, 65 | body: JSON.stringify(sub) 66 | }); 67 | if (Notification.permission === 'granted') { 68 | return true; 69 | } else { 70 | throw new Error('Permission appears granted but notification state is inconsistent'); 71 | } 72 | } catch (error) { 73 | console.error('Failed to register for push notifications:', error); 74 | alert('Failed to enable notifications: ' + error.message); 75 | return false; 76 | } 77 | } 78 | 79 | async function unsubscribeFromPush() { 80 | try { 81 | if (!('serviceWorker' in navigator)) { 82 | throw new Error('Service workers are not supported in this browser'); 83 | } 84 | const registrations = await navigator.serviceWorker.getRegistrations(); 85 | let hadSubscription = false; 86 | for (let registration of registrations) { 87 | const subscription = await registration.pushManager.getSubscription(); 88 | if (subscription) { 89 | await subscription.unsubscribe(); 90 | hadSubscription = true; 91 | } 92 | await registration.unregister(); 93 | } 94 | return true; 95 | } catch (error) { 96 | console.error('Failed to unsubscribe from push notifications:', error); 97 | alert('Failed to disable notifications: ' + error.message); 98 | return false; 99 | } 100 | } 101 | 102 | 
if ('serviceWorker' in navigator) { 103 | window.addEventListener('load', async () => { 104 | try { 105 | console.debug('Page loaded, service workers will be managed by notification functions'); 106 | } catch (error) { 107 | console.error('Error managing service workers on page load:', error); 108 | } 109 | }); 110 | } -------------------------------------------------------------------------------- /printguard/routes/notification_routes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fastapi import APIRouter, Request 4 | 5 | from ..models import SavedConfig, SavedKey 6 | from ..utils.config import get_config, get_key, update_config 7 | 8 | router = APIRouter() 9 | 10 | @router.get("/notification/public_key") 11 | async def get_public_key(): 12 | """Retrieve the VAPID public key for push notification subscriptions. 13 | 14 | Returns: 15 | dict: VAPID public key for client-side push notification setup, 16 | or error message if key is not configured. 17 | """ 18 | config = get_config() 19 | vapid_public_key = config.get(SavedConfig.VAPID_PUBLIC_KEY, None) 20 | if not vapid_public_key: 21 | logging.error("VAPID public key is not set in the configuration.") 22 | return {"error": "VAPID public key not configured"} 23 | return {"publicKey": vapid_public_key} 24 | 25 | @router.post("/notification/subscribe") 26 | async def subscribe(request: Request): 27 | """Subscribe a client to push notifications. 28 | 29 | Args: 30 | request (Request): The FastAPI request object containing subscription data. 31 | 32 | Returns: 33 | dict: Success status indicating whether subscription was added successfully. 34 | """ 35 | try: 36 | subscription = await request.json() 37 | logging.debug("Received subscription request: %s", subscription.get('endpoint', 'no endpoint')) 38 | if not subscription.get('endpoint') or not subscription.get('keys'): 39 | logging.error("Invalid subscription format - missing endpoint or keys") 40 | return {"success": False, "error": "Invalid subscription format"} 41 | for existing_sub in request.app.state.subscriptions: 42 | if existing_sub.get('endpoint') == subscription.get('endpoint'): 43 | request.app.state.subscriptions.remove(existing_sub) 44 | logging.debug("Removed existing subscription for same endpoint") 45 | break 46 | request.app.state.subscriptions.append(subscription) 47 | config = get_config() or {} 48 | config[SavedConfig.PUSH_SUBSCRIPTIONS] = request.app.state.subscriptions 49 | update_config(config) 50 | logging.debug("Successfully added subscription. Total subscriptions: %d", len(request.app.state.subscriptions)) 51 | return {"success": True} 52 | # pylint: disable=W0718 53 | except Exception as e: 54 | logging.error("Subscription error: %s", str(e)) 55 | return {"success": False, "error": f"Server error: {str(e)}"} 56 | 57 | @router.post("/notification/unsubscribe") 58 | async def unsubscribe(request: Request): 59 | """Unsubscribe all clients from push notifications. 60 | 61 | Args: 62 | request (Request): The FastAPI request object containing app state. 63 | 64 | Returns: 65 | dict: Success status indicating all subscriptions were cleared. 
66 | """ 67 | request.app.state.subscriptions.clear() 68 | config = get_config() or {} 69 | config[SavedConfig.PUSH_SUBSCRIPTIONS] = [] 70 | update_config(config) 71 | logging.debug("All push subscriptions cleared and persisted.") 72 | return {"success": True} 73 | 74 | @router.get("/notification/debug") 75 | async def notification_debug(request: Request): 76 | """Get debug information about notification configuration and subscriptions. 77 | 78 | Args: 79 | request (Request): The FastAPI request object containing app state. 80 | 81 | Returns: 82 | dict: Debug information including subscription count, VAPID configuration, 83 | and subscription details for troubleshooting. 84 | """ 85 | config = get_config() 86 | debug_info = { 87 | "subscriptions_count": len(request.app.state.subscriptions), 88 | "subscriptions": [ 89 | { 90 | "endpoint": sub.get('endpoint', 'unknown')[:50] + "..." if len(sub.get('endpoint', '')) > 50 else sub.get('endpoint', 'unknown'), 91 | "has_keys": bool(sub.get('keys')) 92 | } 93 | for sub in request.app.state.subscriptions 94 | ], 95 | "vapid_config": { 96 | "has_public_key": bool(config.get(SavedConfig.VAPID_PUBLIC_KEY)), 97 | "has_subject": bool(config.get(SavedConfig.VAPID_SUBJECT)), 98 | "has_private_key": bool(get_key(SavedKey.VAPID_PRIVATE_KEY)), 99 | "subject": config.get(SavedConfig.VAPID_SUBJECT, "not set") 100 | } 101 | } 102 | return debug_info 103 | -------------------------------------------------------------------------------- /printguard/utils/detection_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import uuid 3 | import logging 4 | import cv2 5 | 6 | from .alert_utils import (dismiss_alert, alert_to_response_json, 7 | get_alert, append_new_alert) 8 | from .sse_utils import append_new_outbound_packet 9 | from .camera_utils import (get_camera_state, get_camera_state_sync, 10 | update_camera_state, update_camera_detection_history) 11 | from .printer_utils import get_printer_config, suspend_print_job 12 | from .notification_utils import send_defect_notification 13 | from ..models import Alert, AlertAction, SSEDataType 14 | 15 | def _passed_majority_vote(camera_state): 16 | """Determine if failures in detection history meet the majority threshold. 17 | 18 | Args: 19 | camera_state (CameraState): The camera state containing detection history, 20 | which includes a list of tuples `(timestamp, label)`. 21 | 22 | Returns: 23 | bool: True if the number of 'failure' labels in the most recent 24 | `majority_vote_window` entries is at least `majority_vote_threshold`. 25 | """ 26 | detection_history = camera_state.detection_history 27 | majority_vote_window = camera_state.majority_vote_window 28 | majority_vote_threshold = camera_state.majority_vote_threshold 29 | results_to_retreive = min(len(detection_history), majority_vote_window) 30 | detection_window_results = detection_history[-results_to_retreive:] 31 | failed_detections = [res for res in detection_window_results if res[1] == 'failure'] 32 | return len(failed_detections) >= majority_vote_threshold 33 | 34 | async def _send_alert(alert): 35 | """Send an alert to clients via Server-Sent Events. 36 | 37 | Args: 38 | alert (Alert): The alert object to send. 39 | """ 40 | await append_new_outbound_packet(alert_to_response_json(alert), SSEDataType.ALERT) 41 | 42 | async def _terminate_alert_after_cooldown(alert): 43 | """Wait for the alert's countdown, then dismiss or act on the print job. 
44 | 45 | Args: 46 | alert (Alert): The alert object with `countdown_time` and `countdown_action`. 47 | """ 48 | await asyncio.sleep(alert.countdown_time) 49 | if get_alert(alert.id) is not None: 50 | camera_uuid = alert.camera_uuid 51 | camera_state = await get_camera_state(camera_uuid) 52 | if not camera_state: 53 | return 54 | match camera_state.countdown_action: 55 | case AlertAction.DISMISS: 56 | await dismiss_alert(alert.id) 57 | case AlertAction.CANCEL_PRINT | AlertAction.PAUSE_PRINT: 58 | suspend_print_job(camera_uuid, camera_state.countdown_action) 59 | return await dismiss_alert(alert.id) 60 | 61 | async def _create_alert_and_notify(camera_state_ref, camera_uuid, frame, timestamp_arg): 62 | """Create a new Alert object and notify all subsystems. 63 | 64 | Args: 65 | camera_state_ref (CameraState): The state reference for the camera. 66 | camera_uuid (str): The UUID of the camera. 67 | frame (ndarray): The image frame where a defect was detected. 68 | timestamp_arg (float): The timestamp of detection. 69 | 70 | Returns: 71 | Alert: The newly created alert. 72 | """ 73 | alert_id = f"{camera_uuid}_{str(uuid.uuid4())}" 74 | # pylint: disable=E1101 75 | _, img_buf = cv2.imencode('.jpg', frame) 76 | has_printer = get_printer_config(camera_uuid) is not None 77 | alert = Alert( 78 | id=alert_id, 79 | camera_uuid=camera_uuid, 80 | timestamp=timestamp_arg, 81 | snapshot=img_buf.tobytes(), 82 | title=f"Defect - Camera {camera_state_ref.nickname}", 83 | message=f"Defect detected on camera {camera_state_ref.nickname}", 84 | countdown_time=camera_state_ref.countdown_time, 85 | countdown_action=camera_state_ref.countdown_action, 86 | has_printer=has_printer, 87 | ) 88 | append_new_alert(alert) 89 | asyncio.create_task(_terminate_alert_after_cooldown(alert)) 90 | await update_camera_state(camera_uuid, {"current_alert_id": alert_id}) 91 | await send_defect_notification(alert_id) 92 | return alert 93 | 94 | async def _live_detection_loop(app_state, camera_uuid): 95 | """Continuously run detection on camera frames and generate alerts using shared video stream. 96 | 97 | This loop reads frames from the shared video stream, runs inference, updates state, 98 | and dispatches alerts when defects are detected based on majority vote. 99 | 100 | Args: 101 | app_state: The application state holding model, transforms, and other context. 102 | camera_uuid (str): The UUID of the camera to process. 
103 | """ 104 | # pylint: disable=C0415 105 | from .stream_utils import create_optimized_detection_loop 106 | update_functions = { 107 | 'update_camera_state': update_camera_state, 108 | 'update_camera_detection_history': update_camera_detection_history, 109 | } 110 | try: 111 | await create_optimized_detection_loop( 112 | app_state, 113 | camera_uuid, 114 | get_camera_state_sync, 115 | update_functions 116 | ) 117 | except Exception as e: 118 | logging.error("Error in optimized detection loop for camera %s: %s", camera_uuid, e) 119 | await update_camera_state(camera_uuid, { 120 | "error": f"Detection loop error: {str(e)}", 121 | "live_detection_running": False 122 | }) 123 | -------------------------------------------------------------------------------- /printguard/utils/notification_utils.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlparse 2 | import logging 3 | import json 4 | from pywebpush import WebPushException, webpush 5 | 6 | from ..models import Notification, SavedKey, SavedConfig 7 | from ..utils.config import get_key, get_config 8 | from ..utils.alert_utils import get_alert 9 | 10 | def get_subscriptions(): 11 | """Retrieve the list of current push notification subscriptions. 12 | 13 | Returns: 14 | list: A list of subscription dictionaries, each with at least an 'id' and 'endpoint'. 15 | """ 16 | # pylint: disable=C0415 17 | from ..app import app 18 | return app.state.subscriptions 19 | 20 | def remove_subscription(subscription_id = None, subscription = None): 21 | """Remove a subscription by ID or subscription object. 22 | 23 | Args: 24 | subscription_id (str, optional): The ID of the subscription to remove. 25 | subscription (dict, optional): The subscription object to remove. 26 | """ 27 | # pylint: disable=C0415 28 | from ..app import app 29 | if subscription_id is not None: 30 | app.state.subscriptions = [ 31 | sub for sub in app.state.subscriptions if sub.get('id') != subscription_id 32 | ] 33 | elif subscription is not None: 34 | app.state.subscriptions.remove(subscription) 35 | else: 36 | logging.error("No subscription ID or object provided to remove.") 37 | 38 | async def send_defect_notification(alert_id): 39 | """Send a defect notification for a given alert ID to all subscribers. 40 | 41 | Args: 42 | alert_id (str): The ID of the alert for which to send a notification. 43 | """ 44 | logging.debug("Attempting to send defect notification for alert ID: %s", alert_id) 45 | alert = get_alert(alert_id) 46 | if alert: 47 | logging.debug("Alert found for ID %s, preparing notification", alert_id) 48 | # pylint: disable=import-outside-toplevel 49 | from .camera_utils import get_camera_state 50 | camera_state = await get_camera_state(alert.camera_uuid) 51 | camera_nickname = camera_state.nickname if camera_state else alert.camera_uuid 52 | notification = Notification( 53 | title=f"Defect - Camera {camera_nickname}", 54 | body=f"Defect detected on camera {camera_nickname}", 55 | ) 56 | subscriptions = get_subscriptions() or [] 57 | logging.debug("Created notification object without image payload, sending to %d subscriptions", 58 | len(subscriptions)) 59 | send_notification(notification) 60 | else: 61 | logging.error("No alert found for ID: %s", alert_id) 62 | 63 | def send_notification(notification: Notification): 64 | """Send a push notification to all current subscriptions. 65 | 66 | Args: 67 | notification (Notification): The notification object to send. Should have 'title' and 'body' fields at minimum. 
68 | 69 | Returns: 70 | bool: True if at least one notification was sent successfully, False otherwise. 71 | """ 72 | logging.debug("Starting notification send process") 73 | config = get_config() 74 | vapid_subject = config.get(SavedConfig.VAPID_SUBJECT, None) 75 | if not vapid_subject: 76 | logging.error("VAPID subject is not set in the configuration.") 77 | return False 78 | vapid_private_key = get_key(SavedKey.VAPID_PRIVATE_KEY) 79 | if not vapid_private_key: 80 | logging.error("VAPID private key is not set in the configuration.") 81 | return False 82 | subscriptions = get_subscriptions() 83 | logging.debug("VAPID configuration found. Subject: %s", vapid_subject) 84 | logging.debug("Number of subscriptions: %d", len(subscriptions)) 85 | vapid_claims = { 86 | "sub": vapid_subject, 87 | "aud": None, 88 | } 89 | success_count = 0 90 | if not subscriptions: 91 | logging.warning("No push subscriptions available to send notifications") 92 | return False 93 | for i, sub in enumerate(subscriptions.copy()): 94 | logging.debug("Sending notification to subscription %d/%d", 95 | i+1, len(subscriptions)) 96 | try: 97 | endpoint = sub.get('endpoint', '') 98 | if not endpoint: 99 | logging.error("Subscription %d has no endpoint", i+1) 100 | continue 101 | parsed_endpoint = urlparse(endpoint) 102 | audience = f"{parsed_endpoint.scheme}://{parsed_endpoint.netloc}" 103 | aud_vapid_claims = dict(vapid_claims) 104 | aud_vapid_claims['aud'] = audience 105 | payload_dict = { 106 | 'title': notification.title, 107 | 'body': notification.body 108 | } 109 | data_payload = json.dumps(payload_dict) 110 | logging.debug("Sending to endpoint: %s", endpoint) 111 | webpush( 112 | subscription_info=sub, 113 | data=data_payload, 114 | vapid_private_key=vapid_private_key, 115 | vapid_claims=aud_vapid_claims 116 | ) 117 | success_count += 1 118 | logging.debug("Successfully sent notification to subscription %d", i+1) 119 | except WebPushException as ex: 120 | logging.error("WebPush failed for subscription %d: %s", i+1, ex) 121 | if ex.response and ex.response.status_code == 410: 122 | remove_subscription(subscription=sub) 123 | logging.info("Subscription expired and removed: %s", sub.get('endpoint', 'unknown')) 124 | else: 125 | logging.error("Push failed: %s", ex) 126 | except Exception as e: 127 | logging.error("Unexpected error sending notification to subscription %d: %s", i+1, e) 128 | 129 | logging.debug("Notification send complete. Success count: %d/%d", success_count, len(subscriptions)) 130 | return success_count > 0 131 | -------------------------------------------------------------------------------- /docs/setup.md: -------------------------------------------------------------------------------- 1 | # Setup Documentation 2 | > This is the technical documentation for PrintGuard's setup process, if you are looking for guides on how to install and use PrintGuard, please refer to the [README](../README.md). 3 | 4 | ## Table of Contents 5 | - [Network Configuration](#network-configuration) 6 | - [Local Network](#local-network) 7 | - [External Access](#external-access) 8 | - [Ngrok](#ngrok) 9 | - [Cloudflare](#cloudflare) 10 | - [VAPID Keys Setup](#vapid-keys-setup) 11 | - [SSL Certificate Setup](#ssl-certificate-setup) 12 | - [Startup Process](#startup-process) 13 | 14 | ## Network Configuration 15 | PrintGuard can run in two modes: 16 | 17 | ### Local Network 18 | - Runs the FastAPI server on `https://localhost:8000`. 19 | - No external access required. 20 | - Recommended for secure LAN-only deployments. 
21 | - Requires SSL certificate and VAPID keys, generated during setup. 22 | 23 | ### External Access 24 | External access allows you to expose PrintGuard outside your local network. 25 | 26 | #### Ngrok 27 | [Ngrok](https://ngrok.com) is a reverse proxy tool, enabling secure internet access to your local server with minimal configuration through both free and paid plans. The setup uses your ngrok authtoken to create and configure tunnels via the official [ngrok python package](https://pypi.org/project/ngrok/). 28 | 1. Obtain an Ngrok authtoken from https://dashboard.ngrok.com/get-started/your-authtoken 29 | 2. In the setup UI, select **Ngrok** and enter your authtoken and desired subdomain (e.g., `myprintguard`). 30 | 3. The token is stored securely using the system keyring under the service `printguard` and key `TUNNEL_API_KEY`. 31 | 4. The domain (e.g., `myprintguard.ngrok.io`) is saved in the JSON config file (`~/.config/printguard/config.json`) under `SITE_DOMAIN`. 32 | 5. The setup code uses `ngrok.forward(8000, authtoken=<authtoken>, domain=<domain>)` to establish the tunnel. 33 | 34 | #### Cloudflare 35 | [Cloudflare tunnels](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/) provide a secure way to expose your local web interface to the internet, offering a reliable and secure connection without needing to open ports on your router. Cloudflare tunnels are free to use and offer a simple setup process; however, a domain connected to your Cloudflare account is required. Restricted access to your PrintGuard site can be set up through [Cloudflare Access](https://one.dash.cloudflare.com/) (configurable at setup). During setup, your API key is used to create a tunnel to your local server and insert a DNS record for the tunnel, allowing you to access your PrintGuard instance via your custom domain or subdomain. 36 | 37 | 1. Create a Cloudflare account and obtain an API Token with *Tunnel* and *DNS:Edit* permissions. 38 | 2. In the setup UI, select **Cloudflare**, enter your API token, and check **Use Global API Key** if desired. 39 | 3. The token is stored securely via keyring under key `TUNNEL_API_KEY`, with an optional `CLOUDFLARE_EMAIL` saved in config. 40 | 4. The setup contacts the Cloudflare API (`CloudflareAPI` in [`printguard/utils/cloudflare_utils.py`](../printguard/utils/cloudflare_utils.py)) to list accounts and zones. 41 | 5. Select an account, zone, and subdomain (e.g., `printguard.example.com`). 42 | 6. A Cloudflare Tunnel is created via the `/api/tunnels` endpoint; the tunnel token returned is stored under key `TUNNEL_TOKEN` in the keyring. 43 | 7. A DNS CNAME record is created pointing your subdomain to the tunnel using the Cloudflare API. 44 | 8. The final domain is saved as `SITE_DOMAIN` in `config.json`. 45 | 46 | ## VAPID Keys Setup 47 | PrintGuard uses Web Push notifications. 48 | 1. In the setup UI, generate or import existing VAPID keys. 49 | - If generating, the keys are created using the [py_vapid library](https://pypi.org/project/py_vapid/). 50 | - If importing, the public key must be in base64 format. 51 | 2. The public key, subject (a `mailto:` address) and private key are stored as follows: 52 | - Public and subject stored in `config.json` under `VAPID_PUBLIC_KEY` and `VAPID_SUBJECT`. 53 | - Private key stored in keyring under `VAPID_PRIVATE_KEY`. 54 | 55 | ## SSL Certificate Setup 56 | HTTPS is required for secure web push and SSE. 57 | 1. Generate a self-signed certificate using `trustme.CA()` and `issue_cert(domain)`.
The [trustme library](https://pypi.org/project/trustme/) is used to create a fake certificate authority (CA) and issue SSL certificates that are trusted on your local machine. 58 | - Alternatively, import an existing certificate and private key. 59 | 2. In the setup UI, click **Generate Self-Signed Certificate** or import your own. 60 | 3. The certificate is saved in the app directory as `cert.pem` (`SSL_CERT_FILE`). 61 | 4. The private key is stored in the keyring under `SSL_PRIVATE_KEY`. 62 | 5. On startup, FastAPI loads these files for TLS. 63 | 64 | ## Startup Process 65 | When you execute `printguard`, the application follows these steps to determine how to launch the server: 66 | 67 | 1. **Initialize configuration**: `init_config()` creates or loads the JSON config file stored in the application data directory and ensures default values. 68 | 2. **Determine startup mode**: `startup_mode_requirements_met()` inspects `config.json` and keyring entries to select one of these startup modes: 69 | - `SETUP`: missing required keys or certificates → launch the setup UI at `http://localhost:8000/setup`. 70 | - `LOCAL`: all SSL and VAPID requirements met → start FastAPI with HTTPS on port 8000 using `SSL_CERT_FILE` and the key from keyring. 71 | - `TUNNEL`: VAPID keys and tunnel credentials exist → continue to tunnel provider logic. 72 | 3. **Ngrok tunnel** (_if `TUNNEL_PROVIDER` is NGROK_): 73 | - Calls `setup_ngrok_tunnel()` to forward port 8000 to your custom `SITE_DOMAIN` through the ngrok package. 74 | - On success, runs Uvicorn normally; on failure, resets `STARTUP_MODE` to `SETUP` and restarts. 75 | 4. **Cloudflare tunnel** (_if `TUNNEL_PROVIDER` is CLOUDFLARE_): 76 | - Executes `stop_cloudflare_tunnel()` to clear any previous session. 77 | - Uses `start_cloudflare_tunnel()` to invoke `cloudflared` on your OS (installed via brew, curl, or winget as appropriate) using the stored tunnel credentials. 78 | - On failure, resets `STARTUP_MODE` to `SETUP` and restarts. 79 | 5. **Final launch**: Uvicorn serves the app at `0.0.0.0:8000`, secured by HTTPS for LOCAL or routed through the external domain for TUNNEL modes. 80 | 81 | This logic ensures the server automatically falls back to setup if any required credentials, certificates, or tunnels are missing or fail to start. -------------------------------------------------------------------------------- /printguard/routes/camera_routes.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import uuid 4 | 5 | import cv2 # pylint: disable=E0401 6 | from fastapi import APIRouter, Body, HTTPException, Request 7 | from fastapi.responses import StreamingResponse 8 | 9 | from ..utils.camera_utils import (add_camera, find_available_serial_cameras, 10 | get_camera_state) 11 | from ..utils.camera_utils import remove_camera as remove_camera_util 12 | from ..utils.shared_video_stream import get_shared_stream_manager 13 | from ..utils.stream_utils import generate_frames 14 | 15 | router = APIRouter() 16 | 17 | @router.post("/camera/state", include_in_schema=False) 18 | async def get_camera_state_ep(request: Request, camera_uuid: str = Body(..., embed=True)): 19 | """Get the current state of a specific camera. 20 | 21 | Args: 22 | request (Request): The FastAPI request object. 23 | camera_uuid (str): UUID of the camera to retrieve state for. 24 | 25 | Returns: 26 | dict: Dictionary containing comprehensive camera state information including 27 | detection history, settings, error status, and printer configuration.
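    Example:
        An illustrative client call (the UUID and base URL are placeholders;
        `verify=False` is only needed with the self-signed certificate used in
        LOCAL mode):

            import requests

            resp = requests.post(
                "https://localhost:8000/camera/state",
                json={"camera_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6"},
                verify=False,
            )
            state = resp.json()
            print(state["nickname"], state["live_detection_running"])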
28 | """ 29 | camera_state = await get_camera_state(camera_uuid) 30 | detection_times = [t for t, _ in camera_state.detection_history] if ( 31 | camera_state.detection_history 32 | ) else [] 33 | response = { 34 | "nickname": camera_state.nickname, 35 | "start_time": camera_state.start_time, 36 | "last_result": camera_state.last_result, 37 | "last_time": camera_state.last_time, 38 | "detection_times": detection_times, 39 | "error": camera_state.error, 40 | "live_detection_running": camera_state.live_detection_running, 41 | "brightness": camera_state.brightness, 42 | "contrast": camera_state.contrast, 43 | "focus": camera_state.focus, 44 | "countdown_time": camera_state.countdown_time, 45 | "majority_vote_threshold": camera_state.majority_vote_threshold, 46 | "majority_vote_window": camera_state.majority_vote_window, 47 | "current_alert_id": camera_state.current_alert_id, 48 | "sensitivity": camera_state.sensitivity, 49 | "printer_id": camera_state.printer_id, 50 | "printer_config": camera_state.printer_config, 51 | "countdown_action": camera_state.countdown_action 52 | } 53 | return response 54 | 55 | @router.get('/camera/feed/{camera_uuid}', include_in_schema=False) 56 | async def camera_feed(camera_uuid: str): 57 | """Stream live camera feed for a specific camera. 58 | 59 | Args: 60 | camera_uuid (str): UUID of the camera to stream from. 61 | 62 | Returns: 63 | StreamingResponse: MJPEG streaming response with camera frames. 64 | """ 65 | return StreamingResponse(generate_frames(camera_uuid), 66 | media_type='multipart/x-mixed-replace; boundary=frame') 67 | 68 | @router.post("/camera/add") 69 | async def add_camera_ep(request: Request): 70 | """Add a new camera.""" 71 | data = await request.json() 72 | nickname = data.get('nickname') 73 | source = data.get('source') 74 | if not nickname or not source: 75 | raise HTTPException(status_code=400, detail="Missing camera nickname or source.") 76 | camera = await add_camera(source=source, nickname=nickname) 77 | return {"camera_uuid": camera['camera_uuid'], "nickname": camera['nickname'], "source": camera['source']} 78 | 79 | @router.post("/camera/remove") 80 | async def remove_camera_ep(request: Request): 81 | """Remove a camera.""" 82 | data = await request.json() 83 | camera_uuid = data.get('camera_uuid') 84 | if not camera_uuid: 85 | raise HTTPException(status_code=400, detail="Missing camera_uuid.") 86 | success = await remove_camera_util(camera_uuid) 87 | if not success: 88 | raise HTTPException(status_code=404, detail="Camera not found.") 89 | return {"message": "Camera removed successfully."} 90 | 91 | @router.get("/camera/serial_devices") 92 | async def get_serial_devices_ep(): 93 | """Get a list of available serial devices.""" 94 | devices = find_available_serial_cameras() 95 | return devices 96 | 97 | def generate_preview_frames(source: str): 98 | """Generate frames for camera preview using shared video stream. 99 | 100 | Args: 101 | source (str): The camera source (device path or RTSP URL). 102 | 103 | Yields: 104 | bytes: Multipart JPEG frame data. 
105 | """ 106 | preview_uuid = f"preview_{uuid.uuid4()}" 107 | manager = get_shared_stream_manager() 108 | try: 109 | stream = manager.get_stream(preview_uuid, source) 110 | max_wait = 50 111 | wait_count = 0 112 | while not stream.is_frame_available() and wait_count < max_wait: 113 | time.sleep(0.1) 114 | wait_count += 1 115 | if not stream.is_frame_available(): 116 | logging.error("Failed to get initial frame from source: %s", source) 117 | return 118 | while True: 119 | frame = stream.get_frame() 120 | if frame is None: 121 | logging.warning("Failed to get frame from source: %s", source) 122 | time.sleep(0.1) 123 | continue 124 | _, buffer = cv2.imencode('.jpg', frame) 125 | frame_bytes = buffer.tobytes() 126 | yield (b'--frame\r\nContent-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n') 127 | time.sleep(0.2) 128 | except (cv2.error, OSError, RuntimeError) as e: 129 | logging.error("Error in preview frame generation for source %s: %s", source, e) 130 | finally: 131 | try: 132 | manager.release_stream(preview_uuid) 133 | except (AttributeError, RuntimeError) as cleanup_error: 134 | logging.error("Error cleaning up preview stream %s: %s", preview_uuid, cleanup_error) 135 | 136 | @router.get('/camera/preview', include_in_schema=False) 137 | async def camera_preview(source: str): 138 | """Stream live camera preview for a specific source without registration. 139 | 140 | Args: 141 | source (str): Camera source (device path or RTSP URL). 142 | 143 | Returns: 144 | StreamingResponse: MJPEG streaming response with camera frames. 145 | """ 146 | return StreamingResponse(generate_preview_frames(source), 147 | media_type='multipart/x-mixed-replace; boundary=frame') 148 | -------------------------------------------------------------------------------- /printguard/utils/printer_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | 4 | import requests 5 | 6 | from ..models import PollingTask, SavedConfig, AlertAction 7 | from .camera_utils import get_camera_state_sync, update_camera_state 8 | from .config import PRINTER_STAT_POLLING_RATE_MS, get_config 9 | from .printer_services.octoprint import OctoPrintClient 10 | from .sse_utils import add_polling_task, sse_update_printer_state 11 | 12 | def get_printer_config(camera_uuid): 13 | """Retrieve printer configuration from camera state. 14 | 15 | Args: 16 | camera_uuid (str): The UUID of the camera. 17 | 18 | Returns: 19 | dict or None: The printer_config dictionary if set, otherwise None. 20 | Structure of printer_config example: 21 | { 22 | 'printer_type': str, 23 | 'base_url': str, 24 | 'api_key': str, 25 | 'name': str 26 | } 27 | """ 28 | camera_state = get_camera_state_sync(camera_uuid) 29 | if camera_state and hasattr(camera_state, 'printer_config') and camera_state.printer_config: 30 | return camera_state.printer_config 31 | return None 32 | 33 | def get_printer_id(camera_uuid): 34 | """Retrieve the printer ID associated with a camera. 35 | 36 | Args: 37 | camera_uuid (str): The UUID of the camera. 38 | 39 | Returns: 40 | str or None: The printer_id if set, otherwise None. 41 | """ 42 | camera_state = get_camera_state_sync(camera_uuid) 43 | if camera_state and hasattr(camera_state, 'printer_id') and camera_state.printer_id: 44 | return camera_state.printer_id 45 | return None 46 | 47 | async def set_printer(camera_uuid, printer_id, printer_config): 48 | """Associate a printer with a camera and persist in state. 
49 | 50 | Args: 51 | camera_uuid (str): The UUID of the camera. 52 | printer_id (str): The unique identifier for the printer. 53 | printer_config (dict): The configuration details for the printer. 54 | 55 | Returns: 56 | Optional[CameraState]: The updated camera state, or None if failed. 57 | """ 58 | return await update_camera_state(camera_uuid, { 59 | "printer_id": printer_id, 60 | "printer_config": printer_config 61 | }) 62 | 63 | async def remove_printer(camera_uuid): 64 | """Remove the printer association from a camera. 65 | 66 | Args: 67 | camera_uuid (str): The UUID of the camera. 68 | 69 | Returns: 70 | Optional[CameraState]: The updated camera state, or None if failed. 71 | """ 72 | return await update_camera_state(camera_uuid, { 73 | "printer_id": None, 74 | "printer_config": None 75 | }) 76 | 77 | async def poll_printer_state_func(client, interval, stop_event): 78 | """Continuously poll the printer state and send updates via SSE. 79 | 80 | Args: 81 | client (OctoPrintClient): The client to query printer status. 82 | interval (float): Time in seconds between polls. 83 | stop_event (asyncio.Event): An event to signal polling should stop. 84 | """ 85 | while not stop_event.is_set(): 86 | try: 87 | current_printer_state = client.get_printer_state() 88 | await sse_update_printer_state(current_printer_state) 89 | except (requests.exceptions.RequestException, ConnectionError, 90 | TimeoutError, ValueError) as e: 91 | logging.warning("Error polling printer state: %s", str(e)) 92 | except Exception as e: 93 | logging.error("Unexpected error polling printer state: %s", str(e)) 94 | await asyncio.sleep(interval) 95 | 96 | async def start_printer_state_polling(camera_uuid): 97 | """Start background polling of printer state for a camera. 98 | 99 | Args: 100 | camera_uuid (str): The UUID of the camera to poll. 101 | """ 102 | stop_event = asyncio.Event() 103 | camera_printer_config = get_printer_config(camera_uuid) 104 | if not camera_printer_config: 105 | logging.warning("No printer configuration found for camera UUID %s", camera_uuid) 106 | return 107 | config = get_config() 108 | printer_polling_rate = float(config.get( 109 | SavedConfig.PRINTER_STAT_POLLING_RATE_MS, PRINTER_STAT_POLLING_RATE_MS 110 | ) / 1000) 111 | client = OctoPrintClient( 112 | camera_printer_config.get('base_url'), 113 | camera_printer_config.get('api_key') 114 | ) 115 | task = asyncio.create_task(poll_printer_state_func(client, printer_polling_rate, stop_event)) 116 | add_polling_task(camera_uuid, PollingTask(task=task, stop_event=stop_event)) 117 | logging.debug("Started printer state polling for camera UUID %s", camera_uuid) 118 | 119 | def suspend_print_job(camera_uuid, action: AlertAction): 120 | """Pause or cancel an ongoing print job based on an alert action. 121 | 122 | Args: 123 | camera_uuid (str): The UUID of the camera associated with the printer. 124 | action (AlertAction): The action to perform (CANCEL_PRINT or PAUSE_PRINT). 125 | 126 | Returns: 127 | bool: True if the job was suspended successfully or no job was active, False otherwise. 
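    Example:
        Minimal illustrative usage (the UUID is a placeholder; assumes the camera
        has an OctoPrint printer configured via set_printer):

            from printguard.models import AlertAction

            if not suspend_print_job("3fa85f64-5717-4562-b3fc-2c963f66afa6",
                                     AlertAction.PAUSE_PRINT):
                logging.warning("Print job could not be paused")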
128 | """ 129 | printer_config = get_printer_config(camera_uuid) 130 | if printer_config: 131 | if printer_config['printer_type'] == 'octoprint': 132 | client = OctoPrintClient( 133 | printer_config['base_url'], 134 | printer_config['api_key'] 135 | ) 136 | try: 137 | job_info = client.get_job_info() 138 | if job_info.state != "Printing": 139 | return True 140 | match action: 141 | case AlertAction.CANCEL_PRINT: 142 | client.cancel_job() 143 | logging.debug("Print cancelled for printer %s on camera %s", 144 | printer_config['name'], camera_uuid) 145 | return True 146 | case AlertAction.PAUSE_PRINT: 147 | client.pause_job() 148 | logging.debug("Print paused for printer %s on camera %s", 149 | printer_config['name'], camera_uuid) 150 | return True 151 | case _: 152 | logging.debug("No action taken for printer %s on camera %s as %s", 153 | printer_config['name'], camera_uuid, action) 154 | return True 155 | except Exception as e: 156 | logging.error("Error suspending print job for printer %s on camera %s: %s", 157 | printer_config['name'], camera_uuid, e) 158 | return False 159 | logging.error("No printer configuration found for camera UUID %s", camera_uuid) 160 | return False 161 | -------------------------------------------------------------------------------- /printguard/utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import concurrent.futures 3 | import logging 4 | import uuid 5 | import sys 6 | import glob 7 | 8 | import cv2 9 | 10 | from ..models import CameraState 11 | from .camera_state_manager import get_camera_state_manager 12 | 13 | 14 | async def add_camera(source, nickname): 15 | """ 16 | Adds a new camera, assigns a UUID, and stores it. 17 | 18 | Args: 19 | source (str): The camera source (e.g., device path or RTSP URL). 20 | nickname (str): A user-friendly name for the camera. 21 | 22 | Returns: 23 | dict: A dictionary containing the new camera's UUID, nickname, and source. 24 | """ 25 | manager = get_camera_state_manager() 26 | camera_uuid = str(uuid.uuid4()) 27 | new_camera_state = CameraState( 28 | nickname=nickname, 29 | source=source, 30 | ) 31 | await manager.update_camera_state(camera_uuid, new_camera_state.model_dump()) 32 | return {"camera_uuid": camera_uuid, "nickname": nickname, "source": source} 33 | 34 | async def remove_camera(camera_uuid: str) -> bool: 35 | """ 36 | Removes a camera completely. 37 | 38 | Args: 39 | camera_uuid (str): The UUID of the camera to remove. 40 | 41 | Returns: 42 | bool: True if the camera was removed successfully, False otherwise. 43 | """ 44 | manager = get_camera_state_manager() 45 | return await manager.remove_camera(camera_uuid) 46 | 47 | def find_available_serial_cameras() -> list[str]: 48 | """ 49 | Finds all available camera devices and returns their paths or indices. 50 | 51 | This function is designed to be cross-platform and works on Linux, macOS, 52 | Windows, and within Docker containers where devices are correctly mapped. 53 | 54 | On Linux, it first attempts to find device paths like '/dev/video*'. 55 | If that fails or on other platforms (macOS, Windows), it probes for 56 | camera indices by trying to open them sequentially. 57 | 58 | Returns: 59 | list[str]: A list of strings, where each string is either a device 60 | path (on Linux) or a camera index. An empty list is 61 | returned if no cameras are found. 
62 | """ 63 | logging.debug("INFO: Running on platform: %s", sys.platform) 64 | if sys.platform.startswith('linux'): 65 | logging.debug("INFO: Detected Linux platform. Searching for /dev/video* devices.") 66 | device_paths = glob.glob('/dev/video*') 67 | if device_paths: 68 | logging.debug("INFO: Found device paths: %s", device_paths) 69 | return sorted(device_paths) 70 | else: 71 | logging.warning("WARN: No /dev/video* devices found. Falling back to index probing.") 72 | api_preference = cv2.CAP_ANY 73 | if sys.platform == "win32": 74 | api_preference = cv2.CAP_DSHOW 75 | available_indices = [] 76 | index = 0 77 | while len(available_indices) < 10: 78 | cap = cv2.VideoCapture(index, api_preference) 79 | if cap.isOpened(): 80 | logging.debug("INFO: Camera found at index: %s", index) 81 | available_indices.append(str(index)) 82 | cap.release() 83 | else: 84 | logging.debug("INFO: No camera found at index: %s", index) 85 | cap.release() 86 | break 87 | index += 1 88 | return available_indices 89 | 90 | def open_camera(camera_uuid) -> cv2.VideoCapture: 91 | """ 92 | Open the camera and return a VideoCapture object. 93 | 94 | Args: 95 | camera_uuid (str): The UUID of the camera. 96 | 97 | Returns: 98 | cv2.VideoCapture: The VideoCapture object for the camera. 99 | """ 100 | camera_state = get_camera_state_sync(camera_uuid) 101 | if not camera_state or not camera_state.source: 102 | raise ValueError(f"Camera with UUID {camera_uuid} does not have a valid source.") 103 | source = camera_state.source 104 | if isinstance(source, str) and source.isdigit(): 105 | source = int(source) 106 | 107 | cap = cv2.VideoCapture(source, cv2.CAP_ANY) 108 | if not cap.isOpened(): 109 | raise RuntimeError(f"Failed to open camera with UUID {camera_uuid}") 110 | return cap 111 | 112 | async def get_camera_state(camera_uuid, reset=False): 113 | """Get this camera's state, handling async context appropriately. 114 | 115 | Args: 116 | camera_uuid (str): The UUID of the camera. 117 | reset (bool): If True, resets the camera state to its default. 118 | 119 | Returns: 120 | CameraState: The state of the camera. 121 | """ 122 | manager = get_camera_state_manager() 123 | try: 124 | def sync_get_state(): 125 | return asyncio.run(manager.get_camera_state(camera_uuid, reset)) 126 | return await asyncio.to_thread(sync_get_state) 127 | except Exception as e: 128 | logging.error("Error in camera state access for camera %d: %s", camera_uuid, e) 129 | return CameraState() 130 | 131 | def get_camera_state_sync(camera_uuid, reset=False): 132 | """Synchronous wrapper for get_camera_state for contexts that cannot use async/await. 133 | 134 | Args: 135 | camera_uuid (str): The UUID of the camera. 136 | reset (bool): If True, resets the camera state to its default. 137 | 138 | Returns: 139 | CameraState: The state of the camera. 140 | """ 141 | try: 142 | try: 143 | asyncio.get_running_loop() 144 | def run_in_new_loop(): 145 | return asyncio.run(get_camera_state(camera_uuid, reset)) 146 | with concurrent.futures.ThreadPoolExecutor() as executor: 147 | future = executor.submit(run_in_new_loop) 148 | return future.result(timeout=5.0) 149 | except RuntimeError: 150 | return asyncio.run(get_camera_state(camera_uuid, reset)) 151 | except Exception as e: 152 | logging.error("Error in synchronous camera state access for camera %d: %s", camera_uuid, e) 153 | return CameraState() 154 | 155 | async def update_camera_detection_history(camera_uuid, pred, time_val): 156 | """Append a detection to the camera's detection history. 
157 | 158 | Args: 159 | camera_uuid (str): The UUID of the camera. 160 | pred (str): The prediction (detection) label. 161 | time_val (float): The timestamp of the detection. 162 | 163 | Returns: 164 | Optional[CameraState]: The updated camera state, or None if not found. 165 | """ 166 | manager = get_camera_state_manager() 167 | return await manager.update_camera_detection_history(camera_uuid, pred, time_val) 168 | 169 | async def update_camera_state(camera_uuid, new_states): 170 | """Update the camera's state with thread safety and persistence. 171 | 172 | Args: 173 | camera_uuid (str): The UUID of the camera. 174 | new_states (dict): A dictionary of states to update. 175 | Example: {"state_key": new_value} 176 | 177 | Returns: 178 | Optional[CameraState]: The updated camera state, or None if not found. 179 | """ 180 | manager = get_camera_state_manager() 181 | return await manager.update_camera_state(camera_uuid, new_states) 182 | -------------------------------------------------------------------------------- /printguard/routes/index_routes.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | 4 | from fastapi import Form, Request, APIRouter 5 | from fastapi.exceptions import HTTPException 6 | from fastapi.responses import RedirectResponse 7 | 8 | from ..utils.config import (STREAM_MAX_FPS, STREAM_TUNNEL_FPS, 9 | STREAM_JPEG_QUALITY, STREAM_MAX_WIDTH, 10 | DETECTION_INTERVAL_MS, PRINTER_STAT_POLLING_RATE_MS, 11 | MIN_SSE_DISPATCH_DELAY_MS, 12 | update_config, get_config) 13 | from ..utils.camera_utils import update_camera_state 14 | from ..utils.camera_state_manager import get_camera_state_manager 15 | from ..utils.stream_utils import stream_optimizer 16 | from ..models import FeedSettings, SavedConfig 17 | 18 | router = APIRouter() 19 | 20 | @router.get("/", include_in_schema=False) 21 | async def serve_index(request: Request): 22 | """Serve the main index page with camera states and configuration. 23 | 24 | Args: 25 | request (Request): The FastAPI request object. 26 | 27 | Returns: 28 | TemplateResponse: Rendered index.html template with camera states and settings. 29 | """ 30 | # pylint: disable=import-outside-toplevel 31 | from ..app import templates 32 | camera_state_manager = get_camera_state_manager() 33 | camera_uuids = await camera_state_manager.get_all_camera_uuids() 34 | if not camera_uuids: 35 | logging.warning("No camera UUIDs found, attempting to initialize cameras...") 36 | camera_uuids = await camera_state_manager.get_all_camera_uuids() 37 | camera_states = {} 38 | for cam_uuid in camera_uuids: 39 | camera_states[cam_uuid] = await camera_state_manager.get_camera_state(cam_uuid) 40 | return templates.TemplateResponse("index.html", { 41 | "camera_states": camera_states, 42 | "request": request, 43 | "current_time": time.time(), 44 | }) 45 | 46 | # pylint: disable=unused-argument 47 | @router.post("/", include_in_schema=False) 48 | async def update_settings(request: Request, 49 | camera_uuid: str = Form(...), 50 | sensitivity: float = Form(...), 51 | brightness: float = Form(...), 52 | contrast: float = Form(...), 53 | focus: float = Form(...), 54 | countdown_time: int = Form(...), 55 | countdown_action: str = Form(...), 56 | majority_vote_threshold: int = Form(...), 57 | majority_vote_window: int = Form(...), 58 | ): 59 | """Update camera settings and detection parameters. 60 | 61 | Args: 62 | request (Request): The FastAPI request object. 
63 | camera_uuid (str): UUID of the camera to update settings for. 64 | sensitivity (float): Detection sensitivity level. 65 | brightness (float): Camera brightness setting. 66 | contrast (float): Camera contrast setting. 67 | focus (float): Camera focus setting. 68 | countdown_time (int): Alert countdown duration in seconds. 69 | countdown_action (str): Action to take when countdown expires. 70 | majority_vote_threshold (int): Number of detections needed for majority vote. 71 | majority_vote_window (int): Time window for majority vote calculation. 72 | 73 | Returns: 74 | RedirectResponse: Redirect to the main index page. 75 | """ 76 | await update_camera_state(camera_uuid, { 77 | "sensitivity": sensitivity, 78 | "brightness": brightness, 79 | "contrast": contrast, 80 | "focus": focus, 81 | "countdown_time": countdown_time, 82 | "countdown_action": countdown_action, 83 | "majority_vote_threshold": majority_vote_threshold, 84 | "majority_vote_window": majority_vote_window, 85 | }) 86 | return RedirectResponse("/", status_code=303) 87 | 88 | 89 | @router.post("/save-feed-settings", include_in_schema=False) 90 | async def save_feed_settings(settings: FeedSettings): 91 | """Save camera feed and detection settings to configuration. 92 | 93 | Args: 94 | settings (FeedSettings): Feed configuration settings including FPS, 95 | quality, detection intervals, and polling rates. 96 | 97 | Returns: 98 | dict: Success status and message indicating settings were saved. 99 | 100 | Raises: 101 | HTTPException: If saving settings fails due to validation or storage errors. 102 | """ 103 | try: 104 | config_data = { 105 | SavedConfig.STREAM_MAX_FPS: settings.stream_max_fps, 106 | SavedConfig.STREAM_TUNNEL_FPS: settings.stream_tunnel_fps, 107 | SavedConfig.STREAM_JPEG_QUALITY: settings.stream_jpeg_quality, 108 | SavedConfig.STREAM_MAX_WIDTH: settings.stream_max_width, 109 | SavedConfig.DETECTION_INTERVAL_MS: settings.detection_interval_ms, 110 | SavedConfig.PRINTER_STAT_POLLING_RATE_MS: settings.printer_stat_polling_rate_ms, 111 | SavedConfig.MIN_SSE_DISPATCH_DELAY_MS: settings.min_sse_dispatch_delay_ms 112 | } 113 | update_config(config_data) 114 | stream_optimizer.invalidate_cache() 115 | logging.debug("Feed settings saved successfully.") 116 | return {"success": True, "message": "Feed settings saved successfully."} 117 | except Exception as e: 118 | logging.error("Error saving feed settings: %s", e) 119 | raise HTTPException( 120 | status_code=500, 121 | detail=f"Failed to save feed settings: {str(e)}" 122 | ) 123 | 124 | @router.get("/get-feed-settings", include_in_schema=False) 125 | async def get_feed_settings(): 126 | """Retrieve current camera feed and detection settings. 127 | 128 | Returns: 129 | dict: Current feed settings including FPS, quality, detection intervals, 130 | polling rates, and calculated detections per second. 131 | 132 | Raises: 133 | HTTPException: If loading settings fails due to configuration errors. 
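    Example:
        An illustrative client call (the base URL is a placeholder; `verify=False`
        is only needed with the self-signed certificate used in LOCAL mode):

            import requests

            payload = requests.get("https://localhost:8000/get-feed-settings",
                                   verify=False).json()
            if payload["success"]:
                print("Detections per second:",
                      payload["settings"]["detections_per_second"])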
134 | """ 135 | try: 136 | config = get_config() 137 | # pylint:disable=import-outside-toplevel 138 | settings = { 139 | "stream_max_fps": config.get(SavedConfig.STREAM_MAX_FPS, STREAM_MAX_FPS), 140 | "stream_tunnel_fps": config.get(SavedConfig.STREAM_TUNNEL_FPS, STREAM_TUNNEL_FPS), 141 | "stream_jpeg_quality": config.get(SavedConfig.STREAM_JPEG_QUALITY, STREAM_JPEG_QUALITY), 142 | "stream_max_width": config.get(SavedConfig.STREAM_MAX_WIDTH, STREAM_MAX_WIDTH), 143 | "detection_interval_ms": config.get(SavedConfig.DETECTION_INTERVAL_MS, DETECTION_INTERVAL_MS), 144 | "printer_stat_polling_rate_ms": config.get(SavedConfig.PRINTER_STAT_POLLING_RATE_MS, PRINTER_STAT_POLLING_RATE_MS), 145 | "min_sse_dispatch_delay_ms": config.get(SavedConfig.MIN_SSE_DISPATCH_DELAY_MS, MIN_SSE_DISPATCH_DELAY_MS) 146 | } 147 | settings["detections_per_second"] = round(1000 / settings["detection_interval_ms"]) 148 | return {"success": True, "settings": settings} 149 | except Exception as e: 150 | logging.error("Error loading feed settings: %s", e) 151 | raise HTTPException( 152 | status_code=500, 153 | detail=f"Failed to load feed settings: {str(e)}" 154 | ) 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PrintGuard - Local 3D Printing Failure Detection and Monitoring 2 | [![PyPI - Version](https://img.shields.io/pypi/v/printguard?style=for-the-badge&logo=pypi&logoColor=white&logoSize=auto&color=yellow)](https://pypi.org/project/printguard/) 3 | [![GitHub Repo stars](https://img.shields.io/github/stars/oliverbravery/printguard?style=for-the-badge&logo=github&logoColor=white&logoSize=auto&color=yellow)](https://github.com/oliverbravery/printguard) 4 | 5 | PrintGuard offers local, **real-time print failure detection** for **3D printing** on edge devices. A **web interface** enables users to **monitor multiple printer-facing cameras**, **connect to printers** through compatible services (i.e. [Octoprint](https://octoprint.org)) and **receive failure notifications** when the **computer vision** fault detection model designed for local edge deployment detects an issue and **automatically suspend or terminate the print job**. 6 | 7 | > _The machine learning model's training code and technical research paper can be found [here](https://github.com/oliverbravery/Edge-FDM-Fault-Detection)._ 8 | 9 | ## Features 10 | - **Web Interface**: A user-friendly web interface to monitor print jobs and camera feeds. 11 | - **Live Print Failure Detection**: Uses a custom computer vision model to detect print failures in real-time on edge devices. 12 | - **Multiple Inference Backends**: Supports PyTorch & ONNX Runtime for optimized performance across different deployment scenarios. 13 | - **Notifications**: Sends notifications subscribable on desktop and mobile devices via web push notifications to notify of detected print failures. 14 | - **Camera Integration**: Supports multiple camera feeds and simultaneous failure detection. 15 | - **Printer Integration**: Integrates with printers through services like Octoprint, allowing users to link cameras to specific printers for automatic print termination or suspension when a failure is detected. 16 | - **Local and Remote Access**: Can be accessed locally or remotely via secure tunnels (e.g. ngrok, Cloudflare Tunnel) or within a local network utilising the setup page for easy configuration. 
17 | 18 | ## Table of Contents 19 | - [Features](#features) 20 | - [Installation](#installation) 21 | - [PyPI Installation](#pypi-installation) 22 | - [Docker Installation](#docker-installation) 23 | - [Initial Configuration](#initial-configuration) 24 | - [Usage](#usage) 25 | - [Technical Documentation](/docs/overview.md) 26 | 27 | ## Installation 28 | 29 | ### PyPI Installation 30 | > _The project is currently in pre-release, so the `--pre` flag is required for installation._ 31 | 32 | PrintGuard is installable via [pip](https://pypi.org/project/printguard/). The following command will install the latest version: 33 | ```bash 34 | pip install --pre printguard 35 | ``` 36 | To start the web interface, run: 37 | ```bash 38 | printguard 39 | ``` 40 | 41 | ### Docker Installation 42 | PrintGuard is also available as a Docker image, which can be pulled from GitHub Container Registry (GHCR): 43 | ```bash 44 | docker pull ghcr.io/oliverbravery/printguard:latest 45 | ``` 46 | 47 | Alternatively, you can build the Docker image from the source, specifying the platforms you want to build for: 48 | ```bash 49 | docker buildx build \ 50 | --platform linux/amd64,linux/arm64,linux/arm/v7 \ 51 | -t oliverbravery/printguard:local \ 52 | --load \ 53 | . 54 | ``` 55 | 56 | To run the Docker container, use the following command. Note that the container requires a volume for persistent data storage and an environment variable for the secret key. Use the `--privileged` flag to allow access to the host's camera devices. 57 | 58 | To run the Docker container pulled from GHCR, use the following command: 59 | ```bash 60 | docker run \ 61 | -p 8000:8000 \ 62 | -v "$(pwd)/data:/data" \ 63 | --privileged \ 64 | ghcr.io/oliverbravery/printguard:latest 65 | ``` 66 | 67 | To run the Docker container built from the source, use the following command: 68 | ```bash 69 | docker run \ 70 | -p 8000:8000 \ 71 | -v "$(pwd)/data:/data" \ 72 | --privileged \ 73 | oliverbravery/printguard:local 74 | ``` 75 | 76 | ## Initial Configuration 77 | After installation, you will need to configure PrintGuard. First, visit the setup page at `http://localhost:8000/setup`. The setup page allows users to configure network access to the locally hosted site, including seamless options for exposing it via popular reverse proxies for a streamlined setup. All setups require you to choose to either automatically generate or import self-signed SSL certificates for secure access, alongside VAPID keys which are required for web push notifications. 78 | 79 | > [Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/) - A secure way to expose your local web interface to the internet via reverse proxies, providing a reliable and secure connection without needing to open ports on your router. Cloudflare tunnels are free to use and offer a simple setup process however, a domain connected to your Cloudflare account is required. Restricted access to your PrintGuard site can be setup through [Cloudflare Access](https://one.dash.cloudflare.com/), configurable in the setup page. During setup, your API key is used to create a tunnel to your local server and insert a DNS record for the tunnel, allowing you to access your PrintGuard instance via your custom domain or subdomain. 
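If you want to confirm what the Cloudflare setup stored, the API token, tunnel token and domain end up in the system keyring and JSON config described in [docs/setup.md](docs/setup.md). The sketch below is illustrative only and not part of PrintGuard; the config path shown is the typical Linux location and may differ on your platform:

```python
# Illustrative sketch: inspect the credentials PrintGuard's setup stored.
import json
import keyring
from pathlib import Path

api_token = keyring.get_password("printguard", "TUNNEL_API_KEY")
tunnel_token = keyring.get_password("printguard", "TUNNEL_TOKEN")
config = json.loads(Path("~/.config/printguard/config.json").expanduser().read_text())

print("Site domain:", config.get("SITE_DOMAIN"))
print("Tunnel API token stored:", api_token is not None)
print("Tunnel token stored:", tunnel_token is not None)
```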
80 | 81 | > [Ngrok](https://ngrok.com/) - Reverse proxy tool which allows you to expose the local web interface to the internet for access outside of your local network, offering a secure tunnel to your local server with minimal configuration through both free and paid plans. The setup uses your ngrok API to create a tunnel to your local server and link it to your free static ngrok domain aquired during setup, allowing access to PrintGuard via a custom, static subdomain. 82 | 83 | > Local Network Access - If you prefer not to expose your web interface to the internet, you can configure PrintGuard to be accessible only within your local network. 84 | 85 | ## Usage 86 | | | | 87 | | --- | --- | 88 | | ![PrintGuard Web Interface](docs/media/images/interface-index.png) | The main interface of PrintGuard. All cameras are selectable in the bottom left camera list. The live camera view displayed in the top right shows the feed of the currently selected camera. The current detection status, total detections and frame rate are displayed in the bottom right alongside a button to toggle live detection for the selected camera on or off. | 89 | | ![PrintGuard Camera Settings](docs/media/images/interface-camera-settings.png) | The camera settings page is accessible via the settings button in the bottom right of the main interface. It allows you to configure camera settings, including camera brightness and contrast, detection thresholds, link a printer to the camera via services such as Octoprint, and configure alert and notification settings for that camera. You can also opt into web push notifications for real-time alerts here. | 90 | | ![PrintGuard Setup Settings](docs/media/images/interface-setup-settings.png) | Accessible via the configure setup button in the settings menu, the setup page allows configuration of camera feed streaming settings such as resolution and frame rate, as well as polling intervals and detection rates. | 91 | | ![PrintGuard Alerts and Notifications](docs/media/images/interface-alerts-notifications.png) | When a failure is detected a notification is dispatched to subscribed devices via web push notifications, allowing users to get real-time alerts and updates about their print. On the web interface, an alert modal appears showing a snapshot of the failure and buttons to dismiss the alert or suspend/cancel the print job. If the alert is not addressed within the customisable countdown time, the printer will automatically be suspended, cancelled or resumed based on user settings. | 92 | | | | -------------------------------------------------------------------------------- /printguard/utils/sse_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | import time 5 | 6 | from ..models import (SSEDataType, PrinterState, 7 | PollingTask, SavedConfig) 8 | from ..utils.config import get_config, MIN_SSE_DISPATCH_DELAY_MS 9 | 10 | _last_dispatch_times = {} 11 | 12 | async def outbound_packet_fetch(): 13 | """Async generator yielding outbound SSE packets for clients. 14 | 15 | Yields: 16 | str: Serialized JSON packet from application outbound queue. 17 | """ 18 | # pylint: disable=C0415 19 | from ..app import app 20 | while True: 21 | packet = await app.state.outbound_queue.get() 22 | yield packet 23 | 24 | async def append_new_outbound_packet(packet, sse_data_type: SSEDataType): 25 | """Append a new Server-Sent Event packet to the outbound queue. 
26 | 27 | Args: 28 | packet (str): The JSON-serialized data payload. 29 | sse_data_type (SSEDataType): The type of SSE event. 30 | """ 31 | config = get_config() 32 | min_sse_dispatch_delay = config.get(SavedConfig.MIN_SSE_DISPATCH_DELAY_MS, MIN_SSE_DISPATCH_DELAY_MS) 33 | current_time = time.time() * 1000 34 | last_dispatch_time = _last_dispatch_times.get(sse_data_type, 0) 35 | time_since_last_dispatch = current_time - last_dispatch_time 36 | if time_since_last_dispatch < min_sse_dispatch_delay: 37 | logging.debug("Throttling SSE dispatch for %s (time since last: %.1fms)", 38 | sse_data_type.value, time_since_last_dispatch) 39 | return 40 | # pylint: disable=C0415 41 | from ..app import app 42 | pkt = {"data": {"event": sse_data_type.value, "data": packet}} 43 | pkt_json = json.dumps(pkt) 44 | await app.state.outbound_queue.put(pkt_json) 45 | _last_dispatch_times[sse_data_type] = current_time 46 | 47 | async def append_new_outbound_packet_force(packet, sse_data_type: SSEDataType): 48 | """Force append a new Server-Sent Event packet to the outbound queue, bypassing throttling. 49 | 50 | Args: 51 | packet (str): The JSON-serialized data payload. 52 | sse_data_type (SSEDataType): The type of SSE event. 53 | """ 54 | # pylint: disable=C0415 55 | from ..app import app 56 | pkt = {"data": {"event": sse_data_type.value, "data": packet}} 57 | pkt_json = json.dumps(pkt) 58 | await app.state.outbound_queue.put(pkt_json) 59 | current_time = time.time() * 1000 60 | _last_dispatch_times[sse_data_type] = current_time 61 | 62 | def reset_throttle_for_data_type(sse_data_type: SSEDataType): 63 | """Reset the throttle timer for a specific SSE data type. 64 | 65 | Args: 66 | sse_data_type (SSEDataType): The type of SSE event to reset throttling for. 67 | """ 68 | if sse_data_type in _last_dispatch_times: 69 | del _last_dispatch_times[sse_data_type] 70 | logging.debug("Reset throttle for SSE data type: %s", sse_data_type.value) 71 | 72 | def _calculate_frame_rate(detection_history): 73 | """Calculate frames per second based on detection timestamps. 74 | 75 | Args: 76 | detection_history (list of tuples): Each tuple is (timestamp, label). 77 | 78 | Returns: 79 | float: The calculated frame rate, or 0.0 if insufficient data. 80 | """ 81 | if len(detection_history) < 2: 82 | return 0.0 83 | times = [t for t, _ in detection_history] 84 | duration = times[-1] - times[0] 85 | return (len(times) - 1) / duration if duration > 0 else 0.0 86 | 87 | async def _sse_update_camera_state_func(camera_uuid): 88 | """Build and send a camera state update SSE packet. 89 | 90 | Args: 91 | camera_uuid (str): The UUID of the camera. 92 | """ 93 | # pylint: disable=import-outside-toplevel 94 | from .camera_utils import get_camera_state 95 | state = await get_camera_state(camera_uuid) 96 | detection_history = state.detection_history 97 | total_detections = len(detection_history) 98 | frame_rate = _calculate_frame_rate(detection_history) 99 | data = { 100 | "start_time": state.start_time, 101 | "last_result": state.last_result, 102 | "last_time": state.last_time, 103 | "total_detections": total_detections, 104 | "frame_rate": frame_rate, 105 | "error": state.error, 106 | "live_detection_running": state.live_detection_running, 107 | "camera_uuid": camera_uuid 108 | } 109 | await append_new_outbound_packet(data, SSEDataType.CAMERA_STATE) 110 | 111 | async def sse_update_printer_state(printer_state: PrinterState): 112 | """Send an SSE update with the current printer state. 
113 | 114 | Args: 115 | printer_state (PrinterState): The printer state object. 116 | """ 117 | try: 118 | await asyncio.wait_for( 119 | append_new_outbound_packet(printer_state.model_dump(), SSEDataType.PRINTER_STATE), 120 | timeout=5.0 121 | ) 122 | except asyncio.TimeoutError: 123 | logging.warning("SSE printer state update timed out") 124 | except (ValueError, TypeError, AttributeError) as e: 125 | logging.error("Error in SSE printer state update: %s", e) 126 | except Exception as e: 127 | logging.error("Unexpected error in SSE printer state update: %s", e) 128 | 129 | async def sse_update_camera_state(camera_uuid): 130 | """Send an SSE update with the current camera state. 131 | 132 | Args: 133 | camera_uuid (str): The UUID of the camera. 134 | """ 135 | try: 136 | await asyncio.wait_for(_sse_update_camera_state_func(camera_uuid), timeout=5.0) 137 | except asyncio.TimeoutError: 138 | logging.warning("SSE camera state update timed out for camera %s", camera_uuid) 139 | except (ValueError, TypeError, AttributeError) as e: 140 | logging.error("Error in SSE camera state update for camera %s: %s", camera_uuid, e) 141 | except Exception as e: # pylint: disable=broad-except 142 | logging.error("Unexpected error in SSE camera state update for camera %s: %s", 143 | camera_uuid, e) 144 | 145 | def get_polling_task(camera_uuid): 146 | """Retrieve the current polling task for a camera. 147 | 148 | Args: 149 | camera_uuid (str): The UUID of the camera. 150 | 151 | Returns: 152 | PollingTask or None: The polling task if exists, otherwise None. 153 | """ 154 | # pylint: disable=C0415 155 | from ..app import app 156 | return app.state.polling_tasks.get(camera_uuid) or None 157 | 158 | def stop_and_remove_polling_task(camera_uuid): 159 | """Stop and remove a polling task for a specified camera. 160 | 161 | Args: 162 | camera_uuid (str): The UUID of the camera. 163 | """ 164 | # pylint: disable=C0415 165 | from ..app import app 166 | task = get_polling_task(camera_uuid) 167 | if task: 168 | task.stop_event.set() 169 | if task.task and not task.task.done(): 170 | task.task.cancel() 171 | logging.debug("Stopped polling task for camera UUID %s", camera_uuid) 172 | del app.state.polling_tasks[camera_uuid] 173 | else: 174 | logging.warning("No polling task found for camera UUID %s to stop.", camera_uuid) 175 | 176 | def add_polling_task(camera_uuid, task: PollingTask): 177 | """Add or replace a polling task for a camera. 178 | 179 | Args: 180 | camera_uuid (str): The UUID of the camera. 181 | task (PollingTask): The task object containing the asyncio.Task and stop_event. 
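    Example:
        Illustrative registration from within a running event loop, mirroring
        start_printer_state_polling in printer_utils.py (`poll_forever` and
        `camera_uuid` are placeholders, not part of this module):

            stop_event = asyncio.Event()
            task = asyncio.create_task(poll_forever(stop_event))
            add_polling_task(camera_uuid, PollingTask(task=task, stop_event=stop_event))
            # Later: stop_and_remove_polling_task(camera_uuid) cancels and removes it.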
182 | """ 183 | # pylint: disable=C0415 184 | from ..app import app 185 | if camera_uuid in app.state.polling_tasks: 186 | stop_and_remove_polling_task(camera_uuid) 187 | app.state.polling_tasks[camera_uuid] = task 188 | logging.debug("Added polling task for camera UUID %s", camera_uuid) 189 | -------------------------------------------------------------------------------- /printguard/utils/shared_video_stream.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | from typing import Dict, Optional, List, Callable 5 | import cv2 6 | import numpy as np 7 | 8 | from .camera_utils import get_camera_state_sync 9 | 10 | 11 | class SharedVideoStream: 12 | """A shared video stream that allows multiple consumers to access the same camera source.""" 13 | 14 | def __init__(self, camera_uuid: str, source: str): 15 | # pylint: disable=E1101 16 | self.camera_uuid = camera_uuid 17 | self.source = source 18 | self.cap: Optional[cv2.VideoCapture] = None 19 | self.latest_frame: Optional[np.ndarray] = None 20 | self.frame_lock = threading.Lock() 21 | self.consumers: List[Callable] = [] 22 | self.is_running = False 23 | self.thread: Optional[threading.Thread] = None 24 | self.last_frame_time = 0 25 | self.frame_count = 0 26 | 27 | def start(self): 28 | """Start the video stream capture thread.""" 29 | if self.is_running: 30 | return 31 | self.is_running = True 32 | self.thread = threading.Thread(target=self._capture_loop, daemon=True) 33 | self.thread.start() 34 | logging.debug("Started shared video stream for camera %s", self.camera_uuid) 35 | 36 | def stop(self): 37 | """Stop the video stream capture thread.""" 38 | self.is_running = False 39 | if self.thread: 40 | self.thread.join(timeout=1.0) 41 | if self.cap and self.cap.isOpened(): 42 | self.cap.release() 43 | logging.debug("Stopped shared video stream for camera %s", self.camera_uuid) 44 | 45 | def _capture_loop(self): 46 | """Main capture loop that runs in a separate thread.""" 47 | # pylint: disable=E1101 48 | try: 49 | source = self.source 50 | if isinstance(source, str) and source.isdigit(): 51 | source = int(source) 52 | self.cap = cv2.VideoCapture(source, cv2.CAP_ANY) 53 | if not self.cap.isOpened(): 54 | logging.error("Failed to open camera source %s for shared stream", source) 55 | return 56 | self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) 57 | if isinstance(source, str) and source.startswith('rtp://'): 58 | self.cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 5000) 59 | consecutive_failures = 0 60 | max_consecutive_failures = 10 61 | while self.is_running: 62 | ret, frame = self.cap.read() 63 | if not ret: 64 | consecutive_failures += 1 65 | logging.warning("Failed to read frame from camera %s (failure %d/%d)", 66 | self.camera_uuid, consecutive_failures, max_consecutive_failures) 67 | if consecutive_failures >= max_consecutive_failures: 68 | logging.error( 69 | "Too many consecutive failures for camera %s, stopping stream", 70 | self.camera_uuid) 71 | break 72 | time.sleep(0.1) 73 | continue 74 | else: 75 | consecutive_failures = 0 76 | with self.frame_lock: 77 | self.latest_frame = frame.copy() 78 | self.last_frame_time = time.time() 79 | self.frame_count += 1 80 | time.sleep(0.001) 81 | except (cv2.error, OSError, ValueError) as e: 82 | logging.error("Error in shared video stream for camera %s: %s", self.camera_uuid, e) 83 | finally: 84 | if self.cap and self.cap.isOpened(): 85 | self.cap.release() 86 | 87 | def get_frame(self) -> Optional[np.ndarray]: 88 | """Get 
the latest frame from the shared stream.""" 89 | with self.frame_lock: 90 | if self.latest_frame is not None: 91 | return self.latest_frame.copy() 92 | return None 93 | 94 | def is_frame_available(self) -> bool: 95 | """Check if a frame is available.""" 96 | with self.frame_lock: 97 | return self.latest_frame is not None 98 | 99 | def get_frame_info(self) -> Dict: 100 | """Get information about the current frame.""" 101 | with self.frame_lock: 102 | return { 103 | 'frame_count': self.frame_count, 104 | 'last_frame_time': self.last_frame_time, 105 | 'has_frame': self.latest_frame is not None, 106 | 'is_running': self.is_running, 107 | 'is_healthy': self.is_running and time.time() - self.last_frame_time < 5.0 108 | } 109 | 110 | 111 | class SharedVideoStreamManager: 112 | """Manages shared video streams for multiple cameras.""" 113 | 114 | def __init__(self): 115 | self.streams: Dict[str, SharedVideoStream] = {} 116 | self.lock = threading.Lock() 117 | 118 | def get_stream(self, camera_uuid: str, source: str) -> SharedVideoStream: 119 | """Get or create a shared video stream for a camera.""" 120 | with self.lock: 121 | if camera_uuid not in self.streams: 122 | self.streams[camera_uuid] = SharedVideoStream(camera_uuid, source) 123 | else: 124 | existing_stream = self.streams[camera_uuid] 125 | if (not existing_stream.is_running 126 | or not existing_stream.thread 127 | or not existing_stream.thread.is_alive()): 128 | logging.info("Restarting shared video stream for camera %s", camera_uuid) 129 | existing_stream.stop() 130 | self.streams[camera_uuid] = SharedVideoStream(camera_uuid, source) 131 | stream = self.streams[camera_uuid] 132 | if not stream.is_running: 133 | stream.start() 134 | return stream 135 | 136 | def release_stream(self, camera_uuid: str): 137 | """Release a shared video stream.""" 138 | with self.lock: 139 | if camera_uuid in self.streams: 140 | self.streams[camera_uuid].stop() 141 | del self.streams[camera_uuid] 142 | 143 | def cleanup_all(self): 144 | """Clean up all shared video streams.""" 145 | with self.lock: 146 | for stream in self.streams.values(): 147 | stream.stop() 148 | self.streams.clear() 149 | 150 | def get_stream_health(self, camera_uuid: str) -> Dict: 151 | """Get health information for a specific stream.""" 152 | with self.lock: 153 | if camera_uuid in self.streams: 154 | return self.streams[camera_uuid].get_frame_info() 155 | return {'is_running': False, 'is_healthy': False, 'has_frame': False} 156 | 157 | _shared_stream_manager = SharedVideoStreamManager() 158 | 159 | def get_shared_stream_manager() -> SharedVideoStreamManager: 160 | """Get the global shared stream manager.""" 161 | return _shared_stream_manager 162 | 163 | def get_shared_camera_frame(camera_uuid: str) -> Optional[np.ndarray]: 164 | """Get a frame from the shared camera stream.""" 165 | try: 166 | camera_state = get_camera_state_sync(camera_uuid) 167 | if not camera_state or not camera_state.source: 168 | return None 169 | manager = get_shared_stream_manager() 170 | stream = manager.get_stream(camera_uuid, camera_state.source) 171 | max_wait = 50 172 | wait_count = 0 173 | while not stream.is_frame_available() and wait_count < max_wait: 174 | time.sleep(0.1) 175 | wait_count += 1 176 | return stream.get_frame() 177 | except (ImportError, AttributeError) as e: 178 | logging.error("Error getting shared camera frame for %s: %s", camera_uuid, e) 179 | return None 180 | -------------------------------------------------------------------------------- /scripts/convert_pytorch_to_onnx.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.onnx 11 | import onnxruntime as ort 12 | 13 | project_root = Path(__file__).parent.parent 14 | sys.path.insert(0, str(project_root)) 15 | 16 | try: 17 | import printguard.utils.backends.protonets as _pn 18 | sys.modules['protonets'] = _pn 19 | except ImportError: 20 | pass 21 | 22 | 23 | def get_available_devices(): 24 | """Get list of available devices for model conversion.""" 25 | devices = ["cpu"] 26 | if torch.cuda.is_available(): 27 | devices.append("cuda") 28 | if torch.backends.mps.is_available() and torch.backends.mps.is_built(): 29 | devices.append("mps") 30 | return devices 31 | 32 | def validate_device(device: str): 33 | """Validate if the requested device is available.""" 34 | available_devices = get_available_devices() 35 | if device not in available_devices: 36 | if device == "cuda" and not torch.cuda.is_available(): 37 | raise ValueError("CUDA is not available on this system") 38 | elif device == "mps" and not ( 39 | torch.backends.mps.is_available() and torch.backends.mps.is_built()): 40 | raise ValueError( 41 | "MPS is not available on this system. Requires macOS 12.3+ and Apple Silicon") 42 | else: 43 | raise ValueError( 44 | f"Device '{device}' is not available. Available devices: {available_devices}") 45 | return device 46 | 47 | 48 | def convert_pytorch_to_onnx(pytorch_model_path: str, options_path: str, 49 | output_path: str, device: str = "cpu"): 50 | """Convert a PyTorch model to ONNX format. 51 | 52 | Args: 53 | pytorch_model_path: Path to the PyTorch model file (.pt or .pth) 54 | options_path: Path to the model options JSON file 55 | output_path: Path where the ONNX model will be saved 56 | device: Device to use for conversion ('cpu', 'cuda', or 'mps') 57 | """ 58 | device = validate_device(device) 59 | try: 60 | logging.info("Loading PyTorch model from %s", pytorch_model_path) 61 | device_obj = torch.device(device) 62 | if device == "mps": 63 | logging.info("Using MPS (Metal Performance Shaders) for acceleration") 64 | elif device == "cuda": 65 | logging.info("Using CUDA GPU: %s", torch.cuda.get_device_name()) 66 | else: 67 | logging.info("Using CPU for conversion") 68 | full_model = torch.load(pytorch_model_path, map_location=device_obj, weights_only=False) 69 | if hasattr(full_model, 'encoder'): 70 | model = full_model.encoder 71 | logging.info("Extracted encoder from Protonet model") 72 | else: 73 | model = full_model 74 | logging.info("Using full model (no encoder attribute found)") 75 | model.eval() 76 | with open(options_path, 'r', encoding='utf-8') as f: 77 | model_opt = json.load(f) 78 | x_dim = list(map(int, model_opt['model.x_dim'].split(','))) 79 | logging.info("Model input dimensions: %s", x_dim) 80 | dummy_input = torch.randn(1, *x_dim).to(device_obj) 81 | logging.info("Testing model with dummy input...") 82 | with torch.no_grad(): 83 | test_output = model(dummy_input) 84 | logging.info("Model test successful. 
Output shape: %s", test_output.shape) 85 | export_params = { 86 | "input_names": ["input"], 87 | "output_names": ["output"], 88 | "dynamic_axes": { 89 | "input": {0: "batch_size"}, 90 | "output": {0: "batch_size"} 91 | }, 92 | "opset_version": 11, 93 | "do_constant_folding": True, 94 | "export_params": True, 95 | } 96 | logging.info("Converting to ONNX format...") 97 | torch.onnx.export( 98 | model, 99 | dummy_input, 100 | output_path, 101 | **export_params 102 | ) 103 | logging.info("ONNX model saved to %s", output_path) 104 | try: 105 | logging.info("Verifying ONNX model...") 106 | session = ort.InferenceSession(output_path) 107 | dummy_input_numpy = dummy_input.detach().cpu().numpy() 108 | output = session.run(None, {"input": dummy_input_numpy}) 109 | logging.info("ONNX model verification successful") 110 | logging.info("Output shape: %s", output[0].shape) 111 | pytorch_output = test_output.detach().cpu().numpy() 112 | onnx_output = output[0] 113 | max_diff = np.max(np.abs(pytorch_output - onnx_output)) 114 | logging.info("Maximum difference between PyTorch and ONNX outputs: %.6f", max_diff) 115 | if max_diff < 1e-5: 116 | logging.info("PyTorch and ONNX outputs are very close (diff < 1e-5)") 117 | elif max_diff < 1e-3: 118 | logging.info("PyTorch and ONNX outputs are close (diff < 1e-3)") 119 | else: 120 | logging.warning("PyTorch and ONNX outputs differ significantly (diff = %.6f)", 121 | max_diff) 122 | except ImportError: 123 | logging.warning("ONNX Runtime not available. Skipping verification.") 124 | except Exception as e: 125 | logging.error("ONNX model verification failed: %s", e) 126 | raise 127 | except Exception as e: 128 | logging.error("Failed to convert model: %s", e) 129 | raise 130 | 131 | def main(): 132 | """Main function to handle command line arguments and run conversion.""" 133 | # Get available devices for help text 134 | available_devices = get_available_devices() 135 | device_info = [] 136 | if "cpu" in available_devices: 137 | device_info.append("cpu (always available)") 138 | if "cuda" in available_devices: 139 | device_info.append("cuda (NVIDIA GPU detected)") 140 | if "mps" in available_devices: 141 | device_info.append("mps (Apple Silicon with Metal)") 142 | parser = argparse.ArgumentParser( 143 | description="Convert PyTorch models to ONNX format for PrintGuard", 144 | epilog=f"Available devices on this system: {', '.join(device_info)}" 145 | ) 146 | parser.add_argument( 147 | "pytorch_model", 148 | help="Path to the PyTorch model file (.pt or .pth)" 149 | ) 150 | parser.add_argument( 151 | "options_file", 152 | help="Path to the model options JSON file" 153 | ) 154 | parser.add_argument( 155 | "-o", "--output", 156 | help="Output path for the ONNX model (default: same name with .onnx extension)" 157 | ) 158 | parser.add_argument( 159 | "-d", "--device", 160 | choices=get_available_devices(), 161 | default="cpu", 162 | help="Device to use for conversion. 
'cpu' is always available, 'cuda' requires NVIDIA GPU, 'mps' requires Apple Silicon Mac with macOS 12.3+" 163 | ) 164 | parser.add_argument( 165 | "-v", "--verbose", 166 | action="store_true", 167 | help="Enable verbose logging" 168 | ) 169 | args = parser.parse_args() 170 | log_level = logging.DEBUG if args.verbose else logging.INFO 171 | logging.basicConfig( 172 | level=log_level, 173 | format='%(asctime)s - %(levelname)s - %(message)s' 174 | ) 175 | if args.output: 176 | output_path = args.output 177 | else: 178 | pytorch_path = Path(args.pytorch_model) 179 | output_path = str(pytorch_path.with_suffix('.onnx')) 180 | if not os.path.exists(args.pytorch_model): 181 | logging.error("PyTorch model file not found: %s", args.pytorch_model) 182 | sys.exit(1) 183 | if not os.path.exists(args.options_file): 184 | logging.error("Options file not found: %s", args.options_file) 185 | sys.exit(1) 186 | try: 187 | convert_pytorch_to_onnx( 188 | args.pytorch_model, 189 | args.options_file, 190 | output_path, 191 | args.device 192 | ) 193 | logging.info("Conversion completed successfully!") 194 | except Exception as e: 195 | logging.error("Conversion failed: %s", e) 196 | sys.exit(1) 197 | 198 | if __name__ == "__main__": 199 | main() 200 | -------------------------------------------------------------------------------- /printguard/templates/cloudflare_setup.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Cloudflare Tunnel Authentication Setup - PrintGuard 7 | 8 | 9 | 10 | 11 | 12 | 13 |
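A hedged sketch of calling the conversion utility from scripts/convert_pytorch_to_onnx.py above as a library function rather than from the command line; the model and options paths are placeholders, and the sys.path manipulation assumes the interpreter is started from the repository root:

import logging
import sys

sys.path.insert(0, "scripts")  # assumption: running from the repository root
from convert_pytorch_to_onnx import convert_pytorch_to_onnx

logging.basicConfig(level=logging.INFO)

# Placeholder paths; the options JSON must contain a "model.x_dim" entry
# such as "3,224,224" describing the encoder's input dimensions.
convert_pytorch_to_onnx(
    pytorch_model_path="models/protonet.pt",
    options_path="models/opt.json",
    output_path="models/protonet.onnx",
    device="cpu",
)

The equivalent command-line invocation would be: python scripts/convert_pytorch_to_onnx.py models/protonet.pt models/opt.json -o models/protonet.onnx -d cpu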
🔒 Cloudflare Tunnel Authentication Setup

Set up authentication for your site so it can only be accessed by authorised emails.

Device Information
Current Device: Detecting...

📋 Setup Instructions

Initial Setup (First Time Only)
If this is your first time setting up Cloudflare WARP with Zero Trust, follow these detailed steps:
1. The full high-level instructions are available for reference at https://developers.cloudflare.com/cloudflare-one/setup/
2. First, set up a Zero Trust account and team name by going to dash.cloudflare.com and clicking Zero Trust. Note that the free Zero Trust plan is all you need.
3. Next, in Zero Trust (one.dash.cloudflare.com) go to Settings → Authentication and, under Login methods, click Add new and select One-time PIN. If the instructions change, refer to https://developers.cloudflare.com/cloudflare-one/identity/one-time-pin/
4. Next, define enrollment permissions. Go to Zero Trust again, then Settings → WARP Client. In the Device enrollment section, click the Manage button for device enrollment permissions.
5. Create a new policy. On the Add policy page, in the Basic information section, enter a policy name such as WARP-login, set the action to Allow, and set the session duration to match the application session timeout. In the Add rules section, for the 'Include' rule choose the Emails selector and add the email addresses allowed to log in (trusted emails) in the value field. Note: if you want your application to require users to be authenticated via the WARP client, add a 'Require' rule to the policy with the 'WARP' selector. Save.
6. Back on the Device enrollment permissions page, press Select existing policies and choose the one you just created. Then, in the Login methods tab (still on the Device enrollment permissions page), make sure 'Accept all available identity providers' is toggled off, One-time PIN is ticked, and 'Instant Auth' is toggled on. Click Save.
7. Finally, from the Zero Trust dashboard visit Access → Applications → Add an application. Give it an appropriate application name such as 'printguard' and a session duration of 1 week. Set the public hostname to match your Cloudflare tunnel domain, then, in the Policies section, add the existing WARP-login policy you created earlier.

Download and Connect WARP Client
If, in the initial setup instructions, you chose to require the WARP client to access PrintGuard for added security, follow these steps for each device you wish to enroll. If your policy did not require WARP, you can skip this section.

Step 1: Download WARP Client
First, download and install the Cloudflare WARP client for your device.
Click here to download directly from the Cloudflare website.
Or use the buttons below to download the client for your specific platform (links from Cloudflare):

Step 2: Connect to Organization
1. Open the WARP client application
2. Navigate to Settings → Account → Login to Cloudflare Zero Trust
3. Enter your organization's team name: Loading...
4. Authenticate using your organization's login method

Step 3: Access the Site
Once connected to your organization's WARP network, you should be able to access the PrintGuard site automatically through your Cloudflare domain.

✅ Setup Complete!
If you've followed the setup steps correctly and are connected to your organization's Zero Trust network, you should be able to access the PrintGuard application directly through your Cloudflare domain.
Your Cloudflare domain: Loading...
If your setup requires the WARP client, you will need to download and log in to the WARP client on each device you wish to visit the site from. Otherwise, you can access the site directly through your Cloudflare domain by logging in with one of your allowed email addresses.

📱 Share Link
Scan the QR code or share the link to open this page on your other devices.
Loading...
96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /printguard/utils/backends/pytorch_engine.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pickle 5 | from typing import Any, List, Tuple 6 | 7 | import torch 8 | 9 | from .base_engine import BaseInferenceEngine 10 | 11 | try: 12 | from . import protonets as _pn 13 | import sys 14 | sys.modules['protonets'] = _pn 15 | except ImportError: 16 | pass 17 | 18 | class PyTorchInferenceEngine(BaseInferenceEngine): 19 | """PyTorch-based inference engine implementation.""" 20 | 21 | def load_model(self, model_path: str, options_path: str, device: str) -> Tuple[Any, List[int]]: 22 | """Load a PyTorch model and its configuration options. 23 | 24 | Args: 25 | model_path: Path to the saved model file 26 | options_path: Path to the JSON options file 27 | device: Device to load the model onto 28 | 29 | Returns: 30 | Tuple of (model, input_dimensions) 31 | """ 32 | device_obj = torch.device(device) 33 | model = torch.load(model_path, map_location=device_obj, weights_only=False) 34 | model.eval() 35 | with open(options_path, 'r', encoding='utf-8') as f: 36 | model_opt = json.load(f) 37 | x_dim = list(map(int, model_opt['model.x_dim'].split(','))) 38 | return model, x_dim 39 | 40 | def _compute_prototype_from_embeddings(self, embeddings: Any) -> Any: 41 | """Compute a single prototype from a set of embeddings. 42 | 43 | Args: 44 | embeddings: Embeddings tensor for a single class 45 | 46 | Returns: 47 | Prototype tensor for the class 48 | """ 49 | return embeddings.mean(0) 50 | 51 | def _stack_prototypes(self, prototypes: List[Any]) -> Any: 52 | """Stack individual prototypes into a single structure. 53 | 54 | Args: 55 | prototypes: List of individual prototype tensors 56 | 57 | Returns: 58 | Stacked prototype tensor 59 | """ 60 | return torch.stack(prototypes) 61 | 62 | def _copy_predictions(self, predictions: Any) -> Any: 63 | """Create a copy of predictions tensor.""" 64 | return predictions.clone() 65 | 66 | def _get_prediction_at_index(self, predictions: Any, index: int) -> int: 67 | """Get prediction at a specific index.""" 68 | return int(predictions[index]) 69 | 70 | def _get_min_distance_at_index(self, distances: Any, index: int) -> float: 71 | """Get minimum distance for a specific sample.""" 72 | return float(torch.min(distances[index])) 73 | 74 | def _get_distance_to_class(self, distances: Any, sample_idx: int, class_idx: int) -> float: 75 | """Get distance from sample to specific class.""" 76 | return float(distances[sample_idx, class_idx]) 77 | 78 | def _set_prediction_at_index(self, predictions: Any, index: int, value: int) -> None: 79 | """Set prediction at a specific index.""" 80 | predictions[index] = value 81 | 82 | def _is_empty_batch(self, batch_tensors: Any) -> bool: 83 | """Check if batch is empty (PyTorch-specific).""" 84 | return batch_tensors.shape[0] == 0 85 | 86 | def _compute_embeddings(self, model: Any, processed_images: List[Any], device: str) -> Any: 87 | """Compute embeddings for processed images using PyTorch. 
88 | 89 | Args: 90 | model: The PyTorch model 91 | processed_images: List of processed image tensors 92 | device: Device to run computations on 93 | 94 | Returns: 95 | Computed embeddings tensor 96 | """ 97 | device_obj = torch.device(device) 98 | ts = torch.stack(processed_images).to(device_obj) 99 | with torch.no_grad(): 100 | emb = model.encoder(ts) 101 | return emb 102 | 103 | def predict_batch(self, model: Any, batch_tensors: Any, prototypes: Any, 104 | defect_idx: int, sensitivity: float, device: str) -> List[int]: 105 | """Predict classes for a batch of image tensors using prototype matching. 106 | 107 | Args: 108 | model: The encoder model 109 | batch_tensors: Batch of preprocessed image tensors 110 | prototypes: Class prototype tensors 111 | defect_idx: Index of the defect class for sensitivity adjustment 112 | sensitivity: Sensitivity multiplier for defect detection 113 | device: Device to run computations on 114 | 115 | Returns: 116 | List of predicted class indices for each input 117 | """ 118 | if not self._validate_batch_input(batch_tensors): 119 | return [] 120 | device_obj = torch.device(device) 121 | model.eval() 122 | with torch.no_grad(): 123 | batch_x = batch_tensors.to(device_obj) 124 | batch_emb = model.encoder(batch_x) 125 | distances = torch.cdist(batch_emb, prototypes) 126 | _, initial_preds = torch.min(distances, dim=1) 127 | final_preds = self._apply_sensitivity_adjustment(initial_preds, distances, defect_idx, sensitivity) 128 | return final_preds.cpu().tolist() 129 | 130 | def setup_device(self, requested_device: str) -> str: 131 | """Set up the compute device based on availability and request. 132 | 133 | Args: 134 | requested_device: Requested device ('cuda', 'mps', or 'cpu') 135 | 136 | Returns: 137 | The actual device string to use 138 | """ 139 | if requested_device == 'cuda' and torch.cuda.is_available(): 140 | device = 'cuda' 141 | elif requested_device == 'mps' and torch.backends.mps.is_available(): 142 | device = 'mps' 143 | else: 144 | device = 'cpu' 145 | if requested_device != 'cpu': 146 | logging.warning("%s requested but not available. Falling back to CPU.", 147 | requested_device) 148 | logging.debug("Using device: %s", device) 149 | return device 150 | 151 | def _save_prototypes(self, prototypes: torch.Tensor, class_names: List[str], 152 | defect_idx: int, cache_file: str) -> None: 153 | """Save computed prototypes to a cache file. 154 | 155 | Args: 156 | prototypes: The computed prototype tensors 157 | class_names: List of class names 158 | defect_idx: Index of the defect class 159 | cache_file: Path to save the cache file 160 | """ 161 | try: 162 | cache_dir = os.path.dirname(cache_file) 163 | os.makedirs(cache_dir, exist_ok=True) 164 | cache_data = { 165 | 'prototypes': prototypes.cpu(), 166 | 'class_names': class_names, 167 | 'defect_idx': defect_idx 168 | } 169 | with open(cache_file, 'wb') as f: 170 | pickle.dump(cache_data, f) 171 | logging.debug("Prototypes saved to cache: %s", cache_file) 172 | except (OSError, pickle.PickleError) as e: 173 | logging.warning("Failed to save prototypes to cache: %s", e) 174 | 175 | def _load_prototypes(self, cache_file: str, device: str = None) -> Tuple[Any, List[str], int]: 176 | """Load prototypes from a cache file. 
177 | 178 | Args: 179 | cache_file: Path to the cache file 180 | device: Device to load tensors onto 181 | 182 | Returns: 183 | Tuple of (prototypes, class_names, defect_idx) or (None, None, -1) if loading fails 184 | """ 185 | try: 186 | if not os.path.exists(cache_file): 187 | return None, None, -1 188 | with open(cache_file, 'rb') as f: 189 | cache_data = pickle.load(f) 190 | if device is not None: 191 | device_obj = torch.device(device) if isinstance(device, str) else device 192 | prototypes = cache_data['prototypes'].to(device_obj) 193 | else: 194 | prototypes = cache_data['prototypes'] 195 | class_names = cache_data['class_names'] 196 | defect_idx = cache_data['defect_idx'] 197 | logging.debug("Prototypes loaded from cache: %s", cache_file) 198 | return prototypes, class_names, defect_idx 199 | except (OSError, pickle.PickleError, KeyError) as e: 200 | logging.warning("Failed to load prototypes from cache: %s", e) 201 | return None, None, -1 202 | -------------------------------------------------------------------------------- /printguard/utils/printer_services/octoprint.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import requests 3 | from ...models import (FileInfo, JobInfoResponse, 4 | TemperatureReadings, TemperatureReading, 5 | PrinterState, PrinterTemperatures) 6 | 7 | 8 | class OctoPrintClient: 9 | """ 10 | A client for interacting with OctoPrint's REST API. 11 | 12 | This class provides methods to control and monitor 3D printers through 13 | OctoPrint's web interface, including job management, temperature monitoring, 14 | and printer state retrieval. 15 | 16 | Attributes: 17 | base_url (str): The base URL of the OctoPrint instance 18 | headers (dict): HTTP headers including API key for authentication 19 | """ 20 | 21 | def __init__(self, base_url: str, api_key: str): 22 | """ 23 | Initialize the OctoPrint client. 24 | 25 | Args: 26 | base_url (str): The base URL of the OctoPrint instance (e.g., 'http://octopi.local') 27 | api_key (str): The API key for authentication with OctoPrint 28 | """ 29 | self.base_url = base_url.rstrip("/") 30 | self.headers = { 31 | "X-Api-Key": api_key, 32 | "Content-Type": "application/json" 33 | } 34 | 35 | def get_job_info(self) -> JobInfoResponse: 36 | """ 37 | Retrieve information about the current print job. 38 | 39 | Returns: 40 | JobInfoResponse: Complete job information including progress, file details, 41 | and print statistics 42 | 43 | Raises: 44 | requests.HTTPError: If the API request fails 45 | requests.Timeout: If the request times out 46 | """ 47 | resp = requests.get(f"{self.base_url}/api/job", 48 | headers=self.headers, 49 | timeout=10) 50 | resp.raise_for_status() 51 | return JobInfoResponse(**resp.json()) 52 | 53 | def cancel_job(self) -> None: 54 | """ 55 | Cancel the currently running print job. 56 | 57 | This will immediately stop the current print job and return the printer 58 | to an idle state. 59 | 60 | Raises: 61 | requests.HTTPError: If the API request fails 62 | requests.Timeout: If the request times out 63 | """ 64 | resp = requests.post( 65 | f"{self.base_url}/api/job", 66 | headers=self.headers, 67 | timeout=10, 68 | json={"command": "cancel"} 69 | ) 70 | if resp.status_code == 204: 71 | return 72 | resp.raise_for_status() 73 | 74 | def pause_job(self) -> None: 75 | """ 76 | Pause the currently running print job. 77 | 78 | This will temporarily halt the current print job, allowing it to be 79 | resumed later. 
80 | 81 | Raises: 82 | requests.HTTPError: If the API request fails 83 | requests.Timeout: If the request times out 84 | """ 85 | resp = requests.post( 86 | f"{self.base_url}/api/job", 87 | headers=self.headers, 88 | timeout=10, 89 | json={"command": "pause"} 90 | ) 91 | if resp.status_code == 204: 92 | return 93 | resp.raise_for_status() 94 | 95 | def get_printer_temperatures(self) -> Dict[str, TemperatureReading]: 96 | """ 97 | Retrieve current temperature readings from all printer components. 98 | 99 | Returns: 100 | Dict[str, TemperatureReading]: Dictionary mapping component names 101 | (e.g., 'tool0', 'bed') to their temperature readings. 102 | Returns empty dict if printer is not operational. 103 | 104 | Raises: 105 | requests.HTTPError: If the API request fails (except for 409 conflicts) 106 | requests.Timeout: If the request times out 107 | """ 108 | resp = requests.get(f"{self.base_url}/api/printer", 109 | headers=self.headers, 110 | timeout=10) 111 | if resp.status_code == 409: 112 | return {} 113 | resp.raise_for_status() 114 | state = TemperatureReadings(**resp.json()) 115 | return state.temperature 116 | 117 | def percent_complete(self) -> float: 118 | """ 119 | Get the completion percentage of the current print job. 120 | 121 | Returns: 122 | float: Completion percentage (0.0 to 100.0) 123 | 124 | Raises: 125 | requests.HTTPError: If the API request fails 126 | requests.Timeout: If the request times out 127 | """ 128 | return self.get_job_info().progress.completion * 100 129 | 130 | def current_file(self) -> FileInfo: 131 | """ 132 | Get information about the currently loaded file. 133 | 134 | Returns: 135 | FileInfo: Details about the file being printed, including name, 136 | size, and other metadata 137 | 138 | Raises: 139 | requests.HTTPError: If the API request fails 140 | requests.Timeout: If the request times out 141 | """ 142 | return self.get_job_info().job["file"] 143 | 144 | def nozzle_and_bed_temps(self) -> Dict[str, float]: 145 | """ 146 | Get simplified temperature readings for nozzle and bed. 147 | 148 | This method provides a simplified interface to temperature data, 149 | returning both actual and target temperatures for the primary nozzle 150 | and heated bed. 151 | 152 | Returns: 153 | Dict[str, float]: Dictionary with keys: 154 | - 'nozzle_actual': Current nozzle temperature 155 | - 'nozzle_target': Target nozzle temperature 156 | - 'bed_actual': Current bed temperature 157 | - 'bed_target': Target bed temperature 158 | Returns 0.0 for all values if temperatures are unavailable. 159 | """ 160 | temps = self.get_printer_temperatures() 161 | if not temps: 162 | return { 163 | "nozzle_actual": 0.0, 164 | "nozzle_target": 0.0, 165 | "bed_actual": 0.0, 166 | "bed_target": 0.0, 167 | } 168 | tool0 = temps.get("tool0") 169 | bed = temps.get("bed") 170 | return { 171 | "nozzle_actual": tool0.actual if tool0 else 0.0, 172 | "nozzle_target": tool0.target if tool0 else 0.0, 173 | "bed_actual" : bed.actual if bed else 0.0, 174 | "bed_target" : bed.target if bed else 0.0, 175 | } 176 | 177 | def get_printer_state(self) -> PrinterState: 178 | """ 179 | Get comprehensive printer state information. 180 | 181 | This method combines job information and temperature readings into 182 | a unified printer state object, providing a complete snapshot of 183 | the printer's current status. 184 | 185 | Returns: 186 | PrinterState: Complete printer state including job information 187 | and temperature readings. Job info may be None if 188 | retrieval fails. 
189 | 190 | Note: 191 | If job information retrieval fails, the jobInfoResponse field 192 | will be None, but temperature data will still be included if available. 193 | """ 194 | temperature_readings = self.get_printer_temperatures() 195 | tool0_temp = temperature_readings.get("tool0") if temperature_readings else None 196 | bed_temp = temperature_readings.get("bed") if temperature_readings else None 197 | printer_temps: PrinterTemperatures = PrinterTemperatures( 198 | nozzle_actual=tool0_temp.actual if tool0_temp else None, 199 | nozzle_target=tool0_temp.target if tool0_temp else None, 200 | bed_actual=bed_temp.actual if bed_temp else None, 201 | bed_target=bed_temp.target if bed_temp else None 202 | ) 203 | try: 204 | job_info = self.get_job_info() 205 | except Exception: 206 | job_info = None 207 | printer_state = PrinterState( 208 | jobInfoResponse=job_info, 209 | temperatureReading=printer_temps 210 | ) 211 | return printer_state 212 | -------------------------------------------------------------------------------- /printguard/utils/camera_state_manager.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Dict, Optional 4 | from pydantic import ValidationError 5 | from ..models import CameraState 6 | from .config import get_config, update_config, SavedConfig 7 | 8 | 9 | class CameraStateManager: 10 | """Manages the state of all cameras in the application.""" 11 | def __init__(self): 12 | """Initializes the CameraStateManager, loading states from the configuration.""" 13 | self._states: Dict[str, CameraState] = {} 14 | self._lock: Optional[asyncio.Lock] = None 15 | self._loop: Optional[asyncio.AbstractEventLoop] = None 16 | self._load_states_from_config() 17 | 18 | @property 19 | def lock(self) -> asyncio.Lock: 20 | """Provides a lock for thread-safe operations on camera states. 21 | 22 | Returns: 23 | asyncio.Lock: The lock instance for the current event loop. 
24 | """ 25 | loop = asyncio.get_running_loop() 26 | if self._lock is None or self._loop is not loop: 27 | self._lock = asyncio.Lock() 28 | self._loop = loop 29 | return self._lock 30 | 31 | def _load_states_from_config(self): 32 | """Loads camera states from the application's configuration file.""" 33 | config = get_config() or {} 34 | saved_states = config.get(SavedConfig.CAMERA_STATES, {}) 35 | for camera_uuid, state_data in saved_states.items(): 36 | try: 37 | self._states[camera_uuid] = CameraState(**state_data) 38 | except (ValueError, TypeError, ValidationError) as e: 39 | logging.warning("Failed to load camera state for UUID %s: %s", camera_uuid, e) 40 | try: 41 | self._states[camera_uuid] = CameraState() 42 | logging.info("Created fresh camera state for UUID %s", camera_uuid) 43 | except Exception as ex: 44 | logging.error("Failed to create fresh camera state for UUID %s: %s", camera_uuid, ex) 45 | 46 | def _save_states_to_config(self): 47 | """Saves the current camera states to the application's configuration file.""" 48 | try: 49 | states_data = {} 50 | for camera_uuid, state in self._states.items(): 51 | state_dict = state.model_dump(exclude={'live_detection_task'}) 52 | if 'detection_history' in state_dict and len(state_dict['detection_history']) > 1000: 53 | state_dict['detection_history'] = state_dict['detection_history'][-1000:] 54 | states_data[camera_uuid] = state_dict 55 | update_config({SavedConfig.CAMERA_STATES: states_data}) 56 | except Exception as e: 57 | logging.error("Failed to save camera states to config: %s", e) 58 | 59 | async def get_camera_state(self, camera_uuid: str, reset: bool = False) -> CameraState: 60 | """Get camera state for the given UUID, creating if it doesn't exist 61 | 62 | Args: 63 | camera_uuid (str): The UUID of the camera. 64 | reset (bool): If True, resets the camera state to its default. 65 | 66 | Returns: 67 | CameraState: The state of the camera. 68 | """ 69 | async with self.lock: 70 | if camera_uuid not in self._states or reset: 71 | self._states[camera_uuid] = CameraState() 72 | self._save_states_to_config() 73 | return self._states[camera_uuid] 74 | 75 | async def update_camera_state(self, camera_uuid: str, 76 | new_states: Dict) -> Optional[CameraState]: 77 | """Updates the state of a specific camera. 78 | 79 | Args: 80 | camera_uuid (str): The UUID of the camera to update. 81 | new_states (Dict): A dictionary containing the state updates. 82 | Example: 83 | { 84 | "state_key": new_value, 85 | ... 86 | } 87 | 88 | Returns: 89 | Optional[CameraState]: The updated camera state, or None if not found. 90 | """ 91 | async with self.lock: 92 | camera_state_ref = self._states.get(camera_uuid) 93 | if not camera_state_ref: 94 | camera_state_ref = CameraState(**new_states) 95 | self._states[camera_uuid] = camera_state_ref 96 | else: 97 | for key, value in new_states.items(): 98 | if hasattr(camera_state_ref, key): 99 | setattr(camera_state_ref, key, value) 100 | else: 101 | logging.warning("Key '%s' not found in camera state for UUID %s.", 102 | key, camera_uuid) 103 | self._save_states_to_config() 104 | return camera_state_ref 105 | 106 | async def update_camera_detection_history(self, camera_uuid: str, 107 | pred: str, time_val: float) -> Optional[CameraState]: 108 | """Updates the detection history for a camera. 109 | 110 | Args: 111 | camera_uuid (str): The UUID of the camera. 112 | pred (str): The prediction (detection) label. 113 | time_val (float): The timestamp of the detection. 
114 | 115 | Returns: 116 | Optional[CameraState]: The updated camera state, or None if not found. 117 | """ 118 | async with self.lock: 119 | if camera_uuid not in self._states: 120 | self._states[camera_uuid] = CameraState() 121 | camera_state_ref = self._states.get(camera_uuid) 122 | if camera_state_ref: 123 | camera_state_ref.detection_history.append((time_val, pred)) 124 | max_history = 10000 125 | if len(camera_state_ref.detection_history) > max_history: 126 | camera_state_ref.detection_history = camera_state_ref.detection_history[-max_history:] 127 | if len(camera_state_ref.detection_history) % 100 == 0: 128 | self._save_states_to_config() 129 | return camera_state_ref 130 | return None 131 | 132 | async def get_all_camera_uuids(self) -> list: 133 | """Retrieves a list of all camera UUIDs. 134 | 135 | Returns: 136 | list: A list of all camera UUIDs. 137 | """ 138 | async with self.lock: 139 | return list(self._states.keys()) 140 | 141 | async def remove_camera(self, camera_uuid: str) -> bool: 142 | """ 143 | Removes a camera and its state. 144 | 145 | Args: 146 | camera_uuid (str): The UUID of the camera to remove. 147 | 148 | Returns: 149 | bool: True if the camera was removed, False otherwise. 150 | """ 151 | async with self.lock: 152 | if camera_uuid in self._states: 153 | await self.cleanup_camera_resources(camera_uuid) 154 | del self._states[camera_uuid] 155 | self._save_states_to_config() 156 | logging.info("Successfully removed camera %s.", camera_uuid) 157 | return True 158 | logging.warning("Attempted to remove non-existent camera %s.", camera_uuid) 159 | return False 160 | 161 | async def cleanup_camera_resources(self, camera_uuid: str): 162 | """ 163 | Clean up resources for a specific camera including shared video streams. 164 | 165 | Args: 166 | camera_uuid (str): The UUID of the camera to clean up. 
167 | """ 168 | camera_state_ref = self._states.get(camera_uuid) 169 | if camera_state_ref: 170 | camera_state_ref.live_detection_running = False 171 | camera_state_ref.live_detection_task = None 172 | 173 | try: 174 | from .shared_video_stream import get_shared_stream_manager 175 | manager = get_shared_stream_manager() 176 | manager.release_stream(camera_uuid) 177 | except (ImportError, AttributeError) as e: 178 | logging.warning("Error cleaning up shared video stream for camera %s: %s", camera_uuid, e) 179 | 180 | async def cleanup_all_resources(self): 181 | """Clean up all camera resources including shared video streams.""" 182 | try: 183 | from .shared_video_stream import get_shared_stream_manager 184 | manager = get_shared_stream_manager() 185 | manager.cleanup_all() 186 | except (ImportError, AttributeError) as e: 187 | logging.warning("Error cleaning up all shared video streams: %s", e) 188 | 189 | _camera_state_manager = None 190 | 191 | def get_camera_state_manager() -> CameraStateManager: 192 | """Get the global camera state manager instance""" 193 | global _camera_state_manager 194 | if _camera_state_manager is None: 195 | _camera_state_manager = CameraStateManager() 196 | return _camera_state_manager 197 | -------------------------------------------------------------------------------- /printguard/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import os 4 | from contextlib import asynccontextmanager 5 | 6 | from fastapi import FastAPI, Request 7 | from fastapi.middleware.cors import CORSMiddleware 8 | from fastapi.responses import RedirectResponse 9 | from fastapi.staticfiles import StaticFiles 10 | from fastapi.templating import Jinja2Templates 11 | from .models import (SiteStartupMode, 12 | TunnelProvider, SavedConfig) 13 | from .routes.alert_routes import router as alert_router 14 | from .routes.detection_routes import router as detection_router 15 | from .routes.notification_routes import router as notification_router 16 | from .routes.sse_routes import router as sse_router 17 | from .routes.setup_routes import router as setup_router 18 | from .routes.index_routes import router as index_router 19 | from .routes.camera_routes import router as camera_router 20 | from .routes.printer_routes import router as printer_router 21 | from .utils.config import (get_ssl_private_key_temporary_path, 22 | SSL_CERT_FILE, get_prototypes_dir, 23 | get_model_path, get_model_options_path, 24 | DEVICE_TYPE, SUCCESS_LABEL, 25 | get_config, update_config, init_config) 26 | from .utils.inference_lib import get_inference_engine 27 | from .utils.cloudflare_utils import (start_cloudflare_tunnel, stop_cloudflare_tunnel) 28 | 29 | @asynccontextmanager 30 | async def lifespan(app_instance: FastAPI): 31 | """ 32 | Lifespan event handler for FastAPI application. 33 | 34 | Initializes the device and model, sets up camera indices, and handles startup modes. 35 | """ 36 | # pylint: disable=C0415 37 | from .utils.setup_utils import startup_mode_requirements_met 38 | startup_mode = startup_mode_requirements_met() 39 | inference_engine = get_inference_engine() 40 | if startup_mode is SiteStartupMode.SETUP: 41 | logging.warning("Starting in setup mode. 
Detection model and device will not be initialized.") 42 | yield 43 | return 44 | logging.debug("Setting up device...") 45 | app_instance.state.device = inference_engine.setup_device(DEVICE_TYPE) 46 | logging.debug("Using device: %s", app_instance.state.device) 47 | try: 48 | logging.debug("Loading model...") 49 | app_instance.state.model, _ = inference_engine.load_model(get_model_path(), 50 | get_model_options_path(), 51 | app_instance.state.device) 52 | app_instance.state.transform = inference_engine.get_transform() 53 | logging.debug("Model loaded successfully.") 54 | logging.debug("Building prototypes...") 55 | try: 56 | prototypes, class_names, defect_idx = inference_engine.compute_prototypes( 57 | app_instance.state.model, get_prototypes_dir(), app_instance.state.transform, 58 | app_instance.state.device, SUCCESS_LABEL 59 | ) 60 | app_instance.state.prototypes = prototypes 61 | app_instance.state.class_names = class_names 62 | app_instance.state.defect_idx = defect_idx 63 | logging.debug("Prototypes built successfully.") 64 | except NameError: 65 | logging.warning("Skipping prototype building.") 66 | except ValueError as e: 67 | logging.error("Error building prototypes: %s", e) 68 | except RuntimeError as e: 69 | logging.error("Error during startup: %s", e) 70 | app_instance.state.model = None 71 | raise 72 | logging.debug("Camera indices set up successfully.") 73 | yield 74 | logging.debug("Cleaning up resources on shutdown...") 75 | try: 76 | from .utils.camera_state_manager import get_camera_state_manager 77 | manager = get_camera_state_manager() 78 | await manager.cleanup_all_resources() 79 | logging.debug("Cleaned up camera resources successfully.") 80 | except Exception as e: 81 | logging.error("Error during cleanup: %s", e) 82 | 83 | app = FastAPI( 84 | title="PrintGuard", 85 | description="Real-time Defect Detection on Edge-devices", 86 | version="1.0.0", 87 | lifespan=lifespan, 88 | ) 89 | 90 | app.add_middleware( 91 | CORSMiddleware, 92 | allow_origins=["*"], 93 | allow_credentials=True, 94 | allow_methods=["*"], 95 | allow_headers=["*"], 96 | ) 97 | 98 | app.state.model = None 99 | app.state.transform = None 100 | app.state.device = None 101 | app.state.prototypes = None 102 | app.state.class_names = ['success', 'failure'] 103 | app.state.defect_idx = -1 104 | app.state.alerts = {} 105 | app.state.outbound_queue = asyncio.Queue() 106 | config = get_config() or {} 107 | app.state.subscriptions = config.get(SavedConfig.PUSH_SUBSCRIPTIONS, []) 108 | app.state.polling_tasks = {} 109 | 110 | if app.debug: 111 | logging.basicConfig(level=logging.DEBUG) 112 | 113 | base_dir = os.path.dirname(__file__) 114 | static_dir = os.path.join(base_dir, "static") 115 | templates_dir = os.path.join(base_dir, "templates") 116 | app.mount("/static", StaticFiles(directory=static_dir), name="static") 117 | templates = Jinja2Templates(directory=templates_dir) 118 | 119 | app.include_router(detection_router, tags=["detection"]) 120 | app.include_router(alert_router, tags=["alerts"]) 121 | app.include_router(notification_router, tags=["notifications"]) 122 | app.include_router(sse_router, tags=["sse"]) 123 | app.include_router(setup_router, tags=["setup"]) 124 | app.include_router(index_router, tags=["index"]) 125 | app.include_router(camera_router, tags=["camera"]) 126 | app.include_router(printer_router, tags=["printer"]) 127 | 128 | @app.middleware("http") 129 | async def http_redirect_middleware(request: Request, call_next): 130 | """ 131 | Middleware to handle HTTP requests - redirect to setup 
page unless accessing setup routes. 132 | Only allows setup routes and static files when using HTTP. 133 | """ 134 | if request.url.scheme == "http": 135 | if (request.url.path.startswith("/setup") or 136 | request.url.path.startswith("/static")): 137 | response = await call_next(request) 138 | return response 139 | else: 140 | return RedirectResponse(url="/setup", status_code=307) 141 | response = await call_next(request) 142 | return response 143 | 144 | def run(): 145 | """ 146 | Run the FastAPI application with uvicorn, handling different startup modes. 147 | """ 148 | # pylint: disable=C0415 149 | import uvicorn 150 | from .utils.setup_utils import (startup_mode_requirements_met, 151 | setup_ngrok_tunnel) 152 | init_config() 153 | startup_mode = startup_mode_requirements_met() 154 | app_config = get_config() 155 | site_domain = app_config.get(SavedConfig.SITE_DOMAIN, "") 156 | tunnel_provider = app_config.get(SavedConfig.TUNNEL_PROVIDER, None) 157 | stop_cloudflare_tunnel() 158 | match startup_mode: 159 | case SiteStartupMode.SETUP: 160 | logging.warning("Starting in setup mode. Available at http://localhost:8000/setup") 161 | uvicorn.run(app, host="0.0.0.0", port=8000) 162 | case SiteStartupMode.LOCAL: 163 | logging.warning("Starting in local mode. Available at %s", site_domain) 164 | ssl_private_key_path = get_ssl_private_key_temporary_path() 165 | uvicorn.run(app, 166 | host="0.0.0.0", 167 | port=8000, 168 | ssl_certfile=SSL_CERT_FILE, 169 | ssl_keyfile=ssl_private_key_path) 170 | case SiteStartupMode.TUNNEL: 171 | match tunnel_provider: 172 | case TunnelProvider.NGROK: 173 | logging.warning( 174 | "Starting in tunnel mode with ngrok. Available at %s", 175 | site_domain) 176 | tunnel_setup = setup_ngrok_tunnel(close=False) 177 | if not tunnel_setup: 178 | logging.error("Failed to establish ngrok tunnel. Starting in SETUP mode.") 179 | update_config({SavedConfig.STARTUP_MODE: SiteStartupMode.SETUP}) 180 | run() 181 | else: 182 | uvicorn.run(app, host="0.0.0.0", port=8000) 183 | case TunnelProvider.CLOUDFLARE: 184 | logging.warning("Starting in tunnel mode with Cloudflare.") 185 | if start_cloudflare_tunnel(): 186 | logging.warning("Cloudflare tunnel started. Available at %s", site_domain) 187 | uvicorn.run(app, host="0.0.0.0", port=8000) 188 | else: 189 | logging.error("Failed to start Cloudflare tunnel. 
Starting in SETUP mode.") 190 | update_config({SavedConfig.STARTUP_MODE: SiteStartupMode.SETUP}) 191 | run() 192 | 193 | if __name__ == "__main__": 194 | run() 195 | -------------------------------------------------------------------------------- /printguard/models.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from enum import Enum 3 | from typing import List, Optional, Dict, Any 4 | from pydantic import BaseModel, field_validator, Field 5 | 6 | class Alert(BaseModel): 7 | id: str 8 | snapshot: bytes 9 | title: str 10 | message: str 11 | timestamp: float 12 | countdown_time: float 13 | camera_uuid: str 14 | has_printer: bool = False 15 | countdown_action: str = "dismiss" 16 | 17 | class AlertAction(str, Enum): 18 | DISMISS = "dismiss" 19 | CANCEL_PRINT = "cancel_print" 20 | PAUSE_PRINT = "pause_print" 21 | 22 | class SSEDataType(str, Enum): 23 | ALERT = "alert" 24 | CAMERA_STATE = "camera_state" 25 | PRINTER_STATE = "printer_state" 26 | 27 | class NotificationAction(BaseModel): 28 | action: str 29 | title: str 30 | icon: Optional[str] = None 31 | 32 | class Notification(BaseModel): 33 | title: str 34 | body: str 35 | image_url: Optional[str] = None 36 | icon_url: Optional[str] = None 37 | badge_url: Optional[str] = None 38 | actions: List[NotificationAction] = [] 39 | 40 | def _get_config_value(key: str): 41 | # pylint: disable=import-outside-toplevel 42 | from .utils.config import (BRIGHTNESS, CONTRAST, 43 | FOCUS, SENSITIVITY, 44 | COUNTDOWN_TIME, COUNTDOWN_ACTION, 45 | DETECTION_VOTING_THRESHOLD, 46 | DETECTION_VOTING_WINDOW) 47 | config_map = { 48 | 'BRIGHTNESS': BRIGHTNESS, 49 | 'CONTRAST': CONTRAST, 50 | 'FOCUS': FOCUS, 51 | 'SENSITIVITY': SENSITIVITY, 52 | 'COUNTDOWN_TIME': COUNTDOWN_TIME, 53 | 'COUNTDOWN_ACTION': COUNTDOWN_ACTION, 54 | 'DETECTION_VOTING_THRESHOLD': DETECTION_VOTING_THRESHOLD, 55 | 'DETECTION_VOTING_WINDOW': DETECTION_VOTING_WINDOW, 56 | } 57 | return config_map[key] 58 | 59 | class FileInfo(BaseModel): 60 | name: Optional[str] = None 61 | origin: Optional[str] = None 62 | size: Optional[int] = None 63 | date: Optional[int] = None 64 | 65 | 66 | class Progress(BaseModel): 67 | completion: Optional[float] = None 68 | filepos: Optional[int] = None 69 | printTime: Optional[int] = None 70 | printTimeLeft: Optional[int] = None 71 | 72 | 73 | class JobInfoResponse(BaseModel): 74 | job: Dict = Field(default_factory=dict) 75 | progress: Optional[Progress] = None 76 | state: str 77 | error: Optional[str] = None 78 | 79 | model_config = { 80 | "extra": "ignore" 81 | } 82 | 83 | 84 | class TemperatureReading(BaseModel): 85 | actual: float 86 | target: Optional[float] 87 | offset: Optional[float] 88 | 89 | 90 | class TemperatureReadings(BaseModel): 91 | temperature: Dict[str, TemperatureReading] 92 | 93 | class PrinterTemperatures(BaseModel): 94 | nozzle_actual: Optional[float] = None 95 | nozzle_target: Optional[float] = None 96 | bed_actual: Optional[float] = None 97 | bed_target: Optional[float] = None 98 | 99 | class PrinterState(BaseModel): 100 | jobInfoResponse: Optional[JobInfoResponse] = None 101 | temperatureReading: Optional[PrinterTemperatures] = None 102 | 103 | class CurrentPayload(BaseModel): 104 | state: dict 105 | job: Any 106 | progress: Progress 107 | temps: Optional[list] = Field(None, alias="temps") 108 | 109 | class PrinterType(str, Enum): 110 | OCTOPRINT = "octoprint" 111 | 112 | class PrinterConfig(BaseModel): 113 | name: str 114 | printer_type: PrinterType 115 | camera_uuid: str 116 | 
base_url: str 117 | api_key: str 118 | 119 | class PrinterConfigRequest(BaseModel): 120 | name: str 121 | printer_type: PrinterType 122 | camera_uuid: str 123 | base_url: str 124 | api_key: str 125 | 126 | class CameraState(BaseModel): 127 | nickname: str 128 | source: str 129 | lock: asyncio.Lock = Field(default_factory=asyncio.Lock, exclude=True) 130 | current_alert_id: Optional[str] = None 131 | detection_history: List[tuple] = [] 132 | live_detection_running: bool = False 133 | live_detection_task: Optional[str] = None 134 | last_result: Optional[str] = None 135 | last_time: Optional[float] = None 136 | start_time: Optional[float] = None 137 | error: Optional[str] = None 138 | brightness: float = None 139 | contrast: float = None 140 | focus: float = None 141 | sensitivity: float = None 142 | countdown_time: float = None 143 | countdown_action: str = None 144 | majority_vote_threshold: int = None 145 | majority_vote_window: int = None 146 | printer_id: Optional[str] = None 147 | printer_config: Optional[Dict] = None 148 | 149 | def __init__(self, **data): 150 | if 'brightness' not in data: 151 | data['brightness'] = _get_config_value('BRIGHTNESS') 152 | if 'contrast' not in data: 153 | data['contrast'] = _get_config_value('CONTRAST') 154 | if 'focus' not in data: 155 | data['focus'] = _get_config_value('FOCUS') 156 | if 'sensitivity' not in data: 157 | data['sensitivity'] = _get_config_value('SENSITIVITY') 158 | if 'countdown_time' not in data: 159 | data['countdown_time'] = _get_config_value('COUNTDOWN_TIME') 160 | if 'countdown_action' not in data: 161 | data['countdown_action'] = _get_config_value('COUNTDOWN_ACTION') 162 | if 'majority_vote_threshold' not in data: 163 | data['majority_vote_threshold'] = _get_config_value('DETECTION_VOTING_THRESHOLD') 164 | if 'majority_vote_window' not in data: 165 | data['majority_vote_window'] = _get_config_value('DETECTION_VOTING_WINDOW') 166 | super().__init__(**data) 167 | model_config = { 168 | "arbitrary_types_allowed": True 169 | } 170 | 171 | class VapidSettings(BaseModel): 172 | public_key: str 173 | private_key: str 174 | subject: str 175 | base_url: str 176 | 177 | class SiteStartupMode(str, Enum): 178 | SETUP = "setup" 179 | LOCAL = "local" 180 | TUNNEL = "tunnel" 181 | 182 | class TunnelProvider(str, Enum): 183 | NGROK = "ngrok" 184 | CLOUDFLARE = "cloudflare" 185 | 186 | class OperatingSystem(str, Enum): 187 | MACOS = "macos" 188 | WINDOWS = "windows" 189 | LINUX = "linux" 190 | 191 | class TunnelSettings(BaseModel): 192 | provider: TunnelProvider 193 | token: str 194 | domain: str = "" 195 | email: Optional[str] = None 196 | 197 | @field_validator('domain') 198 | @classmethod 199 | def validate_domain_for_ngrok(cls, v, info): 200 | if info.data.get('provider') == TunnelProvider.NGROK and not v: 201 | raise ValueError('Domain is required for ngrok provider') 202 | return v 203 | 204 | class SetupCompletion(BaseModel): 205 | startup_mode: SiteStartupMode 206 | tunnel_provider: Optional[TunnelProvider] = None 207 | 208 | class SavedKey(str, Enum): 209 | VAPID_PRIVATE_KEY = "vapid_private_key" 210 | SSL_PRIVATE_KEY = "ssl_private_key" 211 | TUNNEL_API_KEY = "tunnel_api_key" 212 | TUNNEL_TOKEN = "tunnel_token" 213 | 214 | class SavedConfig(str, Enum): 215 | VERSION = "version" 216 | VAPID_SUBJECT = "vapid_subject" 217 | VAPID_PUBLIC_KEY = "vapid_public_key" 218 | STARTUP_MODE = "startup_mode" 219 | SITE_DOMAIN = "site_domain" 220 | TUNNEL_PROVIDER = "tunnel_provider" 221 | CLOUDFLARE_EMAIL = "cloudflare_email" 222 | CLOUDFLARE_TEAM_NAME = 
"cloudflare_team_name" 223 | USER_OPERATING_SYSTEM = "user_operating_system" 224 | STREAM_OPTIMIZE_FOR_TUNNEL = "stream_optimize_for_tunnel" 225 | STREAM_MAX_FPS = "stream_max_fps" 226 | STREAM_TUNNEL_FPS = "stream_tunnel_fps" 227 | STREAM_JPEG_QUALITY = "stream_jpeg_quality" 228 | STREAM_MAX_WIDTH = "stream_max_width" 229 | DETECTION_INTERVAL_MS = "detection_interval_ms" 230 | PRINTER_STAT_POLLING_RATE_MS = "printer_stat_polling_rate_ms" 231 | MIN_SSE_DISPATCH_DELAY_MS = "min_sse_dispatch_delay_ms" 232 | PUSH_SUBSCRIPTIONS = "push_subscriptions" 233 | CAMERA_STATES = "camera_states" 234 | 235 | class CloudflareTunnelConfig(BaseModel): 236 | account_id: str 237 | zone_id: str 238 | subdomain: str 239 | 240 | class CloudflareDownloadConfig(BaseModel): 241 | operating_system: OperatingSystem 242 | 243 | class WarpDeviceConfig(BaseModel): 244 | device_id: Optional[str] = None 245 | user_email: Optional[str] = None 246 | 247 | class CloudflareCommandSet(BaseModel): 248 | operating_system: OperatingSystem 249 | install_command: str 250 | enable_command: str = "" 251 | start_command: str 252 | stop_command: str 253 | restart_command: str = "" 254 | setup_sequence: List[str] 255 | 256 | class WarpDeviceEnrollmentRule(BaseModel): 257 | name: str 258 | precedence: int = 0 259 | require: List[str] = [] 260 | include: List[str] = [] 261 | 262 | class FeedSettings(BaseModel): 263 | stream_max_fps: int 264 | stream_tunnel_fps: int 265 | stream_jpeg_quality: int 266 | stream_max_width: int 267 | detections_per_second: int 268 | detection_interval_ms: int 269 | printer_stat_polling_rate_ms: int 270 | min_sse_dispatch_delay_ms: int 271 | 272 | class PollingTask(BaseModel): 273 | task: Optional[asyncio.Task] = None 274 | stop_event: Optional[asyncio.Event] = None 275 | model_config = { 276 | "arbitrary_types_allowed": True 277 | } 278 | -------------------------------------------------------------------------------- /printguard/static/css/universal.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --background-primary: #f4f7f6; 3 | --background-secondary: #ffffff; 4 | --background-tertiary: #f8f9fa; 5 | --background-accent: #f9f9f9; 6 | 7 | --border-primary: #222; 8 | --border-secondary: #ddd; 9 | --border-tertiary: #ccc; 10 | --border-light: #eaeaea; 11 | 12 | --text-primary: #222; 13 | --text-secondary: #666; 14 | --text-muted: #888; 15 | --text-inverted: #fff; 16 | 17 | --color-white: #ffffff; 18 | --color-danger: #ff4136; 19 | --color-danger-dark: #dc3545; 20 | --color-success: #2ecc40; 21 | 22 | --success-color: #2ecc40; 23 | --error-color: #ff4136; 24 | --warning-color: #ffa500; 25 | --info-color: #3b82f6; 26 | 27 | --success-bg: rgba(46, 204, 64, 0.1); 28 | --success-border: rgba(46, 204, 64, 0.3); 29 | --error-bg: rgba(255, 65, 54, 0.1); 30 | --error-border: rgba(255, 65, 54, 0.3); 31 | --warning-bg: rgba(255, 165, 0, 0.1); 32 | --warning-border: rgba(255, 165, 0, 0.3); 33 | --info-bg: rgba(59, 130, 246, 0.1); 34 | --info-border: rgba(59, 130, 246, 0.3); 35 | 36 | --spacing-xs: 5px; 37 | --spacing-sm: 10px; 38 | --spacing-md: 15px; 39 | --spacing-lg: 20px; 40 | --spacing-xl: 30px; 41 | --spacing-2xl: 40px; 42 | --spacing-3xl: 50px; 43 | --spacing-4xl: 60px; 44 | 45 | --radius-sm: 3px; 46 | --radius-md: 6px; 47 | --radius-lg: 8px; 48 | --radius-xl: 12px; 49 | 50 | --font-family: "Nanum Gothic Coding", monospace; 51 | --font-size-sm: 0.8rem; 52 | --font-size-base: 1rem; 53 | --font-size-lg: 1.1rem; 54 | --font-size-xl: 1.8rem; 55 | 56 | --shadow-sm: 
0 3px 8px rgba(0, 0, 0, 0.1); 57 | --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.15); 58 | --shadow-lg: 0 8px 25px rgba(0, 0, 0, 0.2); 59 | 60 | --transition-fast: 0.2s ease; 61 | --transition-normal: 0.3s ease; 62 | --transition-slow: 0.4s ease; 63 | } 64 | 65 | body, html { 66 | margin: 0; 67 | font-family: var(--font-family); 68 | font-weight: 400; 69 | font-style: normal; 70 | background-color: var(--background-primary); 71 | color: var(--text-primary); 72 | } 73 | 74 | .universal-container { 75 | display: flex; 76 | flex-direction: column; 77 | align-items: center; 78 | justify-content: center; 79 | text-align: center; 80 | padding: var(--spacing-lg); 81 | width: 100%; 82 | min-height: 100vh; 83 | box-sizing: border-box; 84 | } 85 | 86 | .universal-card { 87 | max-width: 600px; 88 | margin: 0 auto; 89 | padding: var(--spacing-xl); 90 | background: var(--background-secondary); 91 | border-radius: var(--radius-xl); 92 | box-shadow: var(--shadow-sm); 93 | border: 1px solid var(--border-light); 94 | } 95 | 96 | .universal-section { 97 | background: var(--background-tertiary); 98 | padding: var(--spacing-lg); 99 | border-radius: var(--radius-lg); 100 | margin: var(--spacing-lg) 0; 101 | border: 1px solid var(--border-secondary); 102 | } 103 | 104 | .btn { 105 | padding: var(--spacing-md) var(--spacing-xl); 106 | border: 2px dotted var(--border-primary); 107 | cursor: pointer; 108 | background: transparent; 109 | transition: background var(--transition-fast), color var(--transition-fast); 110 | font-family: inherit; 111 | font-size: var(--font-size-base); 112 | text-decoration: none; 113 | display: inline-block; 114 | text-align: center; 115 | } 116 | 117 | .btn:hover { 118 | background: var(--text-primary); 119 | color: var(--text-inverted); 120 | } 121 | 122 | .btn.primary { 123 | background: var(--text-primary); 124 | color: var(--text-inverted); 125 | } 126 | 127 | .btn.primary:hover { 128 | background: #444; 129 | } 130 | 131 | .btn.secondary { 132 | border-color: var(--border-secondary); 133 | color: var(--text-secondary); 134 | } 135 | 136 | .btn.secondary:hover { 137 | background: var(--text-secondary); 138 | color: var(--text-inverted); 139 | } 140 | 141 | .btn.success { 142 | background: var(--success-color); 143 | color: var(--text-inverted); 144 | border-color: var(--success-color); 145 | } 146 | 147 | .btn.success:hover { 148 | background: #27a73a; 149 | } 150 | 151 | .btn.info { 152 | background: var(--info-color); 153 | color: var(--text-inverted); 154 | border-color: var(--info-color); 155 | } 156 | 157 | .btn.info:hover { 158 | background: #2563eb; 159 | } 160 | 161 | .btn:disabled { 162 | opacity: 0.6; 163 | cursor: not-allowed; 164 | background: var(--background-tertiary); 165 | color: var(--text-muted); 166 | } 167 | 168 | .btn:disabled:hover { 169 | background: var(--background-tertiary); 170 | color: var(--text-muted); 171 | } 172 | 173 | .form-group { 174 | margin-bottom: var(--spacing-lg); 175 | width: 100%; 176 | } 177 | 178 | .form-group label { 179 | display: block; 180 | margin-bottom: var(--spacing-xs); 181 | font-weight: bold; 182 | color: var(--text-primary); 183 | } 184 | 185 | .form-group input, 186 | .form-group select, 187 | .form-group textarea { 188 | width: 100%; 189 | padding: var(--spacing-sm); 190 | border: 2px dotted var(--border-primary); 191 | background: transparent; 192 | font-family: inherit; 193 | font-size: var(--font-size-base); 194 | box-sizing: border-box; 195 | } 196 | 197 | .form-group input:focus, 198 | .form-group select:focus, 199 | 
.form-group textarea:focus { 200 | outline: none; 201 | border-color: #444; 202 | } 203 | 204 | .message { 205 | padding: var(--spacing-md); 206 | border-radius: var(--radius-lg); 207 | margin: var(--spacing-md) 0; 208 | border: 1px solid; 209 | } 210 | 211 | .message.success { 212 | background: var(--success-bg); 213 | border-color: var(--success-border); 214 | color: var(--success-color); 215 | } 216 | 217 | .message.error { 218 | background: var(--error-bg); 219 | border-color: var(--error-border); 220 | color: var(--error-color); 221 | } 222 | 223 | .message.warning { 224 | background: var(--warning-bg); 225 | border-color: var(--warning-border); 226 | color: var(--warning-color); 227 | } 228 | 229 | .message.info { 230 | background: var(--info-bg); 231 | border-color: var(--info-border); 232 | color: var(--info-color); 233 | } 234 | 235 | .status-indicator { 236 | display: flex; 237 | align-items: center; 238 | justify-content: center; 239 | gap: var(--spacing-xs); 240 | margin: var(--spacing-md) 0; 241 | } 242 | 243 | .status-icon { 244 | width: 20px; 245 | height: 20px; 246 | border-radius: 50%; 247 | } 248 | 249 | .status-pending { background-color: var(--warning-color); } 250 | .status-success { background-color: var(--success-color); } 251 | .status-error { background-color: var(--error-color); } 252 | .status-info { background-color: var(--info-color); } 253 | 254 | .text-center { text-align: center; } 255 | .text-left { text-align: left; } 256 | .text-right { text-align: right; } 257 | 258 | .mb-xs { margin-bottom: var(--spacing-xs); } 259 | .mb-sm { margin-bottom: var(--spacing-sm); } 260 | .mb-md { margin-bottom: var(--spacing-md); } 261 | .mb-lg { margin-bottom: var(--spacing-lg); } 262 | .mb-xl { margin-bottom: var(--spacing-xl); } 263 | 264 | .mt-xs { margin-top: var(--spacing-xs); } 265 | .mt-sm { margin-top: var(--spacing-sm); } 266 | .mt-md { margin-top: var(--spacing-md); } 267 | .mt-lg { margin-top: var(--spacing-lg); } 268 | .mt-xl { margin-top: var(--spacing-xl); } 269 | 270 | .p-xs { padding: var(--spacing-xs); } 271 | .p-sm { padding: var(--spacing-sm); } 272 | .p-md { padding: var(--spacing-md); } 273 | .p-lg { padding: var(--spacing-lg); } 274 | .p-xl { padding: var(--spacing-xl); } 275 | 276 | .hover-lift { 277 | transition: transform var(--transition-fast), box-shadow var(--transition-fast); 278 | } 279 | 280 | .hover-lift:hover { 281 | transform: translateY(-2px); 282 | box-shadow: var(--shadow-md); 283 | } 284 | 285 | @media (max-width: 768px) { 286 | .universal-card { 287 | margin: var(--spacing-sm); 288 | padding: var(--spacing-lg); 289 | } 290 | 291 | .btn { 292 | width: 100%; 293 | margin-bottom: var(--spacing-sm); 294 | } 295 | } 296 | 297 | .code-block { 298 | background: var(--background-tertiary); 299 | padding: var(--spacing-sm); 300 | border-radius: var(--radius-md); 301 | font-family: monospace; 302 | font-size: var(--font-size-sm); 303 | word-break: break-all; 304 | border: 1px solid var(--border-secondary); 305 | } 306 | 307 | .info-badge { 308 | display: inline-block; 309 | padding: var(--spacing-xs) var(--spacing-sm); 310 | background: var(--info-bg); 311 | color: var(--info-color); 312 | border-radius: var(--radius-sm); 313 | font-size: var(--font-size-sm); 314 | border: 1px solid var(--info-border); 315 | } 316 | 317 | .divider { 318 | margin: var(--spacing-xl) 0; 319 | padding-top: var(--spacing-md); 320 | border-top: 1px solid var(--border-secondary); 321 | } 322 | 323 | @keyframes fadeIn { 324 | from { opacity: 0; } 325 | to { opacity: 1; 
} 326 | } 327 | 328 | @keyframes slideIn { 329 | from { 330 | opacity: 0; 331 | transform: translateY(20px); 332 | } 333 | to { 334 | opacity: 1; 335 | transform: translateY(0); 336 | } 337 | } 338 | 339 | .fade-in { 340 | animation: fadeIn var(--transition-normal) ease-in-out; 341 | } 342 | 343 | .slide-in { 344 | animation: slideIn var(--transition-normal) ease-out; 345 | } 346 | -------------------------------------------------------------------------------- /printguard/utils/inference_engine.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Any, Dict, List, Tuple 4 | from enum import Enum 5 | 6 | import cv2 7 | 8 | class InferenceBackend(Enum): 9 | """Supported inference backends.""" 10 | PYTORCH = "pytorch" 11 | ONNXRUNTIME = "onnxruntime" 12 | 13 | 14 | class InferenceEngine(ABC): 15 | """Abstract base class for inference engines.""" 16 | 17 | @abstractmethod 18 | def load_model(self, model_path: str, options_path: str, device: str) -> Tuple[Any, List[int]]: 19 | """Load a model and its configuration. 20 | 21 | Args: 22 | model_path: Path to the model file 23 | options_path: Path to the model options/config file 24 | device: Device to load the model on 25 | 26 | Returns: 27 | Tuple of (model, input_dimensions) 28 | """ 29 | 30 | @abstractmethod 31 | def get_transform(self) -> Any: 32 | """Get the image preprocessing transform pipeline. 33 | 34 | Returns: 35 | Transform pipeline for preprocessing images 36 | """ 37 | 38 | @abstractmethod 39 | def compute_prototypes(self, model: Any, support_dir: str, transform: Any, 40 | device: str, success_label: str = "success", 41 | use_cache: bool = True) -> Tuple[Any, List[str], int]: 42 | """Compute class prototypes from support images. 43 | 44 | Args: 45 | model: The loaded model 46 | support_dir: Directory containing class subdirectories with support images 47 | transform: Image preprocessing transform 48 | device: Device to run computations on 49 | success_label: Label for the non-defective class 50 | use_cache: Whether to use cached prototypes if available 51 | 52 | Returns: 53 | Tuple of (prototypes, class_names, defect_idx) 54 | """ 55 | 56 | @abstractmethod 57 | def predict_batch(self, model: Any, batch_tensors: Any, prototypes: Any, 58 | defect_idx: int, sensitivity: float, device: str) -> List[int]: 59 | """Predict classes for a batch of image tensors. 60 | 61 | Args: 62 | model: The loaded model 63 | batch_tensors: Batch of preprocessed image tensors 64 | prototypes: Class prototype tensors 65 | defect_idx: Index of the defect class for sensitivity adjustment 66 | sensitivity: Sensitivity multiplier for defect detection 67 | device: Device to run computations on 68 | 69 | Returns: 70 | List of predicted class indices 71 | """ 72 | 73 | @abstractmethod 74 | def setup_device(self, requested_device: str) -> str: 75 | """Set up the compute device based on availability and request. 76 | 77 | Args: 78 | requested_device: Requested device ('cuda', 'mps', or 'cpu') 79 | 80 | Returns: 81 | The actual device string to use 82 | """ 83 | 84 | @abstractmethod 85 | def clear_prototype_cache(self, support_dir: str) -> None: 86 | """Clear the prototype cache for a support directory. 
87 | 88 | Args: 89 | support_dir: Path to the support directory whose cache should be cleared 90 | """ 91 | 92 | class UniversalInferenceEngine: 93 | """Universal inference engine that delegates to backend-specific engines.""" 94 | def __init__(self, backend: InferenceBackend = InferenceBackend.PYTORCH): 95 | """Initialize the universal inference engine. 96 | 97 | Args: 98 | backend: The inference backend to use 99 | """ 100 | self.backend = backend 101 | self._engine = self._create_engine(backend) 102 | 103 | def _create_engine(self, backend: InferenceBackend) -> InferenceEngine: 104 | """Create the appropriate backend engine. 105 | 106 | Args: 107 | backend: The inference backend to create 108 | 109 | Returns: 110 | The backend-specific inference engine 111 | """ 112 | # pylint: disable=import-outside-toplevel 113 | if backend == InferenceBackend.PYTORCH: 114 | from .backends.pytorch_engine import PyTorchInferenceEngine 115 | return PyTorchInferenceEngine() 116 | elif backend == InferenceBackend.ONNXRUNTIME: 117 | from .backends.onnxruntime_engine import ONNXRuntimeInferenceEngine 118 | return ONNXRuntimeInferenceEngine() 119 | else: 120 | raise ValueError(f"Unsupported backend: {backend}") 121 | 122 | def load_model(self, model_path: str, options_path: str, device: str) -> Tuple[Any, List[int]]: 123 | """Load a model and its configuration.""" 124 | return self._engine.load_model(model_path, options_path, device) 125 | 126 | def get_transform(self) -> Any: 127 | """Get the image preprocessing transform pipeline.""" 128 | return self._engine.get_transform() 129 | 130 | def compute_prototypes(self, model: Any, support_dir: str, transform: Any, 131 | device: str, success_label: str = "success", 132 | use_cache: bool = True) -> Tuple[Any, List[str], int]: 133 | """Compute class prototypes from support images.""" 134 | return self._engine.compute_prototypes( 135 | model, support_dir, transform, device, success_label, use_cache 136 | ) 137 | 138 | def predict_batch(self, model: Any, batch_tensors: Any, prototypes: Any, 139 | defect_idx: int, sensitivity: float, device: str) -> List[int]: 140 | """Predict classes for a batch of image tensors.""" 141 | return self._engine.predict_batch( 142 | model, batch_tensors, prototypes, defect_idx, sensitivity, device 143 | ) 144 | 145 | def setup_device(self, requested_device: str) -> str: 146 | """Set up the compute device based on availability and request.""" 147 | return self._engine.setup_device(requested_device) 148 | 149 | def clear_prototype_cache(self, support_dir: str) -> None: 150 | """Clear the prototype cache for a support directory.""" 151 | self._engine.clear_prototype_cache(support_dir) 152 | 153 | def draw_label(self, frame: Any, label: str, color: Tuple[int, int, int], 154 | success_label: str = "success") -> Any: 155 | """Draw a detection label on an image frame. 156 | 157 | This is a common utility function that doesn't depend on the backend. 
158 | 159 | Args: 160 | frame: The image frame to draw on 161 | label: The prediction label to display 162 | color: RGB color tuple for the label background 163 | success_label: Label considered as "success" (non-defective) 164 | 165 | Returns: 166 | The frame with the label drawn on it 167 | """ 168 | # pylint: disable=E1101 169 | text = "non-defective" if label == success_label else "defect" 170 | font = cv2.FONT_HERSHEY_SIMPLEX 171 | font_scale = 2 172 | thickness = 3 173 | try: 174 | text_size, _ = cv2.getTextSize(text, font, font_scale, thickness) 175 | text_w, text_h = text_size 176 | h, w, _ = frame.shape 177 | rect_start = (w - text_w - 40, h - text_h - 40) 178 | rect_end = (w - 20, h - 20) 179 | text_pos = (w - text_w - 30, h - 30) 180 | cv2.rectangle(frame, rect_start, rect_end, color, -1) 181 | cv2.putText(frame, text, text_pos, font, font_scale, 182 | (255, 255, 255), thickness, cv2.LINE_AA) 183 | except Exception as e: 184 | logging.error("Error drawing label: %s. Frame shape: %s, Label: %s", 185 | e, frame.shape, label) 186 | return frame 187 | 188 | def get_backend_info(self) -> Dict[str, Any]: 189 | """Get information about the current backend. 190 | 191 | Returns: 192 | Dictionary containing backend information 193 | """ 194 | return { 195 | "backend": self.backend.value, 196 | "engine_class": self._engine.__class__.__name__, 197 | "available_devices": self._get_available_devices() 198 | } 199 | 200 | def _get_available_devices(self) -> List[str]: 201 | """Get list of available devices for the current backend.""" 202 | devices = ["cpu"] 203 | if self.backend == InferenceBackend.PYTORCH: 204 | # pylint: disable=import-outside-toplevel 205 | import torch 206 | if torch.cuda.is_available(): 207 | devices.append("cuda") 208 | if torch.backends.mps.is_available(): 209 | devices.append("mps") 210 | elif self.backend == InferenceBackend.ONNXRUNTIME: 211 | # pylint: disable=import-outside-toplevel 212 | try: 213 | import onnxruntime as ort 214 | available_providers = ort.get_available_providers() 215 | if 'CUDAExecutionProvider' in available_providers: 216 | devices.append("cuda") 217 | if 'CoreMLExecutionProvider' in available_providers: 218 | devices.append("mps") 219 | except ImportError: 220 | pass 221 | return devices 222 | -------------------------------------------------------------------------------- /printguard/utils/model_downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Optional, Dict, Any 4 | from pathlib import Path 5 | 6 | from huggingface_hub import hf_hub_download 7 | 8 | from .inference_lib import _detect_backend, InferenceBackend 9 | 10 | class ModelDownloader: 11 | """Downloads models from Hugging Face Hub based on detected backend.""" 12 | def __init__(self, model_repo: str = "oliverbravery/printguard"): 13 | """Initialize the model downloader. 
14 | 15 | Args: 16 | model_repo: Hugging Face repository containing models 17 | """ 18 | self.model_repo = model_repo 19 | self.base_dir = Path(__file__).parent.parent / "model" 20 | self.base_dir.mkdir(exist_ok=True) 21 | self.backend_files = { 22 | InferenceBackend.PYTORCH: { 23 | "model": "model.pt", 24 | "options": "opt.json", 25 | "prototypes": "prototypes.pkl" 26 | }, 27 | InferenceBackend.ONNXRUNTIME: { 28 | "model": "model.onnx", 29 | "options": "opt.json", 30 | "prototypes": "prototypes.pkl" 31 | } 32 | } 33 | 34 | def get_model_path(self, backend: Optional[InferenceBackend] = None) -> str: 35 | """Get the local path to the model file for the given backend. 36 | 37 | Args: 38 | backend: Backend to get model path for (auto-detected if None) 39 | 40 | Returns: 41 | Path to the model file 42 | """ 43 | if backend is None: 44 | backend = _detect_backend() 45 | model_file = self.backend_files[backend]["model"] 46 | return str(self.base_dir / model_file) 47 | 48 | def get_options_path(self) -> str: 49 | """Get the local path to the model options file. 50 | 51 | Returns: 52 | Path to the options JSON file 53 | """ 54 | return str(self.base_dir / "opt.json") 55 | 56 | def get_prototypes_path(self) -> str: 57 | """Get the local path to the prototypes cache directory. 58 | 59 | Returns: 60 | Path to the prototypes directory 61 | """ 62 | return str(self.base_dir / "prototypes") 63 | 64 | def get_prototypes_cache_file(self) -> str: 65 | """Get the local path to the downloaded prototypes cache file. 66 | 67 | Returns: 68 | Path to the prototypes.pkl file 69 | """ 70 | return str(self.base_dir / "prototypes" / "cache" / "prototypes.pkl") 71 | 72 | def _is_file_cached(self, file_path: str) -> bool: 73 | """Check if a file is already cached locally. 74 | 75 | Args: 76 | file_path: Path to the file to check 77 | 78 | Returns: 79 | True if file exists and is non-empty 80 | """ 81 | return os.path.exists(file_path) and os.path.getsize(file_path) > 0 82 | 83 | def _download_file(self, filename: str, local_path: str) -> bool: 84 | """Download a file from Hugging Face Hub. 85 | 86 | Args: 87 | filename: Name of file in the repository 88 | local_path: Local path to save the file 89 | 90 | Returns: 91 | True if download was successful 92 | """ 93 | try: 94 | logging.info("Downloading %s from %s", filename, self.model_repo) 95 | downloaded_path = hf_hub_download( 96 | repo_id=self.model_repo, 97 | filename=filename, 98 | local_dir=self.base_dir, 99 | local_dir_use_symlinks=False 100 | ) 101 | if downloaded_path != local_path: 102 | os.rename(downloaded_path, local_path) 103 | logging.info("Successfully downloaded %s to %s", filename, local_path) 104 | return True 105 | except (OSError, ValueError, RuntimeError) as e: 106 | logging.error("Failed to download %s: %s", filename, e) 107 | return False 108 | 109 | def download_model(self, 110 | backend: Optional[InferenceBackend] = None, 111 | force: bool = False) -> bool: 112 | """Download the model file for the specified backend. 
113 | 114 | Args: 115 | backend: Backend to download model for (auto-detected if None) 116 | force: Force download even if file exists 117 | 118 | Returns: 119 | True if model is available (cached or downloaded) 120 | """ 121 | if backend is None: 122 | backend = _detect_backend() 123 | model_file = self.backend_files[backend]["model"] 124 | local_path = self.get_model_path(backend) 125 | if not force and self._is_file_cached(local_path): 126 | logging.info("Model %s already cached at %s", model_file, local_path) 127 | return True 128 | return self._download_file(model_file, local_path) 129 | 130 | def download_options(self, force: bool = False) -> bool: 131 | """Download the model options file. 132 | 133 | Args: 134 | force: Force download even if file exists 135 | 136 | Returns: 137 | True if options file is available (cached or downloaded) 138 | """ 139 | local_path = self.get_options_path() 140 | if not force and self._is_file_cached(local_path): 141 | logging.info("Options file already cached at %s", local_path) 142 | return True 143 | return self._download_file("opt.json", local_path) 144 | 145 | def download_prototypes(self, force: bool = False) -> bool: 146 | """Download the cached prototypes file. 147 | 148 | Args: 149 | force: Force download even if file exists 150 | 151 | Returns: 152 | True if prototypes are available (cached or downloaded) 153 | """ 154 | prototypes_dir = Path(self.get_prototypes_path()) 155 | prototypes_dir.mkdir(parents=True, exist_ok=True) 156 | cache_dir = prototypes_dir / "cache" 157 | cache_dir.mkdir(exist_ok=True) 158 | local_path = self.get_prototypes_cache_file() 159 | if not force and self._is_file_cached(local_path): 160 | logging.info("Prototypes already cached at %s", local_path) 161 | return True 162 | return self._download_file("prototypes.pkl", local_path) 163 | 164 | def download_all(self, 165 | backend: Optional[InferenceBackend] = None, 166 | force: bool = False) -> bool: 167 | """Download all required files for the specified backend. 168 | 169 | Args: 170 | backend: Backend to download files for (auto-detected if None) 171 | force: Force download even if files exist 172 | 173 | Returns: 174 | True if all files are available 175 | """ 176 | if backend is None: 177 | backend = _detect_backend() 178 | logging.info("Downloading all model files for %s backend", backend.value) 179 | success = True 180 | success &= self.download_model(backend, force) 181 | success &= self.download_options(force) 182 | success &= self.download_prototypes(force) 183 | if success: 184 | logging.info( 185 | "All model files successfully downloaded/cached for %s backend", 186 | backend.value) 187 | else: 188 | logging.error( 189 | "Failed to download some model files for %s backend", 190 | backend.value) 191 | return success 192 | 193 | def get_backend_info(self) -> Dict[str, Any]: 194 | """Get information about the detected backend and file availability. 
195 | 196 | Returns: 197 | Dictionary with backend info and file status 198 | """ 199 | backend = _detect_backend() 200 | info = { 201 | "detected_backend": backend.value, 202 | "model_repo": self.model_repo, 203 | "files": {} 204 | } 205 | model_path = self.get_model_path(backend) 206 | info["files"]["model"] = { 207 | "path": model_path, 208 | "exists": self._is_file_cached(model_path), 209 | "filename": self.backend_files[backend]["model"] 210 | } 211 | options_path = self.get_options_path() 212 | info["files"]["options"] = { 213 | "path": options_path, 214 | "exists": self._is_file_cached(options_path), 215 | "filename": "opt.json" 216 | } 217 | prototypes_file = self.get_prototypes_cache_file() 218 | info["files"]["prototypes"] = { 219 | "path": prototypes_file, 220 | "exists": self._is_file_cached(prototypes_file), 221 | "filename": "prototypes.pkl" 222 | } 223 | return info 224 | 225 | _model_downloader: Optional[ModelDownloader] = None 226 | 227 | def get_model_downloader() -> ModelDownloader: 228 | """Get the global model downloader instance.""" 229 | # pylint: disable=global-statement 230 | global _model_downloader 231 | if _model_downloader is None: 232 | _model_downloader = ModelDownloader() 233 | return _model_downloader 234 | 235 | def ensure_model_files(backend: Optional[InferenceBackend] = None) -> bool: 236 | """Ensure all required model files are available for the specified backend. 237 | 238 | Args: 239 | backend: Backend to ensure files for (auto-detected if None) 240 | 241 | Returns: 242 | True if all files are available 243 | """ 244 | downloader = get_model_downloader() 245 | return downloader.download_all(backend) 246 | -------------------------------------------------------------------------------- /printguard/static/js/sse.js: -------------------------------------------------------------------------------- 1 | const evtSource = new EventSource('/sse'); 2 | const notificationPopup = document.getElementById('notificationPopup'); 3 | const notificationMessage = document.getElementById('notificationMessage'); 4 | const notificationImage = document.getElementById('notificationImage'); 5 | const notificationCountdownTimer = document.getElementById('notificationCountdownTimer'); 6 | const dismissNotificationBtn = document.getElementById('dismissNotificationBtn'); 7 | const cancelPrintBtn = document.getElementById('cancelPrintBtn'); 8 | const pausePrintBtn = document.getElementById('pausePrintBtn'); 9 | 10 | let currentAlertId = null; 11 | 12 | document.addEventListener('DOMContentLoaded', loadPendingAlerts); 13 | 14 | function getLocalActiveAlerts() { 15 | try { 16 | return JSON.parse(localStorage.getItem('activeAlerts')) || {}; 17 | } catch (e) { 18 | console.error("Error parsing activeAlerts from localStorage:", e); 19 | return {}; 20 | } 21 | } 22 | 23 | function getRemoteActiveAlerts() { 24 | return fetch('/alert/active', { 25 | method: 'GET', 26 | }) 27 | .then(response => response.json()) 28 | .then(data => data.active_alerts || []) 29 | .catch(error => { 30 | console.error("Error fetching remote active alerts:", error); 31 | return []; 32 | }); 33 | } 34 | 35 | function saveActiveAlert(alert) { 36 | const activeAlerts = getLocalActiveAlerts(); 37 | const expirationTime = Date.now() + (alert.countdown_time || 10) * 1000; 38 | activeAlerts[alert.id] = { 39 | data: alert, 40 | expirationTime: expirationTime 41 | }; 42 | localStorage.setItem('activeAlerts', JSON.stringify(activeAlerts)); 43 | } 44 | 45 | function removeActiveAlert(alertId) { 46 | const activeAlerts = 
getLocalActiveAlerts(); 47 | if (activeAlerts[alertId]) { 48 | delete activeAlerts[alertId]; 49 | localStorage.setItem('activeAlerts', JSON.stringify(activeAlerts)); 50 | } 51 | } 52 | 53 | async function loadPendingAlerts() { 54 | const activeAlerts = getLocalActiveAlerts(); 55 | const now = Date.now(); 56 | const remoteAlerts = await getRemoteActiveAlerts(); 57 | const remoteAlertIds = remoteAlerts.map(alert => alert.id); 58 | 59 | Object.keys(activeAlerts).forEach(alertId => { 60 | if (activeAlerts[alertId].expirationTime < now || !remoteAlertIds.includes(alertId)) { 61 | delete activeAlerts[alertId]; 62 | } 63 | }); 64 | 65 | remoteAlerts.forEach(remoteAlert => { 66 | if (!activeAlerts[remoteAlert.id]) { 67 | const alert_start_time = remoteAlert.timestamp * 1000; 68 | const expirationTime = alert_start_time + (remoteAlert.countdown_time * 1000); 69 | activeAlerts[remoteAlert.id] = { 70 | data: remoteAlert, 71 | expirationTime: expirationTime 72 | }; 73 | } 74 | }); 75 | 76 | localStorage.setItem('activeAlerts', JSON.stringify(activeAlerts)); 77 | const alertIds = Object.keys(activeAlerts); 78 | 79 | alertIds.forEach(alertId => { 80 | const alert = activeAlerts[alertId].data; 81 | alert.countdown_time = Math.max(1, Math.floor((activeAlerts[alertId].expirationTime - now) / 1000)); 82 | displayAlert(alert); 83 | }); 84 | 85 | return alertIds.length > 0; 86 | } 87 | 88 | function displayAlert(alert_data) { 89 | const parsedData = parseAlertData(alert_data); 90 | updateAlertUI(parsedData); 91 | startAlertCountdown(parsedData); 92 | saveActiveAlert(parsedData); 93 | } 94 | 95 | function parseAlertData(alert_data) { 96 | return typeof alert_data === 'string' ? JSON.parse(alert_data) : alert_data; 97 | } 98 | 99 | function updateAlertUI(data) { 100 | currentAlertId = data.id; 101 | const notificationsContainer = document.getElementById('notificationsContainer'); 102 | 103 | if (document.getElementById(`alert-${data.id}`)) { 104 | return; 105 | } 106 | 107 | const alertElement = document.createElement('div'); 108 | alertElement.id = `alert-${data.id}`; 109 | alertElement.className = 'alert-item'; 110 | alertElement.style.padding = '10px'; 111 | alertElement.style.marginBottom = '10px'; 112 | alertElement.style.borderBottom = '1px solid #dee2e6'; 113 | let alertContent = `

                <div class="alert-message">${data.message}</div>
            `;
114 |         alertContent += `
                <div id="countdown-${data.id}" class="alert-countdown"></div>
            `;
115 | 
        // Reconstructed minimal markup: the classes and ids here match the selectors used
        // later in this file (.dismiss-btn, .suspend-print-btn, countdown-${data.id}); the
        // original styling attributes and the exact snapshot <img> attributes are assumptions.
116 |         if (data.snapshot) {
117 |             alertContent = `<img class="alert-snapshot" src="${data.snapshot}">` + alertContent;
119 |         }
120 |         const hasPrinter = data.has_printer === true;
121 |         alertContent += `
                <div class="alert-actions">
                    <button class="dismiss-btn">Dismiss</button>
                    ${hasPrinter ? `
                    <button class="suspend-print-btn">Cancel Print</button>
                    <button class="suspend-print-btn">Pause Print</button>` : ''}
                </div>
`; 130 | 131 | alertElement.innerHTML = alertContent; 132 | notificationsContainer.prepend(alertElement); 133 | 134 | alertElement.querySelector('.dismiss-btn').addEventListener('click', () => { 135 | dismissAlert('dismiss', data.id); 136 | }); 137 | 138 | const cancelBtns = alertElement.querySelectorAll('.suspend-print-btn'); 139 | if (hasPrinter && cancelBtns.length >= 1) { 140 | cancelBtns[0].addEventListener('click', () => { 141 | dismissAlert('cancel_print', data.id); 142 | }); 143 | } 144 | 145 | if (hasPrinter && cancelBtns.length >= 2) { 146 | cancelBtns[1].addEventListener('click', () => { 147 | dismissAlert('pause_print', data.id); 148 | }); 149 | } 150 | 151 | notificationPopup.style.display = 'block'; 152 | } 153 | 154 | function startAlertCountdown(data) { 155 | if (!data.id) return; 156 | 157 | const countdownElement = document.getElementById(`countdown-${data.id}`); 158 | if (!countdownElement) return; 159 | 160 | const countdownTimerId = `countdown-timer-${data.id}`; 161 | if (window[countdownTimerId]) { 162 | clearInterval(window[countdownTimerId]); 163 | } 164 | 165 | const startTime = Date.now(); 166 | const countdownTime = data.countdown_time || 0; 167 | const endTime = startTime + countdownTime * 1000; 168 | 169 | function updateCountdown() { 170 | const now = Date.now(); 171 | let secondsLeft = Math.max(0, Math.round((endTime - now) / 1000)); 172 | countdownElement.textContent = `${secondsLeft}s remaining`; 173 | 174 | const activeAlerts = getLocalActiveAlerts(); 175 | if (activeAlerts[data.id]) { 176 | activeAlerts[data.id].expirationTime = endTime; 177 | localStorage.setItem('activeAlerts', JSON.stringify(activeAlerts)); 178 | } 179 | if (secondsLeft <= 0) { 180 | clearInterval(window[countdownTimerId]); 181 | const action = data.countdown_action || 'pause_print'; 182 | if (action === 'cancel_print' && data.has_printer) { 183 | executeAlertAction('cancel_print', data.id); 184 | } else if (action === 'pause_print' && data.has_printer) { 185 | executeAlertAction('pause_print', data.id); 186 | } else { 187 | executeAlertAction('dismiss', data.id); 188 | } 189 | } 190 | } 191 | 192 | updateCountdown(); 193 | window[countdownTimerId] = setInterval(updateCountdown, 1000); 194 | } 195 | 196 | evtSource.onmessage = (e) => { 197 | try { 198 | let packet_data = JSON.parse(e.data); 199 | packet_data = packet_data.data; 200 | if (packet_data) { 201 | if (packet_data.event == "alert") { 202 | displayAlert(packet_data.data); 203 | } 204 | else if (packet_data.event == "camera_state") { 205 | const cameraData = packet_data.data; 206 | if (!cameraData.camera_uuid) { 207 | console.warn("Camera data missing camera_uuid", cameraData); 208 | } 209 | if (typeof cameraData.live_detection_running !== 'boolean') { 210 | cameraData.live_detection_running = !!cameraData.live_detection_running; 211 | } 212 | document.dispatchEvent(new CustomEvent('cameraStateUpdated', { 213 | detail: cameraData 214 | })); 215 | } 216 | else if (packet_data.event == "printer_state") { 217 | const printerData = packet_data.data; 218 | document.dispatchEvent(new CustomEvent('printerStateUpdated', { 219 | detail: printerData 220 | })); 221 | } 222 | } 223 | } catch (error) { 224 | console.error("Error processing SSE message:", error); 225 | } 226 | }; 227 | 228 | evtSource.onerror = (err) => { 229 | console.error("SSE error", err); 230 | }; 231 | 232 | function executeAlertAction(action_type, alertId) { 233 | fetch(`/alert/dismiss`, { 234 | method: 'POST', 235 | headers: { 236 | 'Content-Type': 'application/json' 
237 | }, 238 | body: JSON.stringify({ alert_id: alertId, action: action_type }) 239 | }) 240 | .then(response => { 241 | if (response.ok) { 242 | const alertElement = document.getElementById(`alert-${alertId}`); 243 | if (alertElement) alertElement.remove(); 244 | removeActiveAlert(alertId); 245 | 246 | if (document.getElementById('notificationsContainer').children.length === 0) { 247 | notificationPopup.style.display = 'none'; 248 | } 249 | } else { 250 | console.error('Failed to execute alert action'); 251 | } 252 | }) 253 | .catch(error => console.error('Error:', error)); 254 | } 255 | 256 | function dismissAlert(action_type, alertId) { 257 | if (!alertId) alertId = currentAlertId; 258 | executeAlertAction(action_type, alertId); 259 | } 260 | 261 | document.addEventListener('DOMContentLoaded', () => { 262 | const dismissBtn = document.getElementById('dismissNotificationBtn'); 263 | const cancelBtn = document.getElementById('cancelPrintBtn'); 264 | const pauseBtn = document.getElementById('pausePrintBtn'); 265 | if (dismissBtn) dismissBtn.remove(); 266 | if (cancelBtn) cancelBtn.remove(); 267 | if (pauseBtn) pauseBtn.remove(); 268 | }); --------------------------------------------------------------------------------
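The pieces above fit together as a small pipeline: the model downloader fetches the backend-specific files from Hugging Face Hub, and the universal inference engine loads the model, computes class prototypes from support images, and classifies frames. The sketch below is a minimal illustration of that wiring, not a file from this repository; it assumes the auto-detected backend is PyTorch (matching the engine default), that get_transform() returns a tensor-producing transform, and that the downloaded prototype cache under the downloader's prototypes directory can be reused by compute_prototypes().

# Minimal end-to-end sketch (illustrative only, not part of the repository).
from typing import Optional

import torch
from PIL import Image

from printguard.utils.inference_engine import UniversalInferenceEngine
from printguard.utils.model_downloader import ensure_model_files, get_model_downloader


def classify_frame(frame_path: str, support_dir: Optional[str] = None,
                   sensitivity: float = 1.0) -> str:
    """Classify a single image and return the predicted class name."""
    if not ensure_model_files():
        raise RuntimeError("Model files are not available locally and could not be downloaded")
    downloader = get_model_downloader()
    engine = UniversalInferenceEngine()          # defaults to the PyTorch backend
    device = engine.setup_device("cuda")         # falls back per the backend's availability checks
    # Assumes the auto-detected backend matches the engine default used above.
    model, _input_dims = engine.load_model(downloader.get_model_path(),
                                           downloader.get_options_path(),
                                           device)
    transform = engine.get_transform()
    # The downloaded prototypes.pkl lands under <prototypes dir>/cache/, so pointing
    # compute_prototypes at that directory with use_cache=True is expected to reuse it.
    support_dir = support_dir or downloader.get_prototypes_path()
    prototypes, class_names, defect_idx = engine.compute_prototypes(
        model, support_dir, transform, device)
    image = Image.open(frame_path).convert("RGB")
    batch = torch.stack([transform(image)])      # batch of one preprocessed frame
    predictions = engine.predict_batch(model, batch, prototypes, defect_idx,
                                       sensitivity, device)
    return class_names[predictions[0]]


if __name__ == "__main__":
    # "frame.jpg" is a placeholder path for this sketch.
    print(classify_frame("frame.jpg"))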