├── src └── napari_stable_diffusion │ ├── _tests │ ├── __init__.py │ └── test_widget.py │ ├── utils.py │ ├── __init__.py │ ├── napari.yaml │ ├── _widget.py │ ├── _widget_img2img.py │ └── _widget_inpaint.py ├── MANIFEST.in ├── napari_stable_diffusion_demo.png ├── pyproject.toml ├── .napari-hub ├── DESCRIPTION.md └── config.yml ├── tox.ini ├── .pre-commit-config.yaml ├── .gitignore ├── LICENSE ├── setup.cfg ├── .github └── workflows │ └── test_and_deploy.yml └── README.md /src/napari_stable_diffusion/_tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | def get_stable_diffusion_model(): 2 | return "runwayml/stable-diffusion-v1-5" 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | recursive-exclude * __pycache__ 5 | recursive-exclude * *.py[co] 6 | -------------------------------------------------------------------------------- /napari_stable_diffusion_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kephale/napari-stable-diffusion/HEAD/napari_stable_diffusion_demo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 6 | 7 | [tool.black] 8 | line-length = 79 9 | 10 | [tool.isort] 11 | profile = "black" 12 | line_length = 79 13 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.1" 2 | from ._widget import StableDiffusionWidget 3 | from ._widget_img2img import StableDiffusionImg2ImgWidget 4 | from ._widget_inpaint import StableDiffusionInpaintWidget 5 | 6 | __all__ = ( 7 | "StableDiffusionWidget", 8 | "StableDiffusionImg2ImgWidget", 9 | "StableDiffusionInpaintWidget", 10 | ) 11 | -------------------------------------------------------------------------------- /.napari-hub/DESCRIPTION.md: -------------------------------------------------------------------------------- 1 | 8 | 9 | The developer has not yet provided a napari-hub specific description. 10 | -------------------------------------------------------------------------------- /.napari-hub/config.yml: -------------------------------------------------------------------------------- 1 | # You may use this file to customize how your plugin page appears 2 | # on the napari hub: https://www.napari-hub.org/ 3 | # See their wiki for details https://github.com/chanzuckerberg/napari-hub/wiki 4 | 5 | # Please note that this file should only be used IN ADDITION to entering 6 | # metadata fields (such as summary, description, authors, and various URLS) 7 | # in your standard python package metadata (e.g. setup.cfg, setup.py, or 8 | # pyproject.toml), when you would like those fields to be displayed 9 | # differently on the hub than in the napari application. 
10 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{38,39,310}-{linux,macos,windows} 4 | isolated_build=true 5 | 6 | [gh-actions] 7 | python = 8 | 3.8: py38 9 | 3.9: py39 10 | 3.10: py310 11 | 12 | [gh-actions:env] 13 | PLATFORM = 14 | ubuntu-latest: linux 15 | macos-latest: macos 16 | windows-latest: windows 17 | 18 | [testenv] 19 | platform = 20 | macos: darwin 21 | linux: linux 22 | windows: win32 23 | passenv = 24 | CI 25 | GITHUB_ACTIONS 26 | DISPLAY XAUTHORITY 27 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION 28 | PYVISTA_OFF_SCREEN 29 | extras = 30 | testing 31 | commands = pytest -v --color=yes --cov=napari_stable_diffusion --cov-report=xml 32 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/napari.yaml: -------------------------------------------------------------------------------- 1 | name: napari-stable-diffusion 2 | display_name: Stable Diffusion 3 | contributions: 4 | commands: 5 | - id: napari-stable-diffusion.make_qwidget 6 | python_name: napari_stable_diffusion._widget:StableDiffusionWidget 7 | title: "Stable Diffusion: Text to Image" 8 | - id: napari-stable-diffusion.make_img2img_qwidget 9 | python_name: napari_stable_diffusion._widget_img2img:StableDiffusionImg2ImgWidget 10 | title: "Stable Diffusion: Image to Image" 11 | - id: napari-stable-diffusion.make_inpaint_qwidget 12 | python_name: napari_stable_diffusion._widget_inpaint:StableDiffusionInpaintWidget 13 | title: "Stable Diffusion: Inpaint" 14 | widgets: 15 | - command: napari-stable-diffusion.make_qwidget 16 | display_name: "Text to Image" 17 | - command: napari-stable-diffusion.make_img2img_qwidget 18 | display_name: "Image to Image" 19 | - command: napari-stable-diffusion.make_inpaint_qwidget 20 | display_name: "Inpainting" 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-docstring-first 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | exclude: ^.napari-hub/* 9 | - repo: https://github.com/PyCQA/isort 10 | rev: 5.10.1 11 | hooks: 12 | - id: isort 13 | - repo: https://github.com/asottile/pyupgrade 14 | rev: v2.38.0 15 | hooks: 16 | - id: pyupgrade 17 | args: [--py38-plus, --keep-runtime-typing] 18 | - repo: https://github.com/myint/autoflake 19 | rev: v1.5.3 20 | hooks: 21 | - id: autoflake 22 | args: ["--in-place", "--remove-all-unused-imports"] 23 | - repo: https://github.com/psf/black 24 | rev: 22.8.0 25 | hooks: 26 | - id: black 27 | - repo: https://github.com/PyCQA/flake8 28 | rev: 5.0.4 29 | hooks: 30 | - id: flake8 31 | additional_dependencies: [flake8-typing-imports>=1.9.0] 32 | - repo: https://github.com/tlambert03/napari-plugin-checks 33 | rev: v0.3.0 34 | hooks: 35 | - id: napari-plugin-checks 36 | # https://mypy.readthedocs.io/en/stable/introduction.html 37 | # you may wish to add this as well! 
38 | # - repo: https://github.com/pre-commit/mirrors-mypy 39 | # rev: v0.910-1 40 | # hooks: 41 | # - id: mypy 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | .napari_cache 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask instance folder 58 | instance/ 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # MkDocs documentation 64 | /site/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Pycharm and VSCode 70 | .idea/ 71 | venv/ 72 | .vscode/ 73 | 74 | # IPython Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # OS 81 | .DS_Store 82 | 83 | # written by setuptools_scm 84 | **/_version.py 85 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/_tests/test_widget.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from napari_stable_diffusion import ( 4 | StableDiffusionWidget, 5 | StableDiffusionImg2ImgWidget, 6 | StableDiffusionInpaintWidget, 7 | ) 8 | 9 | 10 | # make_napari_viewer is a pytest fixture that returns a napari viewer object 11 | # capsys is a pytest fixture that captures stdout and stderr output streams 12 | def test_example_text2img(make_napari_viewer, capsys): 13 | # make viewer and add an image layer using our fixture 14 | viewer = make_napari_viewer() 15 | viewer.add_image(np.random.random((100, 100))) 16 | 17 | # create our widget, passing in the viewer 18 | my_widget = StableDiffusionWidget(viewer) 19 | 20 | assert my_widget is not None 21 | 22 | 23 | def test_example_img2img(make_napari_viewer, capsys): 24 | # make viewer and add an image layer using our fixture 25 | viewer = make_napari_viewer() 26 | viewer.add_image(np.random.random((100, 100))) 27 | 28 | # create our widget, passing in the viewer 29 | my_widget = StableDiffusionImg2ImgWidget(viewer) 30 | 31 | assert my_widget is not None 32 | 33 | 34 | def test_example_inpaint(make_napari_viewer, capsys): 35 | # make viewer and add an image layer using our fixture 36 | viewer = make_napari_viewer() 37 | viewer.add_image(np.random.random((100, 100))) 38 | 39 | # create our widget, passing in the viewer 40 | my_widget = StableDiffusionInpaintWidget(viewer) 41 | 42 | assert my_widget is not None 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2022, Kyle Harrington 3 | All rights reserved. 
4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = napari-stable-diffusion 3 | version = 0.1.1 4 | description = A demo of stable diffusion in napari 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | url = https://github.com/kephale/napari-stable-diffusion 8 | author = Kyle Harrington 9 | author_email = napari@kyleharrington.com 10 | license = BSD-3-Clause 11 | license_files = LICENSE 12 | classifiers = 13 | Development Status :: 2 - Pre-Alpha 14 | Framework :: napari 15 | Intended Audience :: Developers 16 | License :: OSI Approved :: BSD License 17 | Operating System :: OS Independent 18 | Programming Language :: Python 19 | Programming Language :: Python :: 3 20 | Programming Language :: Python :: 3 :: Only 21 | Programming Language :: Python :: 3.8 22 | Programming Language :: Python :: 3.9 23 | Programming Language :: Python :: 3.10 24 | Topic :: Scientific/Engineering :: Image Processing 25 | project_urls = 26 | Bug Tracker = https://github.com/kephale/napari-stable-diffusion/issues 27 | Documentation = https://github.com/kephale/napari-stable-diffusion#README.md 28 | Source Code = https://github.com/kephale/napari-stable-diffusion 29 | User Support = https://github.com/kephale/napari-stable-diffusion/issues 30 | 31 | [options] 32 | packages = find: 33 | install_requires = 34 | napari 35 | napari-plugin-engine>=0.1.4 36 | numpy 37 | magicgui 38 | qtpy 39 | diffusers 40 | transformers 41 | torch 42 | 43 | python_requires = >=3.8 44 | include_package_data = True 45 | package_dir = 46 | =src 47 | 48 | # add your package requirements here 49 | 50 | [options.packages.find] 51 | where = src 52 | 53 | [options.entry_points] 54 | napari.manifest = 55 | napari-stable-diffusion = napari_stable_diffusion:napari.yaml 56 | 57 | [options.extras_require] 58 | testing = 59 | tox 60 | pytest # 
https://docs.pytest.org/en/latest/contents.html 61 | pytest-cov # https://pytest-cov.readthedocs.io/en/latest/ 62 | pytest-qt # https://pytest-qt.readthedocs.io/en/latest/ 63 | napari 64 | pyqt5 65 | 66 | 67 | [options.package_data] 68 | * = *.yaml 69 | -------------------------------------------------------------------------------- /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - npe2 11 | tags: 12 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 13 | pull_request: 14 | branches: 15 | - main 16 | - npe2 17 | workflow_dispatch: 18 | 19 | jobs: 20 | test: 21 | name: ${{ matrix.platform }} py${{ matrix.python-version }} 22 | runs-on: ${{ matrix.platform }} 23 | strategy: 24 | matrix: 25 | platform: [ubuntu-latest, windows-latest, macos-latest] 26 | python-version: ['3.8', '3.9', '3.10'] 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | 36 | # these libraries enable testing on Qt on linux 37 | - uses: tlambert03/setup-qt-libs@v1 38 | 39 | # strategy borrowed from vispy for installing opengl libs on windows 40 | - name: Install Windows OpenGL 41 | if: runner.os == 'Windows' 42 | run: | 43 | git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git 44 | powershell gl-ci-helpers/appveyor/install_opengl.ps1 45 | 46 | # note: if you need dependencies from conda, considering using 47 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda 48 | # and 49 | # tox-conda: https://github.com/tox-dev/tox-conda 50 | - name: Install dependencies 51 | run: | 52 | python -m pip install --upgrade pip 53 | python -m pip install setuptools tox tox-gh-actions 54 | 55 | # this runs the platform-specific tests declared in tox.ini 56 | - name: Test with tox 57 | uses: GabrielBB/xvfb-action@v1 58 | with: 59 | run: python -m tox 60 | env: 61 | PLATFORM: ${{ matrix.platform }} 62 | 63 | - name: Coverage 64 | uses: codecov/codecov-action@v2 65 | 66 | deploy: 67 | # this will run when you have tagged a commit, starting with "v*" 68 | # and requires that you have put your twine API key in your 69 | # github secrets (see readme for details) 70 | needs: [test] 71 | runs-on: ubuntu-latest 72 | if: contains(github.ref, 'tags') 73 | steps: 74 | - uses: actions/checkout@v2 75 | - name: Set up Python 76 | uses: actions/setup-python@v2 77 | with: 78 | python-version: "3.x" 79 | - name: Install dependencies 80 | run: | 81 | python -m pip install --upgrade pip 82 | pip install -U setuptools setuptools_scm wheel twine build 83 | - name: Build and publish 84 | env: 85 | TWINE_USERNAME: __token__ 86 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 87 | run: | 88 | git tag 89 | python -m build . 
90 | twine upload dist/* 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # napari-stable-diffusion 2 | 3 | [![License BSD-3](https://img.shields.io/pypi/l/napari-stable-diffusion.svg?color=green)](https://github.com/kephale/napari-stable-diffusion/raw/main/LICENSE) 4 | [![PyPI](https://img.shields.io/pypi/v/napari-stable-diffusion.svg?color=green)](https://pypi.org/project/napari-stable-diffusion) 5 | [![Python Version](https://img.shields.io/pypi/pyversions/napari-stable-diffusion.svg?color=green)](https://python.org) 6 | [![tests](https://github.com/kephale/napari-stable-diffusion/workflows/tests/badge.svg)](https://github.com/kephale/napari-stable-diffusion/actions) 7 | [![codecov](https://codecov.io/gh/kephale/napari-stable-diffusion/branch/main/graph/badge.svg)](https://codecov.io/gh/kephale/napari-stable-diffusion) 8 | [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-stable-diffusion)](https://napari-hub.org/plugins/napari-stable-diffusion) 9 | 10 | A demo of stable diffusion in napari. 11 | 12 | ---------------------------------- 13 | 14 | This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. 15 | 16 | ![demo image of napari-stable-diffusion of the prompt "a unicorn and a dinosaur eating cookies and drinking tea"](https://github.com/kephale/napari-stable-diffusion/raw/main/napari_stable_diffusion_demo.png) 17 | 18 | 25 | 26 | ## Installation 27 | 28 | You can install `napari-stable-diffusion` via [pip]: 29 | 30 | pip install napari-stable-diffusion 31 | 32 | To install latest development version : 33 | 34 | pip install git+https://github.com/kephale/napari-stable-diffusion.git 35 | 36 | You will also need to sign up with HuggingFace and [generate an access 37 | token](https://huggingface.co/docs/hub/security-tokens) to get access to the 38 | Stable Diffusion model we use. 39 | 40 | When you have generated your access token you can either permanently 41 | set the `HF_TOKEN_SD` environment variable in your `.bashrc` or whichever file 42 | your OS uses, or you can include it on the command line 43 | 44 | ``` 45 | HF_TOKEN_SD="hf_aaaAaaaasdadsadsaoaoaoasoidijo" napari 46 | ``` 47 | 48 | For more information on the Stable Diffusion model itself, please see https://huggingface.co/CompVis/stable-diffusion-v1-4. 49 | 50 | ### Apple M1 specific instructions 51 | 52 | To utilize the M1 GPU, the nightly version of PyTorch needs to be 53 | installed. Consider using `conda` or `mamba` like this: 54 | 55 | ``` 56 | mamba create -c pytorch-nightly -n napari-stable-diffusion python=3.9 pip pyqt pytorch torchvision 57 | pip install git+https://github.com/kephale/napari-stable-diffusion.git 58 | ``` 59 | 60 | ## Next steps 61 | 62 | - Image 2 Image support 63 | - Inpainting support 64 | 65 | ## Contributing 66 | 67 | Contributions are very welcome. Tests can be run with [tox], please ensure 68 | the coverage at least stays the same before you submit a pull request. 69 | 70 | ## License 71 | 72 | Distributed under the terms of the [BSD-3] license, 73 | "napari-stable-diffusion" is free and open source software 74 | 75 | ## Issues 76 | 77 | If you encounter any problems, please [file an issue] along with a detailed description. 
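A short script that reproduces the problem makes an issue much easier to act on. As a minimal sketch (untested; it assumes the plugin is installed, and the plugin and widget names are taken from `napari.yaml`), the Text to Image widget can be opened programmatically like this:

```
import napari

viewer = napari.Viewer()
# plugin name and widget display name as declared in napari.yaml
viewer.window.add_plugin_dock_widget(
    "napari-stable-diffusion", "Text to Image"
)
napari.run()
```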
78 | 79 | [napari]: https://github.com/napari/napari 80 | [Cookiecutter]: https://github.com/audreyr/cookiecutter 81 | [@napari]: https://github.com/napari 82 | [MIT]: http://opensource.org/licenses/MIT 83 | [BSD-3]: http://opensource.org/licenses/BSD-3-Clause 84 | [GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt 85 | [GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt 86 | [Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0 87 | [Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt 88 | [cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin 89 | 90 | [file an issue]: https://github.com/kephale/napari-stable-diffusion/issues 91 | 92 | [napari]: https://github.com/napari/napari 93 | [tox]: https://tox.readthedocs.io/en/latest/ 94 | [pip]: https://pypi.org/project/pip/ 95 | [PyPI]: https://pypi.org/ 96 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/_widget.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is an example of a barebones QWidget plugin for napari 3 | 4 | It implements the Widget specification. 5 | see: https://napari.org/stable/plugins/guides.html?#widgets 6 | 7 | Replace code below according to your needs. 8 | """ 9 | from typing import TYPE_CHECKING 10 | 11 | from qtpy.QtWidgets import ( 12 | QPushButton, 13 | QWidget, 14 | QComboBox, 15 | QSpinBox, 16 | QCheckBox, 17 | QVBoxLayout, 18 | QLabel, 19 | QPlainTextEdit, 20 | ) 21 | 22 | import numpy as np 23 | 24 | import os 25 | 26 | import torch 27 | from diffusers import StableDiffusionPipeline 28 | 29 | if TYPE_CHECKING: 30 | import napari 31 | 32 | from napari.qt.threading import thread_worker 33 | 34 | from napari_stable_diffusion.utils import get_stable_diffusion_model 35 | 36 | 37 | class StableDiffusionWidget(QWidget): 38 | def __init__(self, napari_viewer): 39 | super().__init__() 40 | self.viewer = napari_viewer 41 | 42 | # Textbox for entering prompt 43 | self.prompt_textbox = QPlainTextEdit(self) 44 | 45 | # Number of output images 46 | self.gallery_size = QSpinBox(self) 47 | self.gallery_size.setMinimum(1) 48 | self.gallery_size.setValue(9) 49 | 50 | # Width and height 51 | self.width_input = QSpinBox(self) 52 | self.width_input.setMinimum(1) 53 | self.width_input.setMaximum(2**31 - 1) 54 | # Overflows if larger than this maximum 55 | self.width_input.setValue(512) 56 | 57 | self.height_input = QSpinBox(self) 58 | self.height_input.setMinimum(1) 59 | self.height_input.setMaximum(2**31 - 1) 60 | self.height_input.setValue(512) 61 | 62 | # Select devices: 63 | # CPU is always available 64 | available_devices = ["cpu"] 65 | # Add 'mps' for M1 66 | if ( 67 | hasattr(torch.backends, "mps") 68 | and torch.backends.mps.is_available() 69 | ): 70 | available_devices += ["mps"] 71 | # Add 'cuda' for nvidia cards 72 | if torch.cuda.is_available(): 73 | available_devices += [ 74 | f"cuda:{id}" for id in range(torch.cuda.device_count()) 75 | ] 76 | 77 | self.device_list = QComboBox(self) 78 | self.device_list.addItems(available_devices) 79 | 80 | self.num_inference_steps = QSpinBox(self) 81 | self.num_inference_steps.setMinimum(1) 82 | self.num_inference_steps.setValue(50) 83 | 84 | # Not Safe For Work button 85 | self.nsfw_button = QCheckBox(self) 86 | self.nsfw_button.setCheckState(True) 87 | 88 | btn = QPushButton("Run") 89 | btn.clicked.connect(self._on_click) 90 | 91 | # Layout and labels 92 | 
self.setLayout(QVBoxLayout()) 93 | 94 | label = QLabel(self) 95 | label.setText("Prompt") 96 | self.layout().addWidget(label) 97 | self.layout().addWidget(self.prompt_textbox) 98 | 99 | # negative prompt: ugly, disfigured, low quality, blurry, nsfw 100 | 101 | label = QLabel(self) 102 | label.setText("Number of images") 103 | self.layout().addWidget(label) 104 | self.layout().addWidget(self.gallery_size) 105 | 106 | label = QLabel(self) 107 | label.setText("Number of inference steps") 108 | self.layout().addWidget(label) 109 | self.layout().addWidget(self.num_inference_steps) 110 | 111 | label = QLabel(self) 112 | label.setText("Image width") 113 | self.layout().addWidget(label) 114 | self.layout().addWidget(self.width_input) 115 | 116 | label = QLabel(self) 117 | label.setText("Image height") 118 | self.layout().addWidget(label) 119 | self.layout().addWidget(self.height_input) 120 | 121 | label = QLabel(self) 122 | label.setText("Enable Not Safe For Work mode") 123 | self.layout().addWidget(label) 124 | self.layout().addWidget(self.nsfw_button) 125 | 126 | label = QLabel(self) 127 | label.setText("Compute device") 128 | self.layout().addWidget(label) 129 | self.layout().addWidget(self.device_list) 130 | 131 | self.layout().addWidget(btn) 132 | 133 | def _on_click(self): 134 | # Has issues on mps and small GPUs 135 | # self.generate_images_batch() 136 | 137 | # worker = create_worker(self.generate_images_sequential) 138 | # worker.start() 139 | 140 | # TODO: Notify the user that things are processing 141 | 142 | worker = self.generate_images_sequential() 143 | 144 | def yield_catcher(payload): 145 | array, title = payload 146 | 147 | self.viewer.add_image(array, name=title, rgb=True) 148 | 149 | # Show gallery as grid 150 | self.viewer.grid.enabled = True 151 | 152 | worker.yielded.connect(yield_catcher) 153 | 154 | worker.start() 155 | 156 | @thread_worker 157 | def generate_images_sequential(self): 158 | prompt = self.prompt_textbox.document().toPlainText() 159 | print(f"Prompt is {prompt}") 160 | 161 | # Get the device: cpu or gpu 162 | device = self.device_list.currentText() 163 | 164 | # Get huggingface token from environment variable. 
Generate at HF 165 | MY_SECRET_TOKEN = ( 166 | os.environ.get("HF_TOKEN_SD") 167 | if "HF_TOKEN_SD" in os.environ 168 | else None 169 | ) 170 | 171 | # Pre-generate the latents to ensure correct dtype 172 | batch_size = len(prompt) 173 | in_channels = 3 174 | height = self.height_input.value() 175 | width = self.width_input.value() 176 | latents_shape = (batch_size, in_channels, height // 8, width // 8) 177 | 178 | # Load the pipeline 179 | pipe = StableDiffusionPipeline.from_pretrained( 180 | get_stable_diffusion_model(), 181 | use_auth_token=MY_SECRET_TOKEN, 182 | ) 183 | pipe.to(device) 184 | 185 | # Run the pipeline 186 | num_images = self.gallery_size.value() 187 | 188 | # Populate the gallery 189 | for gallery_id in range(num_images): 190 | # Generate our random latent space uniquely per image 191 | latents = torch.randn( 192 | latents_shape, 193 | generator=None, 194 | device=("cpu" if device == "mps" else device), 195 | # dtype=torch.float16, 196 | ) 197 | pipe.latents = latents 198 | pipe.to(device) 199 | 200 | image_list = pipe([prompt], 201 | height=height, 202 | width=width, 203 | num_inference_steps=self.num_inference_steps.value() 204 | ) 205 | 206 | array = np.array(image_list.images[0]) 207 | 208 | # If NSFW, then zero over image 209 | if image_list["nsfw_content_detected"][0]: 210 | array = np.zeros_like(array) 211 | 212 | # Empty GPU cache as we generate images 213 | if torch.cuda.is_available(): 214 | torch.cuda.empty_cache() 215 | 216 | yield (array, f"nsd_{prompt}-{gallery_id}") 217 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/_widget_img2img.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is an example of a barebones QWidget plugin for napari 3 | 4 | It implements the Widget specification. 5 | see: https://napari.org/stable/plugins/guides.html?#widgets 6 | 7 | Replace code below according to your needs. 
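In this plugin, the module provides StableDiffusionImg2ImgWidget, which adds an image-layer picker to the text-to-image controls, converts the selected layer to RGB and resizes it to 768x512, and runs diffusers' StableDiffusionImg2ImgPipeline on it together with the prompt, yielding each generated image back into the viewer as a new layer.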
8 | """ 9 | from qtpy.QtWidgets import ( 10 | QPushButton, 11 | QWidget, 12 | QComboBox, 13 | QSpinBox, 14 | QCheckBox, 15 | QVBoxLayout, 16 | QLabel, 17 | QPlainTextEdit, 18 | ) 19 | 20 | from magicgui.widgets import create_widget 21 | 22 | from PIL import Image 23 | import numpy as np 24 | 25 | import os 26 | 27 | import torch 28 | from diffusers import StableDiffusionImg2ImgPipeline 29 | 30 | import napari 31 | 32 | from napari.qt.threading import thread_worker 33 | 34 | from napari_stable_diffusion.utils import get_stable_diffusion_model 35 | 36 | 37 | class StableDiffusionImg2ImgWidget(QWidget): 38 | def __init__(self, napari_viewer): 39 | super().__init__() 40 | self.viewer = napari_viewer 41 | 42 | # Textbox for entering prompt 43 | self.prompt_textbox = QPlainTextEdit(self) 44 | 45 | # Number of output images 46 | self.gallery_size = QSpinBox(self) 47 | self.gallery_size.setMinimum(1) 48 | self.gallery_size.setValue(9) 49 | 50 | # Width and height 51 | self.width_input = QSpinBox(self) 52 | self.width_input.setMinimum(1) 53 | self.width_input.setMaximum(2**31 - 1) 54 | # Overflows if larger than this maximum 55 | self.width_input.setValue(512) 56 | 57 | self.height_input = QSpinBox(self) 58 | self.height_input.setMinimum(1) 59 | self.height_input.setMaximum(2**31 - 1) 60 | self.height_input.setValue(512) 61 | 62 | # Select devices: 63 | # CPU is always available 64 | available_devices = ["cpu"] 65 | # Add 'mps' for M1 66 | if ( 67 | hasattr(torch.backends, "mps") 68 | and torch.backends.mps.is_available() 69 | ): 70 | available_devices += ["mps"] 71 | # Add 'cuda' for nvidia cards 72 | if torch.cuda.is_available(): 73 | available_devices += [ 74 | f"cuda:{id}" for id in range(torch.cuda.device_count()) 75 | ] 76 | 77 | self.device_list = QComboBox(self) 78 | self.device_list.addItems(available_devices) 79 | 80 | self.num_inference_steps = QSpinBox(self) 81 | self.num_inference_steps.setMinimum(1) 82 | self.num_inference_steps.setValue(50) 83 | 84 | # Not Safe For Work button 85 | self.nsfw_button = QCheckBox(self) 86 | self.nsfw_button.setCheckState(True) 87 | 88 | btn = QPushButton("Run") 89 | btn.clicked.connect(self._on_click) 90 | 91 | # Layout and labels 92 | self.setLayout(QVBoxLayout()) 93 | self._image_layers = create_widget(annotation=napari.layers.Image) 94 | self.layout().addWidget(QLabel("Image")) 95 | self.layout().addWidget(self._image_layers.native) 96 | 97 | label = QLabel(self) 98 | label.setText("Prompt") 99 | self.layout().addWidget(label) 100 | self.layout().addWidget(self.prompt_textbox) 101 | 102 | label = QLabel(self) 103 | label.setText("Number of images") 104 | self.layout().addWidget(label) 105 | self.layout().addWidget(self.gallery_size) 106 | 107 | label = QLabel(self) 108 | label.setText("Number of inference steps") 109 | self.layout().addWidget(label) 110 | self.layout().addWidget(self.num_inference_steps) 111 | 112 | label = QLabel(self) 113 | label.setText("Image width") 114 | self.layout().addWidget(label) 115 | self.layout().addWidget(self.width_input) 116 | 117 | label = QLabel(self) 118 | label.setText("Image height") 119 | self.layout().addWidget(label) 120 | self.layout().addWidget(self.height_input) 121 | 122 | label = QLabel(self) 123 | label.setText("Enable Not Safe For Work mode") 124 | self.layout().addWidget(label) 125 | self.layout().addWidget(self.nsfw_button) 126 | 127 | label = QLabel(self) 128 | label.setText("Compute device") 129 | self.layout().addWidget(label) 130 | self.layout().addWidget(self.device_list) 131 | 132 | 
self.layout().addWidget(btn) 133 | 134 | def _on_click(self): 135 | # Has issues on mps and small GPUs 136 | # self.generate_images_batch() 137 | 138 | # worker = create_worker(self.generate_images_sequential) 139 | # worker.start() 140 | 141 | # TODO: Notify the user that things are processing 142 | 143 | worker = self.generate_images_sequential() 144 | 145 | def yield_catcher(payload): 146 | array, title = payload 147 | 148 | self.viewer.add_image(array, name=title, rgb=True) 149 | 150 | # Show gallery as grid 151 | self.viewer.grid.enabled = True 152 | 153 | worker.yielded.connect(yield_catcher) 154 | 155 | worker.start() 156 | 157 | @thread_worker 158 | def generate_images_sequential(self): 159 | prompt = self.prompt_textbox.document().toPlainText() 160 | print(f"Prompt is {prompt}") 161 | 162 | # Get the device: cpu or gpu 163 | device = self.device_list.currentText() 164 | 165 | # Get huggingface token from environment variable. Generate at HF 166 | MY_SECRET_TOKEN = ( 167 | os.environ.get("HF_TOKEN_SD") 168 | if "HF_TOKEN_SD" in os.environ 169 | else None 170 | ) 171 | 172 | # Pre-generate the latents to ensure correct dtype 173 | batch_size = len(prompt) 174 | in_channels = 3 175 | height = self.height_input.value() 176 | width = self.width_input.value() 177 | latents_shape = (batch_size, in_channels, height // 8, width // 8) 178 | 179 | # Load the pipeline 180 | pipe = StableDiffusionImg2ImgPipeline.from_pretrained( 181 | get_stable_diffusion_model(), 182 | use_auth_token=MY_SECRET_TOKEN, 183 | height=height, 184 | width=width, 185 | num_inference_steps=self.num_inference_steps.value(), 186 | ) 187 | pipe.to(device) 188 | 189 | # Fail if no image selected 190 | if self._image_layers.value is None: 191 | print("No image selected") 192 | return 193 | 194 | # Get initial image 195 | init_image = Image.fromarray(self._image_layers.value.data).convert( 196 | "RGB" 197 | ) 198 | init_image = init_image.resize((768, 512)) 199 | 200 | # Run the pipeline 201 | num_images = self.gallery_size.value() 202 | 203 | # Populate the gallery 204 | for gallery_id in range(num_images): 205 | # Generate our random latent space uniquely per image 206 | latents = torch.randn( 207 | latents_shape, 208 | generator=None, 209 | device=("cpu" if device == "mps" else device), 210 | # dtype=torch.float16, 211 | ) 212 | pipe.latents = latents 213 | pipe.to(device) 214 | 215 | # TODO add strength and guidance_scale to GUI 216 | image_list = pipe( 217 | prompt=[prompt], 218 | image=init_image, 219 | strength=0.75, 220 | guidance_scale=7.5, 221 | ) 222 | 223 | array = np.array(image_list.images[0]) 224 | 225 | # If NSFW, then zero over image 226 | if image_list["nsfw_content_detected"][0]: 227 | array = np.zeros_like(array) 228 | 229 | # Empty GPU cache as we generate images 230 | if torch.cuda.is_available(): 231 | torch.cuda.empty_cache() 232 | 233 | yield (array, f"nsd_{prompt}-{gallery_id}") 234 | -------------------------------------------------------------------------------- /src/napari_stable_diffusion/_widget_inpaint.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is an example of a barebones QWidget plugin for napari 3 | 4 | It implements the Widget specification. 5 | see: https://napari.org/stable/plugins/guides.html?#widgets 6 | 7 | Replace code below according to your needs. 
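In this plugin, the module provides StableDiffusionInpaintWidget, which adds pickers for an image layer and a mask layer, runs diffusers' StableDiffusionInpaintPipeline on the selected layers together with the prompt, and composites the generated pixels into the masked region before adding the result to the viewer.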
8 | """ 9 | from qtpy.QtWidgets import ( 10 | QPushButton, 11 | QWidget, 12 | QComboBox, 13 | QSpinBox, 14 | QCheckBox, 15 | QVBoxLayout, 16 | QLabel, 17 | QPlainTextEdit, 18 | ) 19 | 20 | from magicgui.widgets import create_widget 21 | 22 | from PIL import Image 23 | import numpy as np 24 | 25 | import os 26 | 27 | import torch 28 | from diffusers import StableDiffusionInpaintPipeline 29 | 30 | import napari 31 | 32 | from napari.qt.threading import thread_worker 33 | 34 | from napari_stable_diffusion.utils import get_stable_diffusion_model 35 | 36 | class StableDiffusionInpaintWidget(QWidget): 37 | def __init__(self, napari_viewer): 38 | super().__init__() 39 | self.viewer = napari_viewer 40 | 41 | # Textbox for entering prompt 42 | self.prompt_textbox = QPlainTextEdit(self) 43 | 44 | # Number of output images 45 | self.gallery_size = QSpinBox(self) 46 | self.gallery_size.setMinimum(1) 47 | self.gallery_size.setValue(9) 48 | 49 | # Width and height 50 | self.width_input = QSpinBox(self) 51 | self.width_input.setMinimum(1) 52 | self.width_input.setMaximum(2**31 - 1) 53 | # Overflows if larger than this maximum 54 | self.width_input.setValue(512) 55 | 56 | self.height_input = QSpinBox(self) 57 | self.height_input.setMinimum(1) 58 | self.height_input.setMaximum(2**31 - 1) 59 | self.height_input.setValue(512) 60 | 61 | # Select devices: 62 | # CPU is always available 63 | available_devices = ["cpu"] 64 | # Add 'mps' for M1 65 | if ( 66 | hasattr(torch.backends, "mps") 67 | and torch.backends.mps.is_available() 68 | ): 69 | available_devices += ["mps"] 70 | # Add 'cuda' for nvidia cards 71 | if torch.cuda.is_available(): 72 | available_devices += [ 73 | f"cuda:{id}" for id in range(torch.cuda.device_count()) 74 | ] 75 | 76 | self.device_list = QComboBox(self) 77 | self.device_list.addItems(available_devices) 78 | 79 | self.num_inference_steps = QSpinBox(self) 80 | self.num_inference_steps.setMinimum(1) 81 | self.num_inference_steps.setValue(50) 82 | 83 | # Not Safe For Work button 84 | self.nsfw_button = QCheckBox(self) 85 | self.nsfw_button.setCheckState(True) 86 | 87 | btn = QPushButton("Run") 88 | btn.clicked.connect(self._on_click) 89 | 90 | # Layout and labels 91 | self.setLayout(QVBoxLayout()) 92 | 93 | # Image selection widget 94 | self._image_layers = create_widget(annotation=napari.layers.Image) 95 | self.layout().addWidget(QLabel("Image")) 96 | self.layout().addWidget(self._image_layers.native) 97 | 98 | # Mask selection widget 99 | self._mask_layers = create_widget(annotation=napari.layers.Image) 100 | self.layout().addWidget(QLabel("Mask")) 101 | self.layout().addWidget(self._mask_layers.native) 102 | 103 | label = QLabel(self) 104 | label.setText("Prompt") 105 | self.layout().addWidget(label) 106 | self.layout().addWidget(self.prompt_textbox) 107 | 108 | label = QLabel(self) 109 | label.setText("Number of images") 110 | self.layout().addWidget(label) 111 | self.layout().addWidget(self.gallery_size) 112 | 113 | label = QLabel(self) 114 | label.setText("Number of inference steps") 115 | self.layout().addWidget(label) 116 | self.layout().addWidget(self.num_inference_steps) 117 | 118 | label = QLabel(self) 119 | label.setText("Image width") 120 | self.layout().addWidget(label) 121 | self.layout().addWidget(self.width_input) 122 | 123 | label = QLabel(self) 124 | label.setText("Image height") 125 | self.layout().addWidget(label) 126 | self.layout().addWidget(self.height_input) 127 | 128 | label = QLabel(self) 129 | label.setText("Enable Not Safe For Work mode") 130 | 
self.layout().addWidget(label) 131 | self.layout().addWidget(self.nsfw_button) 132 | 133 | label = QLabel(self) 134 | label.setText("Compute device") 135 | self.layout().addWidget(label) 136 | self.layout().addWidget(self.device_list) 137 | 138 | self.layout().addWidget(btn) 139 | 140 | def _on_click(self): 141 | # Has issues on mps and small GPUs 142 | # self.generate_images_batch() 143 | 144 | # worker = create_worker(self.generate_images_sequential) 145 | # worker.start() 146 | 147 | # TODO: Notify the user that things are processing 148 | 149 | worker = self.generate_images_sequential() 150 | 151 | def yield_catcher(payload): 152 | array, title = payload 153 | 154 | self.viewer.add_image(array, name=title, rgb=True) 155 | 156 | # Show gallery as grid 157 | self.viewer.grid.enabled = True 158 | 159 | worker.yielded.connect(yield_catcher) 160 | 161 | worker.start() 162 | 163 | @thread_worker 164 | def generate_images_sequential(self): 165 | prompt = self.prompt_textbox.document().toPlainText() 166 | print(f"Prompt is {prompt}") 167 | 168 | # Get the device: cpu or gpu 169 | device = self.device_list.currentText() 170 | 171 | # Get huggingface token from environment variable. Generate at HF 172 | MY_SECRET_TOKEN = ( 173 | os.environ.get("HF_TOKEN_SD") 174 | if "HF_TOKEN_SD" in os.environ 175 | else None 176 | ) 177 | 178 | # Pre-generate the latents to ensure correct dtype 179 | batch_size = len(prompt) 180 | in_channels = 3 181 | height = self.height_input.value() 182 | width = self.width_input.value() 183 | latents_shape = (batch_size, in_channels, height // 8, width // 8) 184 | 185 | # Load the pipeline 186 | pipe = StableDiffusionInpaintPipeline.from_pretrained( 187 | get_stable_diffusion_model(), 188 | use_auth_token=MY_SECRET_TOKEN, 189 | height=height, 190 | width=width, 191 | num_inference_steps=self.num_inference_steps.value(), 192 | ) 193 | pipe.to(device) 194 | 195 | # Fail if no image selected 196 | if self._image_layers.value is None: 197 | print("No image selected") 198 | return 199 | 200 | # Fail if no mask selected 201 | if self._image_layers.value is None: 202 | print("No mask selected") 203 | return 204 | 205 | # Get initial image 206 | init_image = Image.fromarray(self._image_layers.value.data).convert( 207 | "RGB" 208 | ) 209 | init_image = init_image.resize((768, 512)) 210 | 211 | # Get initial mask 212 | mask = Image.fromarray(self._mask_layers.value.data).convert("RGB") 213 | mask = mask.resize((768, 512)) 214 | 215 | # Run the pipeline 216 | num_images = self.gallery_size.value() 217 | 218 | # Populate the gallery 219 | for gallery_id in range(num_images): 220 | # Generate our random latent space uniquely per image 221 | latents = torch.randn( 222 | latents_shape, 223 | generator=None, 224 | device=("cpu" if device == "mps" else device), 225 | # dtype=torch.float16, 226 | ) 227 | pipe.latents = latents 228 | pipe.to(device) 229 | 230 | # TODO add strength and guidance_scale to GUI 231 | image_list = pipe( 232 | prompt=[prompt], 233 | image=init_image, 234 | mask_image=mask, 235 | strength=0.75, 236 | guidance_scale=7.5, 237 | ) 238 | 239 | # This is the SD output 240 | array = np.array(image_list.images[0]) 241 | 242 | # Mask the output 243 | array = np.where(mask == 1, array, init_image) 244 | 245 | # If NSFW, then zero over image 246 | if image_list["nsfw_content_detected"][0]: 247 | array = np.zeros_like(array) 248 | 249 | # Empty GPU cache as we generate images 250 | if torch.cuda.is_available(): 251 | torch.cuda.empty_cache() 252 | 253 | yield (array, 
f"nsd_{prompt}-{gallery_id}") 254 | --------------------------------------------------------------------------------
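A note on the generation loops in the widgets above: they build a `latents` tensor and assign it to `pipe.latents`, but the diffusers pipelines read latents (or, more simply, a seeded `generator`) from the call arguments, so per-image seeding is normally done by passing a generator into the call itself. Below is a minimal sketch of that form, reusing the model helper and demo prompt from this repository; it is untested and assumes a recent `diffusers` release and that `HF_TOKEN_SD` is set as described in the README.

```
import os

import torch
from diffusers import StableDiffusionPipeline

from napari_stable_diffusion.utils import get_stable_diffusion_model

device = "cuda:0" if torch.cuda.is_available() else "cpu"

pipe = StableDiffusionPipeline.from_pretrained(
    get_stable_diffusion_model(),
    use_auth_token=os.environ.get("HF_TOKEN_SD"),
)
pipe.to(device)

prompt = "a unicorn and a dinosaur eating cookies and drinking tea"

for gallery_id in range(3):
    # one seeded generator per image: distinct but reproducible results
    generator = torch.Generator(device=device).manual_seed(gallery_id)
    result = pipe(
        [prompt],
        height=512,
        width=512,
        num_inference_steps=50,
        generator=generator,
    )
    image = result.images[0]
    # skip flagged images instead of blanking them
    if result.nsfw_content_detected[0]:
        continue
    image.save(f"nsd_{gallery_id}.png")
```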