├── .gitattributes ├── .github └── workflows │ ├── docker-publish.yml │ ├── publish.yml │ └── python-app.yml ├── .gitignore ├── CLAUDE.md ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── PACKAGING.md ├── README.md ├── __init__.py ├── example ├── blender_view.png ├── depth_map.png ├── external_config.png ├── hacking.gif ├── inpainting-thumb.jpg ├── input.png ├── input_plus_depth.png ├── output.gif ├── thumb.png ├── unreal-thumb.jpg ├── webui.jpg ├── webui_3d.jpg └── workflow.json ├── package.json ├── parallax_maker ├── __init__.py ├── assets │ ├── .pydiscard │ ├── css │ │ ├── features.css │ │ └── tailwind_compiled.css │ └── scripts │ │ └── utility.js ├── automatic1111.py ├── camera.py ├── clientside.py ├── comfyui.py ├── components.py ├── constants.py ├── controller.py ├── depth.py ├── falai.py ├── gltf.py ├── gltf_cli.py ├── inpainting.py ├── instance.py ├── segmentation.py ├── setup_dev.py ├── slice.py ├── stabilityai.py ├── test_automatic1111.py ├── test_components.py ├── test_constants.py ├── test_controller.py ├── test_gltf.py ├── test_inpainting.py ├── test_instance.py ├── test_segmentation.py ├── test_stabilityai.py ├── test_upscaler.py ├── test_utils.py ├── test_webui.py ├── upscaler.py ├── utils.py └── webui.py ├── poetry.lock ├── postcss.config.js ├── pyproject.toml ├── requirements.txt ├── setup_dev.py ├── tailwind.config.js ├── tailwind.css └── test_installation.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/docker-publish.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Images 2 | 3 | on: 4 | release: 5 | types: [published] 6 | # Optionally allow manual trigger 7 | workflow_dispatch: 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up QEMU 17 | uses: docker/setup-qemu-action@v3 18 | 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v3 21 | 22 | - name: Login to DockerHub 23 | uses: docker/login-action@v3 24 | with: 25 | username: nielsprovos 26 | password: ${{ secrets.DOCKERHUB_TOKEN }} 27 | 28 | - name: Build and push 29 | uses: docker/build-push-action@v5 30 | with: 31 | context: . 
32 | platforms: linux/amd64,linux/arm64 33 | push: true 34 | tags: | 35 | nielsprovos/parallax-maker:latest 36 | nielsprovos/parallax-maker:${{ github.event.release.tag_name }} 37 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: # Allows manual triggering 7 | inputs: 8 | upload_to_pypi: 9 | description: 'Upload to PyPI (be careful!)' 10 | required: false 11 | default: false 12 | type: boolean 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: '3.10' 27 | 28 | - name: Install build dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build twine 32 | 33 | - name: Build package 34 | run: python -m build 35 | 36 | - name: Check distribution 37 | run: twine check dist/* 38 | 39 | - name: Upload to PyPI 40 | if: (github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'workflow_dispatch' && inputs.upload_to_pypi == true) 41 | env: 42 | TWINE_USERNAME: __token__ 43 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 44 | run: twine upload dist/* 45 | 46 | - name: Upload artifacts 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: dist 50 | path: dist/ 51 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.10" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -e .[dev] 30 | - name: Lint with flake8 31 | run: | 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 35 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | - name: Test with pytest 37 | run: | 38 | pytest 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | test.py 154 | tmp-images/ 155 | .DS_Store 156 | 157 | # Node modules 158 | node_modules/ 159 | package-lock.json 160 | 161 | # Parallax Maker 162 | appstate-*/ -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | Parallax Maker is a Python-based computer vision and AI tool that converts 2D images into 2.5D animations using depth estimation, segmentation, and inpainting. It features a web UI built with Dash and exports to glTF for 3D applications. 8 | 9 | ## Development Commands 10 | 11 | ### Setup and Dependencies 12 | ```bash 13 | # Create virtual environment 14 | python3.12 -m venv .venv 15 | source .venv/bin/activate 16 | 17 | # Install in development mode 18 | pip install -e .[dev] 19 | 20 | # Install TailwindCSS dependencies (only if modifying styles) 21 | npm install -D tailwindcss 22 | ``` 23 | 24 | ### Running the Application 25 | ```bash 26 | # Start the web UI 27 | parallax-maker 28 | 29 | # Or run the module directly 30 | python -m parallax_maker.webui 31 | 32 | # Prefetch models to avoid download delays 33 | parallax-maker --prefetch-models=default 34 | ``` 35 | 36 | ### Development Workflow 37 | ```bash 38 | # Run tests with coverage 39 | pytest 40 | 41 | # Lint code 42 | flake8 . --max-line-length=127 43 | 44 | # Format code 45 | black . 46 | 47 | # Type checking 48 | mypy parallax_maker 49 | 50 | # Build TailwindCSS (when modifying styles) 51 | npm run build 52 | 53 | # Watch TailwindCSS changes during development 54 | npm run watch 55 | ``` 56 | 57 | ### Running Individual Tests 58 | ```bash 59 | # Run specific test file 60 | pytest parallax_maker/test_filename.py 61 | 62 | # Run specific test 63 | pytest parallax_maker/test_filename.py::test_function_name 64 | ``` 65 | 66 | ## Architecture Overview 67 | 68 | ### Core Components 69 | 70 | 1. **Web UI (webui.py)**: Main Dash application that orchestrates the workflow 71 | - Handles image loading, depth generation, segmentation, inpainting, and export 72 | - State management via AppState class 73 | 74 | 2. 
**Depth Estimation (depth.py)**: Supports multiple models 75 | - MiDaS, DINOv2, ZoeDepth for depth map generation 76 | - Configurable model selection and parameters 77 | 78 | 3. **Segmentation (segmentation.py)**: Interactive image segmentation 79 | - Segment Anything Model (SAM) with point-based selection 80 | - Manual card creation and depth manipulation 81 | 82 | 4. **Inpainting (inpainting.py)**: Multiple backends for mask inpainting 83 | - Local: Stable Diffusion XL, SD3 Medium 84 | - Remote: Automatic1111 (automatic1111.py), ComfyUI (comfyui.py) 85 | - API: StabilityAI 86 | 87 | 5. **3D Export (gltf.py)**: glTF 2.0 scene generation 88 | - Creates depth-based card arrangements 89 | - Supports depth displacement for realistic geometry 90 | - Command-line tool: parallax-gltf-cli 91 | 92 | ### Key Design Patterns 93 | 94 | - **AppState**: Central state management for the entire workflow 95 | - **Component Registry**: UI components registered in components.py 96 | - **Model Management**: Lazy loading of AI models to optimize memory 97 | - **Callback Architecture**: Dash callbacks handle all UI interactions 98 | 99 | ### Data Flow 100 | 101 | 1. Image loaded → Depth model generates depth map 102 | 2. Depth map → Segmentation creates cards/slices 103 | 3. Cards → Inpainting fills masked regions 104 | 4. Processed cards → glTF export or GIF animation 105 | 106 | ### File Organization 107 | 108 | - `parallax_maker/`: Main package directory 109 | - `assets/`: CSS and JavaScript files 110 | - `components.py`: All Dash UI components 111 | - `controller.py`: Main workflow orchestration 112 | - `constants.py`: Configuration and defaults 113 | - Model-specific files: depth.py, segmentation.py, inpainting.py 114 | - Integration files: automatic1111.py, comfyui.py 115 | - Export: gltf.py, gltf_cli.py 116 | 117 | ### State Persistence 118 | 119 | - Application state saved to `appstate-*` directories 120 | - Contains processed images, depth maps, masks, and metadata 121 | - JSON serialization for state data 122 | 123 | ## Important Considerations 124 | 125 | - First-time model downloads can take several minutes 126 | - GPU recommended for reasonable performance 127 | - Memory usage scales with image size and model complexity 128 | - TailwindCSS compilation required only when modifying styles 129 | - SD3 Medium requires latest diffusers from GitHub -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=$BUILDPLATFORM python:3.10-slim 2 | 3 | RUN apt-get update && apt-get install -y git ffmpeg 4 | COPY requirements.txt /app/ 5 | WORKDIR /app 6 | RUN pip install --no-cache-dir -r requirements.txt 7 | COPY . 
/app 8 | 9 | EXPOSE 8050 10 | 11 | # The .cache directory will be optionally bound to /root/.cache/ 12 | VOLUME ["/root/.cache/"] 13 | 14 | # The working directory can be mounted to /app/workdir 15 | VOLUME ["/app/workdir"] 16 | 17 | WORKDIR /app/workdir 18 | 19 | CMD ["python", "/app/webui.py"] -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt 4 | include package.json 5 | include postcss.config.js 6 | include tailwind.config.js 7 | include tailwind.css 8 | recursive-include parallax_maker/assets * 9 | global-exclude __pycache__ 10 | global-exclude *.py[co] 11 | global-exclude .DS_Store 12 | global-exclude *.egg-info 13 | -------------------------------------------------------------------------------- /PACKAGING.md: -------------------------------------------------------------------------------- 1 | # Building and Publishing Parallax Maker to PyPI 2 | 3 | This document provides instructions for building wheel packages and publishing to PyPI. 4 | 5 | ## Prerequisites 6 | 7 | 1. Install build tools: 8 | ```bash 9 | pip install build twine 10 | ``` 11 | 12 | 2. Configure PyPI credentials (one-time setup): 13 | ```bash 14 | pip install keyring 15 | # For PyPI 16 | python -m keyring set https://upload.pypi.org/legacy/ your-username 17 | # For TestPyPI 18 | python -m keyring set https://test.pypi.org/legacy/ your-username 19 | ``` 20 | 21 | ## Building the Package 22 | 23 | 1. Clean any previous builds: 24 | ```bash 25 | rm -rf dist/ build/ *.egg-info 26 | ``` 27 | 28 | 2. Build Tailwind CSS assets (if changed): 29 | ```bash 30 | npm run build 31 | ``` 32 | 33 | 3. Build the wheel and source distribution: 34 | ```bash 35 | python -m build 36 | ``` 37 | 38 | This will create: 39 | - `dist/parallax_maker-*.whl` (wheel package) 40 | - `dist/parallax-maker-*.tar.gz` (source distribution) 41 | 42 | ## Testing the Package Locally 43 | 44 | 1. Create a test environment: 45 | ```bash 46 | python -m venv test_env 47 | source test_env/bin/activate # On Windows: test_env\Scripts\activate 48 | ``` 49 | 50 | 2. Install the wheel: 51 | ```bash 52 | pip install dist/parallax_maker-*.whl 53 | ``` 54 | 55 | 3. Test the installation: 56 | ```bash 57 | parallax-maker --help 58 | parallax-gltf-cli --help 59 | ``` 60 | 61 | ## Publishing to PyPI 62 | 63 | ### Test on TestPyPI first (recommended) 64 | 65 | 1. Upload to TestPyPI: 66 | ```bash 67 | python -m twine upload --repository testpypi dist/* 68 | ``` 69 | 70 | 2. Test installation from TestPyPI: 71 | ```bash 72 | pip install --index-url https://test.pypi.org/simple/ parallax-maker 73 | ``` 74 | 75 | ### Upload to PyPI 76 | 77 | 1. Upload to PyPI: 78 | ```bash 79 | python -m twine upload dist/* 80 | ``` 81 | 82 | 2. Verify the upload at https://pypi.org/project/parallax-maker/ 83 | 84 | ## Version Management 85 | 86 | Update the version in `pyproject.toml` before building: 87 | 88 | ```toml 89 | [project] 90 | version = "1.0.2" # Update this 91 | ``` 92 | 93 | ## Automated Publishing with GitHub Actions 94 | 95 | Consider setting up GitHub Actions for automated publishing. 
Create `.github/workflows/publish.yml`: 96 | 97 | ```yaml 98 | name: Publish to PyPI 99 | 100 | on: 101 | release: 102 | types: [published] 103 | 104 | jobs: 105 | deploy: 106 | runs-on: ubuntu-latest 107 | steps: 108 | - uses: actions/checkout@v3 109 | - name: Set up Python 110 | uses: actions/setup-python@v4 111 | with: 112 | python-version: '3.10' 113 | - name: Install dependencies 114 | run: | 115 | python -m pip install --upgrade pip 116 | pip install build twine 117 | - name: Build package 118 | run: python -m build 119 | - name: Publish to PyPI 120 | env: 121 | TWINE_USERNAME: __token__ 122 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 123 | run: twine upload dist/* 124 | ``` 125 | 126 | ## Notes 127 | 128 | - The package name is `parallax-maker` (with hyphen) but the Python module is `parallax_maker` (with underscore) 129 | - Entry points are configured for both the web UI (`parallax-maker`) and CLI tool (`parallax-gltf-cli`) 130 | - All assets (CSS, JS) are included via `MANIFEST.in` 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Hacking On Parallax Maker](https://raw.githubusercontent.com/provos/parallax-maker/main/example/hacking.gif) 2 | 3 | # Installation and First Usage 4 | 5 | ## Prerequisites 6 | - Python 3.10, 3.11, or 3.12 7 | - pip (for package management) 8 | 9 | ## Installation Methods 10 | 11 | Create a new environment with Python 3.10+ and install the project in development mode: 12 | 13 | ```bash 14 | # Create and activate a virtual environment 15 | python3.12 -m venv .venv 16 | source .venv/bin/activate # On macOS/Linux 17 | # .venv\Scripts\activate # On Windows 18 | 19 | # Install the project and dependencies 20 | pip install -e . 21 | ``` 22 | 23 | ## Running the Application 24 | After installation, you can start the application using the entry point: 25 | 26 | ```bash 27 | # Using the installed entry point 28 | parallax-maker 29 | 30 | # Or run the module directly 31 | python -m parallax_maker.webui 32 | ``` 33 | 34 | You can then reach the web UI via [http://127.0.0.1:8050/](http://127.0.0.1:8050/). Be aware that the first time any new functionality is used, the corresponding models need to be downloaded, which can take a few minutes depending on your connection speed. If you want to prefetch the default models, you can start the application with: 35 | 36 | ```bash 37 | parallax-maker --prefetch-models=default 38 | ``` 39 | 40 | > [!NOTE] 41 | > If you want to make changes to the styles, you need to set up `node` and run `npm run build` to rebuild the compiled Tailwind CSS file. This requires installing `tailwindcss` via `npm install -D tailwindcss`. 42 | 43 | # Parallax-Maker 44 | 45 | Provides a workflow for turning images into 2.5D animations like the one seen above. 46 | 47 | ## Features 48 | - Segmentation of images 49 | - Using depth models like MiDaS or ZoeDepth 50 | - Using instance segmentation via Segment Anything with multiple positive and negative point selection 51 | - Adding and removing cards, direct manipulation of depth values 52 | - Inpainting 53 | - Inpainting of masks that can be padded and blurred 54 | - Replacing the masked regions with new images via image generation models like Stable Diffusion XL 1.0 and Stable Diffusion 3 Medium, via Automatic1111 or ComfyUI endpoints, or via the StabilityAI API.
55 | - 3D Export 56 | - Generation of glTF scenes that can be imported into Blender or Unreal Engine 57 | - Support for depth displacement of cards to generate more realistic 3D geometry 58 | - In-browser 3D preview of the generated glTF scene. 59 | 60 | ## Basic Examples 61 | 62 | Using an input image, the tool runs a depth model like **MiDaS** or **DINOv2** to generate a depth map 63 | 64 | ![Input Image](https://raw.githubusercontent.com/provos/parallax-maker/main/example/input_plus_depth.png) 65 | 66 | and then creates cards that can be used for 2.5D parallax animation. 67 | 68 | ![Animation](https://raw.githubusercontent.com/provos/parallax-maker/main/example/output.gif) 69 | 70 | This animation was created using the following command: 71 | 72 | ~~~ 73 | ffmpeg -framerate 24 -i rendered_image_%03d.png -filter_complex "fps=5,scale=480:-1:flags=lanczos,split[s0][s1];[s0]palettegen=max_colors=32[p];[s1][p]paletteuse=dither=bayer" output.gif 74 | ~~~ 75 | 76 | 77 | # 3D Export 78 | 79 | The tool also supports generating a glTF 2.0 scene file that can be easily imported into 3D apps like Blender or Unreal Engine. 80 | 81 | > [!TIP] 82 | > To utilize depth of field camera effects for the Blender scene, the material needs to be changed to **ALPHA HASHED**. 83 | 84 | > [!TIP] 85 | > To utilize depth of field camera effects for Unreal Engine, the material needs to be changed to **Translucent Masked**. 86 | 87 | 88 | ![Blender Scene View](https://raw.githubusercontent.com/provos/parallax-maker/main/example/blender_view.png) 89 | 90 | 91 | # Web UI 92 | 93 | ![Web UI](https://raw.githubusercontent.com/provos/parallax-maker/main/example/webui.jpg) 94 | 95 | A Dash-based web UI provides a browser-assisted workflow to generate slices from images, inpaint the slices, and then export them as a glTF scene to Blender or Unreal Engine. The resulting glTF scene can also be visualized within the app or manipulated with a command-line tool using the state file saved by the app. 96 | 97 | ![Web UI 3D Example](https://raw.githubusercontent.com/provos/parallax-maker/main/example/webui_3d.jpg) 98 | 99 | # Advanced Use Cases 100 | Parallax Maker also supports the Automatic1111 and ComfyUI API endpoints. This allows the tool to utilize GPUs remotely and potentially achieve much higher performance compared to the local GPU. It also means that it's possible to use more specialized inpainting models and workflows. Here is [an example](https://raw.githubusercontent.com/provos/parallax-maker/main/example/workflow.json) ComfyUI inpainting workflow that makes use of the offset LoRA published by Stability AI.
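The ComfyUI integration can also be scripted outside the web UI. Below is a minimal sketch that calls the `inpainting_comfyui` helper from `parallax_maker.comfyui` directly; the server address, file names, and prompts are placeholders to adapt to your own setup (the prompts here are taken from the example workflow above).

```python
# Minimal sketch: drive the ComfyUI inpainting helper directly from Python.
# Assumes a ComfyUI server on localhost:8188, a workflow file exported in API
# format like the example above, and local input.png / mask.png files.
from PIL import Image

from parallax_maker.comfyui import inpainting_comfyui

image = Image.open("input.png")  # slice to be inpainted
mask = Image.open("mask.png")    # mask selecting the region to replace

result = inpainting_comfyui(
    "localhost:8188",   # address:port of the ComfyUI server
    "workflow.json",    # inpainting workflow exported in API format
    image,
    mask,
    "photo, fruit bowl on table in apartment living room",
    "ugly",
    strength=0.8,
    steps=35,
)
result.save("inpainted_slice.png")
```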
101 | 102 | ![Example configuration for ComfyUI](https://raw.githubusercontent.com/provos/parallax-maker/main/example/external_config.png) 103 | 104 | # Watch the Video 105 | [![Watch the video](https://raw.githubusercontent.com/provos/parallax-maker/main/example/thumb.png)](https://www.youtube.com/watch?v=4JBQCz-wWYQ) 106 | 107 | # Tutorials 108 | ## Segmentation and Inpainting Tutorial 109 | [![Segmentation and Inpainting Tutorial](https://raw.githubusercontent.com/provos/parallax-maker/main/example/inpainting-thumb.jpg)](https://youtu.be/hb_x8z4WIeI) 110 | ## Unreal Engine Import and Rendering Tutorial 111 | [![Unreal Import and Rendering Tutorial](https://raw.githubusercontent.com/provos/parallax-maker/main/example/unreal-thumb.jpg)](https://www.youtube.com/watch?v=fLSCCS53h_U) -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """Parallax Maker - A workflow for turning images into 2.5D animations. 2 | 3 | This package provides tools for: 4 | - Depth estimation from images using models like MiDaS and ZoeDepth 5 | - Image segmentation using Segment Anything with point selection 6 | - Inpainting with various models including Stable Diffusion 7 | - 3D export to glTF format for Blender and Unreal Engine 8 | - Web-based user interface for interactive workflow 9 | """ 10 | 11 | __version__ = "1.0.2" 12 | __author__ = "Niels Provos" 13 | __email__ = "niels@provos.org" 14 | __license__ = "AGPL-3.0-or-later" 15 | -------------------------------------------------------------------------------- /example/blender_view.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/blender_view.png -------------------------------------------------------------------------------- /example/depth_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/depth_map.png -------------------------------------------------------------------------------- /example/external_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/external_config.png -------------------------------------------------------------------------------- /example/hacking.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/hacking.gif -------------------------------------------------------------------------------- /example/inpainting-thumb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/inpainting-thumb.jpg -------------------------------------------------------------------------------- /example/input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/input.png -------------------------------------------------------------------------------- /example/input_plus_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/input_plus_depth.png 
-------------------------------------------------------------------------------- /example/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/output.gif -------------------------------------------------------------------------------- /example/thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/thumb.png -------------------------------------------------------------------------------- /example/unreal-thumb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/unreal-thumb.jpg -------------------------------------------------------------------------------- /example/webui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/webui.jpg -------------------------------------------------------------------------------- /example/webui_3d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/provos/parallax-maker/HEAD/example/webui_3d.jpg -------------------------------------------------------------------------------- /example/workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": { 3 | "inputs": { 4 | "ckpt_name": "sd_xl_base_1.0_0.9vae.safetensors" 5 | }, 6 | "class_type": "CheckpointLoaderSimple", 7 | "_meta": { 8 | "title": "Load Checkpoint" 9 | } 10 | }, 11 | "2": { 12 | "inputs": { 13 | "seed": 545041323232936, 14 | "steps": 35, 15 | "cfg": 5, 16 | "sampler_name": "dpmpp_2m_sde_gpu", 17 | "scheduler": "karras", 18 | "denoise": 0.85, 19 | "model": [ 20 | "20", 21 | 0 22 | ], 23 | "positive": [ 24 | "3", 25 | 0 26 | ], 27 | "negative": [ 28 | "4", 29 | 0 30 | ], 31 | "latent_image": [ 32 | "40", 33 | 0 34 | ] 35 | }, 36 | "class_type": "KSampler", 37 | "_meta": { 38 | "title": "KSampler" 39 | } 40 | }, 41 | "3": { 42 | "inputs": { 43 | "text": "photo, fruit bowl on table in apartment living room,\n\nsoft lighting, hard shadows, shot on kodak,", 44 | "clip": [ 45 | "20", 46 | 1 47 | ] 48 | }, 49 | "class_type": "CLIPTextEncode", 50 | "_meta": { 51 | "title": "CLIP Text Encode (Prompt)" 52 | } 53 | }, 54 | "4": { 55 | "inputs": { 56 | "text": "ugly", 57 | "clip": [ 58 | "20", 59 | 1 60 | ] 61 | }, 62 | "class_type": "CLIPTextEncode", 63 | "_meta": { 64 | "title": "CLIP Text Encode (Prompt)" 65 | } 66 | }, 67 | "20": { 68 | "inputs": { 69 | "lora_name": "sd_xl_offset_example-lora_1.0.safetensors", 70 | "strength_model": 0.3, 71 | "strength_clip": 0.3, 72 | "model": [ 73 | "1", 74 | 0 75 | ], 76 | "clip": [ 77 | "1", 78 | 1 79 | ] 80 | }, 81 | "class_type": "LoraLoader", 82 | "_meta": { 83 | "title": "Load LoRA" 84 | } 85 | }, 86 | "24": { 87 | "inputs": { 88 | "samples": [ 89 | "2", 90 | 0 91 | ], 92 | "vae": [ 93 | "1", 94 | 2 95 | ] 96 | }, 97 | "class_type": "VAEDecode", 98 | "_meta": { 99 | "title": "VAE Decode" 100 | } 101 | }, 102 | "26": { 103 | "inputs": { 104 | "image": "image_slice_0_v8.png", 105 | "upload": "image" 106 | }, 107 | "class_type": "LoadImage", 108 | "_meta": { 109 | "title": "Load Image" 110 | } 111 | }, 112 | "33": { 113 | "inputs": { 114 | "filename_prefix": "ComfyUI", 115 | "images": [ 116 | "24", 117 | 0 118 | ] 119 | }, 120 | 
"class_type": "SaveImage", 121 | "_meta": { 122 | "title": "Save Image" 123 | } 124 | }, 125 | "38": { 126 | "inputs": { 127 | "image": "image_slice_0_v8_mask.png", 128 | "channel": "red", 129 | "upload": "image" 130 | }, 131 | "class_type": "LoadImageMask", 132 | "_meta": { 133 | "title": "Load Image (as Mask)" 134 | } 135 | }, 136 | "39": { 137 | "inputs": { 138 | "pixels": [ 139 | "26", 140 | 0 141 | ], 142 | "vae": [ 143 | "1", 144 | 2 145 | ] 146 | }, 147 | "class_type": "VAEEncode", 148 | "_meta": { 149 | "title": "VAE Encode" 150 | } 151 | }, 152 | "40": { 153 | "inputs": { 154 | "samples": [ 155 | "39", 156 | 0 157 | ], 158 | "mask": [ 159 | "38", 160 | 0 161 | ] 162 | }, 163 | "class_type": "SetLatentNoiseMask", 164 | "_meta": { 165 | "title": "Set Latent Noise Mask" 166 | } 167 | } 168 | } -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "parallax-maker", 3 | "version": "1.0.0", 4 | "description": "This is not a node package. I just want to compile tailwindcss.", 5 | "main": "index.js", 6 | "scripts": { 7 | "build": "tailwindcss -i ./tailwind.css -o ./parallax_maker/assets/css/tailwind_compiled.css", 8 | "watch": "tailwindcss -i ./tailwind.css -o ./parallax_maker/assets/css/tailwind_compiled.css --watch", 9 | "test": "echo \"Error: no test specified\" && exit 1" 10 | }, 11 | "author": "Niels Provos", 12 | "license": "AGPL-3.0-only", 13 | "dependencies": { 14 | "autoprefixer": "^10.4.19", 15 | "postcss": "^8.4.38" 16 | }, 17 | "devDependencies": { 18 | "tailwindcss": "^3.4.4" 19 | } 20 | } -------------------------------------------------------------------------------- /parallax_maker/__init__.py: -------------------------------------------------------------------------------- 1 | """Parallax Maker - A workflow for turning images into 2.5D animations. 
2 | 3 | This package provides tools for: 4 | - Depth estimation from images using models like MiDaS and ZoeDepth 5 | - Image segmentation using Segment Anything with point selection 6 | - Inpainting with various models including Stable Diffusion 7 | - 3D export to glTF format for Blender and Unreal Engine 8 | - Web-based user interface for interactive workflow 9 | """ 10 | 11 | __version__ = "1.0.1" 12 | __author__ = "Niels Provos" 13 | __email__ = "niels@provos.org" 14 | __license__ = "AGPL-3.0-or-later" 15 | 16 | from .controller import AppState, CompositeMode 17 | from .camera import Camera 18 | from .slice import ImageSlice 19 | 20 | __all__ = [ 21 | "AppState", 22 | "CompositeMode", 23 | "Camera", 24 | "ImageSlice", 25 | ] 26 | -------------------------------------------------------------------------------- /parallax_maker/assets/.pydiscard: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /parallax_maker/assets/css/features.css: -------------------------------------------------------------------------------- 1 | .text-overlay { 2 | position: absolute; 3 | /* Overlap the image */ 4 | bottom: 0; 5 | /* Position at bottom */ 6 | right: 0; 7 | /* Position at right corner */ 8 | padding: 3px; 9 | /* Add some padding */ 10 | background-color: rgba(0, 0, 0, 0.5); 11 | /* Semi-transparent black background */ 12 | color: white; 13 | /* White text color */ 14 | } 15 | 16 | .checkerboard { 17 | background-image: 18 | linear-gradient(45deg, #ccc 25%, transparent 25%), 19 | linear-gradient(135deg, #ccc 25%, transparent 25%), 20 | linear-gradient(45deg, transparent 75%, #ccc 75%), 21 | linear-gradient(135deg, transparent 75%, #ccc 75%); 22 | background-size: 25px 25px; 23 | /* Must be a square */ 24 | background-position: 0 0, 12.5px 0, 12.5px -12.5px, 0px 12.5px; 25 | /* Must be half of one side of the square */ 26 | } 27 | 28 | #preview-canvas { 29 | position: absolute; 30 | pointer-events: none; 31 | /* Allow clicks and events to pass through */ 32 | opacity: 0.5; 33 | } 34 | 35 | button:disabled { 36 | background-color: #BBD; 37 | } -------------------------------------------------------------------------------- /parallax_maker/automatic1111.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import requests 3 | from PIL import Image 4 | from io import BytesIO 5 | import json 6 | import argparse 7 | 8 | from .utils import to_image_url 9 | 10 | MAX_DIMENSION = 2048 11 | 12 | 13 | def create_img2img_payload( 14 | input_image, 15 | positive_prompt, 16 | negative_prompt, 17 | mask_image=None, 18 | strength=0.8, 19 | steps=20, 20 | cfg_scale=8.0, 21 | ): 22 | """ 23 | Create a payload for the img2img API. 24 | 25 | Args: 26 | input_image (PIL.Image.Image): The input image. 27 | positive_prompt (str): The positive prompt for generating the output image. 28 | negative_prompt (str): The negative prompt for generating the output image. 29 | mask_image (PIL.Image.Image, optional): The mask for the input image. Defaults to None. 30 | strength (float, optional): The denoising strength. Defaults to 0.8. 31 | steps (int, optional): The number of steps for generating the output image. Defaults to 20. 32 | cfg_scale (float, optional): The scale factor for the configuration. Defaults to 8.0. 33 | 34 | Returns: 35 | dict: The payload for the img2img API. 
36 | """ 37 | width, height = input_image.size 38 | if width > MAX_DIMENSION or height > MAX_DIMENSION: 39 | # resize the image so the largest dimension is MAX_DIMENSION 40 | if width > height: 41 | new_width = MAX_DIMENSION 42 | new_height = int(MAX_DIMENSION * height / width) 43 | else: 44 | new_height = MAX_DIMENSION 45 | new_width = int(MAX_DIMENSION * width / height) 46 | input_image = input_image.resize((new_width, new_height)) 47 | mask_image = mask_image.resize((new_width, new_height)) if mask_image is not None else None 48 | 49 | payload = { 50 | "init_images": [to_image_url(input_image)], 51 | "prompt": positive_prompt, 52 | "negative_prompt": negative_prompt, 53 | "denoising_strength": strength, 54 | "width": input_image.width, # use the (possibly resized) dimensions 55 | "height": input_image.height, 56 | "steps": steps, 57 | "cfg_scale": cfg_scale, 58 | "mask_blur_x": 0, # if we provide a mask it's already blurred 59 | "mask_blur_y": 0, 60 | "mask_blur": 0, 61 | } 62 | 63 | if mask_image is not None: 64 | payload["mask"] = to_image_url(mask_image) 65 | 66 | return payload 67 | 68 | 69 | def make_img2img_request(server_address, payload): 70 | """ 71 | Sends an img2img request to the automatic1111 API endpoint. 72 | 73 | Args: 74 | server_address (str): The address:port of the server to send the request to. 75 | payload (dict): The payload generated by create_img2img_payload. 76 | 77 | Returns: 78 | list: A list of images returned in the response. 79 | 80 | Raises: 81 | requests.exceptions.Timeout: If the request times out. 82 | 83 | """ 84 | server_url = f"http://{server_address}/sdapi/v1/img2img" 85 | payload = json.dumps(payload) 86 | response = requests.post(url=server_url, data=payload, timeout=500).json() 87 | 88 | return response.get("images") 89 | 90 | 91 | def make_models_request(server_address): 92 | """ 93 | Sends a request to the automatic1111 API endpoint to get the available models. 94 | 95 | Args: 96 | server_address (str): The address:port of the server to send the request to. 97 | 98 | Returns: 99 | list: A list of available models. 100 | 101 | Raises: 102 | requests.exceptions.Timeout: If the request times out.
103 | 104 | """ 105 | server_url = f"http://{server_address}/sdapi/v1/sd-models" 106 | response = requests.get(url=server_url, timeout=3).json() 107 | 108 | if not isinstance(response, list): 109 | return None 110 | 111 | return [entry["model_name"] for entry in response] 112 | 113 | 114 | def main(): 115 | # Parse arguments - input image, positive prompt, negative prompt, server address 116 | parser = argparse.ArgumentParser(description="Process some integers.") 117 | parser.add_argument("-i", "--input-image", type=str, help="Path to the input image") 118 | parser.add_argument( 119 | "-m", "--mask-image", type=str, default=None, help="Path to the mask image" 120 | ) 121 | parser.add_argument("-p", "--prompt", type=str, help="Positive prompt") 122 | parser.add_argument( 123 | "-n", "--negative-prompt", type=str, default="", help="Negative prompt" 124 | ) 125 | parser.add_argument("-e", "--steps", type=int, default=30, help="Number of steps") 126 | parser.add_argument("-s", "--server", type=str, help="Server address") 127 | args = parser.parse_args() 128 | 129 | image = Image.open(args.input_image) 130 | mask_image = None 131 | if args.mask_image is not None: 132 | mask_image = Image.open(args.mask_image) 133 | 134 | payload = create_img2img_payload( 135 | image, 136 | args.prompt, 137 | args.negative_prompt, 138 | mask_image=mask_image, 139 | steps=args.steps, 140 | ) 141 | images = make_img2img_request(args.server, payload) 142 | if images is None: 143 | print("No response from server") 144 | else: 145 | for image in images: 146 | image = Image.open(BytesIO(base64.b64decode(image))) 147 | image.show() 148 | input("Press Enter to continue...") 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /parallax_maker/camera.py: -------------------------------------------------------------------------------- 1 | # (c) 2024 Niels Provos 2 | 3 | import numpy as np 4 | 5 | 6 | class Camera: 7 | __slots__ = ( 8 | "_camera_position", 9 | "_camera_distance", 10 | "_max_distance", 11 | "_focal_length", 12 | "_sensor_width", 13 | ) 14 | 15 | def __init__( 16 | self, distance=100, max_distance=500, focal_length=100, sensor_width=35.0 17 | ): 18 | """ 19 | Initializes a Camera object. 20 | 21 | It isn't really a camera but encapsulates related properties. 22 | 23 | Args: 24 | distance (float): The distance of the camera from the scene. 25 | max_distance (float): The maximum distance/depth of the scene. 26 | focal_length (float): The focal length of the camera. 27 | sensor_width (float, optional): The width of the camera sensor. Defaults to 35.0. 28 | 29 | Returns: 30 | None 31 | """ 32 | self.camera_position = np.array([0, 0, -distance], dtype=np.float32) 33 | self.camera_distance = distance 34 | self.max_distance = max_distance 35 | self.focal_length = focal_length 36 | self.sensor_width = sensor_width 37 | 38 | def focal_length_px(self, image_width): 39 | """ 40 | Calculate the focal length in pixels. 41 | 42 | Args: 43 | image_width (int): The width of the image. 44 | 45 | Returns: 46 | float: The focal length in pixels. 47 | """ 48 | return (image_width * self.focal_length) / self.sensor_width 49 | 50 | def camera_matrix(self, image_width, image_height): 51 | """ 52 | Calculate the camera matrix. 53 | 54 | Args: 55 | image_width (int): The width of the image. 56 | image_height (int): The height of the image. 57 | 58 | Returns: 59 | np.ndarray: The camera matrix. 
60 | """ 61 | 62 | fl_px = self.focal_length_px(image_width) 63 | return np.array( 64 | [[fl_px, 0, image_width / 2], [0, fl_px, image_height / 2], [0, 0, 1]], 65 | dtype=np.float32, 66 | ) 67 | 68 | def to_json(self): 69 | return { 70 | "position": self._camera_position.tolist(), 71 | "camera_distance": self._camera_distance, 72 | "max_distance": self._max_distance, 73 | "focal_length": self._focal_length, 74 | } 75 | 76 | @staticmethod 77 | def from_json(data): 78 | camera = Camera() 79 | if "position" in data: 80 | camera.camera_position = np.array(data["position"], dtype=np.float32) 81 | if "camera_distance" in data: 82 | camera.camera_distance = data["camera_distance"] 83 | if "max_distance" in data: 84 | camera.max_distance = data["max_distance"] 85 | if "focal_length" in data: 86 | camera.focal_length = data["focal_length"] 87 | 88 | return camera 89 | 90 | def __str__(self): 91 | return ( 92 | f"Camera(distance={self._camera_distance}, max_distance={self._max_distance}, " 93 | f"focal_length={self._focal_length}, position={self._camera_position})" 94 | ) 95 | 96 | def __repr__(self): 97 | return str(self) 98 | 99 | def __eq__(self, other): 100 | if not isinstance(other, Camera): 101 | return False 102 | return ( 103 | self.camera_distance == other.camera_distance 104 | and self.max_distance == other.max_distance 105 | and self.focal_length == other.focal_length 106 | ) 107 | 108 | @property 109 | def camera_position(self): 110 | return self._camera_position 111 | 112 | @camera_position.setter 113 | def camera_position(self, value): 114 | if ( 115 | not isinstance(value, np.ndarray) 116 | or value.dtype != np.float32 117 | or value.shape != (3,) 118 | ): 119 | raise ValueError( 120 | "camera_position must be a numpy array of dtype 'float32' with shape (3,)" 121 | ) 122 | self._camera_position = value 123 | 124 | @property 125 | def camera_distance(self): 126 | return self._camera_distance 127 | 128 | @camera_distance.setter 129 | def camera_distance(self, value): 130 | if not isinstance(value, (float, int)) or value < 0: 131 | raise ValueError("camera_distance must be a non-negative number") 132 | self._camera_distance = value 133 | 134 | @property 135 | def max_distance(self): 136 | return self._max_distance 137 | 138 | @max_distance.setter 139 | def max_distance(self, value): 140 | if not isinstance(value, (float, int)) or value < 0: 141 | raise ValueError("max_distance must be a non-negative number") 142 | self._max_distance = value 143 | 144 | @property 145 | def focal_length(self): 146 | return self._focal_length 147 | 148 | @focal_length.setter 149 | def focal_length(self, value): 150 | if not isinstance(value, (float, int)) or value <= 0: 151 | raise ValueError("focal_length must be a positive number") 152 | self._focal_length = value 153 | 154 | @property 155 | def sensor_width(self): 156 | return self._sensor_width 157 | 158 | @sensor_width.setter 159 | def sensor_width(self, value): 160 | if not isinstance(value, (float, int)) or value <= 0: 161 | raise ValueError("sensor_width must be a positive number") 162 | self._sensor_width = value 163 | -------------------------------------------------------------------------------- /parallax_maker/clientside.py: -------------------------------------------------------------------------------- 1 | # (c) 2024 Niels Provos 2 | 3 | from dash.dependencies import Input, Output, State, ClientsideFunction 4 | 5 | from . 
import constants as C 6 | 7 | 8 | def make_clientside_callbacks(app): 9 | app.clientside_callback( 10 | ClientsideFunction(namespace="clientside", function_name="store_rect_coords"), 11 | Output(C.STORE_RECT_DATA, "data"), 12 | Input(C.IMAGE, "src"), 13 | Input("evScroll", "n_events"), 14 | ) 15 | 16 | app.clientside_callback( 17 | ClientsideFunction( 18 | namespace="clientside", function_name="suppress_contextmenu" 19 | ), 20 | Output(C.CTR_INPUT_IMAGE, "id"), 21 | Input(C.CTR_INPUT_IMAGE, "id"), 22 | ) 23 | 24 | app.clientside_callback( 25 | ClientsideFunction(namespace="clientside", function_name="store_current_tab"), 26 | Output(C.STORE_CURRENT_TAB, "data"), 27 | Input(C.STORE_CURRENT_TAB, "data"), 28 | ) 29 | 30 | app.clientside_callback( 31 | ClientsideFunction( 32 | namespace="clientside", function_name="record_selected_slice" 33 | ), 34 | Output(C.STORE_IGNORE, "data", allow_duplicate=True), 35 | Input(C.STORE_SELECTED_SLICE, "data"), 36 | prevent_initial_call=True, 37 | ) 38 | 39 | app.clientside_callback( 40 | ClientsideFunction(namespace="clientside", function_name="visualize_point"), 41 | Output(C.STORE_IGNORE, "data", allow_duplicate=True), 42 | Input(C.STORE_CLICKED_POINT, "data"), 43 | prevent_initial_call=True, 44 | ) 45 | 46 | app.clientside_callback( 47 | ClientsideFunction( 48 | namespace="clientside", function_name="preview_canvas_clear" 49 | ), 50 | Output(C.STORE_IGNORE, "data", allow_duplicate=True), 51 | Input(C.STORE_CLEAR_PREVIEW, "data"), 52 | prevent_initial_call=True, 53 | ) 54 | 55 | app.clientside_callback( 56 | ClientsideFunction(namespace="clientside", function_name="show_bounding_box"), 57 | Output(C.STORE_IGNORE, "data", allow_duplicate=True), 58 | Input(C.STORE_BOUNDING_BOX, "data"), 59 | prevent_initial_call=True, 60 | ) 61 | -------------------------------------------------------------------------------- /parallax_maker/comfyui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import io 4 | import os 5 | from PIL import Image 6 | from urllib import request, parse 7 | import random 8 | import requests 9 | 10 | import uuid 11 | import websocket 12 | 13 | 14 | def load_workflow(workflow_path): 15 | try: 16 | with open(workflow_path, "r") as file: 17 | workflow = json.load(file) 18 | return json.dumps(workflow) 19 | except Exception as e: 20 | print(f"Failed to load workflow: {e}") 21 | return None 22 | 23 | 24 | def patch_inpainting_workflow( 25 | workflow, 26 | image, 27 | mask, 28 | prompt, 29 | negative_prompt, 30 | strength=0.8, 31 | steps=1, 32 | cfg_scale=5.0, 33 | seed=-1, 34 | ): 35 | workflow = json.loads(workflow) 36 | 37 | # find the sampler note in workflow 38 | sampler_id = [ 39 | key for key, value in workflow.items() if value["class_type"] == "KSampler" 40 | ][0] 41 | 42 | sampler = workflow.get(sampler_id) 43 | sampler["inputs"]["denoise"] = strength 44 | sampler["inputs"]["steps"] = steps 45 | sampler["inputs"]["cfg"] = cfg_scale 46 | 47 | if seed == -1: 48 | seed = random.randint(0, 2**32 - 1) 49 | sampler["inputs"]["seed"] = seed 50 | 51 | positive_id = sampler["inputs"]["positive"][0] 52 | positive_node = workflow.get(positive_id) 53 | positive_node["inputs"]["text"] = prompt 54 | 55 | negative_id = sampler["inputs"]["negative"][0] 56 | negative_node = workflow.get(negative_id) 57 | negative_node["inputs"]["text"] = negative_prompt 58 | 59 | latent_id = sampler["inputs"]["latent_image"][0] 60 | latent_node = workflow.get(latent_id) 61 | 62 | # this is very 
fragile and is not going to work in general 63 | mask_id = None 64 | if latent_node["class_type"] == "SetLatentNoiseMask": 65 | mask_id = latent_node["inputs"]["mask"][0] 66 | latent_id = latent_node["inputs"]["samples"][0] 67 | latent_node = workflow.get(latent_id) 68 | 69 | if latent_node["class_type"] == "VAEEncode": 70 | image_id = latent_node["inputs"]["pixels"][0] 71 | if mask_id is None: 72 | mask_id = latent_node["inputs"]["mask"][0] 73 | else: 74 | raise ValueError(f"Unknown node type: {latent_node['class_type']}") 75 | 76 | image_node = workflow.get(image_id) 77 | mask_node = workflow.get(mask_id) 78 | 79 | assert "image" in image_node["inputs"], "Image node does not have an image input" 80 | assert "image" in mask_node["inputs"], "Mask node does not have an image input" 81 | 82 | image_node["inputs"]["image"] = image 83 | mask_node["inputs"]["image"] = mask 84 | 85 | return workflow 86 | 87 | 88 | def queue_prompt(server_address, client_id, workflow): 89 | # send the workflow to the server 90 | p = {"prompt": workflow, "client_id": client_id} 91 | data = json.dumps(p).encode("utf-8") 92 | req = request.Request(f"http://{server_address}/prompt", data=data) 93 | return json.loads(request.urlopen(req).read()) 94 | 95 | 96 | def upload_image(server_address, file, overwrite=False): 97 | try: 98 | files = { 99 | "image": open(file, "rb"), 100 | } 101 | data = {"type": "input", "overwrite": str(overwrite).lower()} 102 | 103 | # Making the POST request 104 | response = requests.post( 105 | f"http://{server_address}/upload/image", files=files, data=data, timeout=500 106 | ) 107 | 108 | if response.status_code == 200: 109 | data = response.json() 110 | path = data["name"] 111 | return path 112 | else: 113 | print(f"HTTP Error: {response.status_code} - {response.reason}") 114 | return None 115 | 116 | except Exception as error: 117 | print(f"An error occurred: {error}") 118 | return None 119 | 120 | 121 | def get_image(server_address, filename, subfolder, folder_type): 122 | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} 123 | url_values = parse.urlencode(data) 124 | with request.urlopen(f"http://{server_address}/view?{url_values}") as response: 125 | return response.read() 126 | 127 | 128 | def get_images(server_address, client_id, prompt_id): 129 | ws = websocket.WebSocket() 130 | ws.connect(f"ws://{server_address}/ws?clientId={client_id}") 131 | 132 | output_images = {} 133 | while True: 134 | out = ws.recv() 135 | if isinstance(out, str): 136 | message = json.loads(out) 137 | if message["type"] == "executing": 138 | data = message["data"] 139 | if data["node"] is None and data["prompt_id"] == prompt_id: 140 | break # Execution is done 141 | else: 142 | continue # previews are binary data 143 | 144 | history = get_history(server_address, prompt_id)[prompt_id] 145 | for o in history["outputs"]: 146 | for node_id in history["outputs"]: 147 | node_output = history["outputs"][node_id] 148 | if "images" in node_output: 149 | images_output = [] 150 | for image in node_output["images"]: 151 | image_data = get_image( 152 | server_address, 153 | image["filename"], 154 | image["subfolder"], 155 | image["type"], 156 | ) 157 | images_output.append(image_data) 158 | output_images[node_id] = images_output 159 | 160 | return output_images 161 | 162 | 163 | def get_history(server_address, prompt_id): 164 | with request.urlopen(f"http://{server_address}/history/{prompt_id}") as response: 165 | return json.loads(response.read()) 166 | 167 | 168 | def 
temporary_filename(prefix="tmp", suffix=".png"): 169 | return f"{prefix}-{uuid.uuid4()}{suffix}" 170 | 171 | 172 | def inpainting_comfyui( 173 | server_address, 174 | workflow_path, 175 | image, 176 | mask, 177 | prompt, 178 | negative_prompt, 179 | strength=0.5, 180 | steps=40, 181 | cfg_scale=5.0, 182 | seed=-1, 183 | ): 184 | """ 185 | Perform inpainting using an external ComfyUI server. 186 | 187 | Args: 188 | server_address (str): The address of the server. 189 | workflow_path (str): The path to the workflow file. 190 | image (PIL.Image): The input image. 191 | mask (PIL.Image): The mask indicating the areas to be inpainted. 192 | prompt (str): The prompt for the inpainting process. 193 | negative_prompt (str): The negative prompt for the inpainting process. 194 | strength (float, optional): The strength of the inpainting. Defaults to 0.5. 195 | steps (int, optional): The number of steps for the inpainting process. Defaults to 40. 196 | cfg_scale (float, optional): The scale factor for the configuration. Defaults to 5.0. 197 | seed (int, optional): The seed for the inpainting process. Defaults to -1. 198 | 199 | Returns: 200 | PIL.Image: The inpainted image. 201 | """ 202 | workflow = load_workflow(workflow_path) 203 | if workflow is None: 204 | return None 205 | 206 | client_id = str(uuid.uuid4()) 207 | 208 | # we assume PIL images 209 | image_path = temporary_filename() 210 | image.save(image_path) 211 | mask_path = temporary_filename() 212 | mask.save(mask_path) 213 | 214 | image = None 215 | try: 216 | image = upload_image(server_address, image_path, overwrite=True) 217 | mask = upload_image(server_address, mask_path, overwrite=True) 218 | 219 | workflow = patch_inpainting_workflow( 220 | workflow, 221 | image, 222 | mask, 223 | prompt, 224 | negative_prompt, 225 | strength=strength, 226 | steps=steps, 227 | cfg_scale=cfg_scale, 228 | seed=seed, 229 | ) 230 | 231 | status = queue_prompt(server_address, client_id, workflow) 232 | prompt_id = status["prompt_id"] 233 | 234 | images = get_images(server_address, client_id, prompt_id) 235 | 236 | # XXX - only one image for now 237 | id = list(images.keys())[0] 238 | image_data = images[id][0] 239 | image = Image.open(io.BytesIO(image_data)) 240 | finally: 241 | # remove temporary files 242 | os.remove(image_path) 243 | os.remove(mask_path) 244 | 245 | return image 246 | 247 | 248 | def main(): 249 | argsparse = argparse.ArgumentParser() 250 | argsparse.add_argument( 251 | "-w", "--workflow", type=str, help="Path to the workflow file" 252 | ) 253 | argsparse.add_argument("-i", "--image", type=str, help="Path to the image file") 254 | argsparse.add_argument("-m", "--mask", type=str, help="Path to the mask file") 255 | argsparse.add_argument( 256 | "-p", 257 | "--prompt", 258 | type=str, 259 | default="A beautiful landscape with a mountain in the background", 260 | help="Positive prompt", 261 | ) 262 | argsparse.add_argument( 263 | "-n", 264 | "--negative-prompt", 265 | type=str, 266 | default="out of focus", 267 | help="Negative prompt", 268 | ) 269 | argsparse.add_argument("-s", "--strength", default=0.8, type=float, help="Strength") 270 | argsparse.add_argument( 271 | "-e", "--steps", type=int, default=40, help="Number of steps for the sampler" 272 | ) 273 | args = argsparse.parse_args() 274 | 275 | server_address = "localhost:8188" 276 | 277 | image = Image.open(args.image) 278 | mask = Image.open(args.mask) 279 | 280 | image = inpainting_comfyui( 281 | server_address, 282 | args.workflow, 283 | image, 284 | mask, 285 | args.prompt, 
286 | args.negative_prompt, 287 | args.strength, 288 | ) 289 | image.show() 290 | 291 | 292 | if __name__ == "__main__": 293 | main() 294 | -------------------------------------------------------------------------------- /parallax_maker/constants.py: -------------------------------------------------------------------------------- 1 | LOGS_DATA = "logs-data" 2 | 3 | ICON_DARK_MODE = "dark-mode-icon" 4 | BTN_DARK_MODE = "dark-mode-toggle" 5 | 6 | LOADING_UPLOAD = "loading-upload" 7 | LOADING_DEPTHMAP = "loading-depthmap" 8 | LOADING_GENERATE_SLICE = "loading-generate-slice" 9 | LOADING_GENERATE_INPAINTING = "generate-inpainting" 10 | LOADING_GLTF = "gltf-loading" 11 | LOADING_ANIMATION = "animation-loading" 12 | 13 | CTR_DEPTH_MAP = "depth-map-container" 14 | BTN_GENERATE_DEPTHMAP = "generate-depthmap-button" 15 | DEPTHMAP_OUTPUT = "gen-depthmap-output" 16 | ANIMATION_OUTPUT = "gen-animation-output" 17 | 18 | CANVAS_DATA = "canvas-data" 19 | CANVAS_MASK_DATA = "canvas-mask-data" 20 | CANVAS_PAINT = "canvas-paint" 21 | 22 | STORE_SELECTED_SLICE = "selected-slice" 23 | STORE_IGNORE = "store-ignore" 24 | STORE_GENERATE_SLICE = "generate-slice-request" 25 | STORE_UPDATE_SLICE = "update-slice-request" 26 | STORE_INPAINTING = "inpainting-request" 27 | STORE_APPSTATE_FILENAME = "application-state-filename" 28 | STORE_RESTORE_STATE = "restore-state" 29 | STORE_RECT_DATA = "rect-data" 30 | STORE_BOUNDING_BOX = "bounding-box" 31 | STORE_TRIGGER_GEN_DEPTHMAP = "trigger-generate-depthmap" 32 | STORE_TRIGGER_UPDATE_DEPTHMAP = "trigger-update-depthmap" 33 | STORE_UPDATE_THRESHOLD_CONTAINER = "update-thresholds-container" 34 | STORE_CLICKED_POINT = "clicked-point" 35 | STORE_CLEAR_PREVIEW = "clear-preview" 36 | 37 | TEXT_POSITIVE_PROMPT = "positive-prompt" 38 | TEXT_NEGATIVE_PROMPT = "negative-prompt" 39 | 40 | CTR_INPUT_IMAGE = "input-image-container" 41 | IMAGE = "image" 42 | CANVAS = "canvas" 43 | PREVIEW_CANVAS = "preview-canvas" 44 | UPLOAD_IMAGE = "upload-image" 45 | UPLOAD_STATE = "upload-state" 46 | UPLOAD_SLICE = "slice-upload" 47 | BTN_SAVE_STATE = "save-state" 48 | 49 | SLIDER_INPAINT_STRENGTH = "inpaint-strength" 50 | SLIDER_INPAINT_GUIDANCE = "inpaint-guidance" 51 | SLIDER_NUM_SLICES = "num-slices-slider" 52 | SLIDER_MASK_PADDING = "mask-padding-slider" 53 | SLIDER_MASK_BLUR = "mask-blur-slider" 54 | 55 | SLIDER_CAMERA_DISTANCE = "camera-distance-slider" 56 | SLIDER_MAX_DISTANCE = "max-distance-slider" 57 | SLIDER_FOCAL_LENGTH = "focal-length-slider" 58 | SLIDER_DISPLACEMENT = "displacement-slider" 59 | SLIDER_NUM_FRAMES = "number-of-frames-slider" 60 | 61 | DROPDOWN_DEPTH_MODEL = "depth-model-dropdown" 62 | DROPDOWN_INPAINT_MODEL = "inpainting-model-dropdown" 63 | DROPDOWN_MODE_SELECTOR = "mode-selector" 64 | 65 | INPUT_EXTERNAL_SERVER = "external-server-address" 66 | BTN_EXTERNAL_TEST_CONNECTION = "external-test-connection-button" 67 | 68 | UPLOAD_COMFYUI_WORKFLOW = "comfyui-workflow-upload" 69 | CTR_COMFYUI_WORKFLOW = "comfyui-workflow-container" 70 | CTR_AUTOMATIC_CONFIG = "automatic-config-container" 71 | 72 | INPUT_API_KEY = "api-key" 73 | BTN_VALIDATE_API_KEY = "validate-api-key" 74 | CTR_API_KEY = "api-key-container" 75 | 76 | DOWNLOAD_IMAGE = "download-image" 77 | DOWNLOAD_GLTF = "download-gltf" 78 | DOWNLOAD_ANIMATION = "download-animation" 79 | 80 | BTN_EXPORT_ANIMATION = "animation-export" 81 | 82 | CTR_THRESHOLDS = "thresholds-container" 83 | CTR_SLICE_IMAGES = "slice-img-container" 84 | 85 | BTN_GENERATE_INPAINTING = "generate-inpainting-button" 86 | BTN_FILL_INPAINTING = 
"fill-inpainting-button" 87 | BTN_ENHANCE = "enhance-button" 88 | BTN_ERASE_INPAINTING = "erase-inpainting-button" 89 | BTN_APPLY_INPAINTING = "apply-inpainting-button" 90 | CTR_INPAINTING_OUTPUT = "inpainting-output" 91 | CTR_INPAINTING_IMAGES = "inpainting-img-container" 92 | CTR_INPAINTING_DISPLAY = "inpainting-image-display" 93 | CTR_INPAINTING_COLUNM = "inpainting-column" 94 | 95 | ID_INPAINTING_IMAGE = "inpainting-image" 96 | 97 | ID_SLICE_DEPTH_DISPLAY = "depth-display" 98 | ID_SLICE_OVERLAY = "slicer-overlay" 99 | INPUT_SLICE_DEPTH = "depth-input" 100 | 101 | BTN_GENERATE_SLICE = "generate-slice-button" 102 | BTN_BALANCE_SLICE = "balance-slice-button" 103 | BTN_CREATE_SLICE = "create-slice-button" 104 | BTN_DELETE_SLICE = "delete-slice-button" 105 | BTN_ADD_SLICE = "add-to-slice-button" 106 | BTN_REMOVE_SLICE = "remove-from-slice-button" 107 | BTN_COPY_SLICE = "copy-button" 108 | BTN_PASTE_SLICE = "paste-button" 109 | 110 | BTN_CLEAR_CANVAS = "clear-canvas" 111 | BTN_ERASE_MODE = "erase-mode-canvas" 112 | BTN_LOAD_CANVAS = "load-canvas" 113 | 114 | CTR_CANVAS_BUTTONS = "canvas-buttons" 115 | NAV_ZOOM_OUT = "nav-zoom-out" 116 | NAV_UP = "nav-up" 117 | NAV_DOWN = "nav-down" 118 | NAV_LEFT = "nav-left" 119 | NAV_RIGHT = "nav-right" 120 | NAV_ZOOM_IN = "nav-zoom-in" 121 | NAV_RESET = "nav-reset" 122 | 123 | CTR_SEG_BUTTONS = "segmentation-buttons" 124 | SEG_TOGGLE_CHECKERBOARD = "toggle-checkerboard" 125 | SEG_INVERT_MASK = "invert-mask" 126 | SEG_FEATHER_MASK = "feather-mask" 127 | SEG_MULTI_POINT = "multi-point" 128 | SEG_MULTI_COMMIT = "multi-commit" 129 | 130 | CTR_PROGRESS_BAR = "progress-bar-container" 131 | PROGRESS_INTERVAL = "progress-interval" 132 | 133 | CTR_GLTF_OUTPUT = "gen-gltf-output" 134 | BTN_GLTF_CREATE = "gltf-create" 135 | BTN_GLTF_EXPORT = "gltf-export" 136 | BTN_UPSCALE_TEXTURES = "upscale-textures" 137 | 138 | CTR_MODEL_VIEWER = "model-viewer-container" 139 | IFRAME_MODEL_VIEWER = "model-viewer" 140 | 141 | CTR_HELP_WINDOW = "help-window" 142 | STORE_CURRENT_TAB = "current-tab" 143 | 144 | CHECKLIST_DOF = "toggle-dof-support" 145 | CHECKLIST_REGION_OF_INTEREST = "toggle-region-of-interest" 146 | -------------------------------------------------------------------------------- /parallax_maker/depth.py: -------------------------------------------------------------------------------- 1 | # (c) 2024 Niels Provos 2 | # 3 | """ 4 | Create Depth Maps from Images 5 | 6 | This module provides functionality to create depth maps from images using pre-trained deep learning models. 7 | The depth maps can be used to create parallax effects in images and videos. 
8 | 9 | TODO: 10 | - Investigate https://huggingface.co/LiheYoung/depth_anything_vitl14 - also at https://huggingface.co/docs/transformers/main/model_doc/depth_anything 11 | """ 12 | 13 | 14 | from transformers import AutoImageProcessor, DPTForDepthEstimation 15 | import torch 16 | import cv2 17 | import numpy as np 18 | from PIL import Image 19 | 20 | from .utils import torch_get_device 21 | 22 | 23 | class DepthEstimationModel: 24 | MODELS = ["midas", "zoedepth", "dinov2"] 25 | 26 | def __init__(self, model="midas"): 27 | assert model in self.MODELS, f"Model {model} must be one of {self.MODELS}" 28 | self._model_name = model 29 | self.model = None 30 | self.transforms = None 31 | self.image_processor = None 32 | 33 | def __eq__(self, other): 34 | if not isinstance(other, DepthEstimationModel): 35 | return False 36 | return self._model_name == other._model_name 37 | 38 | @property 39 | def model_name(self): 40 | return self._model_name 41 | 42 | def load_model(self, progress_callback=None): 43 | load_pipeline = { 44 | "midas": create_medias_pipeline, 45 | "zoedepth": create_zoedepth_pipeline, 46 | "dinov2": create_dinov2_pipeline, 47 | } 48 | 49 | result = load_pipeline[self._model_name](progress_callback=progress_callback) 50 | 51 | if self._model_name == "midas": 52 | self.model, self.transforms = result 53 | elif self._model_name == "zoedepth": 54 | self.model = result 55 | elif self._model_name == "dinov2": 56 | self.model, self.image_processor = result 57 | 58 | def depth_map(self, image, progress_callback=None): 59 | if self.model is None: 60 | self.load_model() 61 | 62 | run_pipeline = { 63 | "midas": lambda img, cb: run_medias_pipeline( 64 | img, self.model, self.transforms, progress_callback=cb 65 | ), 66 | "zoedepth": lambda img, cb: run_zoedepth_pipeline( 67 | img, self.model, progress_callback=cb 68 | ), 69 | "dinov2": lambda img, cb: run_dinov2_pipeline( 70 | img, self.model, self.image_processor, progress_callback=cb 71 | ), 72 | } 73 | 74 | return run_pipeline[self._model_name](image, progress_callback) 75 | 76 | 77 | def create_dinov2_pipeline(progress_callback=None): 78 | image_processor = AutoImageProcessor.from_pretrained( 79 | "facebook/dpt-dinov2-large-nyu" 80 | ) 81 | model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-large-nyu") 82 | model.to(torch_get_device()) 83 | return model, image_processor 84 | 85 | 86 | def run_dinov2_pipeline(image, model, image_processor, progress_callback=None): 87 | image = Image.fromarray(image) 88 | 89 | new_size = image.size 90 | if image.width > image.height: 91 | if image.width > 1024: 92 | new_size = (1024, int(image.height * 1024 / image.width)) 93 | else: 94 | if image.height > 1024: 95 | new_size = (int(image.width * 1024 / image.height), 1024) 96 | 97 | resized_image = image.convert("RGB").resize(new_size, Image.BICUBIC) 98 | inputs = image_processor(images=resized_image, return_tensors="pt") 99 | inputs = {k: v.to(torch_get_device()) for k, v in inputs.items()} 100 | with torch.no_grad(): 101 | outputs = model(**inputs) 102 | predicted_depth = outputs.predicted_depth 103 | 104 | # interpolate to original size 105 | prediction = torch.nn.functional.interpolate( 106 | predicted_depth.unsqueeze(1), 107 | size=image.size[::-1], 108 | mode="bicubic", 109 | align_corners=False, 110 | ) 111 | 112 | # visualize the prediction 113 | output = prediction.squeeze().cpu().numpy() 114 | formatted = (output * 255 / np.max(output)).astype("uint8") 115 | formatted[:, :] = 255 - formatted[:, :] # invert the depth map 116 | 117 | 
# resize to original size 118 | formatted = cv2.resize( 119 | formatted, (image.width, image.height), interpolation=cv2.INTER_CUBIC 120 | ) 121 | 122 | return formatted 123 | 124 | 125 | def create_medias_pipeline(progress_callback=None): 126 | """ 127 | Creates a depth-estimation pipeline using the MiDaS model. 128 | 129 | Args: 130 | progress_callback (callable, optional): A callback function to report progress. Defaults to None. 131 | 132 | Returns: 133 | tuple: A tuple containing the MiDaS model and the transformation pipeline. 134 | 135 | """ 136 | # Load the MiDaS DPT model 137 | model_type = "DPT_Large" 138 | midas = torch.hub.load("intel-isl/MiDaS", model_type, skip_validation=True) 139 | 140 | if progress_callback: 141 | progress_callback(30, 100) 142 | 143 | # Set the model to evaluation mode 144 | midas.eval() 145 | 146 | # Define the transformation pipeline 147 | midas_transforms = torch.hub.load( 148 | "intel-isl/MiDaS", "transforms", skip_validation=True 149 | ) 150 | if model_type == "DPT_Large" or model_type == "DPT_Hybrid": 151 | transforms = midas_transforms.dpt_transform 152 | else: 153 | transforms = midas_transforms.small_transform 154 | 155 | if progress_callback: 156 | progress_callback(50, 100) 157 | 158 | # Set the device (CPU or GPU) 159 | midas.to(torch_get_device()) 160 | 161 | return midas, transforms 162 | 163 | 164 | def run_medias_pipeline(image, midas, transforms, progress_callback=None): 165 | """ 166 | Runs the MiDaS pipeline for depth estimation. 167 | 168 | Args: 169 | image (numpy.ndarray): The input image. 170 | midas (torch.nn.Module): The MiDaS model. 171 | transforms (torchvision.transforms.Compose): The image transforms. 172 | progress_callback (callable, optional): A callback function to report progress. 173 | 174 | Returns: 175 | numpy.ndarray: The predicted depth map.
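Example (illustrative sketch; assumes `img` is an RGB numpy array):

    midas, transforms = create_medias_pipeline()
    depth = run_medias_pipeline(img, midas, transforms)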
176 | """ 177 | input_batch = transforms(image).to(torch_get_device()) 178 | with torch.no_grad(): 179 | prediction = midas(input_batch) 180 | 181 | prediction = torch.nn.functional.interpolate( 182 | prediction.unsqueeze(1), 183 | size=image.shape[:2], 184 | mode="bicubic", 185 | align_corners=False, 186 | ).squeeze() 187 | 188 | if progress_callback: 189 | progress_callback(90, 100) 190 | 191 | return prediction.cpu().numpy() 192 | 193 | 194 | def midas_depth_map(image, progress_callback=None): 195 | if progress_callback: 196 | progress_callback(0, 100) 197 | 198 | midas, transforms = create_medias_pipeline(progress_callback=progress_callback) 199 | 200 | depth_map = run_medias_pipeline( 201 | image, midas, transforms, progress_callback=progress_callback 202 | ) 203 | 204 | if progress_callback: 205 | progress_callback(100, 100) 206 | 207 | return depth_map 208 | 209 | 210 | def create_zoedepth_pipeline(progress_callback=None): 211 | # Triggers fresh download of MiDaS repo 212 | torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True) 213 | 214 | # Zoe_NK 215 | model_zoe_nk = torch.hub.load( 216 | "isl-org/ZoeDepth", "ZoeD_NK", pretrained=True, skip_validation=True 217 | ) 218 | 219 | # Set the device (CPU or GPU) 220 | device = torch_get_device() 221 | model_zoe_nk.to(device) 222 | 223 | if progress_callback: 224 | progress_callback(50, 100) 225 | 226 | return model_zoe_nk 227 | 228 | 229 | def run_zoedepth_pipeline(image, model_zoe_nk, progress_callback=None): 230 | depth_map = model_zoe_nk.infer_pil(image) # as numpy 231 | 232 | # invert the depth map since we are expecting the farthest objects to be black 233 | depth_map = 255 - depth_map 234 | 235 | if progress_callback: 236 | progress_callback(100, 100) 237 | 238 | return depth_map 239 | 240 | 241 | def zoedepth_depth_map(image, progress_callback=None): 242 | model_zoe_nk = create_zoedepth_pipeline(progress_callback=progress_callback) 243 | 244 | return run_zoedepth_pipeline( 245 | image, model_zoe_nk, progress_callback=progress_callback 246 | ) 247 | -------------------------------------------------------------------------------- /parallax_maker/falai.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # (c) 2024 Niels Provos 3 | 4 | import os 5 | from PIL import Image 6 | from io import BytesIO 7 | import fal_client 8 | import requests 9 | import argparse 10 | from typing import Optional, Tuple, Dict, Any 11 | 12 | 13 | class FalAI: 14 | """ 15 | Implements the fal.ai API for inpainting images using various models. 16 | 17 | Supports: 18 | - Foocus inpainting 19 | - Flux general inpainting 20 | - SD/SDXL inpainting 21 | """ 22 | 23 | # Model endpoints on fal.ai 24 | MODELS = { 25 | "foocus": "fal-ai/fooocus/inpaint", 26 | "flux-general": "fal-ai/flux-general/inpainting", 27 | "sd": "fal-ai/inpaint", # Supports both SD and SDXL 28 | "sdxl": "fal-ai/inpaint" 29 | } 30 | 31 | def __init__(self, api_key: str): 32 | """ 33 | Initialize the FalAI client with an API key. 34 | 35 | Args: 36 | api_key (str): The fal.ai API key 37 | """ 38 | self.api_key = api_key 39 | # Set the API key as environment variable for fal_client 40 | os.environ["FAL_KEY"] = api_key 41 | 42 | def validate_key(self) -> Tuple[bool, Optional[str]]: 43 | """ 44 | Validate the API key by making a simple request. 
45 | 46 | Returns: 47 | Tuple[bool, Optional[str]]: (success, error_message) 48 | """ 49 | try: 50 | # Try a simple model call with minimal parameters 51 | result = fal_client.run( 52 | "fal-ai/fast-sdxl", 53 | arguments={ 54 | "prompt": "test", 55 | "image_size": "square_hd", 56 | "num_inference_steps": 1, 57 | "num_images": 1 58 | } 59 | ) 60 | return True, None 61 | except Exception as e: 62 | error_msg = str(e) 63 | if "401" in error_msg or "unauthorized" in error_msg.lower(): 64 | return False, "Invalid API key" 65 | return False, f"Connection error: {error_msg}" 66 | 67 | def _download_image(self, url: str) -> Image.Image: 68 | """Download an image from a URL and return as PIL Image.""" 69 | response = requests.get(url) 70 | response.raise_for_status() 71 | return Image.open(BytesIO(response.content)) 72 | 73 | def _prepare_image_for_upload(self, image: Image.Image) -> str: 74 | """ 75 | Prepare image for upload to fal.media CDN. 76 | 77 | Args: 78 | image: PIL Image to upload 79 | 80 | Returns: 81 | str: URL of uploaded image 82 | """ 83 | # Save image to bytes 84 | img_bytes = BytesIO() 85 | image.save(img_bytes, format="PNG") 86 | img_bytes.seek(0) 87 | 88 | # Upload to fal.media 89 | url = fal_client.upload(img_bytes.getvalue(), "image/png") 90 | return url 91 | 92 | def inpaint_foocus( 93 | self, 94 | image: Image.Image, 95 | mask: Image.Image, 96 | prompt: str, 97 | negative_prompt: str = "", 98 | strength: float = 0.8, 99 | guidance_scale: float = 7.5, 100 | num_inference_steps: int = 25, 101 | seed: int = -1 102 | ) -> Image.Image: 103 | """ 104 | Inpaint using Foocus model. 105 | 106 | Args: 107 | image: Input image 108 | mask: Mask indicating areas to inpaint 109 | prompt: Text prompt for generation 110 | negative_prompt: Negative prompt 111 | strength: Inpainting strength (0-1) 112 | guidance_scale: Guidance scale 113 | num_inference_steps: Number of inference steps 114 | seed: Random seed (-1 for random) 115 | 116 | Returns: 117 | Inpainted image 118 | """ 119 | # Upload images 120 | image_url = self._prepare_image_for_upload(image) 121 | mask_url = self._prepare_image_for_upload(mask) 122 | 123 | arguments = { 124 | "inpaint_image_url": image_url, 125 | "mask_image_url": mask_url, 126 | "prompt": prompt, 127 | "negative_prompt": negative_prompt, 128 | "guidance_scale": guidance_scale, 129 | "inpaint_strength": strength, 130 | "inpaint_respective_field": 0.618, # Default value from API docs 131 | "inpaint_engine": "v2.6", # Latest version 132 | "seed": seed if seed != -1 else None, 133 | "num_images": 1, 134 | "enable_safety_checker": False 135 | } 136 | 137 | result = fal_client.run( 138 | self.MODELS["foocus"], 139 | arguments=arguments 140 | ) 141 | 142 | # Download and return the first generated image 143 | image_url = result["images"][0]["url"] 144 | return self._download_image(image_url) 145 | 146 | def inpaint_flux_general( 147 | self, 148 | image: Image.Image, 149 | mask: Image.Image, 150 | prompt: str, 151 | negative_prompt: str = "", 152 | strength: float = 0.8, 153 | guidance_scale: float = 3.5, 154 | num_inference_steps: int = 28, 155 | seed: int = -1 156 | ) -> Image.Image: 157 | """ 158 | Inpaint using Flux general inpainting model. 
159 | 160 | Args: 161 | image: Input image 162 | mask: Mask indicating areas to inpaint 163 | prompt: Text prompt for generation 164 | negative_prompt: Negative prompt (not used by Flux) 165 | strength: Inpainting strength (0-1) 166 | guidance_scale: Guidance scale 167 | num_inference_steps: Number of inference steps 168 | seed: Random seed (-1 for random) 169 | 170 | Returns: 171 | Inpainted image 172 | """ 173 | # Upload images 174 | image_url = self._prepare_image_for_upload(image) 175 | mask_url = self._prepare_image_for_upload(mask) 176 | 177 | arguments = { 178 | "image_url": image_url, 179 | "mask_url": mask_url, 180 | "prompt": prompt, 181 | "guidance_scale": guidance_scale, 182 | "num_inference_steps": num_inference_steps, 183 | "strength": strength, 184 | "seed": seed if seed != -1 else None, 185 | "num_images": 1, 186 | "enable_safety_checker": False 187 | } 188 | 189 | result = fal_client.run( 190 | self.MODELS["flux-general"], 191 | arguments=arguments 192 | ) 193 | 194 | # Download and return the first generated image 195 | image_url = result["images"][0]["url"] 196 | return self._download_image(image_url) 197 | 198 | def inpaint_sd( 199 | self, 200 | image: Image.Image, 201 | mask: Image.Image, 202 | prompt: str, 203 | negative_prompt: str = "", 204 | strength: float = 0.8, 205 | guidance_scale: float = 7.5, 206 | num_inference_steps: int = 50, 207 | model_type: str = "sdxl", 208 | seed: int = -1 209 | ) -> Image.Image: 210 | """ 211 | Inpaint using Stable Diffusion or SDXL model. 212 | 213 | Args: 214 | image: Input image 215 | mask: Mask indicating areas to inpaint 216 | prompt: Text prompt for generation 217 | negative_prompt: Negative prompt 218 | strength: Inpainting strength (0-1) 219 | guidance_scale: Guidance scale 220 | num_inference_steps: Number of inference steps 221 | model_type: "sd" or "sdxl" 222 | seed: Random seed (-1 for random) 223 | 224 | Returns: 225 | Inpainted image 226 | """ 227 | # Upload images 228 | image_url = self._prepare_image_for_upload(image) 229 | mask_url = self._prepare_image_for_upload(mask) 230 | 231 | arguments = { 232 | "image_url": image_url, 233 | "mask_url": mask_url, 234 | "prompt": prompt, 235 | "negative_prompt": negative_prompt, 236 | "guidance_scale": guidance_scale, 237 | "num_inference_steps": num_inference_steps, 238 | "strength": strength, 239 | "model": model_type, # "sd" or "sdxl" 240 | "seed": seed if seed != -1 else None, 241 | "num_images": 1, 242 | "enable_safety_checker": False, 243 | "sync_mode": True 244 | } 245 | 246 | result = fal_client.run( 247 | self.MODELS[model_type], 248 | arguments=arguments 249 | ) 250 | 251 | # Download and return the first generated image 252 | image_url = result["images"][0]["url"] 253 | return self._download_image(image_url) 254 | 255 | 256 | def main(): # pragma: no cover 257 | parser = argparse.ArgumentParser(description="Test fal.ai inpainting") 258 | parser.add_argument("--api-key", required=True, help="fal.ai API key") 259 | parser.add_argument("-i", "--image", required=True, help="Input image path") 260 | parser.add_argument("-m", "--mask", required=True, help="Mask image path") 261 | parser.add_argument( 262 | "-p", "--prompt", 263 | default="a beautiful landscape", 264 | help="Prompt for inpainting" 265 | ) 266 | parser.add_argument( 267 | "-n", "--negative-prompt", 268 | default="", 269 | help="Negative prompt" 270 | ) 271 | parser.add_argument( 272 | "--model", 273 | choices=["foocus", "flux-general", "sd", "sdxl"], 274 | default="sdxl", 275 | help="Model to use" 276 | ) 277 
| parser.add_argument( 278 | "-s", "--strength", 279 | type=float, 280 | default=0.8, 281 | help="Inpainting strength" 282 | ) 283 | parser.add_argument( 284 | "-o", "--output", 285 | default="output_falai.png", 286 | help="Output filename" 287 | ) 288 | 289 | args = parser.parse_args() 290 | 291 | # Initialize client 292 | client = FalAI(args.api_key) 293 | 294 | # Validate API key 295 | success, error = client.validate_key() 296 | if not success: 297 | print(f"API key validation failed: {error}") 298 | return 299 | 300 | print("API key validated successfully") 301 | 302 | # Load images 303 | image = Image.open(args.image).convert("RGB") 304 | mask = Image.open(args.mask).convert("L") 305 | 306 | # Perform inpainting 307 | print(f"Inpainting with {args.model} model...") 308 | 309 | if args.model == "foocus": 310 | result = client.inpaint_foocus( 311 | image, mask, args.prompt, args.negative_prompt, 312 | strength=args.strength 313 | ) 314 | elif args.model == "flux-general": 315 | result = client.inpaint_flux_general( 316 | image, mask, args.prompt, args.negative_prompt, 317 | strength=args.strength 318 | ) 319 | else: # sd or sdxl 320 | result = client.inpaint_sd( 321 | image, mask, args.prompt, args.negative_prompt, 322 | strength=args.strength, model_type=args.model 323 | ) 324 | 325 | # Save result 326 | result.save(args.output) 327 | print(f"Saved result to {args.output}") 328 | 329 | 330 | if __name__ == "__main__": 331 | main() -------------------------------------------------------------------------------- /parallax_maker/gltf.py: -------------------------------------------------------------------------------- 1 | # (c) 2024 Niels Provos 2 | # 3 | # This file contains functions for creating and exporting glTF files. 4 | # We generate a glTF file representing a 3D scene with a camera, cards, and image slices. 5 | # The resulting file can be opened in a 3D application like Blender, Houdini or Unreal. 6 | # 7 | import base64 8 | 9 | import numpy as np 10 | import pygltflib as gltf 11 | from PIL import Image 12 | 13 | 14 | def rotation_quaternion_y(y_rot_degrees): 15 | """Calculates the rotation quaternion for a rotation around the y-axis. 16 | 17 | Args: 18 | y_rot_degrees: The rotation angle in degrees. 19 | 20 | Returns: 21 | A NumPy array representing the rotation quaternion (x, y, z, w). 22 | """ 23 | 24 | # Convert to radians and half the angle 25 | theta = np.radians(y_rot_degrees) / 2 26 | axis = np.array([0, 1, 0]) # Rotation around the y-axis 27 | 28 | quaternion = np.zeros(4) 29 | quaternion[:3] = axis * np.sin(theta) 30 | quaternion[3] = np.cos(theta) 31 | 32 | return quaternion.tolist() 33 | 34 | 35 | def create_camera( 36 | gltf_obj, focal_length, aspect_ratio, translation, rotation_quarternion 37 | ): 38 | """ 39 | Creates a camera in the glTF object with the specified parameters. 40 | 41 | Args: 42 | gltf_obj (gltf.Gltf): The glTF object to add the camera to. 43 | focal_length (float): The focal length of the camera. 44 | aspect_ratio (float): The aspect ratio of the camera. 45 | translation (List[float]): The translation of the camera node. 46 | rotation_quarternion (List[float]): The rotation of the camera node as a quaternion. 47 | 48 | Returns: 49 | int: The index of the created camera. 
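Example (illustrative sketch; the numeric values are arbitrary and mirror how export_gltf below wires up the camera):

    gltf_obj = gltf.GLTF2(scene=0)
    gltf_obj.scenes.append(gltf.Scene())
    camera_index = create_camera(gltf_obj, 35.0, 16 / 9, [0, 0, -100], rotation_quaternion_y(180))
    gltf_obj.scenes[0].nodes.append(camera_index)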
50 | 51 | """ 52 | camera_index = len(gltf_obj.cameras) 53 | 54 | sensor_width = 35.0 # Sensor width in mm 55 | sensor_height = sensor_width / aspect_ratio 56 | 57 | # Create the camera object 58 | camera = gltf.Camera( 59 | type="perspective", 60 | name=f"Camera_{camera_index}", 61 | perspective=gltf.Perspective( 62 | aspectRatio=aspect_ratio, 63 | yfov=2 * np.arctan(sensor_height / focal_length), 64 | znear=0.01, 65 | zfar=10000, 66 | ), 67 | ) 68 | gltf_obj.cameras.append(camera) 69 | 70 | # Create the camera node 71 | camera_node = gltf.Node( 72 | translation=translation, rotation=rotation_quarternion, camera=camera_index 73 | ) 74 | gltf_obj.nodes.append(camera_node) 75 | 76 | return camera_index 77 | 78 | 79 | def create_buffer_and_view(gltf_obj, data, target=gltf.ARRAY_BUFFER): 80 | """ 81 | Creates a buffer and buffer view in a glTF object. 82 | 83 | Args: 84 | gltf_obj (gltf.Gltf): The glTF object to add the buffer and buffer view to. 85 | data (numpy.ndarray): The data to be stored in the buffer. 86 | target (int): The target usage of the buffer view (default: gltf.ARRAY_BUFFER). 87 | 88 | Returns: 89 | int: the index of the created buffer view. 90 | """ 91 | tmp_buffer = gltf.Buffer( 92 | byteLength=data.nbytes, 93 | uri=f"data:application/octet-stream;base64,{base64.b64encode(data.tobytes()).decode()}", 94 | ) 95 | gltf_obj.buffers.append(tmp_buffer) 96 | tmp_buffer_index = len(gltf_obj.buffers) - 1 97 | 98 | tmp_buffer_view = gltf.BufferView( 99 | buffer=tmp_buffer_index, byteOffset=0, byteLength=data.nbytes, target=target 100 | ) 101 | gltf_obj.bufferViews.append(tmp_buffer_view) 102 | tmp_buffer_view_index = len(gltf_obj.bufferViews) - 1 103 | 104 | return tmp_buffer_view_index 105 | 106 | 107 | def subdivide_geometry(coords, subdivisions, dimension): 108 | """ 109 | Subdivides a plane into a grid with the specified number of subdivisions. Can handle both 3D and 2D geometries. 110 | 111 | Args: 112 | coords (numpy.ndarray): The corner coordinates of the geometry (3D for spatial coordinates, 2D for texture coordinates). 113 | subdivisions (int): The number of subdivisions to create. 114 | dimension (int): The dimension of the target points (3 for spatial coordinates, 2 for texture coordinates). 115 | 116 | Returns: 117 | numpy.ndarray: The corner coordinates of the subdivided geometry. 118 | """ 119 | x = np.linspace(coords[0, 0], coords[1, 0], subdivisions + 1, dtype=np.float32) 120 | y = np.linspace(coords[0, 1], coords[3, 1], subdivisions + 1, dtype=np.float32) 121 | x, y = np.meshgrid(x, y) 122 | 123 | if dimension == 3: 124 | z = np.zeros_like(x) 125 | points = np.stack([x, y, z], axis=-1) 126 | elif dimension == 2: 127 | points = np.stack([x, y], axis=-1) 128 | 129 | return points.reshape(-1, dimension) 130 | 131 | 132 | def triangle_indices_from_grid(vertices): 133 | """ 134 | Generates triangle indices for a grid of vertices. 135 | 136 | Args: 137 | vertices (numpy.ndarray): The 3D corner coordinates of the grid. 138 | 139 | Returns: 140 | numpy.ndarray: The triangle indices for the grid. 
141 | """ 142 | # Calculate the number of vertices in each row 143 | row_length = int(np.sqrt(len(vertices))) 144 | 145 | # Create the indices for the triangles 146 | indices = [] 147 | for i in range(row_length - 1): 148 | for j in range(row_length - 1): 149 | # Calculate the indices for the current quad 150 | tl = i * row_length + j 151 | tr = tl + 1 152 | bl = (i + 1) * row_length + j 153 | br = bl + 1 154 | 155 | # Create the two triangles for the quad 156 | indices.append([tl, tr, bl]) 157 | indices.append([bl, tr, br]) 158 | 159 | return np.array(indices, dtype=np.uint32) 160 | 161 | 162 | def displace_vertices(vertices, depth_map, displacement_scale=10.0): 163 | """ 164 | Displaces the vertices of a plane based on a depth map. 165 | 166 | Args: 167 | vertices (numpy.ndarray): The 3D corner coordinates of the plane. 168 | depth_map (numpy.ndarray): The depth map to displace the vertices with. Normalized to [0, 1]. 169 | 170 | Returns: 171 | numpy.ndarray: The displaced vertices. 172 | """ 173 | # Get the dimensions of the depth map 174 | depth_map_width, depth_map_height = depth_map.shape 175 | 176 | # Calculate the texture coordinates for the vertices 177 | tex_coords = vertices[:, :2].copy() 178 | tex_coords[:, 1] = -tex_coords[:, 1] # flip Y-axis 179 | 180 | # normalize the texture coordinates to [0, 1] 181 | tex_min_x, tex_min_y = tex_coords.min(axis=0) 182 | tex_max_x, tex_max_y = tex_coords.max(axis=0) 183 | tex_coords -= [tex_min_x, tex_min_y] 184 | tex_coords /= [tex_max_x - tex_min_x, tex_max_y - tex_min_y] 185 | 186 | # Calculate the pixel coordinates for the texture coordinates 187 | pixel_coords = (tex_coords * [depth_map_height - 1, depth_map_width - 1]).astype( 188 | int 189 | ) 190 | 191 | # Get the depth values for the pixel coordinates 192 | depths = depth_map[pixel_coords[:, 1], pixel_coords[:, 0]] * displacement_scale 193 | 194 | # Displace the vertices based on the depth values 195 | vertices[:, 2] = depths 196 | 197 | return vertices 198 | 199 | 200 | def create_card( 201 | gltf_obj, i, corners_3d, subdivisions=300, depth_map=None, displacement_scale=0.0 202 | ): 203 | """ 204 | Creates a card (plane) in the glTF object with the specified parameters. 205 | 206 | Args: 207 | gltf_obj (gltf.Gltf): The glTF object to add the card to. 208 | i (int): The index of the card. 209 | corners_3d (numpy.ndarray): The 3D corner coordinates for the card. 210 | subdivisions (int, optional): The number of subdivisions for the card. Defaults to 300. 211 | depth_map (numpy.ndarray, optional): The depth map for the card. Defaults to None. 212 | displacement_scale (float, optional): The scale of the displacement. Defaults to 0.0. 213 | 214 | Returns: 215 | int: The index of the created mesh. 
216 | """ 217 | # Set the vertices and indices for the plane 218 | 219 | # negate the y coordinates of corners_3d 220 | vertices = np.array(corners_3d, dtype=np.float32) 221 | vertices[:, 1] = -vertices[:, 1] 222 | 223 | # reorder the vertices of the 4 point plane 224 | tl = vertices[0] 225 | tr = vertices[1] 226 | bl = vertices[3] 227 | br = vertices[2] 228 | 229 | vertices = np.array([tl, tr, bl, br], dtype=np.float32) 230 | 231 | tex_coords = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=np.float32) 232 | 233 | if displacement_scale > 0.0 and depth_map is not None: 234 | vertices = subdivide_geometry(vertices, subdivisions, 3) 235 | vertices = displace_vertices( 236 | vertices, depth_map, displacement_scale=displacement_scale 237 | ) 238 | tex_coords = subdivide_geometry(tex_coords, subdivisions, 2) 239 | 240 | indices = triangle_indices_from_grid(vertices) 241 | 242 | # Create the buffer and buffer view for vertices 243 | vertex_bufferview_index = create_buffer_and_view( 244 | gltf_obj, vertices, target=gltf.ARRAY_BUFFER 245 | ) 246 | 247 | # Create the buffer and buffer view for texture coordinates 248 | tex_coord_bufferview_index = create_buffer_and_view( 249 | gltf_obj, tex_coords, target=gltf.ARRAY_BUFFER 250 | ) 251 | 252 | # Create the buffer and buffer view for indices 253 | index_bufferview_index = create_buffer_and_view( 254 | gltf_obj, indices, target=gltf.ELEMENT_ARRAY_BUFFER 255 | ) 256 | 257 | # Create the accessor for texture coordinates 258 | tex_coord_accessor = gltf.Accessor( 259 | bufferView=tex_coord_bufferview_index, 260 | componentType=gltf.FLOAT, 261 | count=len(tex_coords), 262 | type=gltf.VEC2, 263 | max=tex_coords.max(axis=0).tolist(), 264 | min=tex_coords.min(axis=0).tolist(), 265 | ) 266 | gltf_obj.accessors.append(tex_coord_accessor) 267 | tex_coord_accessor_index = len(gltf_obj.accessors) - 1 268 | 269 | # Create the accessor for vertices 270 | vertex_accessor = gltf.Accessor( 271 | bufferView=vertex_bufferview_index, 272 | componentType=gltf.FLOAT, 273 | count=len(vertices), 274 | type=gltf.VEC3, 275 | max=vertices.max(axis=0).tolist(), 276 | min=vertices.min(axis=0).tolist(), 277 | ) 278 | gltf_obj.accessors.append(vertex_accessor) 279 | vertex_accessor_index = len(gltf_obj.accessors) - 1 280 | 281 | # Create the accessor for indices 282 | index_accessor = gltf.Accessor( 283 | bufferView=index_bufferview_index, 284 | componentType=gltf.UNSIGNED_INT, 285 | count=indices.size, 286 | type=gltf.SCALAR, 287 | ) 288 | gltf_obj.accessors.append(index_accessor) 289 | index_accessor_index = len(gltf_obj.accessors) - 1 290 | 291 | card_name = f"Card_{i}" 292 | 293 | # Create the mesh for the plane 294 | mesh = gltf.Mesh( 295 | name=card_name, 296 | primitives=[ 297 | gltf.Primitive( 298 | attributes=gltf.Attributes( 299 | POSITION=vertex_accessor_index, 300 | TEXCOORD_0=tex_coord_accessor_index, 301 | ), 302 | indices=index_accessor_index, 303 | material=i, 304 | ) 305 | ], 306 | ) 307 | 308 | return mesh 309 | 310 | 311 | def export_gltf( 312 | output_path, 313 | cam, 314 | image_slices, 315 | image_paths, 316 | depth_paths=[], 317 | displacement_scale=0.0, 318 | inline_images=True, 319 | support_dof=False, 320 | ): 321 | """ 322 | Export the camera, cards, and image slices to a glTF file. 323 | 324 | Args: 325 | output_path (str): The path to save the glTF file. 326 | aspect_ratio (float): The aspect ratio of the camera. 327 | focal_length (float): The focal length of the camera. 328 | camera_distance (float): The distance of the camera from the origin. 
329 | cam (Camera): The camera object for the scene. 330 | image_slices (list): List of 3D corner coordinates for each card. 331 | image_paths (list): List of file paths for each image slice. 332 | depth_paths (list, optional): List of file paths for each depth map. Defaults to []. 333 | displacement_scale (float, optional): The scale of the displacement. Defaults to 0.0. 334 | inline_images (bool, optional): Whether to inline the images in the glTF file. Defaults to True. 335 | """ 336 | 337 | # compute pre-requisites 338 | image_height, image_width = image_slices[0].image.shape[:2] 339 | camera_matrix = cam.camera_matrix(image_width, image_height) 340 | aspect_ratio = float(camera_matrix[0, 2]) / camera_matrix[1, 2] 341 | focal_length = cam.focal_length 342 | camera_distance = cam.camera_distance 343 | 344 | # Create a new glTF object 345 | gltf_obj = gltf.GLTF2(scene=0) 346 | 347 | # Create the scene 348 | scene = gltf.Scene() 349 | gltf_obj.scenes.append(scene) 350 | 351 | camera_index = create_camera( 352 | gltf_obj, 353 | focal_length, 354 | aspect_ratio, 355 | [0, 0, -camera_distance], 356 | rotation_quaternion_y(180), 357 | ) 358 | # Add the camera node to the scene 359 | scene.nodes.append(camera_index) 360 | 361 | subdivisions = 500 362 | 363 | alpha_mode = "MASK" if support_dof else "BLEND" 364 | 365 | # Create the card objects (planes) 366 | for i, image_slice in enumerate(image_slices): 367 | corners_3d = image_slice.create_card(image_height, image_width, cam) 368 | # Translaton hack so that we can put the depth on the node 369 | z_transform = corners_3d[0][2] 370 | corners_3d[:, 2] -= z_transform 371 | 372 | depth_map = None 373 | if len(depth_paths) > i: 374 | depth_map = Image.open(depth_paths[i]) 375 | width, height = depth_map.size 376 | depth_map = depth_map.resize( 377 | (subdivisions + 1, subdivisions + 1), Image.BICUBIC 378 | ) 379 | depth_map = depth_map.resize((width, height), Image.BICUBIC) 380 | depth_map = np.array(depth_map) 381 | depth_map = depth_map.astype(np.float32) / 255.0 382 | 383 | mesh = create_card( 384 | gltf_obj, 385 | i, 386 | corners_3d, 387 | subdivisions, 388 | depth_map, 389 | displacement_scale=displacement_scale, 390 | ) 391 | gltf_obj.meshes.append(mesh) 392 | 393 | # Create the material and assign the texture 394 | material = gltf.Material( 395 | name=f"Material_{i}", 396 | pbrMetallicRoughness=gltf.PbrMetallicRoughness( 397 | baseColorTexture=gltf.TextureInfo(index=i) 398 | ), 399 | # Set the emissive color (RGB values) 400 | emissiveFactor=[1.0, 1.0, 1.0], 401 | emissiveTexture=gltf.TextureInfo(index=i), 402 | alphaMode=alpha_mode, 403 | alphaCutoff=0.5 if alpha_mode == "MASK" else None, 404 | doubleSided=True, 405 | ) 406 | 407 | image = gltf.Image(uri=str(image_paths[i])) 408 | gltf_obj.images.append(image) 409 | 410 | texture = gltf.Texture( 411 | source=i, 412 | ) 413 | gltf_obj.textures.append(texture) 414 | 415 | gltf_obj.materials.append(material) 416 | 417 | # Create the card node and add it to the scene 418 | card_node = gltf.Node( 419 | mesh=i, 420 | translation=[0, 0, int(z_transform)], 421 | rotation=rotation_quaternion_y(180), 422 | ) 423 | gltf_obj.nodes.append(card_node) 424 | scene.nodes.append(len(gltf_obj.nodes) - 1) 425 | 426 | # Save the glTF file 427 | if inline_images: 428 | gltf_obj.convert_images(gltf.ImageFormat.DATAURI) 429 | else: 430 | gltf_obj.convert_images(gltf.ImageFormat.FILE) 431 | 432 | gltf_obj.save(str(output_path)) 433 | 434 | return str(output_path) 435 | 
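# Illustrative usage sketch (assumption: `state` is an AppState with `camera` and
# `image_slices` populated, as in gltf_cli.py; the output path is a placeholder):
#
#     path = export_gltf(
#         "output/model.gltf",
#         state.camera,
#         state.image_slices,
#         [s.filename for s in state.image_slices],
#         displacement_scale=0.0,
#         inline_images=True,
#     )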
-------------------------------------------------------------------------------- /parallax_maker/gltf_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # (c) 2024 Niels Provos 3 | # 4 | 5 | import argparse 6 | import os 7 | from pathlib import Path 8 | 9 | from PIL import Image 10 | 11 | from .controller import AppState 12 | from .depth import DepthEstimationModel 13 | from .segmentation import generate_depth_map 14 | from .utils import postprocess_depth_map 15 | from .webui import export_state_as_gltf 16 | 17 | 18 | def compute_depth_map_for_slices(state: AppState, postprocess: bool = True): 19 | depth_maps = [] 20 | 21 | model_name = state.depth_model_name if state.depth_model_name else "midas" 22 | model = DepthEstimationModel(model=model_name) 23 | for i, image_slice in enumerate(state.image_slices): 24 | filename = image_slice.filename 25 | print(f"Processing {filename}") 26 | 27 | image = image_slice.image 28 | 29 | depth_map = generate_depth_map(image[:, :, :3], model) 30 | 31 | tmp_filename = state._make_filename(i, "depth_tmp") 32 | depth_image = Image.fromarray(depth_map) 33 | depth_image.save(tmp_filename, compress_level=9) 34 | 35 | if postprocess: 36 | image_alpha = image[:, :, 3] 37 | depth_map = postprocess_depth_map(depth_map, image_alpha, final_blur=50) 38 | 39 | depth_image = Image.fromarray(depth_map) 40 | 41 | output_filename = Path(state.filename) / (Path(filename).stem + "_depth.png") 42 | 43 | depth_image.save(output_filename, compress_level=1) 44 | print(f"Saved depth map to {output_filename}") 45 | 46 | depth_maps.append(output_filename) 47 | return depth_maps 48 | 49 | 50 | def main(): 51 | os.environ["DISABLE_TELEMETRY"] = "YES" 52 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 53 | 54 | # get arguments from the command line 55 | # -i name of the state file 56 | # -o output for the gltf file 57 | parser = argparse.ArgumentParser( 58 | description="Create a glTF file from the state file" 59 | ) 60 | parser.add_argument("-i", "--state_file", type=str, help="Path to the state file") 61 | parser.add_argument( 62 | "-o", 63 | "--output_path", 64 | type=str, 65 | default="output", 66 | help="Path to save the glTF file", 67 | ) 68 | parser.add_argument( 69 | "-n", 70 | "--no-inline", 71 | action="store_true", 72 | help="Do not inline images in the glTF file", 73 | ) 74 | parser.add_argument( 75 | "-d", "--depth", action="store_true", help="Compute depth maps for slices" 76 | ) 77 | parser.add_argument( 78 | "-s", "--scale", type=float, default=0.0, help="Displacement scale factor" 79 | ) 80 | args = parser.parse_args() 81 | 82 | state = AppState.from_file(args.state_file) 83 | 84 | output_path = Path(args.output_path) 85 | if not output_path.exists(): 86 | output_path.mkdir(parents=True) 87 | 88 | if args.depth: 89 | compute_depth_map_for_slices(state) 90 | 91 | gltf_path = export_state_as_gltf( 92 | state, 93 | args.output_path, 94 | state.camera, 95 | displacement_scale=args.scale, 96 | inline_images=not args.no_inline, 97 | support_dof=True, 98 | ) 99 | print(f"Exported glTF to {gltf_path}") 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /parallax_maker/instance.py: -------------------------------------------------------------------------------- 1 | # (c) 2024 Niels Provos 2 | # 3 | """ 4 | Provides instance segmentation functionality using pre-trained deep learning models. 
5 | Segment Anything seems superior to Mask2Former. 6 | 7 | Supports two models: 8 | - Mask2Former: https://huggingface.co/facebook/mask2former-swin-large-coco-instance 9 | - SAM: https://huggingface.co/facebook/sam-vit-huge 10 | 11 | Todo: 12 | - Create support for HQ-SAM: https://github.com/SysCV/sam-hq?tab=readme-ov-file 13 | """ 14 | 15 | from PIL import Image 16 | import torch 17 | from transformers import ( 18 | AutoImageProcessor, 19 | Mask2FormerForUniversalSegmentation, 20 | SamModel, 21 | SamProcessor, 22 | ) 23 | import numpy as np 24 | 25 | from .utils import torch_get_device, image_overlay, draw_circle 26 | 27 | 28 | class SegmentationModel: 29 | MODELS = ["mask2former", "sam"] 30 | 31 | def __init__(self, model="sam"): 32 | assert model in self.MODELS 33 | self.model_name = model 34 | self.model = None 35 | self.image_processor = None 36 | self.image = None 37 | self.mask = None 38 | 39 | def __eq__(self, other): 40 | if not isinstance(other, SegmentationModel): 41 | return False 42 | return self.model_name == other.model_name 43 | 44 | def load_model(self): 45 | load_model = { 46 | "mask2former": self.load_mask2former_model, 47 | "sam": self.load_sam_model, 48 | } 49 | 50 | result = load_model[self.model_name]() 51 | 52 | if self.model_name == "mask2former": 53 | self.model, self.image_processor = result 54 | elif self.model_name == "sam": 55 | self.model, self.image_processor = result 56 | 57 | @staticmethod 58 | def load_mask2former_model(): 59 | image_processor = AutoImageProcessor.from_pretrained( 60 | "facebook/mask2former-swin-large-coco-instance" 61 | ) 62 | model = Mask2FormerForUniversalSegmentation.from_pretrained( 63 | "facebook/mask2former-swin-large-coco-instance" 64 | ) 65 | model.to(torch_get_device()) 66 | return model, image_processor 67 | 68 | @staticmethod 69 | def load_sam_model(): 70 | model = SamModel.from_pretrained("facebook/sam-vit-huge").to(torch_get_device()) 71 | processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") 72 | return model, processor 73 | 74 | def segment_image(self, image): 75 | self.image = image 76 | if isinstance(image, np.ndarray): 77 | self.image = Image.fromarray(image) 78 | self.image = self.image.convert("RGB") 79 | 80 | run_pipeline = { 81 | "mask2former": self.segment_image_mask2former, 82 | "sam": self.segment_image_sam, 83 | } 84 | 85 | self.mask = run_pipeline[self.model_name]() 86 | return self.mask 87 | 88 | def _get_mask_at_point_function(self): 89 | needs_mask = {"mask2former": True, "sam": False} 90 | if needs_mask[self.model_name] and self.mask is None: 91 | return None 92 | 93 | run_pipeline = { 94 | "mask2former": self.mask_at_point_mask2former, 95 | "sam": self.mask_at_point_sam, 96 | } 97 | return run_pipeline[self.model_name] 98 | 99 | def mask_at_point(self, point_xy): 100 | executor = self._get_mask_at_point_function() 101 | if executor is None: 102 | return None 103 | return executor(point_xy) 104 | 105 | def mask_at_point_blended(self, point_xy): 106 | executor = self._get_mask_at_point_function() 107 | 108 | # Apply segmentation function to each transformed image 109 | transforms = [ 110 | "identity", 111 | "rotate_90", 112 | "rotate_180", 113 | "rotate_270", 114 | "flip_h", 115 | "flip_v", 116 | ] 117 | computed_masks = [] 118 | image = self.image.copy() 119 | for transformation in transforms: 120 | transformed_image = SegmentationModel._transform_image( 121 | image, transformation 122 | ) 123 | transformed_point = SegmentationModel._transform_point( 124 | point_xy, transformation, image.size 
125 | ) 126 | self.segment_image(transformed_image) 127 | mask = executor(transformed_point) 128 | if mask is not None: 129 | mask = SegmentationModel._inverse_transform(mask, transformation) 130 | mask_size = (mask.shape[1], mask.shape[0]) 131 | assert ( 132 | mask_size == image.size 133 | ), f"Mask size {mask_size} does not match image size {image.size} for transformation {transformation}" 134 | computed_masks.append(mask) 135 | 136 | self.image = image 137 | 138 | computed_masks = SegmentationModel._filter_mask(computed_masks) 139 | blended_mask = np.array(computed_masks[0]).astype(np.float16) 140 | for mask in computed_masks[1:]: 141 | # print the max mask value in the console 142 | blended_mask += np.array(mask).astype(np.float16) 143 | 144 | blended_mask /= len(computed_masks) 145 | blended_mask = blended_mask.astype(np.uint8) 146 | 147 | return blended_mask 148 | 149 | def segment_image_mask2former(self): 150 | if self.model is None: 151 | self.load_model() 152 | 153 | inputs = self.image_processor(self.image, size=1280, return_tensors="pt").to( 154 | self.model.device 155 | ) 156 | with torch.no_grad(): 157 | outputs = self.model(**inputs) 158 | 159 | pred_map = self.image_processor.post_process_instance_segmentation( 160 | outputs, target_sizes=[self.image.size[::-1]] 161 | )[0] 162 | self.mask = pred_map["segmentation"] 163 | return self.mask 164 | 165 | def mask_at_point_mask2former(self, point_xy): 166 | width, height = self.image.size 167 | assert point_xy[0] < width and point_xy[1] < height 168 | 169 | label = int(self.mask[point_xy[1], point_xy[0]]) 170 | if label == -1: 171 | return None 172 | 173 | mask = (self.mask == label).numpy().astype(np.uint8) 174 | mask *= 255 175 | 176 | return mask 177 | 178 | def segment_image_sam(self): 179 | # this is a no-op for sam as the model needs to be guided by an input point 180 | return None 181 | 182 | def mask_at_point_sam(self, point_input): 183 | """ 184 | Masks the given points using the SAM algorithm. 185 | 186 | Args: 187 | point_input (tuple, list, dict): The input points to be masked. It can be one of the following: 188 | - tuple: A single point. 189 | - list of tuples: Multiple points. 190 | - dict with 'positive_points' and 'negative_points' keys: Lists of positive and negative points. 191 | 192 | Returns: 193 | list: The masked points. 194 | 195 | Raises: 196 | ValueError: If the input is invalid. 
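Examples (illustrative; assumes model = SegmentationModel("sam"), that segment_image() has been called to store the image, and that the coordinates are arbitrary pixel positions; mask_at_point dispatches here for the SAM model):

    model.mask_at_point((120, 80))                                   # single positive point
    model.mask_at_point([(120, 80), (200, 150)])                     # list of positive points
    model.mask_at_point({"positive_points": [(120, 80)], "negative_points": [(40, 40)]})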
197 | 198 | """ 199 | positive_points = [] 200 | negative_points = [] 201 | if isinstance(point_input, tuple): 202 | # Single point case 203 | positive_points = [point_input] 204 | elif isinstance(point_input, list) and all( 205 | isinstance(item, tuple) for item in point_input 206 | ): 207 | # List of points case 208 | positive_points = point_input 209 | elif ( 210 | isinstance(point_input, dict) 211 | and "positive_points" in point_input 212 | and "negative_points" in point_input 213 | ): 214 | # Lists of positive and negative points case 215 | positive_points = point_input["positive_points"] 216 | negative_points = point_input["negative_points"] 217 | else: 218 | raise ValueError("Invalid input for mask_at_point function.") 219 | return self._mask_at_point_sam(positive_points, negative_points) 220 | 221 | def _mask_at_point_sam(self, positive_points, negative_points): 222 | if self.model is None: 223 | self.load_model() 224 | 225 | input_points = positive_points + negative_points 226 | input_labels = [1] * len(positive_points) + [0] * len(negative_points) 227 | inputs = self.image_processor( 228 | self.image, 229 | input_points=[input_points], 230 | input_labels=[input_labels], 231 | return_tensors="pt", 232 | ) 233 | # convert inputs to dtype torch.float32 234 | inputs = inputs.to(torch.float32).to(self.model.device) 235 | 236 | with torch.no_grad(): 237 | outputs = self.model(**inputs) 238 | 239 | # the model is capable of returning multiple masks, but we only return one for now 240 | masks = self.image_processor.image_processor.post_process_masks( 241 | outputs.pred_masks.cpu(), 242 | inputs["original_sizes"].cpu(), 243 | inputs["reshaped_input_sizes"].cpu(), 244 | ) 245 | scores = outputs.iou_scores.cpu() 246 | 247 | mask = masks[0].numpy() 248 | 249 | # format the mask for visualization 250 | if len(mask.shape) == 4: 251 | mask = mask.squeeze() 252 | if scores.shape[0] == 1: 253 | scores = scores.squeeze() 254 | 255 | # SAM returns three candidate mask predictions at this point
256 | # find the index where scores has the highest value 257 | index = int(torch.argmax(scores)) 258 | mask = mask[index] 259 | 260 | h, w = mask.shape[-2:] 261 | mask_image = mask.reshape(h, w) 262 | mask_image = (mask_image * 255).astype(np.uint8) 263 | 264 | return mask_image 265 | 266 | # Function to rotate point around the image center for specific angles 267 | @staticmethod 268 | def _rotate_point(point, angle, image_size): 269 | assert angle in [0, 90, 180, 270] 270 | ox, oy = image_size[0] // 2, image_size[1] // 2 271 | x, y = point[0] - ox, point[1] - oy 272 | 273 | if angle == 90: 274 | new_x, new_y = oy - y, x + ox 275 | elif angle == 180: 276 | new_x, new_y = ox - x, oy - y 277 | elif angle == 270: 278 | new_x, new_y = y + oy, ox - x 279 | else: 280 | new_x, new_y = point 281 | 282 | return int(new_x), int(new_y) 283 | 284 | @staticmethod 285 | def _transform_image(image, transformation): 286 | if transformation == "rotate_90": 287 | return image.rotate(-90, expand=True) 288 | elif transformation == "rotate_180": 289 | return image.rotate(-180, expand=True) 290 | elif transformation == "rotate_270": 291 | return image.rotate(-270, expand=True) 292 | elif transformation == "flip_h": 293 | return image.transpose(Image.FLIP_LEFT_RIGHT) 294 | elif transformation == "flip_v": 295 | return image.transpose(Image.FLIP_TOP_BOTTOM) 296 | else: 297 | return image 298 | 299 | @staticmethod 300 | def _transform_point(point, transformation, image_size): 301 | if isinstance(point, tuple): 302 | return SegmentationModel._transform_point_single( 303 | point, transformation, image_size 304 | ) 305 | elif isinstance(point, list) and all(isinstance(item, tuple) for item in point): 306 | transformed_points = [] 307 | for p in point: 308 | transformed_points.append( 309 | SegmentationModel._transform_point_single( 310 | p, transformation, image_size 311 | ) 312 | ) 313 | return transformed_points 314 | elif ( 315 | isinstance(point, dict) 316 | and "positive_points" in point 317 | and "negative_points" in point 318 | ): 319 | positive_points = point["positive_points"] 320 | negative_points = point["negative_points"] 321 | transformed_points = {"positive_points": [], "negative_points": []} 322 | for p in positive_points: 323 | transformed_points["positive_points"].append( 324 | SegmentationModel._transform_point_single( 325 | p, transformation, image_size 326 | ) 327 | ) 328 | for n in negative_points: 329 | transformed_points["negative_points"].append( 330 | SegmentationModel._transform_point_single( 331 | n, transformation, image_size 332 | ) 333 | ) 334 | return transformed_points 335 | else: 336 | raise ValueError("Invalid input for _transform_point function.") 337 | 338 | @staticmethod 339 | def _transform_point_single(point, transformation, image_size): 340 | if transformation == "rotate_90": 341 | return SegmentationModel._rotate_point(point, 90, image_size) 342 | elif transformation == "rotate_180": 343 | return SegmentationModel._rotate_point(point, 180, image_size) 344 | elif transformation == "rotate_270": 345 | return SegmentationModel._rotate_point(point, 270, image_size) 346 | elif transformation == "flip_h": 347 | return (image_size[0] - point[0], point[1]) 348 | elif transformation == "flip_v": 349 | return (point[0], image_size[1] - point[1]) 350 | else: 351 | return point 352 | 353 | @staticmethod 354 | def _inverse_transform(mask, transformation): 355 | mask = Image.fromarray(mask) 356 | if transformation == "rotate_90": 357 | result = mask.rotate(90, expand=True) 358 | elif 
transformation == "rotate_180": 359 | result = mask.rotate(180, expand=True) 360 | elif transformation == "rotate_270": 361 | result = mask.rotate(270, expand=True) 362 | elif transformation == "flip_h": 363 | result = mask.transpose(Image.FLIP_LEFT_RIGHT) 364 | elif transformation == "flip_v": 365 | result = mask.transpose(Image.FLIP_TOP_BOTTOM) 366 | else: 367 | result = mask 368 | 369 | result = np.array(result) 370 | 371 | return result 372 | 373 | @staticmethod 374 | def _filter_mask(masks): 375 | selected_pixels = [] 376 | for mask in masks: 377 | # compute the number of pixels in the mask 378 | num_pixels = np.sum(mask) 379 | selected_pixels.append(num_pixels) 380 | 381 | # compute the median and standard deviation of the number of pixels 382 | median_pixels = np.median(selected_pixels) 383 | std_pixels = np.std(selected_pixels) 384 | 385 | print(f"Median pixels: {median_pixels}, Std pixels: {std_pixels}") 386 | 387 | # filter out masks that have less than the median number of pixels 388 | filtered_masks = [] 389 | for mask in masks: 390 | num_pixels = np.sum(mask) 391 | if ( 392 | num_pixels >= median_pixels - std_pixels 393 | and num_pixels <= median_pixels + std_pixels 394 | ): 395 | filtered_masks.append(mask) 396 | else: 397 | print(f"Filtering mask with {num_pixels} pixels") 398 | 399 | return filtered_masks 400 | 401 | 402 | if __name__ == "__main__": 403 | filename = "input.jpg" 404 | 405 | model = SegmentationModel() 406 | image = Image.open(filename) 407 | 408 | # Apply segmentation function to original image 409 | point_xy = (500, 600) 410 | negative_point_xy = (200, 600) 411 | mask = model.segment_image(image) 412 | 413 | point_input = { 414 | "positive_points": [point_xy], 415 | "negative_points": [negative_point_xy], 416 | } 417 | 418 | mask_at_point = model.mask_at_point_blended(point_input) 419 | mask = Image.fromarray(mask_at_point).convert("RGB") 420 | 421 | result = image_overlay(image, mask) 422 | 423 | # draw a circle at the point on image 424 | draw_circle(result, point_xy, 20, fill_color=(0, 255, 0)) 425 | draw_circle(result, negative_point_xy, 20, fill_color=(255, 0, 0)) 426 | 427 | result.save("segmented_image.png") 428 | -------------------------------------------------------------------------------- /parallax_maker/setup_dev.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Setup script for Parallax Maker development and packaging. 
4 | """ 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import shutil 10 | from pathlib import Path 11 | 12 | 13 | def run_command(cmd, description): 14 | """Run a command and handle errors.""" 15 | print(f"Running: {description}") 16 | print(f"Command: {cmd}") 17 | result = subprocess.run(cmd, shell=True, capture_output=True, text=True) 18 | if result.returncode != 0: 19 | print(f"Error: {result.stderr}") 20 | return False 21 | print(f"Success: {result.stdout}") 22 | return True 23 | 24 | 25 | def clean_build(): 26 | """Clean previous build artifacts.""" 27 | print("Cleaning previous build artifacts...") 28 | dirs_to_remove = ["dist", "build", "*.egg-info"] 29 | for pattern in dirs_to_remove: 30 | for path in Path(".").glob(pattern): 31 | if path.is_dir(): 32 | shutil.rmtree(path) 33 | print(f"Removed directory: {path}") 34 | else: 35 | path.unlink() 36 | print(f"Removed file: {path}") 37 | 38 | 39 | def build_package(): 40 | """Build the package.""" 41 | print("Building package...") 42 | return run_command("python -m build", "Building wheel and source distribution") 43 | 44 | 45 | def check_package(): 46 | """Check the built package.""" 47 | print("Checking package...") 48 | return run_command("python -m twine check dist/*", "Checking package integrity") 49 | 50 | 51 | def install_dev_tools(): 52 | """Install development tools.""" 53 | print("Installing development tools...") 54 | tools = ["build", "twine", "pytest", "black", "flake8"] 55 | for tool in tools: 56 | if not run_command(f"pip install {tool}", f"Installing {tool}"): 57 | return False 58 | return True 59 | 60 | 61 | def run_tests(): 62 | """Run tests.""" 63 | print("Running tests...") 64 | return run_command("python -m pytest", "Running test suite") 65 | 66 | 67 | def format_code(): 68 | """Format code with black.""" 69 | print("Formatting code...") 70 | return run_command("python -m black .", "Formatting code with black") 71 | 72 | 73 | def main(): 74 | """Main setup function.""" 75 | if len(sys.argv) < 2: 76 | print("Usage: python setup_dev.py ") 77 | print("Commands:") 78 | print(" install-tools - Install development tools") 79 | print(" clean - Clean build artifacts") 80 | print(" build - Build the package") 81 | print(" check - Check the package") 82 | print(" test - Run tests") 83 | print(" format - Format code") 84 | print(" full-build - Clean, build, and check") 85 | print(" dev-setup - Full development setup") 86 | sys.exit(1) 87 | 88 | command = sys.argv[1] 89 | 90 | if command == "install-tools": 91 | install_dev_tools() 92 | elif command == "clean": 93 | clean_build() 94 | elif command == "build": 95 | build_package() 96 | elif command == "check": 97 | check_package() 98 | elif command == "test": 99 | run_tests() 100 | elif command == "format": 101 | format_code() 102 | elif command == "full-build": 103 | clean_build() 104 | if build_package(): 105 | check_package() 106 | elif command == "dev-setup": 107 | install_dev_tools() 108 | format_code() 109 | run_tests() 110 | clean_build() 111 | if build_package(): 112 | check_package() 113 | print("\n✅ Development setup complete!") 114 | print("📦 Package built successfully!") 115 | print("\nNext steps:") 116 | print("1. Test install: pip install dist/*.whl") 117 | print( 118 | "2. Upload to TestPyPI: python -m twine upload --repository testpypi dist/*" 119 | ) 120 | print("3. 
Upload to PyPI: python -m twine upload dist/*") 121 | else: 122 | print(f"Unknown command: {command}") 123 | sys.exit(1) 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /parallax_maker/slice.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from pathlib import Path 3 | from PIL import Image 4 | import numpy as np 5 | 6 | from .utils import filename_add_version, filename_previous_version 7 | from .camera import Camera 8 | 9 | 10 | class ImageSlice: 11 | __slots__ = ( 12 | "_image", 13 | "_depth", 14 | "_filename", 15 | "_is_ground_plane", 16 | "positive_prompt", 17 | "negative_prompt", 18 | ) 19 | 20 | def __init__( 21 | self, 22 | image=None, 23 | depth=-1, 24 | filename=None, 25 | positive_prompt="", 26 | negative_prompt="", 27 | ): 28 | self.image = image 29 | self._depth = depth 30 | self._filename = filename 31 | self._is_ground_plane = False 32 | 33 | self.positive_prompt = positive_prompt 34 | self.negative_prompt = negative_prompt 35 | 36 | @property 37 | def image(self): 38 | return self._image 39 | 40 | @image.setter 41 | def image(self, value): 42 | if ( 43 | not isinstance(value, np.ndarray) 44 | and not isinstance(value, Image.Image) 45 | and value is not None 46 | ): 47 | raise ValueError("image must be a np.ndarray, Image.Image object or None") 48 | if isinstance(value, Image.Image): 49 | print( 50 | "WARNING: ImageSlice.image is being set with a PIL Image object. This is not recommended." 51 | ) 52 | value = np.array(value) 53 | self._image = value 54 | 55 | @property 56 | def is_ground_plane(self): 57 | return self._is_ground_plane 58 | 59 | @is_ground_plane.setter 60 | def is_ground_plane(self, value): 61 | if not isinstance(value, bool): 62 | raise ValueError("is_ground_plane must be a boolean value") 63 | self._is_ground_plane = value 64 | 65 | @property 66 | def depth(self): 67 | return self._depth 68 | 69 | @depth.setter 70 | def depth(self, value): 71 | if not isinstance(value, (float, int)) or value < -1: 72 | raise ValueError( 73 | "depth must be a non-negative number or -1 for unknown depth" 74 | ) 75 | self._depth = value 76 | 77 | @property 78 | def filename(self): 79 | return self._filename 80 | 81 | @filename.setter 82 | def filename(self, value): 83 | if ( 84 | not isinstance(value, Path) 85 | and not isinstance(value, str) 86 | and value is not None 87 | ): 88 | raise ValueError("filename must be a Path, str object or None") 89 | self._filename = value 90 | 91 | def __eq__(self, other): 92 | if not isinstance(other, ImageSlice): 93 | return False 94 | return ( 95 | np.array_equal(self.image, other.image) 96 | and self.depth == other.depth 97 | and self.filename == other.filename 98 | ) 99 | 100 | @staticmethod 101 | def _dimension_at_depth(z: float, image_height: int, image_width: int, cam: Camera): 102 | fl_px = cam.focal_length_px(image_width) 103 | card_width = (image_width * (z + cam.camera_distance)) / fl_px 104 | card_height = (image_height * (z + cam.camera_distance)) / fl_px 105 | return card_width, card_height 106 | 107 | def _depth_to_z(self, depth: float, cam: Camera): 108 | return cam.max_distance * ((255 - depth) / 255.0) 109 | 110 | def create_card(self, image_height: int, image_width: int, cam: Camera): 111 | if self.is_ground_plane: 112 | near_z = 0 113 | far_z = cam.max_distance 114 | 115 | near_width, near_height = self._dimension_at_depth( 116 | near_z, image_height, image_width, cam 117 | ) 
118 | far_width, far_height = self._dimension_at_depth( 119 | far_z, image_height, image_width, cam 120 | ) 121 | 122 | card_corners_3d = np.array( 123 | [ 124 | [-far_width / 2, -far_height / 2, far_z], 125 | [far_width / 2, -far_height / 2, far_z], 126 | [near_width / 2, near_height / 2, near_z + 2 * far_z], # NO IDEA 127 | [-near_width / 2, near_height / 2, near_z + 2 * far_z], 128 | ], 129 | dtype=np.float32, 130 | ) 131 | else: 132 | z = self._depth_to_z(self.depth, cam) 133 | 134 | # Calculate the 3D points of the card corners 135 | card_width, card_height = self._dimension_at_depth( 136 | z, image_height, image_width, cam 137 | ) 138 | 139 | card_corners_3d = np.array( 140 | [ 141 | [-card_width / 2, -card_height / 2, z], 142 | [card_width / 2, -card_height / 2, z], 143 | [card_width / 2, card_height / 2, z], 144 | [-card_width / 2, card_height / 2, z], 145 | ], 146 | dtype=np.float32, 147 | ) 148 | 149 | return card_corners_3d 150 | 151 | def save_image(self): 152 | slice_image = self.image 153 | if not isinstance(slice_image, Image.Image): 154 | slice_image = Image.fromarray(slice_image, mode="RGBA") 155 | output_image_path = self.filename 156 | print(f"Saving image slice: {output_image_path}") 157 | slice_image.save(str(output_image_path)) 158 | 159 | def read_image(self): 160 | img = cv2.imread(str(self.filename), cv2.IMREAD_UNCHANGED) 161 | self.image = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) 162 | 163 | def new_version(self, image=None, save=True): 164 | """ 165 | Creates a new version of the image and saves it if specified. 166 | 167 | Args: 168 | image (np.ndarray, optional): The new image to be set. Defaults to None. 169 | save (bool, optional): Whether to save the new image. Defaults to True. 170 | 171 | Returns: 172 | str: The filename of the new version of the image. 173 | """ 174 | if image is not None: 175 | assert isinstance(image, np.ndarray) 176 | self.image = image 177 | self.filename = filename_add_version(self.filename) 178 | if save: 179 | self.save_image() 180 | return self.filename 181 | 182 | def can_undo(self, forward=False): 183 | """ 184 | Check if it is possible to undo the specified image slice. 185 | 186 | Args: 187 | forward (bool, optional): If True, check for the next version of the slice. 188 | If False, check for the previous version. Defaults to False. 189 | 190 | Returns: 191 | bool: True if the specified slice version exists, False otherwise. 192 | """ 193 | if forward: 194 | filename = filename_add_version(self.filename) 195 | else: 196 | filename = filename_previous_version(self.filename) 197 | 198 | if filename is None: 199 | return False 200 | 201 | return Path(filename).exists() 202 | 203 | def undo(self, forward=False): 204 | """ 205 | Undo the specified image slice. 206 | 207 | Args: 208 | forward (bool, optional): If True, undo the next version of the slice. 209 | If False, undo the previous version. Defaults to False. 210 | 211 | Returns: 212 | bool: True if the undo operation is successful, False otherwise. 
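        Example (illustrative sketch; the filename is a placeholder and the
        versioned files follow the naming produced by `filename_add_version`):

            img_slice = ImageSlice(np.zeros((4, 4, 4), dtype=np.uint8), filename="slice_0.png")
            img_slice.save_image()            # writes slice_0.png
            img_slice.new_version(save=True)  # bumps to slice_0_v2.png and saves it
            img_slice.undo()                  # back to slice_0.png, reloaded from disk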
213 | """ 214 | if not self.can_undo(forward): 215 | return False 216 | 217 | if forward: 218 | filename = filename_add_version(self.filename) 219 | else: 220 | filename = filename_previous_version(self.filename) 221 | 222 | if filename is None: 223 | return False 224 | 225 | self.filename = filename 226 | self.read_image() 227 | return True 228 | -------------------------------------------------------------------------------- /parallax_maker/stabilityai.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import requests 4 | import argparse 5 | 6 | 7 | class StabilityAI: 8 | """ 9 | Implements the Stability AI API for generating and inpainting images. 10 | 11 | This API has several limitations: 12 | - Inpainting masks cannot be feathered or have soft edges. Consequently, the inpainting does not work with low strength. 13 | """ 14 | 15 | MAX_PIXELS = 9437184 16 | 17 | def __init__(self, api_key): 18 | self.api_key = api_key 19 | 20 | def validate_key(self): 21 | response = requests.get( 22 | "https://api.stability.ai/v1/user/balance", 23 | headers={"Authorization": f"Bearer {self.api_key}"}, 24 | ) 25 | 26 | if response.status_code != 200: 27 | return False, None 28 | 29 | # Do something with the payload... 30 | payload = response.json() 31 | assert "credits" in payload, f"Invalid API response: {payload}." 32 | return True, payload["credits"] 33 | 34 | def generate_image( 35 | self, prompt, negative_prompt="", aspect_ratio="16:9", output_format="png" 36 | ): 37 | """ 38 | Generates an image using the Stability AI API. 39 | 40 | Args: 41 | prompt (str): The main text prompt for generating the image. 42 | negative_prompt (str, optional): The negative text prompt for generating the image. Defaults to ''. 43 | aspect_ratio (str, optional): The aspect ratio of the generated image. Defaults to "16:9". 44 | output_format (str, optional): The output format of the generated image. Defaults to "png". 45 | 46 | Returns: 47 | PIL.Image.Image: The generated image as a PIL Image object. 48 | 49 | Raises: 50 | Exception: If the API response status code is not 200. 51 | """ 52 | response = requests.post( 53 | f"https://api.stability.ai/v2beta/stable-image/generate/core", 54 | headers={"authorization": f"Bearer {self.api_key}", "accept": "image/*"}, 55 | files={"none": ""}, 56 | data={ 57 | "mode": "text-to-image", 58 | "prompt": prompt, 59 | "negative_prompt": negative_prompt, 60 | "aspect_ratio": aspect_ratio, 61 | "output_format": output_format, 62 | }, 63 | ) 64 | 65 | if response.status_code != 200: 66 | raise Exception(str(response.json())) 67 | 68 | return Image.open(BytesIO(response.content)) 69 | 70 | def image_to_image( 71 | self, image, prompt, negative_prompt="", strength=0.8, output_format="png" 72 | ): 73 | """ 74 | Generates an image-to-image request using the Stability AI API. 75 | 76 | Args: 77 | image (PIL.Image.Image): The input image to be transformed. 78 | prompt (str): The prompt to guide the inpainting process. 79 | negative_prompt (str, optional): The negative prompt to guide the inpainting process. Defaults to ''. 80 | strength (float, optional): The strength of the transformation effect. Defaults to 0.8. 81 | output_format (str, optional): The format of the output image. Defaults to "png". 82 | 83 | Returns: 84 | PIL.Image.Image: The inpainted image. 85 | 86 | Raises: 87 | ValueError: If the image is smaller than 64x64 pixels. 88 | Exception: If the API request fails. 
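        Example (illustrative sketch; the API key, filename and prompt are placeholders):

            ai = StabilityAI(api_key="sk-...")
            src = Image.open("input.png")
            out = ai.image_to_image(src, "a foggy forest at dawn", strength=0.6)
            assert out.size == src.size  # the result is resized back to the input size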
89 | 
90 |         """
91 |         # check that the image is at least 64x64
92 |         if image.width < 64 or image.height < 64:
93 |             raise ValueError("Image must be at least 64x64")
94 | 
95 |         orig_width, orig_height = image.width, image.height
96 | 
97 |         image = self._resize_image(image)
98 | 
99 |         # prepare image and mask data
100 |         image_data = BytesIO()
101 |         image.save(image_data, format="PNG")
102 |         image_data = image_data.getvalue()
103 | 
104 |         response = requests.post(
105 |             f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
106 |             headers={"authorization": f"Bearer {self.api_key}", "accept": "image/*"},
107 |             files={"image": image_data},
108 |             data={
109 |                 "prompt": prompt,
110 |                 "negative_prompt": negative_prompt,
111 |                 "mode": "image-to-image",
112 |                 "strength": strength,
113 |                 "output_format": output_format,
114 |             },
115 |         )
116 | 
117 |         if response.status_code != 200:
118 |             raise Exception(str(response.json()))
119 | 
120 |         image = Image.open(BytesIO(response.content))
121 |         if orig_width != image.width or orig_height != image.height:
122 |             image = image.resize((orig_width, orig_height), Image.LANCZOS)
123 |         return image
124 | 
125 |     def _resize_image(self, image, max_pixels=MAX_PIXELS):
126 |         if image.width * image.height > max_pixels:
127 |             # square root, since the ratio is applied to both width and height below
128 |             ratio = (max_pixels / (image.width * image.height)) ** 0.5
129 |             new_width = int(image.width * ratio)
130 |             new_height = int(image.height * ratio)
131 |             image = image.resize((new_width, new_height), Image.LANCZOS)
132 |         return image
133 | 
134 |     def inpaint_image(
135 |         self, image, mask, prompt, negative_prompt="", strength=1.0, output_format="png"
136 |     ):
137 |         """
138 |         Inpaints an image using the Stability AI API.
139 | 
140 |         Args:
141 |             image (PIL.Image.Image): The input image to be inpainted.
142 |             mask (PIL.Image.Image): The mask indicating the areas to be inpainted.
143 |             prompt (str): The prompt to guide the inpainting process.
144 |             negative_prompt (str, optional): The negative prompt to guide the inpainting process. Defaults to ''.
145 |             strength (float, optional): The strength of the inpainting effect. Defaults to 1.0.
146 |             output_format (str, optional): The format of the output image. Defaults to "png".
147 | 
148 |         Returns:
149 |             PIL.Image.Image: The inpainted image.
150 | 
151 |         Raises:
152 |             ValueError: If the image is smaller than 64x64 pixels.
153 |             Exception: If the API request fails.
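        Example (illustrative sketch; filenames and prompt are placeholders, and the
        mask is assumed to be white where content should be regenerated):

            ai = StabilityAI(api_key="sk-...")
            image = Image.open("slice_2.png")
            mask = Image.open("mask_2.png")
            result = ai.inpaint_image(image, mask, "distant mountains", strength=0.8)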
154 | 
155 |         """
156 |         # check that the image is at least 64x64
157 |         if image.width < 64 or image.height < 64:
158 |             raise ValueError("Image must be at least 64x64")
159 | 
160 |         orig_width, orig_height = image.width, image.height
161 | 
162 |         image = self._resize_image(image)
163 |         mask = self._resize_image(mask)
164 | 
165 |         # scale the mask by strength - the API does not work well with low strength
166 |         if strength < 1.0:
167 |             mask = mask.point(lambda p: p * strength)
168 | 
169 |         # prepare image and mask data
170 |         image_data = BytesIO()
171 |         image.save(image_data, format="PNG")
172 |         image_data = image_data.getvalue()
173 |         mask_data = BytesIO()
174 |         mask.save(mask_data, format="PNG")
175 |         mask_data = mask_data.getvalue()
176 | 
177 |         response = requests.post(
178 |             f"https://api.stability.ai/v2beta/stable-image/edit/inpaint",
179 |             headers={"authorization": f"Bearer {self.api_key}", "accept": "image/*"},
180 |             files={"image": image_data, "mask": mask_data},
181 |             data={
182 |                 "prompt": prompt,
183 |                 "negative_prompt": negative_prompt,
184 |                 "output_format": output_format,
185 |             },
186 |         )
187 | 
188 |         if response.status_code != 200:
189 |             raise Exception(str(response.json()))
190 | 
191 |         image = Image.open(BytesIO(response.content))
192 |         if orig_width != image.width or orig_height != image.height:
193 |             image = image.resize((orig_width, orig_height), Image.LANCZOS)
194 |         return image
195 | 
196 |     def upscale_image(self, image, prompt, negative_prompt="", output_format="png"):
197 |         # check that the image is at least 64x64
198 |         if image.width < 64 or image.height < 64:
199 |             raise ValueError("Image must be at least 64x64")
200 | 
201 |         if image.width * image.height > 1024 * 1024:
202 |             raise ValueError(
203 |                 f"Image must be at most 1024x1024: Got {image.width}x{image.height} = {image.width * image.height} pixels."
204 | ) 205 | 206 | # prepare image and mask data 207 | image_data = BytesIO() 208 | image.save(image_data, format="PNG") 209 | image_data = image_data.getvalue() 210 | 211 | response = requests.post( 212 | f"https://api.stability.ai/v2beta/stable-image/upscale/conservative", 213 | headers={"authorization": f"Bearer {self.api_key}", "accept": "image/*"}, 214 | files={"image": image_data}, 215 | data={ 216 | "prompt": prompt, 217 | "negative_prompt": negative_prompt, 218 | "output_format": output_format, 219 | }, 220 | ) 221 | 222 | if response.status_code != 200: 223 | raise Exception(str(response.json())) 224 | 225 | upscaled_image = Image.open(BytesIO(response.content)) 226 | return upscaled_image 227 | 228 | 229 | def main(): # pragma: no cover 230 | parser = argparse.ArgumentParser() 231 | parser.add_argument("--api-key", required=True) 232 | parser.add_argument( 233 | "-p", 234 | "--prompt", 235 | default="an intelligent AI robot on the surface of the moon; intelligent and cute glowing eyes; " 236 | + "surrounded by canyons and rocks; dystopian atmosphere; scfi vibe; bladerunner and cyberpunk; " 237 | + "distant stars visible in the sky behind the mountains; photorealistic; everything in sharp focus", 238 | ) 239 | parser.add_argument("--aspect-ratio", default="16:9") 240 | parser.add_argument("--output-format", default="png") 241 | parser.add_argument("-i", "--image", type=str, default=None) 242 | parser.add_argument("-m", "--mask", type=str, default=None) 243 | parser.add_argument("-s", "--strength", type=float, default=0.8) 244 | parser.add_argument("-u", "--upscale", action="store_true") 245 | args = parser.parse_args() 246 | 247 | ai = StabilityAI(args.api_key) 248 | success, _ = ai.validate_key() 249 | if not success: 250 | print("Invalid API key.") 251 | return 252 | if args.image and args.upscale: 253 | image = Image.open(args.image) 254 | image = ai._resize_image(image, 1024 * 1024) 255 | image = ai.upscale_image(image, args.prompt) 256 | elif args.image and args.mask: 257 | image = Image.open(args.image) 258 | mask = Image.open(args.mask) 259 | 260 | image = ai.inpaint_image(image, mask, args.prompt, strength=1.0) 261 | elif args.image: 262 | print("Running image to image") 263 | image = Image.open(args.image) 264 | image = ai.image_to_image(image, args.prompt, strength=args.strength) 265 | else: 266 | image = ai.generate_image( 267 | args.prompt, 268 | aspect_ratio=args.aspect_ratio, 269 | output_format=args.output_format, 270 | ) 271 | image.save("output.png") 272 | 273 | 274 | if __name__ == "__main__": 275 | main() 276 | -------------------------------------------------------------------------------- /parallax_maker/test_automatic1111.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | from PIL import Image 4 | from .automatic1111 import create_img2img_payload, make_models_request 5 | 6 | 7 | class TestCreateImg2ImgPayload(unittest.TestCase): 8 | def setUp(self): 9 | self.input_image = Image.new("RGB", (100, 100)) 10 | self.positive_prompt = "Generate a beautiful image" 11 | self.negative_prompt = "Avoid using bright colors" 12 | self.mask_image = Image.new("L", (100, 100)) 13 | self.strength = 0.8 14 | self.steps = 20 15 | self.cfg_scale = 8.0 16 | 17 | def test_create_img2img_payload_without_mask(self): 18 | with patch( 19 | "parallax_maker.automatic1111.to_image_url", 20 | return_value="data:image/png;base64,", 21 | ) as mock_to_image_url: 22 | payload = 
create_img2img_payload( 23 | self.input_image, 24 | self.positive_prompt, 25 | self.negative_prompt, 26 | mask_image=None, 27 | strength=self.strength, 28 | steps=self.steps, 29 | cfg_scale=self.cfg_scale, 30 | ) 31 | 32 | self.assertEqual(payload["init_images"], ["data:image/png;base64,"]) 33 | self.assertEqual(payload["prompt"], self.positive_prompt) 34 | self.assertEqual(payload["negative_prompt"], self.negative_prompt) 35 | self.assertEqual(payload["denoising_strength"], self.strength) 36 | self.assertEqual(payload["width"], 100) 37 | self.assertEqual(payload["height"], 100) 38 | self.assertEqual(payload["steps"], self.steps) 39 | self.assertEqual(payload["cfg_scale"], self.cfg_scale) 40 | self.assertNotIn("mask", payload) 41 | 42 | def test_create_img2img_payload_with_mask(self): 43 | with patch( 44 | "parallax_maker.automatic1111.to_image_url", 45 | return_value="data:image/png;base64,", 46 | ) as mock_to_image_url: 47 | payload = create_img2img_payload( 48 | self.input_image, 49 | self.positive_prompt, 50 | self.negative_prompt, 51 | mask_image=self.mask_image, 52 | strength=self.strength, 53 | steps=self.steps, 54 | cfg_scale=self.cfg_scale, 55 | ) 56 | 57 | self.assertEqual(payload["init_images"], ["data:image/png;base64,"]) 58 | self.assertEqual(payload["prompt"], self.positive_prompt) 59 | self.assertEqual(payload["negative_prompt"], self.negative_prompt) 60 | self.assertEqual(payload["denoising_strength"], self.strength) 61 | self.assertEqual(payload["width"], 100) 62 | self.assertEqual(payload["height"], 100) 63 | self.assertEqual(payload["steps"], self.steps) 64 | self.assertEqual(payload["cfg_scale"], self.cfg_scale) 65 | self.assertEqual(payload["mask"], "data:image/png;base64,") 66 | 67 | 68 | class TestMakeModelsRequest(unittest.TestCase): 69 | @patch("parallax_maker.automatic1111.requests.get") 70 | def test_make_models_request(self, mock_get): 71 | mock_response = [ 72 | { 73 | "title": "sd_xl_base_1.0.safetensors [31e35c80fc]", 74 | "model_name": "sd_xl_base_1.0", 75 | "hash": "31e35c80fc", 76 | "sha256": "31e35c80fc4829d14f90153f4c74cd59c90b779f6afe05a74cd6120b893f7e5b", 77 | "filename": "models/Stable-diffusion/sd_xl_base_1.0.safetensors", 78 | "config": None, 79 | }, 80 | { 81 | "title": "sd_xl_refiner_1.0.safetensors [7440042bbd]", 82 | "model_name": "sd_xl_refiner_1.0", 83 | "hash": "7440042bbd", 84 | "sha256": "7440042bbdc8a24813002c09b6b69b64dc90fded4472613437b7f55f9b7d9c5f", 85 | "filename": "models/Stable-diffusion/sd_xl_refiner_1.0.safetensors", 86 | "config": None, 87 | }, 88 | { 89 | "title": "v2-1_768-ema-pruned.ckpt [ad2a33c361]", 90 | "model_name": "v2-1_768-ema-pruned", 91 | "hash": "ad2a33c361", 92 | "sha256": "ad2a33c361c1f593c4a1fb32ea81afce2b5bb7d1983c6b94793a26a3b54b08a0", 93 | "filename": "models/Stable-diffusion/v2-1_768-ema-pruned.ckpt", 94 | "config": None, 95 | }, 96 | ] 97 | 98 | mock_get.return_value.json.return_value = mock_response 99 | 100 | server_address = "localhost:7860" 101 | models = make_models_request(server_address) 102 | 103 | self.assertEqual( 104 | models, ["sd_xl_base_1.0", "sd_xl_refiner_1.0", "v2-1_768-ema-pruned"] 105 | ) 106 | mock_get.assert_called_once_with( 107 | url="http://localhost:7860/sdapi/v1/sd-models", timeout=3 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /parallax_maker/test_components.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 
from unittest.mock import patch, MagicMock 3 | 4 | import io 5 | import base64 6 | from PIL import Image 7 | 8 | import dash 9 | from dash import html 10 | 11 | from . import constants as C 12 | import numpy as np 13 | from .components import ( 14 | make_inpainting_container_callbacks, 15 | make_inpainting_container, 16 | make_configuration_container, 17 | ) 18 | from .slice import ImageSlice 19 | 20 | 21 | class TestUpdateInpaintingImageDisplay(unittest.TestCase): 22 | def setUp(self): 23 | # Setup a minimal Dash app 24 | self.app = dash.Dash(__name__) 25 | self.app.layout = html.Div( 26 | [ 27 | make_inpainting_container(), 28 | make_configuration_container(), 29 | ] 30 | ) 31 | 32 | # Save the callback 33 | self._update_inpainting_image_display = make_inpainting_container_callbacks( 34 | self.app 35 | ) 36 | 37 | @patch("parallax_maker.components.ctx", new_callable=MagicMock) 38 | @patch("parallax_maker.components.Image.open") 39 | @patch("parallax_maker.components.Image.fromarray") 40 | @patch("parallax_maker.components.Path") 41 | @patch("parallax_maker.controller.AppState.from_cache") 42 | @patch("parallax_maker.inpainting.InpaintingModel") 43 | def test_callback_triggered_comfyui( 44 | self, 45 | mock_model, 46 | mock_from_cache, 47 | mock_path, 48 | mock_image_open, 49 | mock_image_fromarray, 50 | mock_callback_context, 51 | ): # Mock callback context 52 | mock_callback_context.triggered_id = C.BTN_GENERATE_INPAINTING 53 | 54 | # Setup test data 55 | n_clicks = 1 56 | filename = "test_filename" 57 | model = "comfyui" 58 | server_address = "http://localhost:8000" 59 | workflow = "data:filetype;base64,workflow_data" 60 | positive_prompt = "sky is clear" 61 | negative_prompt = "cloudy sky" 62 | strength = 0.7 63 | guidance_scale = 5.0 64 | padding = 10 65 | blur = 5 66 | 67 | # Workflow path mock 68 | workflow_path = MagicMock() 69 | workflow_path.exists.return_value = True 70 | 71 | # Set up the AppState mock 72 | state = MagicMock() 73 | state.selected_slice = ( 74 | 0 # Ensure this matches the state expected by the function 75 | ) 76 | state.mask_filename.return_value = "mask_0.png" 77 | state.image_slices = [ImageSlice(np.zeros((100, 100, 4)))] 78 | state.workflow_path.return_value = workflow_path 79 | mock_from_cache.return_value = state 80 | 81 | # Mock mask Path exists 82 | mock_path.return_value.exists.return_value = True 83 | 84 | # Return a mock Image when open succeeds 85 | mock_img = Image.new("L", (100, 100)) 86 | mock_image_open.return_value = mock_img 87 | 88 | # Setup InpaintingModel mock 89 | pipeline = MagicMock() 90 | mock_model.return_value = pipeline 91 | pipeline.inpaint.return_value = "painted_image" # Mock the result of inpainting 92 | 93 | # Return the final image 94 | mock_image_fromarray.return_value = Image.new("RGBA", (100, 100)) 95 | 96 | # Call the function under test 97 | results = self._update_inpainting_image_display( 98 | n_clicks, 99 | None, 100 | None, 101 | filename, 102 | model, 103 | workflow, 104 | positive_prompt, 105 | negative_prompt, 106 | strength, 107 | guidance_scale, 108 | padding, 109 | blur, 110 | ) 111 | 112 | # Verify outputs and behaviors 113 | pipeline.load_model.assert_called_once() # Check that model was loaded 114 | self.assertEqual(len(results[0]), 3) # Assuming 3 images are expected 115 | for i, img in enumerate(results[0]): 116 | id = img.id 117 | self.assertDictEqual(id, {"type": C.ID_INPAINTING_IMAGE, "index": i}) 118 | img_data = img.src 119 | # Assuming you have some format to represent the image 120 | img = 
Image.open(io.BytesIO(base64.b64decode(img_data.split(",")[1]))) 121 | self.assertEqual(img.size, (100, 100)) 122 | 123 | # make sure the second return is a list 124 | self.assertIsInstance(results[1], list) 125 | # If this is expected to be an empty list 126 | self.assertEqual(len(results[1]), 0) 127 | 128 | @patch("parallax_maker.components.ctx", new_callable=MagicMock) 129 | @patch("parallax_maker.components.Image.open") 130 | @patch("parallax_maker.components.Image.fromarray") 131 | @patch("parallax_maker.components.Path") 132 | @patch("parallax_maker.controller.AppState.from_cache") 133 | @patch("parallax_maker.inpainting.InpaintingModel") 134 | def test_callback_triggered_normal( 135 | self, 136 | mock_model, 137 | mock_from_cache, 138 | mock_path, 139 | mock_image_open, 140 | mock_image_fromarray, 141 | mock_callback_context, 142 | ): 143 | # Mock callback context 144 | mock_callback_context.triggered_id = C.BTN_GENERATE_INPAINTING 145 | 146 | # Setup test data 147 | n_clicks = 1 148 | filename = "test_filename" 149 | model = "other_model" 150 | server_address = None 151 | workflow = None 152 | positive_prompt = "sky is clear" 153 | negative_prompt = "cloudy sky" 154 | strength = 0.7 155 | guidance_scale = 5.0 156 | padding = 10 157 | blur = 5 158 | 159 | # Workflow path mock 160 | workflow_path = MagicMock() 161 | workflow_path.exists.return_value = False 162 | 163 | # Set up the AppState mock 164 | state = MagicMock() 165 | state.selected_slice = ( 166 | 0 # Ensure this matches the state expected by the function 167 | ) 168 | state.mask_filename.return_value = "mask_0.png" 169 | state.image_slices = [ImageSlice(np.zeros((100, 100, 4)))] 170 | state.workflow_path.return_value = workflow_path 171 | mock_from_cache.return_value = state 172 | 173 | # Mock mask Path exists 174 | mock_path.return_value.exists.return_value = True 175 | 176 | # Return a mock Image when open succeeds 177 | mock_img = Image.new("L", (100, 100)) 178 | mock_image_open.return_value = mock_img 179 | 180 | # Setup InpaintingModel mock 181 | pipeline = MagicMock() 182 | mock_model.return_value = pipeline 183 | pipeline.inpaint.return_value = "painted_image" # Mock the result of inpainting 184 | 185 | # Return the final image 186 | mock_image_fromarray.return_value = Image.new("RGBA", (100, 100)) 187 | 188 | # Call the function under test 189 | results = self._update_inpainting_image_display( 190 | n_clicks, 191 | None, 192 | None, 193 | filename, 194 | model, 195 | workflow, 196 | positive_prompt, 197 | negative_prompt, 198 | strength, 199 | guidance_scale, 200 | padding, 201 | blur, 202 | ) 203 | 204 | # Verify outputs and behaviors 205 | pipeline.load_model.assert_called_once() # Check that model was loaded 206 | self.assertEqual(len(results[0]), 3) # Assuming 3 images are expected 207 | for i, img in enumerate(results[0]): 208 | id = img.id 209 | self.assertDictEqual(id, {"type": C.ID_INPAINTING_IMAGE, "index": i}) 210 | img_data = img.src 211 | # Assuming you have some format to represent the image 212 | img = Image.open(io.BytesIO(base64.b64decode(img_data.split(",")[1]))) 213 | self.assertEqual(img.size, (100, 100)) 214 | 215 | # make sure the second return is a list 216 | self.assertIsInstance(results[1], list) 217 | # If this is expected to be an empty list 218 | self.assertEqual(len(results[1]), 0) 219 | 220 | 221 | if __name__ == "__main__": 222 | unittest.main() 223 | -------------------------------------------------------------------------------- /parallax_maker/test_constants.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | from . import constants as C 3 | 4 | 5 | class TestConstants(unittest.TestCase): 6 | 7 | def test_unique_constants(self): 8 | constants_dict = { 9 | name: value for name, value in vars(C).items() if isinstance(value, str) 10 | } 11 | values = list(constants_dict.values()) 12 | unique_values = set(values) 13 | 14 | if len(values) != len(unique_values): 15 | # Find non-unique values 16 | duplicates = set([value for value in values if values.count(value) > 1]) 17 | non_unique_constants = { 18 | name: value 19 | for name, value in constants_dict.items() 20 | if value in duplicates 21 | } 22 | self.fail(f"String constants are not unique: {non_unique_constants}") 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /parallax_maker/test_gltf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import unittest 3 | from .gltf import subdivide_geometry, triangle_indices_from_grid 4 | 5 | 6 | class TestSubdivideGeometry(unittest.TestCase): 7 | def test_subdivide_2d(self): 8 | coords = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=np.float32) 9 | subdivisions = 2 10 | dimension = 2 11 | expected = np.array( 12 | [ 13 | [0.0, 0.0], 14 | [0.5, 0.0], 15 | [1.0, 0.0], 16 | [0.0, 0.5], 17 | [0.5, 0.5], 18 | [1.0, 0.5], 19 | [0.0, 1.0], 20 | [0.5, 1.0], 21 | [1.0, 1.0], 22 | ], 23 | dtype=np.float32, 24 | ) 25 | result = subdivide_geometry(coords, subdivisions, dimension) 26 | np.testing.assert_array_equal(result, expected) 27 | 28 | def test_subdivide_3d(self): 29 | coords = np.array( 30 | [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float32 31 | ) 32 | subdivisions = 2 33 | dimension = 3 34 | expected = np.array( 35 | [ 36 | [0.0, 0.0, 0.0], 37 | [0.5, 0.0, 0.0], 38 | [1.0, 0.0, 0.0], 39 | [0.0, 0.5, 0.0], 40 | [0.5, 0.5, 0.0], 41 | [1.0, 0.5, 0.0], 42 | [0.0, 1.0, 0.0], 43 | [0.5, 1.0, 0.0], 44 | [1.0, 1.0, 0.0], 45 | ], 46 | dtype=np.float32, 47 | ) 48 | result = subdivide_geometry(coords, subdivisions, dimension) 49 | np.testing.assert_array_equal(result, expected) 50 | 51 | 52 | class TestTriangleIndicesFromGrid(unittest.TestCase): 53 | def test_triangle_indices_from_grid(self): 54 | vertices = np.array( 55 | [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float32 56 | ) 57 | expected = np.array([[0, 1, 2], [2, 1, 3]], dtype=np.uint32) 58 | result = triangle_indices_from_grid(vertices) 59 | np.testing.assert_array_equal(result, expected) 60 | 61 | 62 | if __name__ == "__main__": 63 | unittest.main() 64 | -------------------------------------------------------------------------------- /parallax_maker/test_inpainting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import unittest 3 | 4 | from .inpainting import find_nearest_alpha 5 | 6 | 7 | class TestFindNearestAlpha(unittest.TestCase): 8 | def test_nearest_alpha_with_full_alpha(self): 9 | alpha = np.ones((3, 3)) * 255 10 | nearest_alpha = np.full((3, 3, 4, 2), -1, dtype=int) 11 | expected = np.zeros((3, 3, 4, 2)) 12 | for i in range(3): 13 | for j in range(3): 14 | expected[i, j, :, :] = [[i, j], [i, j], [i, j], [i, j]] 15 | find_nearest_alpha(alpha, nearest_alpha) 16 | np.testing.assert_array_equal(nearest_alpha, expected) 17 | 18 | def test_nearest_alpha_with_partial_alpha(self): 19 | alpha = np.array([[255, 0, 
255], [255, 255, 0], [0, 255, 255]]) 20 | nearest_alpha = np.full((3, 3, 4, 2), -1, dtype=int) 21 | expected = np.zeros((3, 3, 4, 2)) 22 | expected[0, 0, :, :] = [[0, 0], [0, 0], [0, 0], [0, 0]] 23 | expected[0, 1, :, :] = [[-1, -1], [0, 0], [1, 1], [0, 2]] 24 | expected[0, 2, :, :] = [[0, 2], [0, 2], [0, 2], [0, 2]] 25 | expected[1, 0, :, :] = [[1, 0], [1, 0], [1, 0], [1, 0]] 26 | expected[1, 1, :, :] = [[1, 1], [1, 1], [1, 1], [1, 1]] 27 | expected[1, 2, :, :] = [[0, 2], [1, 1], [2, 2], [-1, -1]] 28 | expected[2, 0, :, :] = [[1, 0], [-1, -1], [-1, -1], [2, 1]] 29 | expected[2, 1, :, :] = [[2, 1], [2, 1], [2, 1], [2, 1]] 30 | expected[2, 2, :, :] = [[2, 2], [2, 2], [2, 2], [2, 2]] 31 | find_nearest_alpha(alpha, nearest_alpha) 32 | np.testing.assert_array_equal(nearest_alpha, expected) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /parallax_maker/test_instance.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from .instance import SegmentationModel 7 | 8 | 9 | class TestSegmentationModel(unittest.TestCase): 10 | 11 | def setUp(self): 12 | self.model = SegmentationModel("sam") 13 | self.mock_image = Image.new("RGB", (120, 90), color="white") 14 | self.mock_point = (50, 50) 15 | self.mock_points = [(50, 50), (60, 60), (70, 70)] 16 | self.mock_mask = np.zeros((90, 120), np.uint8) 17 | self.mock_mask[40:60, 40:60] = 255 18 | 19 | patcher = patch.object( 20 | SegmentationModel, "segment_image_mask2former", return_value=self.mock_mask 21 | ) 22 | self.mock_segment_image = patcher.start() 23 | self.addCleanup(patcher.stop) 24 | 25 | patcher2 = patch.object( 26 | SegmentationModel, 27 | "_get_mask_at_point_function", 28 | return_value=lambda point: self.mock_mask, 29 | ) 30 | self.mock_get_mask_at_point_function = patcher2.start() 31 | self.addCleanup(patcher2.stop) 32 | 33 | def test_transform(self): 34 | transformations = [ 35 | "identity", 36 | "rotate_90", 37 | "rotate_180", 38 | "rotate_270", 39 | "flip_h", 40 | "flip_v", 41 | ] 42 | for transformation in transformations: 43 | transformed_image = self.model._transform_image( 44 | self.mock_image, transformation 45 | ) 46 | transformed_point = self.model._transform_point( 47 | self.mock_point, transformation, self.mock_image.size 48 | ) 49 | self.assertIsInstance(transformed_image, Image.Image) 50 | self.assertIsInstance(transformed_point, tuple) 51 | self.assertEqual(len(transformed_point), 2) 52 | 53 | def test_rotate_point_single(self): 54 | angles = [0, 90, 180, 270] 55 | image_size = (120, 90) 56 | for angle in angles: 57 | rotated_point = self.model._rotate_point(self.mock_point, angle, image_size) 58 | self.assertIsInstance(rotated_point, tuple) 59 | self.assertEqual(len(rotated_point), 2) 60 | 61 | def test_transform_point_multiple(self): 62 | transformations = [ 63 | "identity", 64 | "rotate_90", 65 | "rotate_180", 66 | "rotate_270", 67 | "flip_h", 68 | "flip_v", 69 | ] 70 | image_size = (120, 90) 71 | for transformation in transformations: 72 | rotated_point = self.model._transform_point( 73 | self.mock_points, transformation, image_size 74 | ) 75 | self.assertIsInstance(rotated_point, list) 76 | for point in rotated_point: 77 | self.assertIsInstance(point, tuple) 78 | self.assertEqual(len(point), 2) 79 | 80 | def test_inverse_transform(self): 81 | transformations = [ 82 | 
"identity", 83 | "rotate_90", 84 | "rotate_180", 85 | "rotate_270", 86 | "flip_h", 87 | "flip_v", 88 | ] 89 | for transformation in transformations: 90 | transformed_image = self.model._transform_image( 91 | self.mock_image, transformation 92 | ) 93 | inverse_transformed_mask = self.model._inverse_transform( 94 | np.array(transformed_image), transformation 95 | ) 96 | self.assertIsInstance(inverse_transformed_mask, np.ndarray) 97 | it_mask = Image.fromarray(inverse_transformed_mask) 98 | self.assertEqual(self.mock_image.size, it_mask.size) 99 | 100 | def test_filter_mask(self): 101 | masks = [] 102 | for _ in range(5): 103 | m = np.zeros((90, 120), np.uint8) 104 | m[40:60, 40:60] = 255 105 | masks.append(Image.fromarray(m)) 106 | 107 | filtered_masks = self.model._filter_mask(masks) 108 | self.assertIsInstance(filtered_masks, list) 109 | self.assertGreater(len(filtered_masks), 0) 110 | 111 | @patch.object(SegmentationModel, "_get_mask_at_point_function") 112 | def test_mask_at_point_blended(self, mock_get_mask_at_point_function): 113 | def mock_mask_function(point): 114 | width, height = self.model.image.size 115 | return np.zeros((height, width), np.uint8) 116 | 117 | mock_get_mask_at_point_function.return_value = mock_mask_function 118 | 119 | self.model.image = self.mock_image 120 | blended_mask = self.model.mask_at_point_blended(self.mock_point) 121 | self.assertIsInstance(blended_mask, np.ndarray) 122 | 123 | 124 | if __name__ == "__main__": 125 | unittest.main() 126 | -------------------------------------------------------------------------------- /parallax_maker/test_segmentation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from .segmentation import ( 3 | analyze_depth_histogram, 4 | blend_with_alpha, 5 | generate_simple_thresholds, 6 | ) 7 | import numpy as np 8 | 9 | 10 | class TestAnalyzeDepthHistogram(unittest.TestCase): 11 | def test_analyze_depth_histogram(self): 12 | # Create a dummy depth map 13 | depth_map = np.array([[0, 20, 30], [40, 50, 60], [70, 80, 255]]) 14 | 15 | # Define the number of slices 16 | num_slices = 5 17 | 18 | # Call the function 19 | result = analyze_depth_histogram(depth_map, num_slices=num_slices) 20 | 21 | # Define the expected result 22 | expected_result = [0, 30, 40, 60, 70, 255] 23 | 24 | # Assert that the result is as expected 25 | self.assertEqual(result, expected_result) 26 | 27 | def test_analyze_depth_histogram_with_different_slices(self): 28 | # Create a dummy depth map 29 | depth_map = np.array([[0, 20, 30], [40, 50, 60], [70, 80, 255]]) 30 | 31 | # Define the number of slices 32 | num_slices = 3 33 | 34 | # Call the function 35 | result = analyze_depth_histogram(depth_map, num_slices=num_slices) 36 | 37 | # Define the expected result 38 | expected_result = [0, 40, 70, 255] 39 | 40 | # Assert that the result is as expected 41 | self.assertEqual(result, expected_result) 42 | 43 | def test_analyze_depth_histogram_with_one_slice(self): 44 | # Create a dummy depth map 45 | depth_map = np.array([[0, 20, 30], [40, 50, 60], [70, 80, 255]]) 46 | 47 | # Define the number of slices 48 | num_slices = 1 49 | 50 | # Call the function 51 | result = analyze_depth_histogram(depth_map, num_slices=num_slices) 52 | 53 | # Define the expected result 54 | expected_result = [0, 255] 55 | 56 | # Assert that the result is as expected 57 | self.assertEqual(result, expected_result) 58 | 59 | 60 | class TestGenerateSimpleThresholds(unittest.TestCase): 61 | def test_generate_simple_thresholds(self): 62 | # 
Create a dummy image 63 | image = np.array([[0, 20, 30], [40, 50, 60], [70, 80, 255]]) 64 | 65 | # Define the number of slices 66 | num_slices = 5 67 | 68 | # Call the function 69 | result = generate_simple_thresholds(image, num_slices=num_slices) 70 | 71 | # Define the expected result 72 | expected_result = np.array([0, 51, 102, 153, 204, 255]) 73 | 74 | # Assert that the result is as expected 75 | np.testing.assert_array_equal(result, expected_result) 76 | 77 | 78 | class TestBlendWithAlpha(unittest.TestCase): 79 | def test_blend_with_alpha(self): 80 | # Create target image 81 | target_image = np.array( 82 | [ 83 | [[255, 0, 0, 255], [0, 255, 0, 255], [0, 0, 255, 255]], 84 | [[255, 255, 0, 255], [255, 0, 255, 255], [0, 255, 255, 255]], 85 | [[255, 255, 255, 255], [0, 0, 0, 255], [128, 128, 128, 255]], 86 | ] 87 | ) 88 | 89 | # Create merge image 90 | merge_image = np.array( 91 | [ 92 | [[0, 0, 0, 128], [128, 128, 128, 128], [255, 255, 255, 128]], 93 | [[128, 0, 0, 128], [0, 128, 0, 128], [0, 0, 128, 128]], 94 | [[128, 128, 0, 128], [128, 0, 128, 128], [0, 128, 128, 128]], 95 | ] 96 | ) 97 | 98 | # Call the function 99 | blend_with_alpha(target_image, merge_image) 100 | 101 | # Define the expected result 102 | expected_result = np.array( 103 | [ 104 | [[127, 0, 0, 255], [64, 191, 64, 255], [128, 128, 255, 255]], 105 | [[191, 127, 0, 255], [127, 64, 127, 255], [0, 127, 191, 255]], 106 | [[191, 191, 127, 255], [64, 0, 64, 255], [63, 128, 128, 255]], 107 | ] 108 | ) 109 | 110 | # Assert that the result is as expected 111 | np.testing.assert_array_equal(target_image, expected_result) 112 | 113 | 114 | if __name__ == "__main__": 115 | unittest.main() 116 | -------------------------------------------------------------------------------- /parallax_maker/test_stabilityai.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, Mock 3 | from PIL import Image 4 | from io import BytesIO 5 | import requests 6 | from .stabilityai import StabilityAI 7 | 8 | 9 | class TestStabilityAI(unittest.TestCase): 10 | 11 | @patch("parallax_maker.stabilityai.requests.get") 12 | def test_validate_key_success(self, mock_get): 13 | mock_response = Mock() 14 | mock_response.status_code = 200 15 | mock_response.json.return_value = {"credits": 100} 16 | mock_get.return_value = mock_response 17 | 18 | ai = StabilityAI(api_key="fake_api_key") 19 | success, credits = ai.validate_key() 20 | 21 | self.assertTrue(success) 22 | self.assertEqual(credits, 100) 23 | 24 | @patch("parallax_maker.stabilityai.requests.get") 25 | def test_validate_key_failure(self, mock_get): 26 | mock_response = Mock() 27 | mock_response.status_code = 401 28 | mock_get.return_value = mock_response 29 | 30 | ai = StabilityAI(api_key="fake_api_key") 31 | success, _ = ai.validate_key() 32 | 33 | self.assertFalse(success) 34 | 35 | @patch("parallax_maker.stabilityai.requests.post") 36 | def test_generate_image_success(self, mock_post): 37 | # Prepare mocked response 38 | image = Image.new("RGB", (100, 100)) 39 | image_data = BytesIO() 40 | image.save(image_data, format="PNG") 41 | image_data.seek(0) 42 | 43 | mock_response = Mock() 44 | mock_response.status_code = 200 45 | mock_response.content = image_data.read() 46 | mock_post.return_value = mock_response 47 | 48 | ai = StabilityAI(api_key="fake_api_key") 49 | result_image = ai.generate_image(prompt="A dog running in the park") 50 | 51 | self.assertIsInstance(result_image, Image.Image) 52 | 
self.assertEqual(result_image.size, (100, 100)) 53 | 54 | @patch("parallax_maker.stabilityai.requests.post") 55 | def test_generate_image_failure(self, mock_post): 56 | mock_response = Mock() 57 | mock_response.status_code = 400 58 | mock_response.json.return_value = {"error": "Bad request"} 59 | mock_post.return_value = mock_response 60 | 61 | ai = StabilityAI(api_key="fake_api_key") 62 | with self.assertRaises(Exception): 63 | ai.generate_image(prompt="A dog running in the park") 64 | 65 | @patch("parallax_maker.stabilityai.requests.post") 66 | def test_image_to_image_success(self, mock_post): 67 | image = Image.new("RGB", (100, 100)) 68 | input_image = Image.new("RGB", (100, 100)) 69 | image_data = BytesIO() 70 | image.save(image_data, format="PNG") 71 | image_data.seek(0) 72 | 73 | mock_response = Mock() 74 | mock_response.status_code = 200 75 | mock_response.content = image_data.read() 76 | mock_post.return_value = mock_response 77 | 78 | ai = StabilityAI(api_key="fake_api_key") 79 | result_image = ai.image_to_image(input_image, prompt="A cat on a sofa") 80 | 81 | self.assertIsInstance(result_image, Image.Image) 82 | self.assertEqual(result_image.size, (100, 100)) 83 | 84 | @patch("parallax_maker.stabilityai.requests.post") 85 | def test_image_to_image_invalid_size(self, mock_post): 86 | small_image = Image.new("RGB", (50, 50)) 87 | ai = StabilityAI(api_key="fake_api_key") 88 | 89 | with self.assertRaises(ValueError): 90 | ai.image_to_image(small_image, prompt="A cat on a sofa") 91 | 92 | @patch("parallax_maker.stabilityai.requests.post") 93 | def test_inpaint_image_success(self, mock_post): 94 | image = Image.new("RGB", (100, 100)) 95 | mask = Image.new("1", (100, 100)) 96 | output_image = Image.new("RGB", (100, 100)) 97 | image_data = BytesIO() 98 | output_image.save(image_data, format="PNG") 99 | image_data.seek(0) 100 | 101 | mock_response = Mock() 102 | mock_response.status_code = 200 103 | mock_response.content = image_data.read() 104 | mock_post.return_value = mock_response 105 | 106 | ai = StabilityAI(api_key="fake_api_key") 107 | result_image = ai.inpaint_image( 108 | image, mask, prompt="A beach sunset", strength=0.5 109 | ) 110 | 111 | self.assertIsInstance(result_image, Image.Image) 112 | self.assertEqual(result_image.size, (100, 100)) 113 | 114 | @patch("parallax_maker.stabilityai.requests.post") 115 | def test_inpaint_image_invalid_size(self, mock_post): 116 | small_image = Image.new("RGB", (50, 50)) 117 | mask = Image.new("1", (50, 50)) 118 | ai = StabilityAI(api_key="fake_api_key") 119 | 120 | with self.assertRaises(ValueError): 121 | ai.inpaint_image(small_image, mask, prompt="A beach sunset") 122 | 123 | @patch("parallax_maker.stabilityai.requests.post") 124 | def test_upscale_image_success(self, mock_post): 125 | image = Image.new("RGB", (100, 100)) 126 | output_image = Image.new("RGB", (200, 200)) 127 | image_data = BytesIO() 128 | output_image.save(image_data, format="PNG") 129 | image_data.seek(0) 130 | 131 | mock_response = Mock() 132 | mock_response.status_code = 200 133 | mock_response.content = image_data.read() 134 | mock_post.return_value = mock_response 135 | 136 | ai = StabilityAI(api_key="fake_api_key") 137 | result_image = ai.upscale_image(image, prompt="Upscale this image") 138 | 139 | self.assertIsInstance(result_image, Image.Image) 140 | self.assertEqual(result_image.size, (200, 200)) 141 | 142 | @patch("parallax_maker.stabilityai.requests.post") 143 | def test_upscale_image_invalid_size(self, mock_post): 144 | large_image = Image.new("RGB", (2000, 
2000)) 145 | ai = StabilityAI(api_key="fake_api_key") 146 | 147 | with self.assertRaises(ValueError): 148 | ai.upscale_image(large_image, prompt="Upscale this image") 149 | 150 | 151 | if __name__ == "__main__": 152 | unittest.main() 153 | -------------------------------------------------------------------------------- /parallax_maker/test_upscaler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | from PIL import Image 4 | import numpy as np 5 | 6 | from .upscaler import Upscaler 7 | 8 | 9 | class TestUpscaler(unittest.TestCase): 10 | def setUp(self): 11 | # Initialize Upscaler object without creating a model. 12 | self.upscaler = Upscaler(model_name="simple") 13 | 14 | # Create a dummy input image 15 | self.input_image = Image.new("RGBA", (800, 800)) 16 | # Fill the alpha with 1s - so that the bounding box is the full image 17 | self.input_image.putalpha(255) 18 | 19 | # Create a dummy upscaled tile 20 | self.upscaled_tile_dummy = Image.new("RGB", (1, 1)) 21 | self.upscaled_tile_dummy_real = Image.new( 22 | "RGB", (self.upscaler.tile_size * 2, self.upscaler.tile_size * 2) 23 | ) 24 | 25 | def test_upscale_image_tiled_overlap(self): 26 | # Mocking the creation of the model to skip loading model and processor 27 | with patch.object(self.upscaler, "create_model"): 28 | # Mock the upscale_tile function to return our dummy tile 29 | with patch.object( 30 | self.upscaler, "upscale_tile", return_value=self.upscaled_tile_dummy 31 | ) as mock_upscale_tile: 32 | # Mock the integrate_tile to ensure integral processing of tiles 33 | with patch.object(self.upscaler, "integrate_tile") as mock_integrate: 34 | # Call the function to test 35 | upscaled_image = self.upscaler.upscale_image_tiled( 36 | self.input_image, overlap=64 37 | ) 38 | 39 | # Check that integrate_tile was called the correct number of times (expected 3x3 grid based on sizes). 40 | self.assertEqual(mock_integrate.call_count, 4) 41 | mock_upscale_tile.assert_called() 42 | 43 | def test_upscale_image_tiled_no_overlap(self): 44 | # Mocking the creation of the model to skip loading model and processor 45 | with patch.object(self.upscaler, "create_model"): 46 | # Mock the upscale_tile function to return our dummy tile 47 | with patch.object( 48 | self.upscaler, "upscale_tile", return_value=self.upscaled_tile_dummy 49 | ): 50 | # Mock the integrate_tile to ensure integral processing of tiles 51 | with patch.object(self.upscaler, "integrate_tile") as mock_integrate: 52 | # Call the function to test 53 | upscaled_image = self.upscaler.upscale_image_tiled( 54 | self.input_image, overlap=0 55 | ) 56 | 57 | # Check that integrate_tile was called the correct number of times (expected 2x2 grid based on sizes). 58 | self.assertEqual(mock_integrate.call_count, 4) 59 | 60 | def test_upscale_image_tiled_no_mock(self): 61 | # Mocking the creation of the model to skip loading model and processor 62 | with patch.object(self.upscaler, "create_model"): 63 | # Mock the upscale_tile function to return our dummy tile 64 | with patch.object( 65 | self.upscaler, 66 | "upscale_tile", 67 | return_value=self.upscaled_tile_dummy_real, 68 | ): 69 | # Call the function to test 70 | upscaled_image = self.upscaler.upscale_image_tiled( 71 | self.input_image, overlap=64 72 | ) 73 | 74 | # Check that integrate_tile was called the correct number of times (expected 3x3 grid based on sizes). 
75 | self.assertEqual(upscaled_image.size, (2 * 800, 2 * 800)) 76 | 77 | 78 | class TestIntegrateTile(unittest.TestCase): 79 | 80 | def setUp(self): 81 | # Create a sample tile and image for testing 82 | self.tile = np.ones((100, 100, 3), dtype=np.uint8) * 255 83 | self.image = np.zeros((200, 200, 3), dtype=np.uint8) 84 | 85 | def test_integrate_tile_no_overlap(self): 86 | # Test integrating a tile with no overlap 87 | Upscaler.integrate_tile(self.tile, self.image, 50, 50, 150, 150, 0, 0, 0) 88 | self.assertTrue(np.array_equal(self.image[50:150, 50:150], self.tile)) 89 | 90 | def test_integrate_tile_left_overlap(self): 91 | # Test integrating a tile with left overlap 92 | Upscaler.integrate_tile(self.tile, self.image, 0, 50, 100, 150, 1, 0, 20) 93 | self.assertTrue(np.all(self.image[50:150, 1:20] > 0)) 94 | self.assertTrue(np.all(self.image[50:150, 20:100] == 255)) 95 | 96 | def test_integrate_tile_top_overlap(self): 97 | # Test integrating a tile with top overlap 98 | Upscaler.integrate_tile(self.tile, self.image, 50, 0, 150, 100, 0, 1, 20) 99 | self.assertTrue(np.all(self.image[1:20, 50:150] > 0)) 100 | self.assertTrue(np.all(self.image[20:100, 50:150] == 255)) 101 | 102 | def test_integrate_tile_corner_overlap(self): 103 | # Test integrating a tile with corner overlap 104 | Upscaler.integrate_tile(self.tile, self.image, 0, 0, 100, 100, 1, 1, 20) 105 | self.assertTrue(np.all(self.image[2:20, 2:20] > 0)) 106 | self.assertTrue(np.all(self.image[1:20, 20:100] > 0)) 107 | self.assertTrue(np.all(self.image[20:100, 1:20] > 0)) 108 | self.assertTrue(np.all(self.image[20:100, 20:100] == 255)) 109 | 110 | def test_integrate_tile_gradient(self): 111 | # Test if there is a gradient in the overlapping regions 112 | overlap = 20 113 | Upscaler.integrate_tile(self.tile, self.image, 0, 0, 100, 100, 1, 1, overlap) 114 | 115 | # Check left overlap gradient 116 | left_overlap = self.image[:, :overlap] 117 | self.assertFalse(np.all(left_overlap == left_overlap[0, 0])) 118 | self.assertTrue(np.all(np.diff(left_overlap, axis=1) >= 0)) 119 | 120 | # Check top overlap gradient 121 | top_overlap = self.image[:overlap, :] 122 | self.assertFalse(np.all(top_overlap == top_overlap[0, 0])) 123 | self.assertTrue(np.all(np.diff(top_overlap, axis=0) >= 0)) 124 | 125 | # Check corner overlap gradient 126 | corner_overlap = self.image[:overlap, :overlap] 127 | self.assertFalse(np.all(corner_overlap == corner_overlap[0, 0])) 128 | self.assertTrue(np.all(np.diff(corner_overlap, axis=0) >= 0)) 129 | self.assertTrue(np.all(np.diff(corner_overlap, axis=1) >= 0)) 130 | 131 | 132 | if __name__ == "__main__": 133 | unittest.main() 134 | -------------------------------------------------------------------------------- /parallax_maker/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PIL import Image 3 | from .utils import ( 4 | find_bounding_box, 5 | find_square_from_bounding_box, 6 | find_square_bounding_box, 7 | filename_add_version, 8 | filename_previous_version, 9 | highlight_selected_element, 10 | encode_string_with_nonce, 11 | decode_string_with_nonce, 12 | ) 13 | 14 | 15 | class TestFindBoundingBox(unittest.TestCase): 16 | def test_find_bounding_box(self): 17 | # Create a mask image with a bounding box 18 | mask_image = Image.new("L", (100, 100)) 19 | mask_image.paste(255, (20, 30, 80, 70)) 20 | 21 | # Call the function 22 | result = find_bounding_box(mask_image, padding=0) 23 | 24 | # Define the expected result 25 | expected_result = (20, 30, 79, 
69) 26 | 27 | # Assert that the result is as expected 28 | self.assertEqual(result, expected_result) 29 | 30 | def test_find_bounding_box_with_padding(self): 31 | # Create a mask image with a bounding box 32 | mask_image = Image.new("L", (100, 100)) 33 | mask_image.paste(255, (20, 30, 80, 70)) 34 | 35 | # Call the function with padding 36 | result = find_bounding_box(mask_image, padding=10) 37 | 38 | # Define the expected result 39 | expected_result = (10, 20, 89, 79) 40 | 41 | # Assert that the result is as expected 42 | self.assertEqual(result, expected_result) 43 | 44 | def test_find_bounding_box_with_empty_mask(self): 45 | # Create an empty mask image 46 | mask_image = Image.new("L", (100, 100)) 47 | 48 | # Call the function 49 | result = find_bounding_box(mask_image) 50 | 51 | # Define the expected result 52 | expected_result = (0, 0, 0, 0) 53 | 54 | # Assert that the result is as expected 55 | self.assertEqual(result, expected_result) 56 | 57 | def test_find_square_from_bounding_box(self): 58 | # Define the bounding box coordinates 59 | xmin, ymin, xmax, ymax = 20, 30, 80, 70 60 | 61 | # Call the function 62 | result = find_square_from_bounding_box(xmin, ymin, xmax, ymax) 63 | 64 | # Define the expected result 65 | expected_result = (20, 20, 80, 80) 66 | 67 | # Assert that the result is as expected 68 | self.assertEqual(result, expected_result) 69 | 70 | def test_find_square_bounding_box(self): 71 | # Create an empty mask image 72 | mask_image = Image.new("L", (100, 100)) 73 | mask_image.paste(255, (20, 30, 80, 70)) 74 | 75 | result = find_square_bounding_box(mask_image, padding=10) 76 | 77 | expected_result = (11, 11, 89, 89) 78 | 79 | self.assertEqual(result, expected_result) 80 | 81 | def test_find_square_bounding_box_outside(self): 82 | # Create an empty mask image 83 | mask_image = Image.new("L", (100, 100)) 84 | mask_image.paste(255, (20, 30, 80, 70)) 85 | 86 | result = find_square_bounding_box(mask_image, padding=30) 87 | 88 | expected_result = (0, 0, 100, 100) 89 | 90 | self.assertEqual(result, expected_result) 91 | 92 | 93 | class TestFilenameAddVersion(unittest.TestCase): 94 | def test_filename_add_version(self): 95 | # Define the input filename 96 | filename = "/path/to/image.png" 97 | 98 | # Call the function 99 | result = filename_add_version(filename) 100 | 101 | # Define the expected result 102 | expected_result = "/path/to/image_v2.png" 103 | 104 | # Assert that the result is as expected 105 | self.assertEqual(result, expected_result) 106 | 107 | # Call the function 108 | result = filename_add_version(result) 109 | 110 | # Define the expected result 111 | expected_result = "/path/to/image_v3.png" 112 | 113 | # Assert that the result is as expected 114 | self.assertEqual(result, expected_result) 115 | 116 | def test_filename_previous_version(self): 117 | # Define the input filename 118 | filename = "/path/to/image_v3.png" 119 | 120 | # Call the function 121 | result = filename_previous_version(filename) 122 | 123 | # Define the expected result 124 | expected_result = "/path/to/image_v2.png" 125 | 126 | # Assert that the result is as expected 127 | self.assertEqual(result, expected_result) 128 | 129 | # Call the function 130 | result = filename_previous_version(result) 131 | 132 | # Define the expected result 133 | expected_result = "/path/to/image.png" 134 | 135 | # Assert that the result is as expected 136 | self.assertEqual(result, expected_result) 137 | 138 | # Call the function with a filename without version 139 | filename = "/path/to/image.png" 140 | result = 
filename_previous_version(filename) 141 | 142 | # Define the expected result 143 | expected_result = None 144 | 145 | # Assert that the result is as expected 146 | self.assertEqual(result, expected_result) 147 | 148 | 149 | class TestHighLightSelectedElement(unittest.TestCase): 150 | def setUp(self) -> None: 151 | super().setUp() 152 | 153 | self.classnames = [ 154 | "text-white", 155 | "text-brown", 156 | "text-black", 157 | ] 158 | self.highlight_classname = "color-is-selected" 159 | 160 | def test_highlight_selected_element(self): 161 | result = highlight_selected_element( 162 | self.classnames, 0, self.highlight_classname 163 | ) 164 | 165 | expected_result = self.classnames.copy() 166 | expected_result[0] += f" {self.highlight_classname}" 167 | 168 | # Assert that the result is as expected 169 | self.assertEqual(result, expected_result) 170 | 171 | def test_highlight_selected_element_remove(self): 172 | classnames = self.classnames.copy() 173 | expected_result = self.classnames 174 | classnames[1] += f" {self.highlight_classname}" 175 | 176 | result = highlight_selected_element(classnames, None, self.highlight_classname) 177 | 178 | # Assert that the result is as expected 179 | self.assertEqual(result, expected_result) 180 | 181 | 182 | class TestEncodeStringWithNonce(unittest.TestCase): 183 | def test_encode_string_with_nonce(self): 184 | # Define the plaintext and nonce 185 | plaintext = "Hello, World!" 186 | nonce = "unique-filename" 187 | 188 | # Call the function 189 | result = encode_string_with_nonce(plaintext, nonce) 190 | 191 | self.assertNotEqual(result, plaintext) 192 | 193 | decoded = decode_string_with_nonce(result, nonce) 194 | 195 | self.assertEqual(decoded, plaintext) 196 | 197 | decoded = decode_string_with_nonce(result, "wrong-nonce") 198 | 199 | self.assertNotEqual(decoded, plaintext) 200 | 201 | def test_decode_with_string_bad_data(self): 202 | nonce = "unique-filename" 203 | input = "bad-data" 204 | 205 | decoded = decode_string_with_nonce(input, nonce) 206 | 207 | self.assertEqual(decoded, None) 208 | 209 | 210 | if __name__ == "__main__": 211 | unittest.main() 212 | -------------------------------------------------------------------------------- /parallax_maker/upscaler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # (c) 2024 Niels Provos 3 | # 4 | """ 5 | Upscale Image Tiling 6 | 7 | This module provides functionality to upscale an image using a pre-trained deep learning model, 8 | specifically Swin2SR for Super Resolution. The upscaling process breaks the input image into smaller 9 | tiles. Each tile is individually upscaled using the model, and then these upscaled tiles are reassembled 10 | into the final image. The tile-based approach helps manage memory usage and can handle larger 11 | images by processing small parts one at a time. 12 | 13 | For better integration of the tiles into the resultant upscaled image, tiles are overlapped and 14 | blended to ensure smoother transitions between tiles. 
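Example (illustrative sketch; the Swin2SR weights are fetched from the Hugging Face
hub on first use, and the input filename is a placeholder):

    from PIL import Image
    from parallax_maker.upscaler import Upscaler

    upscaler = Upscaler(model_name="swin2sr")
    image = Image.open("slice_0.png").convert("RGBA")
    upscaled = upscaler.upscale_image_tiled(image, overlap=64)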
15 | """ 16 | 17 | from PIL import Image 18 | import torch 19 | from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution 20 | 21 | import numpy as np 22 | from scipy.ndimage import zoom 23 | 24 | from .utils import torch_get_device, premultiply_alpha_numpy, find_bounding_box 25 | import argparse 26 | 27 | 28 | class Upscaler: 29 | def __init__(self, model_name="swin2sr", external_model=None): 30 | assert model_name in ["swin2sr", "simple", "inpainting", "stabilityai"] 31 | self.model_name = model_name 32 | self.model = None 33 | self.image_processor = None 34 | self.external_model = external_model 35 | self.tile_size = 512 36 | self.scale_factor = 2 37 | 38 | def create_model(self): 39 | if self.model_name == "swin2sr": 40 | self.model, self.image_processor = self.load_swin2sr_model() 41 | elif self.model_name == "simple": 42 | self.model, self.image_processor = None, None 43 | self.tile_size = 512 44 | elif self.model_name == "inpainting": 45 | assert self.external_model is not None 46 | assert self.external_model.get_dimension() is not None 47 | self.model, self.image_processor = self.external_model, None 48 | self.tile_size = self.external_model.get_dimension() // 2 49 | elif self.model_name == "stabilityai": 50 | self.model, self.image_processor = self.external_model, None 51 | self.tile_size = 1024 52 | self.scale_factor = 4 53 | 54 | def __eq__(self, other): 55 | if not isinstance(other, Upscaler): 56 | return False 57 | return self.model_name == other.model_name 58 | 59 | def load_swin2sr_model(self): 60 | # Initialize the image processor and model 61 | image_processor = AutoImageProcessor.from_pretrained( 62 | "caidas/swin2SR-classical-sr-x2-64" 63 | ) 64 | model = Swin2SRForImageSuperResolution.from_pretrained( 65 | "caidas/swin2SR-classical-sr-x2-64" 66 | ) 67 | model.to(torch_get_device()) 68 | 69 | return model, image_processor 70 | 71 | def upscale_image_tiled(self, image, overlap=64, prompt=None, negative_prompt=None): 72 | """ 73 | Upscales an image using a tiled approach. 74 | 75 | Args: 76 | image (PIL.Image or np.ndarray): The input image array. 77 | overlap (int, optional): The overlap between adjacent tiles. Defaults to 64. 78 | 79 | Returns: 80 | np.ndarray: The upscaled image array. 
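        Example (tile-count sketch, assuming the default tile_size of 512 and a
        scale factor of 2): an 800x800 input with overlap=64 uses a stride of
        step_size = 512 - 64 // 2 = 480, so ceil(800 / 480) = 2 tiles are needed
        per axis and the image is processed as a 2x2 grid of tiles.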
81 | """ 82 | if isinstance(image, Image.Image): 83 | image = np.array(image) 84 | 85 | # initializes parameters we need below 86 | if self.model is None: 87 | self.create_model() 88 | 89 | alpha = None 90 | bounding_box = None 91 | if image.shape[2] == 4: # RGBA image 92 | alpha = image[:, :, 3] 93 | bounding_box = find_bounding_box(alpha, padding=0) 94 | 95 | # scale the size of the alpha channel using bicubic interpolation 96 | alpha = zoom(alpha, self.scale_factor, order=3) 97 | image = image[:, :, :3] # Remove alpha channel 98 | 99 | # crop the image to the bounding box 100 | orig_image = image 101 | image = image[ 102 | bounding_box[1] : bounding_box[3], bounding_box[0] : bounding_box[2] 103 | ] 104 | 105 | # Ensure the overlap can be divided by the scale factor 106 | if overlap % self.scale_factor != 0: 107 | overlap = (overlap // self.scale_factor + 1) * self.scale_factor 108 | 109 | # Calculate the number of tiles 110 | height, width, _ = image.shape 111 | step_size = self.tile_size - overlap // self.scale_factor 112 | num_tiles_x = (width + step_size - 1) // step_size 113 | num_tiles_y = (height + step_size - 1) // step_size 114 | 115 | # Create a new array to store the upscaled result 116 | upscaled_height = height * self.scale_factor 117 | upscaled_width = width * self.scale_factor 118 | upscaled_image = np.zeros((upscaled_height, upscaled_width, 3), dtype=np.uint8) 119 | 120 | # Iterate over the tiles 121 | for y in range(num_tiles_y): 122 | for x in range(num_tiles_x): 123 | # Calculate the coordinates of the current tile 124 | left = x * step_size 125 | top = y * step_size 126 | right = min(left + self.tile_size, width) 127 | bottom = min(top + self.tile_size, height) 128 | 129 | # make sure we process a full tile 130 | if x > 0 and right - left < self.tile_size: 131 | left = right - self.tile_size 132 | assert left >= 0 133 | if y > 0 and bottom - top < self.tile_size: 134 | top = bottom - self.tile_size 135 | assert top >= 0 136 | 137 | print( 138 | f"Processing tile ({y}, {x}) with coordinates ({left}, {top}, {right}, {bottom})" 139 | ) 140 | 141 | # Extract the current tile from the image 142 | tile = image[top:bottom, left:right] 143 | # XXX - revisit whether we can keep this as a numpy array 144 | tile = Image.fromarray(tile) 145 | 146 | cur_width, cur_height = tile.size 147 | if cur_width % 64 != 0 or cur_height % 64 != 0: 148 | new_width = ( 149 | cur_width + (64 - cur_width % 64) 150 | if cur_width % 64 != 0 151 | else cur_width 152 | ) 153 | new_height = ( 154 | cur_height + (64 - cur_height % 64) 155 | if cur_height % 64 != 0 156 | else cur_height 157 | ) 158 | tile = tile.resize((new_width, new_height)) 159 | upscaled_tile = self.upscale_tile(tile, prompt, negative_prompt) 160 | if tile.size != (cur_width, cur_height): 161 | upscaled_tile = upscaled_tile.resize( 162 | (cur_width * self.scale_factor, cur_height * self.scale_factor) 163 | ) 164 | upscaled_tile = np.array(upscaled_tile) 165 | 166 | # Calculate the coordinates to paste the upscaled tile 167 | place_left = left * self.scale_factor 168 | place_top = top * self.scale_factor 169 | place_right = place_left + upscaled_tile.shape[1] 170 | place_bottom = place_top + upscaled_tile.shape[0] 171 | 172 | self.integrate_tile( 173 | upscaled_tile, 174 | upscaled_image, 175 | place_left, 176 | place_top, 177 | place_right, 178 | place_bottom, 179 | x, 180 | y, 181 | overlap, 182 | ) 183 | 184 | # Combine the upscaled image with the alpha channel if present 185 | if alpha is not None: 186 | image = zoom(orig_image, 
(self.scale_factor, self.scale_factor, 1), order=3) 187 | bounding_box = [coord * self.scale_factor for coord in bounding_box] 188 | image[ 189 | bounding_box[1] : bounding_box[3], bounding_box[0] : bounding_box[2] 190 | ] = upscaled_image 191 | upscaled_image = np.dstack((image, alpha)) 192 | upscaled_image = premultiply_alpha_numpy(upscaled_image) 193 | else: 194 | upscaled_image = Image.fromarray(upscaled_image) 195 | 196 | return upscaled_image 197 | 198 | @staticmethod 199 | def integrate_tile(tile, image, left, top, right, bottom, tile_x, tile_y, overlap): 200 | height, width, _ = tile.shape 201 | 202 | # xxx - should be move before upscaling 203 | if overlap >= height or overlap >= width: 204 | return 205 | 206 | # Create an alpha channel for the tile 207 | alpha = np.ones((height, width), dtype=np.float32) 208 | if tile_x > 0 and tile_y > 0: 209 | alpha[:overlap, :overlap] = np.outer( 210 | np.linspace(0, 1, overlap), np.linspace(0, 1, overlap) 211 | ) 212 | if tile_x > 0: 213 | new_alpha = np.tile(np.linspace(0, 1, overlap), (height, 1)) 214 | alpha[:, :overlap] = np.minimum(alpha[:, :overlap], new_alpha) 215 | 216 | if tile_y > 0: 217 | new_alpha = np.tile(np.linspace(0, 1, overlap).reshape(-1, 1), (1, width)) 218 | alpha[:overlap, :] = np.minimum(alpha[:overlap, :], new_alpha) 219 | 220 | # Reshape the alpha channel to match the tile shape 221 | alpha = alpha.reshape(height, width, 1) 222 | 223 | # Apply the tile with alpha blending to the image 224 | image[top:bottom, left:right] = ( 225 | alpha * tile + (1 - alpha) * image[top:bottom, left:right] 226 | ).astype(np.uint8) 227 | 228 | def upscale_tile(self, tile, prompt=None, negative_prompt=None): 229 | if self.model_name == "swin2sr": 230 | return self._upscale_tile_swin2sr(tile) 231 | elif self.model_name == "simple": 232 | return self._upscale_tile_simple(tile) 233 | elif self.model_name == "inpainting": 234 | return self._upscale_tile_inpainting(tile, prompt, negative_prompt) 235 | elif self.model_name == "stabilityai": 236 | return self._upscale_tile_stabilityai(tile, prompt, negative_prompt) 237 | 238 | def _upscale_tile_stabilityai(self, tile, prompt, negative_prompt): 239 | upscaled_image = self.external_model.upscale_image( 240 | tile, prompt, negative_prompt=negative_prompt 241 | ) 242 | width, height = tile.size 243 | upscaled_image = upscaled_image.resize( 244 | (width * self.scale_factor, height * self.scale_factor), Image.LANCZOS 245 | ) 246 | return upscaled_image 247 | 248 | def _upscale_tile_inpainting(self, tile, prompt, negative_prompt): 249 | rescaled_tile = tile.resize((tile.size[0] * 2, tile.size[1] * 2), Image.LANCZOS) 250 | rescaled_tile = np.array(rescaled_tile) 251 | mask_image = np.ones(rescaled_tile.shape[:2], dtype=np.uint8) 252 | mask_image *= 255 253 | scale = 2 if len(prompt) or len(negative_prompt) else 0 254 | tile = self.external_model.inpaint( 255 | prompt, 256 | negative_prompt, 257 | rescaled_tile, 258 | mask_image, 259 | strength=0.25, 260 | guidance_scale=scale, 261 | num_inference_steps=75, 262 | padding=0, 263 | blur_radius=0, 264 | ) 265 | tile = tile.convert("RGB") 266 | return tile 267 | 268 | def _upscale_tile_simple(self, tile): 269 | """ 270 | Upscales a tile using a simple bicubic interpolation. 271 | 272 | Args: 273 | tile: The input tile to be upscaled. 274 | 275 | Returns: 276 | An Image object representing the upscaled tile. 
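Note: this fallback always doubles the tile dimensions with PIL's bicubic filter,
which matches the default scale_factor of 2 used by the "simple" backend.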
277 | """ 278 | return tile.resize((tile.width * 2, tile.height * 2), Image.BICUBIC) 279 | 280 | def _upscale_tile_swin2sr(self, tile): 281 | """ 282 | Upscales a tile using a given model and image processor. 283 | 284 | Args: 285 | model: The model used for upscaling the tile. 286 | image_processor: The image processor used to preprocess the tile. 287 | tile: The input tile to be upscaled. 288 | 289 | Returns: 290 | An Image object representing the upscaled tile. 291 | """ 292 | inputs = self.image_processor(tile, return_tensors="pt") 293 | inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()} 294 | 295 | with torch.no_grad(): 296 | outputs = self.model(**inputs) 297 | 298 | output = outputs.reconstruction.squeeze().float().cpu().clamp_(0, 1).numpy() 299 | output = np.moveaxis(output, source=0, destination=-1) 300 | output = (output * 255.0).round().astype(np.uint8) 301 | 302 | return Image.fromarray(output) 303 | 304 | 305 | if __name__ == "__main__": 306 | from inpainting import InpaintingModel 307 | from stabilityai import StabilityAI 308 | 309 | parser = argparse.ArgumentParser(description="Image Upscaling") 310 | parser.add_argument( 311 | "-i", "--input", type=str, default="input.jpg", help="Input image path" 312 | ) 313 | parser.add_argument( 314 | "-p", "--prompt", type=str, default="a sci-fi robot in a futuristic laboratory" 315 | ) 316 | parser.add_argument("--api-key", type=str) 317 | args = parser.parse_args() 318 | 319 | if args.api_key: 320 | model_name = "stabilityai" 321 | model = StabilityAI(args.api_key) 322 | else: 323 | model_name = "inpainting" 324 | model = InpaintingModel() 325 | model.load_model() 326 | upscaler = Upscaler(model_name=model_name, external_model=model) 327 | image = Image.open(args.input) 328 | upscaled_image = upscaler.upscale_image_tiled( 329 | image, 330 | overlap=64, 331 | prompt=args.prompt, 332 | ) 333 | upscaled_image.save("upscaled_image.png") 334 | -------------------------------------------------------------------------------- /postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: [ 3 | require('tailwindcss'), 4 | require('autoprefixer'), 5 | ] 6 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=65.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "parallax-maker" 7 | version = "1.0.3" 8 | description = "A workflow for turning images into 2.5D animations with depth-based segmentation and inpainting" 9 | authors = [ 10 | {name = "Niels Provos", email = "niels@provos.org"} 11 | ] 12 | readme = "README.md" 13 | license = "AGPL-3.0-or-later" 14 | keywords = ["computer-vision", "depth-estimation", "image-segmentation", "3d-animation", "gltf", "dash", "stable-diffusion"] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: End Users/Desktop", 19 | "Operating System :: OS Independent", 20 | "Programming Language :: Python :: 3", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Topic :: Multimedia :: Graphics :: 3D Modeling", 25 | "Topic :: Multimedia :: Graphics :: Graphics Conversion", 26 | "Topic :: Scientific/Engineering :: Image Processing", 27 | "Topic 
:: Scientific/Engineering :: Artificial Intelligence", 28 | ] 29 | requires-python = ">=3.10,<3.14" 30 | dependencies = [ 31 | "dash==2.17.0", 32 | "dash_extensions==1.0.14", 33 | "numpy==1.26.4", 34 | "Pillow>=11.3.0", 35 | "pygltflib==1.16.2", 36 | "torch (>=2.8.0,<3.0.0)", 37 | "torchvision (>=0.23.0,<0.24.0)", 38 | "opencv-python>=4.9.0", 39 | "timm==0.6.12", 40 | "diffusers[torch] (>=0.34.0,<0.35.0)", 41 | "transformers>=4.53.0", 42 | "orjson>=3.10.1", 43 | "numba>=0.59.1", 44 | "websocket-client>=1.8.0", 45 | "scipy>=1.13.1", 46 | "protobuf (>=6.31.1,<7.0.0)", 47 | "sentencepiece (>=0.2.0,<0.3.0)", 48 | "accelerate (>=1.8.1,<2.0.0)", 49 | "fal-client (>=0.7.0,<0.8.0)", 50 | ] 51 | 52 | [project.optional-dependencies] 53 | dev = [ 54 | "pytest", 55 | "pytest-cov", 56 | "black", 57 | "flake8", 58 | "mypy", 59 | "pre-commit", 60 | ] 61 | 62 | [project.urls] 63 | Homepage = "https://github.com/provos/parallax-maker" 64 | Documentation = "https://github.com/provos/parallax-maker#readme" 65 | Repository = "https://github.com/provos/parallax-maker" 66 | Issues = "https://github.com/provos/parallax-maker/issues" 67 | 68 | [project.scripts] 69 | parallax-maker = "parallax_maker.webui:main" 70 | parallax-gltf-cli = "parallax_maker.gltf_cli:main" 71 | 72 | [tool.setuptools] 73 | include-package-data = false 74 | 75 | [tool.setuptools.packages.find] 76 | include = ["parallax_maker"] 77 | 78 | [tool.setuptools.package-data] 79 | parallax_maker = [ 80 | "assets/css/*", 81 | "assets/scripts/*", 82 | "example/**/*", 83 | "*.js", 84 | "*.css", 85 | "*.json", 86 | ] 87 | 88 | [tool.black] 89 | line-length = 88 90 | target-version = ['py310'] 91 | 92 | [tool.pytest.ini_options] 93 | testpaths = ["parallax_maker"] 94 | python_files = ["test_*.py"] 95 | addopts = "--cov=parallax_maker --cov-report=html --cov-report=term-missing" 96 | pythonpath = ["."] 97 | 98 | [tool.mypy] 99 | python_version = "3.10" 100 | warn_return_any = true 101 | warn_unused_configs = true 102 | disallow_untyped_defs = true 103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dash==2.17.0 2 | dash_extensions==1.0.14 3 | numpy==1.26.4 4 | Pillow>=10.3.0 5 | pygltflib==1.16.2 6 | torch>=2.7.1,<3.0.0 7 | torchvision>=0.22.1,<0.23.0 8 | opencv-python>=4.9.0 9 | timm==0.6.12 10 | diffusers[torch]>=0.34.0,<0.35.0 11 | transformers>=4.40 12 | orjson>=3.10.1 13 | numba>=0.59.1 14 | websocket-client>=1.8.0 15 | scipy>=1.13.1 16 | protobuf>=6.31.1,<7.0.0 17 | sentencepiece>=0.2.0,<0.3.0 18 | accelerate>=1.8.1,<2.0.0 19 | fal-client>=0.7.0,<0.8.0 -------------------------------------------------------------------------------- /setup_dev.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Setup script for Parallax Maker development and packaging. 
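Typical invocations (the full command list is printed by main() below):

    python setup_dev.py install-tools   # install build, twine, pytest, black, flake8
    python setup_dev.py full-build      # clean artifacts, build, and twine-check dist/*
    python setup_dev.py dev-setup       # tools + format + tests + clean build + check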
4 | """ 5 | 6 | import sys 7 | import subprocess 8 | import shutil 9 | from pathlib import Path 10 | 11 | 12 | def run_command(cmd, description): 13 | """Run a command and handle errors.""" 14 | print(f"Running: {description}") 15 | print(f"Command: {cmd}") 16 | result = subprocess.run(cmd, shell=True, capture_output=True, text=True) 17 | if result.returncode != 0: 18 | print(f"Error: {result.stderr}") 19 | return False 20 | print(f"Success: {result.stdout}") 21 | return True 22 | 23 | 24 | def clean_build(): 25 | """Clean previous build artifacts.""" 26 | print("Cleaning previous build artifacts...") 27 | dirs_to_remove = ["dist", "build", "*.egg-info"] 28 | for pattern in dirs_to_remove: 29 | for path in Path(".").glob(pattern): 30 | if path.is_dir(): 31 | shutil.rmtree(path) 32 | print(f"Removed directory: {path}") 33 | else: 34 | path.unlink() 35 | print(f"Removed file: {path}") 36 | 37 | 38 | def build_package(): 39 | """Build the package.""" 40 | print("Building package...") 41 | return run_command("python -m build", "Building wheel and source distribution") 42 | 43 | 44 | def check_package(): 45 | """Check the built package.""" 46 | print("Checking package...") 47 | return run_command("python -m twine check dist/*", "Checking package integrity") 48 | 49 | 50 | def install_dev_tools(): 51 | """Install development tools.""" 52 | print("Installing development tools...") 53 | tools = ["build", "twine", "pytest", "black", "flake8"] 54 | for tool in tools: 55 | if not run_command(f"pip install {tool}", f"Installing {tool}"): 56 | return False 57 | return True 58 | 59 | 60 | def run_tests(): 61 | """Run tests.""" 62 | print("Running tests...") 63 | return run_command("python -m pytest", "Running test suite") 64 | 65 | 66 | def format_code(): 67 | """Format code with black.""" 68 | print("Formatting code...") 69 | return run_command("python -m black .", "Formatting code with black") 70 | 71 | 72 | def main(): 73 | """Main setup function.""" 74 | if len(sys.argv) < 2: 75 | print("Usage: python setup_dev.py ") 76 | print("Commands:") 77 | print(" install-tools - Install development tools") 78 | print(" clean - Clean build artifacts") 79 | print(" build - Build the package") 80 | print(" check - Check the package") 81 | print(" test - Run tests") 82 | print(" format - Format code") 83 | print(" full-build - Clean, build, and check") 84 | print(" dev-setup - Full development setup") 85 | sys.exit(1) 86 | 87 | command = sys.argv[1] 88 | 89 | if command == "install-tools": 90 | install_dev_tools() 91 | elif command == "clean": 92 | clean_build() 93 | elif command == "build": 94 | build_package() 95 | elif command == "check": 96 | check_package() 97 | elif command == "test": 98 | run_tests() 99 | elif command == "format": 100 | format_code() 101 | elif command == "full-build": 102 | clean_build() 103 | if build_package(): 104 | check_package() 105 | elif command == "dev-setup": 106 | install_dev_tools() 107 | format_code() 108 | run_tests() 109 | clean_build() 110 | if build_package(): 111 | check_package() 112 | print("\n✅ Development setup complete!") 113 | print("📦 Package built successfully!") 114 | print("\nNext steps:") 115 | print("1. Test install: pip install dist/*.whl") 116 | print( 117 | "2. Upload to TestPyPI: python -m twine upload --repository testpypi dist/*" 118 | ) 119 | print("3. 
Upload to PyPI: python -m twine upload dist/*") 120 | else: 121 | print(f"Unknown command: {command}") 122 | sys.exit(1) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /tailwind.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | content: [ 3 | './parallax_maker/assets/**/*.js', './parallax_maker/*.py', 4 | '!./parallax_maker/test_*.py', '!./parallax_maker/stabilityai.py', '!./parallax_maker/htmlcov/**', 5 | ], 6 | darkMode: 'selector', 7 | theme: { 8 | extend: {}, 9 | }, 10 | variants: { 11 | extend: {}, 12 | }, 13 | plugins: [], 14 | } 15 | -------------------------------------------------------------------------------- /tailwind.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities; 4 | 5 | #app-container { 6 | @apply bg-white text-gray-800; 7 | @apply dark:bg-zinc-900 dark:text-slate-300; 8 | } 9 | 10 | .title-header { 11 | @apply flex justify-between items-center bg-blue-800 text-white p-4 mb-4; 12 | @apply dark:text-black dark:bg-blue-400; 13 | } 14 | 15 | .header-left { 16 | @apply flex items-center; 17 | } 18 | 19 | .header-right { 20 | @apply w-[120px]; 21 | } 22 | 23 | .title-text { 24 | @apply text-2xl font-bold text-center flex-grow; 25 | } 26 | 27 | .dark-mode-toggle { 28 | @apply px-4 py-2 bg-blue-600 text-white rounded-md; 29 | @apply hover:bg-blue-700; 30 | @apply dark:bg-amber-600 dark:hover:bg-amber-500; 31 | } 32 | 33 | .footer { 34 | @apply text-center text-gray-500 p-2; 35 | @apply dark:text-slate-400; 36 | } 37 | 38 | .help-box { 39 | @apply bg-yellow-100 border border-black; 40 | @apply dark:bg-yellow-200 dark:border-gray-700 dark:text-slate-700; 41 | } 42 | 43 | input, textarea { 44 | @apply text-gray-800 bg-white; 45 | @apply dark:text-slate-300 dark:bg-zinc-800; 46 | } 47 | 48 | button:disabled { 49 | @apply bg-gray-300 text-gray-500; 50 | @apply dark:bg-zinc-600 dark:text-slate-300; 51 | } 52 | 53 | button.color-not-selected { 54 | @apply bg-blue-500; 55 | @apply dark:bg-blue-400 dark:text-slate-200; 56 | @apply disabled:bg-gray-300 disabled:text-gray-500; 57 | @apply dark:disabled:bg-zinc-600 dark:disabled:text-slate-300; 58 | } 59 | 60 | button.color-is-selected { 61 | @apply bg-green-500 text-white; 62 | @apply dark:bg-green-400 dark:text-slate-200; 63 | @apply dark:disabled:bg-zinc-600 dark:disabled:text-slate-300; 64 | } 65 | 66 | .rc-slider-disabled { 67 | @apply bg-gray-300; 68 | @apply dark:bg-zinc-600; 69 | } 70 | 71 | .failure-color { 72 | @apply bg-red-300 text-white; 73 | @apply dark:bg-red-500 dark:text-slate-200; 74 | } 75 | 76 | .erase-color { 77 | @apply bg-red-500 text-white; 78 | } 79 | 80 | .color-is-selected-light { 81 | @apply bg-green-200; 82 | @apply dark:bg-green-600 dark:text-slate-200; 83 | } 84 | 85 | .has-history-color { 86 | @apply text-emerald-400; 87 | } 88 | 89 | .no-history-color { 90 | @apply text-orange-800; 91 | } 92 | 93 | .progress-bar { 94 | @apply h-3 w-full bg-gray-200 rounded-lg; 95 | @apply dark:bg-zinc-800; 96 | } 97 | 98 | .progress-bar-fill { 99 | @apply color-is-selected w-0 h-full rounded-lg transition-all; 100 | } 101 | 102 | .general-element { 103 | @apply bg-blue-500 text-white p-2 rounded-md; 104 | @apply dark:bg-blue-400 dark:text-slate-200; 105 | } 106 | 107 | .nav-button { 108 | @apply bg-blue-500 text-white p-1 rounded-full; 109 | @apply 
dark:bg-blue-400 dark:text-slate-200; 110 | } 111 | 112 | .tools-container { 113 | @apply flex justify-between; 114 | } 115 | 116 | .tools-backdrop { 117 | @apply flex justify-center p-1 bg-gray-200 rounded-md mt-1; 118 | @apply dark:bg-zinc-800; 119 | } 120 | 121 | .general-border { 122 | @apply border-dashed border-2 border-blue-500 rounded-md p-2; 123 | @apply dark:border-blue-400; 124 | } 125 | 126 | .image_border { 127 | @apply w-full h-full object-contain border-solid border-2 border-slate-500; 128 | } 129 | 130 | .general-container { 131 | @apply w-full general-border items-center justify-center; 132 | } 133 | 134 | .light-border { 135 | @apply p-2 border border-gray-300 rounded-md mb-2; 136 | @apply dark:border-gray-700; 137 | } 138 | 139 | .gltf-container { 140 | @apply w-full h-full bg-gray-400 p-2; 141 | @apply dark:bg-zinc-800 dark:text-gray-400; 142 | } 143 | 144 | .depth-number-text { 145 | @apply text-amber-800 text-opacity-50 text-center; 146 | @apply dark:text-amber-300 dark:text-opacity-50; 147 | } 148 | 149 | .depth-number-position { 150 | @apply absolute top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2; 151 | } 152 | 153 | .depth-number-display { 154 | @apply depth-number-text depth-number-position text-6xl; 155 | } 156 | 157 | .depth-number-input { 158 | @apply depth-number-text depth-number-position z-10 w-full h-full text-3xl sm:text-4xl md:text-5xl lg:text-6xl; 159 | } 160 | 161 | .general-dropdown .Select-control { 162 | @apply bg-gray-100 dark:bg-zinc-800 text-gray-800 dark:text-slate-200 rounded-md; 163 | } 164 | 165 | .general-dropdown .Select-menu-outer { 166 | @apply bg-gray-100 dark:bg-zinc-800 dark:text-slate-200; 167 | } 168 | 169 | /* React is overriding the colors here and so we need to add !important */ 170 | .general-dropdown .Select-option { 171 | @apply text-gray-800 dark:text-slate-200 hover:bg-gray-200 dark:hover:bg-gray-600 !important; 172 | } 173 | 174 | .general-dropdown .Select-value-label { 175 | @apply text-gray-800 dark:text-slate-200 hover:bg-gray-200 dark:hover:bg-gray-600 !important; 176 | } 177 | 178 | /* Allows us to highlight images as an overlay */ 179 | .overlay { 180 | @apply absolute inset-0 bg-green-200 bg-opacity-50 flex items-center justify-center; 181 | @apply dark:bg-green-600 dark:bg-opacity-50; 182 | } 183 | -------------------------------------------------------------------------------- /test_installation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Test script to verify the parallax-maker package installation. 
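Typical usage (the script looks for a wheel in dist/, so build the package first):

    python -m build
    python test_installation.py

The script creates a throwaway virtual environment, installs the wheel, imports
parallax_maker, and exercises the parallax-maker and parallax-gltf-cli --help entry points.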
4 | """ 5 | 6 | import subprocess 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | 12 | def run_command(cmd, description, capture_output=True, timeout=120): 13 | """Run a command and return success status.""" 14 | print(f"Testing: {description}") 15 | print(f"Command: {cmd}") 16 | 17 | try: 18 | result = subprocess.run( 19 | cmd, 20 | shell=True, 21 | capture_output=capture_output, 22 | text=True, 23 | timeout=timeout, 24 | check=False, 25 | ) 26 | if result.returncode == 0: 27 | print("✅ Success") 28 | if capture_output and result.stdout.strip(): 29 | print(f"Output: {result.stdout.strip()}") 30 | return True 31 | else: 32 | print("❌ Failed") 33 | if capture_output and result.stderr.strip(): 34 | print(f"Error: {result.stderr.strip()}") 35 | return False 36 | except subprocess.TimeoutExpired: 37 | print("⏱️ Timeout") 38 | return False 39 | except Exception as e: 40 | print(f"❌ Exception: {e}") 41 | return False 42 | 43 | 44 | def main(): 45 | """Test the package installation.""" 46 | print("🧪 Testing Parallax Maker Package Installation") 47 | print("=" * 50) 48 | 49 | # Find the wheel file 50 | dist_dir = Path("dist") 51 | if not dist_dir.exists(): 52 | print("❌ No dist/ directory found. Run 'python -m build' first.") 53 | sys.exit(1) 54 | 55 | wheel_files = list(dist_dir.glob("*.whl")) 56 | if not wheel_files: 57 | print("❌ No wheel files found in dist/") 58 | sys.exit(1) 59 | 60 | wheel_file = wheel_files[0] 61 | print(f"📦 Found wheel: {wheel_file}") 62 | 63 | # Create a temporary virtual environment 64 | with tempfile.TemporaryDirectory() as temp_dir: 65 | venv_dir = Path(temp_dir) / "test_env" 66 | 67 | print(f"\n🔧 Creating test environment in {venv_dir}") 68 | 69 | # Create virtual environment 70 | if not run_command( 71 | f"python -m venv {venv_dir}", "Creating virtual environment" 72 | ): 73 | return False 74 | 75 | # Determine activation script 76 | if sys.platform == "win32": 77 | activate_script = venv_dir / "Scripts" / "activate" 78 | python_exe = venv_dir / "Scripts" / "python.exe" 79 | else: 80 | activate_script = venv_dir / "bin" / "activate" 81 | python_exe = venv_dir / "bin" / "python" 82 | 83 | # Upgrade pip first 84 | upgrade_cmd = f"{python_exe} -m pip install --upgrade pip" 85 | if not run_command(upgrade_cmd, "Upgrading pip"): 86 | return False 87 | 88 | # Install the wheel with dependencies 89 | install_cmd = f"{python_exe} -m pip install {wheel_file.absolute()}" 90 | if not run_command( 91 | install_cmd, "Installing parallax-maker wheel", capture_output=False 92 | ): 93 | return False 94 | 95 | # Test importing the package 96 | import_cmd = f'{python_exe} -c "import parallax_maker; print(f\\"Parallax Maker v{{parallax_maker.__version__}} imported successfully\\")"' 97 | if not run_command(import_cmd, "Testing package import"): 98 | return False 99 | 100 | # Test CLI commands 101 | help_cmd = f"{python_exe} -c \"import sys; sys.argv=['parallax-maker', '--help']; from parallax_maker.webui import main; main()\"" 102 | if not run_command(help_cmd, "Testing parallax-maker --help"): 103 | return False 104 | 105 | gltf_help_cmd = f"{python_exe} -c \"import sys; sys.argv=['parallax-gltf-cli', '--help']; from parallax_maker.gltf_cli import main; main()\"" 106 | if not run_command(gltf_help_cmd, "Testing parallax-gltf-cli --help"): 107 | return False 108 | 109 | print("\n🎉 All tests passed!") 110 | print("📦 Package is ready for PyPI upload!") 111 | print("\nNext steps:") 112 | print( 113 | "1. 
Upload to TestPyPI: python -m twine upload --repository testpypi dist/*" 114 | ) 115 | print( 116 | "2. Test from TestPyPI: pip install --index-url https://test.pypi.org/simple/ parallax-maker" 117 | ) 118 | print("3. Upload to PyPI: python -m twine upload dist/*") 119 | 120 | return True 121 | 122 | 123 | if __name__ == "__main__": 124 | success = main() 125 | sys.exit(0 if success else 1) 126 | --------------------------------------------------------------------------------