├── pydalle ├── imperative │ ├── client │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── responses.py │ │ └── dalle.py │ ├── outside │ │ ├── __init__.py │ │ ├── sysrand.py │ │ ├── np.py │ │ ├── files.py │ │ ├── pil.py │ │ └── internet.py │ ├── api │ │ ├── __init__.py │ │ ├── auth0.py │ │ └── labs.py │ └── __init__.py ├── functional │ ├── api │ │ ├── flow │ │ │ ├── __init__.py │ │ │ ├── auth0.py │ │ │ └── labs.py │ │ ├── request │ │ │ ├── __init__.py │ │ │ ├── auth0.py │ │ │ └── labs.py │ │ ├── response │ │ │ ├── __init__.py │ │ │ └── labs.py │ │ └── __init__.py │ ├── __init__.py │ ├── utils.py │ ├── assumptions.py │ └── types.py └── __init__.py ├── MANIFEST.in ├── docs ├── source │ ├── modules.rst │ ├── pydalle.rst │ ├── pydalle.imperative.rst │ ├── pydalle.functional.api.rst │ ├── pydalle.functional.api.response.rst │ ├── index.rst │ ├── pydalle.imperative.api.rst │ ├── pydalle.functional.api.flow.rst │ ├── pydalle.functional.api.request.rst │ ├── pydalle.imperative.client.rst │ ├── pydalle.functional.rst │ ├── pydalle.imperative.outside.rst │ └── conf.py ├── requirements.txt ├── make.py ├── Makefile └── make.bat ├── .gitignore ├── pyproject.toml ├── .readthedocs.yaml ├── Makefile ├── LICENSE ├── setup.py ├── examples ├── low_level │ ├── dev_async.py │ ├── dev.py │ └── dev.ipynb ├── dev_client.py └── dev_client_async.py └── README.md /pydalle/imperative/client/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | pydalle 2 | ======= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | pydalle 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .ipynb_checkpoints 4 | dist 5 | *.egg-info 6 | docs/_build/ 7 | build 8 | 9 | -------------------------------------------------------------------------------- /pydalle/functional/api/flow/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains modules which are used to handle the flow of requests to external APIs. 3 | """ 4 | -------------------------------------------------------------------------------- /pydalle/functional/api/request/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains modules which are used to define the requests to be sent to external APIs. 3 | """ 4 | -------------------------------------------------------------------------------- /pydalle/functional/api/response/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains modules which are used to define the responses expected from external APIs. 3 | """ 4 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Pin exact versions so that new upstream releases cannot break the docs build 2 | sphinx==5.1.1 3 | sphinx_rtd_theme==1.0.0 4 | readthedocs-sphinx-search==0.1.2 5 | -------------------------------------------------------------------------------- /pydalle/functional/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains subpackages which are used to define the functional parts 3 | of communicating with external APIs.
4 | """ 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | asyncio_mode = "auto" 7 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains all the modules which communicate with the outside world. 3 | Code outside this package but still within :mod:`pydalle.imperative` 4 | may use the modules in this package. 5 | """ 6 | -------------------------------------------------------------------------------- /pydalle/functional/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains all the functional parts of PyDalle. 3 | 4 | Specifically, none of the code in this package will: 5 | 6 | * import any external libraries 7 | 8 | * perform any external I/O 9 | 10 | * have any side effects 11 | """ 12 | -------------------------------------------------------------------------------- /docs/source/pydalle.rst: -------------------------------------------------------------------------------- 1 | pydalle package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pydalle.functional 11 | pydalle.imperative 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. 
automodule:: pydalle 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /pydalle/imperative/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains all the modules which actually communicate with external APIs. 3 | It does this by facilitating the communication between the functional flows defined in 4 | :mod:`pydalle.functional.api.flow` and the code for communicating with the outside world in 5 | :mod:`pydalle.imperative.outside`. 6 | """ 7 | -------------------------------------------------------------------------------- /pydalle/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package bundles the functional and imperative subpackages of PyDalle. 3 | If you're an end-user, you probably want to import from the :mod:`pydalle.imperative` subpackage. 4 | 5 | For convenience, :class:`pydalle.imperative.client.dalle.Dalle` is imported into this package. 6 | """ 7 | 8 | from pydalle.imperative.client.dalle import Dalle 9 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.9" 7 | 8 | # Build from the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Install the docs requirements and the package itself (the Python version is set under build.tools above) 13 | python: 14 | install: 15 | - requirements: docs/requirements.txt 16 | - method: setuptools 17 | path: . 18 | -------------------------------------------------------------------------------- /pydalle/imperative/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package contains all the imperative parts of PyDalle.
3 | 4 | Specifically, none of the code OUTSIDE this package will: 5 | 6 | * import any external libraries 7 | 8 | * perform any external I/O 9 | 10 | * have any side effects 11 | 12 | Even within this package, any code which directly performs 13 | I/O will be in the :mod:`pydalle.imperative.outside` package. 14 | """ 15 | -------------------------------------------------------------------------------- /docs/source/pydalle.imperative.rst: -------------------------------------------------------------------------------- 1 | pydalle.imperative package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pydalle.imperative.api 11 | pydalle.imperative.client 12 | pydalle.imperative.outside 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: pydalle.imperative 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/sysrand.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functions pydalle uses which depend on the system's random number generator. 3 | """ 4 | 5 | import secrets 6 | 7 | from pydalle.functional.types import SupportsLenAndGetItem, T 8 | 9 | 10 | def secure_random_choice(seq: SupportsLenAndGetItem[T]) -> T: 11 | """ 12 | Return a cryptographically secure random element from a non-empty sequence. 13 | 14 | Delegates to :func:`secrets.choice`, so an empty *seq* raises :class:`IndexError`. 15 | """ 16 | return secrets.choice(seq) 17 | -------------------------------------------------------------------------------- /docs/source/pydalle.functional.api.rst: -------------------------------------------------------------------------------- 1 | pydalle.functional.api package 2 | ============================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | ..
toctree:: 8 | :maxdepth: 4 9 | 10 | pydalle.functional.api.flow 11 | pydalle.functional.api.request 12 | pydalle.functional.api.response 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: pydalle.functional.api 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/pydalle.functional.api.response.rst: -------------------------------------------------------------------------------- 1 | pydalle.functional.api.response package 2 | ======================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.functional.api.response.labs 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. automodule:: pydalle.functional.api.response 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifneq (,$(wildcard ./.env)) 2 | include .env 3 | export 4 | endif 5 | 6 | .PHONY: build 7 | build: 8 | py -m build 9 | 10 | .PHONY: check 11 | check: 12 | py -m twine check dist/* 13 | 14 | .PHONY: upload 15 | upload: 16 | py -m twine upload --skip-existing dist/* 17 | 18 | .PHONY: uploadtest 19 | uploadtest: 20 | py -m twine upload --repository testpypi --skip-existing dist/* 21 | 22 | .PHONY: install 23 | install: 24 | pip install . 25 | 26 | .PHONY: devinstall 27 | devinstall: 28 | pip install -e . 
29 | -------------------------------------------------------------------------------- /docs/make.py: -------------------------------------------------------------------------------- 1 | """Regenerate the API ``.rst`` stubs with sphinx-apidoc, then build every format listed in OUTPUT_FORMATS.""" 2 | 3 | import os 4 | from subprocess import check_call 5 | 6 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | PACKAGE_DIR = os.path.join(CURRENT_DIR, '..', 'pydalle') 8 | SOURCE_DIR = os.path.join(CURRENT_DIR, 'source') 9 | OUTPUT_FORMATS = ["html"] 10 | 11 | # check_call raises CalledProcessError on a non-zero exit status, so a failed step stops the 12 | # script instead of continuing silently (subprocess.call discarded the return code). 13 | print(f"Generating API documentation for {PACKAGE_DIR}") 14 | check_call(["sphinx-apidoc", "-E", "-a", "-o", SOURCE_DIR, PACKAGE_DIR]) 15 | for output_format in OUTPUT_FORMATS: 16 | print(f"Generating {output_format} documentation") 17 | check_call(["make", output_format], cwd=CURRENT_DIR) 18 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. pydalle documentation master file, created by 2 | sphinx-quickstart on Fri Aug 5 18:00:27 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to pydalle's documentation! 7 | =================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | modules 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/source/pydalle.imperative.api.rst: -------------------------------------------------------------------------------- 1 | pydalle.imperative.api package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.imperative.api.auth0 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | ..
automodule:: pydalle.imperative.api.labs 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. automodule:: pydalle.imperative.api 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /docs/source/pydalle.functional.api.flow.rst: -------------------------------------------------------------------------------- 1 | pydalle.functional.api.flow package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.functional.api.flow.auth0 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | .. automodule:: pydalle.functional.api.flow.labs 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. automodule:: pydalle.functional.api.flow 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /docs/source/pydalle.functional.api.request.rst: -------------------------------------------------------------------------------- 1 | pydalle.functional.api.request package 2 | ====================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.functional.api.request.auth0 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | .. automodule:: pydalle.functional.api.request.labs 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. 
automodule:: pydalle.functional.api.request 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/np.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functions pydalle uses to directly interface with numpy. 3 | """ 4 | 5 | try: 6 | from numpy import array, ndarray 7 | except ImportError as e: # numpy is optional (the 'images' extra in setup.py); presumably LazyImportError defers the failure until first use 8 | from pydalle.functional.types import LazyImportError 9 | 10 | array = LazyImportError("numpy.array", e) 11 | ndarray = LazyImportError("numpy.ndarray", e) 12 | del LazyImportError 13 | 14 | from pydalle.imperative.outside.pil import PILImageType, PILImage 15 | 16 | 17 | def pil_image_to_np_array(image: PILImageType) -> 'ndarray': 18 | """Convert a PIL image to a numpy array with :func:`numpy.array`.""" 19 | return array(image) 20 | 21 | 22 | def np_array_to_pil_image(array: 'ndarray') -> PILImageType: 23 | """Convert a numpy array to a PIL image with :func:`PIL.Image.fromarray`.""" 24 | # NOTE: the ``array`` parameter shadows the module-level ``numpy.array`` import inside this function. 25 | return PILImage.fromarray(array) 26 | -------------------------------------------------------------------------------- /docs/source/pydalle.imperative.client.rst: -------------------------------------------------------------------------------- 1 | pydalle.imperative.client package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.imperative.client.dalle 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | .. automodule:: pydalle.imperative.client.responses 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | 20 | .. automodule:: pydalle.imperative.client.utils 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | Module contents 26 | --------------- 27 | 28 | ..
automodule:: pydalle.imperative.client 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/pydalle.functional.rst: -------------------------------------------------------------------------------- 1 | pydalle.functional package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | pydalle.functional.api 11 | 12 | Submodules 13 | ---------- 14 | 15 | 16 | .. automodule:: pydalle.functional.assumptions 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | 22 | .. automodule:: pydalle.functional.types 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | 28 | .. automodule:: pydalle.functional.utils 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | Module contents 34 | --------------- 35 | 36 | .. 
automodule:: pydalle.functional 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | -------------------------------------------------------------------------------- /docs/source/pydalle.imperative.outside.rst: -------------------------------------------------------------------------------- 1 | pydalle.imperative.outside package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | 8 | .. automodule:: pydalle.imperative.outside.files 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | .. automodule:: pydalle.imperative.outside.internet 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | 19 | 20 | .. automodule:: pydalle.imperative.outside.np 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | 26 | .. automodule:: pydalle.imperative.outside.pil 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | .. automodule:: pydalle.imperative.outside.sysrand 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | 37 | Module contents 38 | --------------- 39 | 40 | .. automodule:: pydalle.imperative.outside 41 | :members: 42 | :undoc-members: 43 | :show-inheritance: 44 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo.
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Michael Phelps 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/files.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functions pydalle uses to interface with the filesystem. 
3 | """ 4 | 5 | import warnings 6 | from os import PathLike 7 | from typing import Union, IO 8 | 9 | try: 10 | import aiofiles 11 | except ImportError as _e: # aiofiles is optional (the 'async' extra in setup.py) 12 | from pydalle.functional.types import LazyImportError 13 | 14 | aiofiles = LazyImportError("aiofiles", _e) 15 | del LazyImportError 16 | 17 | 18 | def read_bytes(file_like: Union[str, PathLike, IO[bytes]]) -> bytes: 19 | """Return all bytes of *file_like*: a path is opened in ``"rb"`` mode (and closed again); anything else is treated as an open binary file object and ``.read()`` is called on it.""" 20 | if isinstance(file_like, (str, PathLike)): 21 | with open(file_like, "rb") as f: 22 | return f.read() 23 | return file_like.read() 24 | 25 | 26 | async def read_bytes_async(file_like: Union[str, PathLike, IO[bytes]]) -> bytes: 27 | """ 28 | Async counterpart of :func:`read_bytes`. Paths are read through :mod:`aiofiles`; if that raises ImportError 29 | (presumably because aiofiles is not installed), a RuntimeWarning is issued and the blocking :func:`read_bytes` is used. 30 | File objects are read synchronously with ``.read()``. 31 | """ 32 | if isinstance(file_like, (str, PathLike)): 33 | try: 34 | async with aiofiles.open(file_like, "rb") as f: 35 | return await f.read() 36 | except ImportError as _e: 37 | warnings.warn(f"aiofiles not found, falling back to sync version: {_e}", RuntimeWarning) 38 | return read_bytes(file_like) 39 | return file_like.read() 40 | -------------------------------------------------------------------------------- /pydalle/functional/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functional utilities used throughout the codebase.
3 | """ 4 | 5 | from typing import Optional 6 | from urllib.parse import parse_qs, urlparse 7 | 8 | from pydalle.functional.types import HttpResponse, JsonDict, FlowError 9 | 10 | 11 | def get_query_param(url: str, param: str) -> str: 12 | """Return the first value of query parameter *param* in *url*; raises KeyError if it is missing (or blank, since parse_qs drops blank values by default).""" 13 | return parse_qs(urlparse(url).query)[param][0] 14 | 15 | 16 | def send_from(generator, fn): 17 | """ 18 | Relay yielded and sent values between the caller and *generator*, then return ``fn(result)`` where *result* is the generator's return value. 19 | Unlike ``yield from``, ``throw()``/``close()`` are not forwarded to *generator*. 20 | """ 21 | try: 22 | first = next(generator) 23 | except StopIteration as e: 24 | # *generator* returned without yielding; letting StopIteration escape a generator is turned into RuntimeError (PEP 479). 25 | return fn(e.value) 26 | r = yield first 27 | while True: 28 | try: 29 | r = yield generator.send(r) 30 | except StopIteration as e: 31 | return fn(e.value) 32 | 33 | 34 | def try_json(r: HttpResponse, status_code: Optional[int] = None) -> JsonDict: 35 | """Return the JSON object in *r*, raising FlowError if *status_code* is given and does not match, the body cannot be parsed, or it is not a JSON object.""" 36 | if status_code is not None and r.status_code != status_code: 37 | raise FlowError("Request returned an unexpected status code", r) 38 | try: 39 | out = r.json() 40 | except Exception as e: 41 | raise FlowError("Failed to parse response", r) from e 42 | if not isinstance(out, dict): 43 | raise FlowError("Response was not a JSON object", r) 44 | return out 45 | 46 | 47 | def filter_none(d: JsonDict) -> JsonDict: 48 | """Return a new dict with the items of *d* whose value is not None.""" 49 | return {k: v for k, v in d.items() if v is not None} 50 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.insert(0, os.path.abspath('../..')) # repository root (relative to docs/source), so `pydalle` is importable without installing it 5 | 6 | # Configuration file for the Sphinx documentation builder.
7 | # 8 | # For the full list of built-in configuration values, see the documentation: 9 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 10 | 11 | # -- Project information ----------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 13 | 14 | project = 'pydalle' 15 | copyright = '2022, Michael Phelps' 16 | author = 'Michael Phelps' 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | 'sphinx.ext.duration', 23 | 'sphinx.ext.doctest', 24 | 'sphinx.ext.autodoc', 25 | 'sphinx.ext.autosummary', 26 | 'sphinx.ext.intersphinx', 27 | ] 28 | 29 | templates_path = ['_templates'] 30 | exclude_patterns = ['build', 'Thumbs.db', '.DS_Store'] 31 | 32 | # -- Options for HTML output ------------------------------------------------- 33 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 34 | 35 | html_theme = 'sphinx_rtd_theme' 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | REPO_URL = 'https://github.com/nottheswimmer/dalle' 4 | REPO_BLOB_PREFIX = f'{REPO_URL}/blob/main/' 5 | 6 | 7 | def github_md_to_setup_md(md: str) -> str: 8 | # Replace relative links with absolute links to blobs in the repo 9 | md = md.replace('](./', f']({REPO_BLOB_PREFIX}') 10 | return md 11 | 12 | 13 | with open('README.md', encoding='utf-8') as f: 14 | long_description = github_md_to_setup_md(f.read()) 15 | 16 | setup( 17 | name='pydalle', 18 | version='0.2.0', 19 | description='A library for providing programmatic access to the DALL·E 2 API', 20 | long_description=long_description, 21 | long_description_content_type='text/markdown', 22
| author='Michael Phelps', 23 | author_email='michaelphelps@nottheswimmer.org', 24 | url=REPO_URL, 25 | packages=find_packages(), 26 | license='MIT', 27 | python_requires='>=3.8', # keep pip from installing on interpreters older than those listed in the classifiers 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.8', 34 | 'Programming Language :: Python :: 3.9', 35 | 'Programming Language :: Python :: 3.10', 36 | 'Topic :: Software Development :: Libraries', 37 | 'Topic :: Software Development :: Libraries :: Python Modules', 38 | ], 39 | extras_require={ 40 | 'async': ['aiofiles', 'aiohttp'], 41 | 'sync': ['requests'], 42 | 'images': ['pillow', 'numpy'], 43 | 'all': ['aiofiles', 'aiohttp', 'requests', 'pillow', 'numpy'], 44 | }, 45 | ) 46 | -------------------------------------------------------------------------------- /pydalle/functional/assumptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains strings about the outside world which were not parameterized 3 | because during development it was assumed that they wouldn't change.
4 | """ 5 | 6 | AUTH0_AUTHORIZE_URL_TEMPLATE = "https://%s/authorize" 7 | AUTH0_TOKEN_URL_TEMPLATE = "https://%s/oauth/token" 8 | 9 | OPENAI_AUTH0_CLIENT_ID = "DMg91f5PCHQtc7u018WKiL0zopKdiHle" 10 | OPENAI_AUTH0_DOMAIN = "auth0.openai.com" 11 | OPENAI_AUTH0_AUDIENCE = "https://api.openai.com/v1" 12 | OPENAI_AUTH0_SCOPE = "openid profile email offline_access" 13 | 14 | OPENAI_LABS_REDIRECT_URI = "https://labs.openai.com/auth/callback" 15 | OPENAI_LABS_API_URL = "https://labs.openai.com/api/labs" 16 | OPENAI_LABS_LOGIN_URL = f"{OPENAI_LABS_API_URL}/auth/login" 17 | OPENAI_LABS_TASKS_URL = f"{OPENAI_LABS_API_URL}/tasks" 18 | OPENAI_LABS_TASK_URL_TEMPLATE = f"{OPENAI_LABS_TASKS_URL}/%s" 19 | OPENAI_LABS_GENERATION_URL = f"{OPENAI_LABS_API_URL}/generations" 20 | OPENAI_LABS_GENERATION_URL_TEMPLATE = f"{OPENAI_LABS_GENERATION_URL}/%s" 21 | OPENAI_LABS_GENERATION_DOWNLOAD_URL_TEMPLATE = f"{OPENAI_LABS_GENERATION_URL_TEMPLATE}/download" 22 | OPENAI_LABS_BILLING_URL = f"{OPENAI_LABS_API_URL}/billing" 23 | OPENAI_LABS_BILLING_CREDIT_SUMMARY_URL = f"{OPENAI_LABS_BILLING_URL}/credit_summary" 24 | OPENAI_LABS_GENERATION_SHARE_URL_TEMPLATE = f"{OPENAI_LABS_GENERATION_URL_TEMPLATE}/share" 25 | OPENAI_LABS_GENERATION_FLAG_URL_TEMPLATE = f"{OPENAI_LABS_GENERATION_URL_TEMPLATE}/flags" 26 | OPENAI_LABS_COLLECTION_URL = f"{OPENAI_LABS_API_URL}/collections" 27 | OPENAI_LABS_COLLECTION_GENERATION_URL_TEMPLATE = f"{OPENAI_LABS_COLLECTION_URL}/%s/generations" 28 | OPENAI_LABS_SHARE_URL_TEMPLATE = "https://labs.openai.com/s/%s" 29 | -------------------------------------------------------------------------------- /examples/low_level/dev_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import platform 4 | 5 | from pydalle.imperative.api.labs import get_bearer_token_async, get_tasks_async, poll_for_task_completion_async, \ 6 | create_text2im_task_async, create_variations_task_async 7 | 8 | if platform.system() ==
'Windows': 9 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 10 | 11 | OPENAI_USERNAME = os.environ['OPENAI_USERNAME'] 12 | OPENAI_PASSWORD = os.environ['OPENAI_PASSWORD'] 13 | 14 | 15 | async def main_async(): 16 | print("Attempting to get token for DALL·E...") 17 | token = await get_bearer_token_async(OPENAI_USERNAME, OPENAI_PASSWORD) 18 | print("Token:", token) 19 | 20 | print("Attempting to check tasks...") 21 | tasks = await get_tasks_async(token) 22 | for task in tasks.data: 23 | print(task) 24 | print() 25 | 26 | print("Attempting to create text2im task...") 27 | pending_task = await create_text2im_task_async(token, "A cute cat") 28 | print(pending_task) 29 | print() 30 | 31 | print("Waiting for task to complete...") 32 | task = await poll_for_task_completion_async(token, pending_task.id) 33 | print(task) 34 | print() 35 | 36 | print("Attempting to create variations task...") 37 | pending_task = await create_variations_task_async(token, task.generations[0].id) 38 | print(pending_task) 39 | print() 40 | 41 | print("Waiting for task to complete...") 42 | task = await poll_for_task_completion_async(token, pending_task.id) 43 | print(task) 44 | print() 45 | 46 | # For additional examples, see dev.py and dev.ipynb 47 | 48 | 49 | if __name__ == '__main__': 50 | asyncio.run(main_async()) 51 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/pil.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functions pydalle uses to directly interface with PIL.
3 | """ 4 | 5 | from io import BytesIO 6 | 7 | try: 8 | from PIL import Image as PILImage 9 | PILImageType = PILImage.Image # the class of PIL images (type(PILImage) would be the *module* type) 10 | except ImportError as e: 11 | from pydalle.functional.types import LazyImportError 12 | 13 | PILImage = LazyImportError("PIL.Image", e) 14 | PILImageType = type(PILImage) # placeholder that keeps the name defined when Pillow is missing 15 | del LazyImportError 16 | 17 | 18 | def bytes_to_pil_image(image: bytes) -> PILImageType: 19 | """Decode encoded image bytes into a PIL image with :func:`PIL.Image.open`.""" 20 | return PILImage.open(BytesIO(image)) 21 | 22 | 23 | def pil_image_to_png_bytes(image: PILImageType) -> bytes: 24 | """Encode *image* as PNG and return the encoded bytes.""" 25 | buffer = BytesIO() 26 | image.save(buffer, format="PNG") 27 | return buffer.getvalue() 28 | 29 | 30 | def image_bytes_to_png_bytes(image: bytes) -> bytes: 31 | """Re-encode encoded image bytes as PNG bytes.""" 32 | return pil_image_to_png_bytes(bytes_to_pil_image(image)) 33 | 34 | 35 | def bytes_to_masked_pil_image(image: bytes, x1: float, y1: float, x2: float, y2: float) -> PILImageType: 36 | """ 37 | Decode *image*, convert it to RGBA and make the box from (x1, y1) to (x2, y2) fully transparent. 38 | Coordinates are fractions of the image width/height; x1 <= x2 and y1 <= y2 are assumed. 39 | """ 40 | image = bytes_to_pil_image(image).convert("RGBA") 41 | x1 = int(x1 * image.width) 42 | y1 = int(y1 * image.height) 43 | x2 = int(x2 * image.width) 44 | y2 = int(y2 * image.height) 45 | image.paste(PILImage.new("RGBA", (x2 - x1, y2 - y1), (0, 0, 0, 0)), (x1, y1)) 46 | return image 47 | 48 | 49 | def bytes_to_padded_pil_image(image: bytes, p: float, cx: float = 0.5, cy: float = 0.5) -> PILImageType: 50 | """ 51 | Scale an image to the fraction ``p`` of its width and height (e.g. 0.5 for half size) and paste it onto a 52 | fully transparent RGBA canvas of the original size. ``cx``/``cy`` place the scaled image within the leftover 53 | space: 0.5 centres it, 0 aligns it to the left/top edge and 1 to the right/bottom edge.
54 | """ 55 | old_image = bytes_to_pil_image(image).convert("RGBA") 56 | new_image = PILImage.new("RGBA", (old_image.width, old_image.height), (0, 0, 0, 0)) 57 | old_image = old_image.resize((int(old_image.width * p), int(old_image.height * p)), 58 | resample=PILImage.LANCZOS) 59 | new_image.paste(old_image, (int((new_image.width - old_image.width) * cx), 60 | int((new_image.height - old_image.height) * cy))) 61 | return new_image 62 | -------------------------------------------------------------------------------- /examples/dev_client.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pydalle import Dalle 4 | 5 | OPENAI_USERNAME = os.environ.get('OPENAI_USERNAME') 6 | OPENAI_PASSWORD = os.environ.get('OPENAI_PASSWORD') 7 | 8 | 9 | def main(): 10 | client = Dalle(OPENAI_USERNAME, OPENAI_PASSWORD) 11 | print(f"Client created. {client.get_credit_summary().aggregate_credits} credits remaining...") 12 | tasks = client.get_tasks(limit=5) 13 | print(f"{len(tasks)} tasks found...") 14 | 15 | print("Attempting to download a generation of the first task and show off some built-in helpers...") 16 | if tasks and tasks[0].generations: 17 | example = tasks[0].generations[0].download() 18 | example.to_pil().show() # Convert the image to a PIL image and show it 19 | example.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1).show() # Show a version with the right half transparent (for edits) 20 | example.to_pil_padded(0.5).show() # Show w/ 50% padding around the image, centered at (50%, 50%) 21 | example.to_pil_padded(0.4, cx=0.25, cy=0.25).show() # Show w/ 40% padding, centered at (25%, 25%) 22 | 23 | print("Attempting to do a text2im task...") 24 | completed_text2im_task = client.text2im("A cute cat") 25 | for image in completed_text2im_task.download(): 26 | image.to_pil().show() 27 | 28 | print("Attempting to create variations task on the first cat...") 29 | first_generation = completed_text2im_task.generations[0] 30 |
completed_variation_task = first_generation.variations() 31 | first_variation = completed_variation_task.generations[0] 32 | first_image = first_variation.download() 33 | first_image.to_pil().show() 34 | 35 | print("Attempting to create inpainting task and showing the mask...") 36 | # Make the right-side of the image transparent. NOTE(review): the mask is cut from the variation's image, but inpainting is requested on first_generation - confirm this is intended. 37 | mask = first_image.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1) 38 | mask.show("inpainting mask") 39 | completed_inpainting_task = first_generation.inpainting("A cute cat, with a dark side", mask) 40 | for image in completed_inpainting_task.download(): 41 | image.to_pil().show() 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /examples/dev_client_async.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import platform 4 | 5 | from pydalle import Dalle 6 | 7 | if platform.system() == 'Windows': 8 | asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) 9 | 10 | OPENAI_USERNAME = os.environ.get('OPENAI_USERNAME') 11 | OPENAI_PASSWORD = os.environ.get('OPENAI_PASSWORD') 12 | 13 | 14 | async def main(): 15 | client = Dalle(OPENAI_USERNAME, OPENAI_PASSWORD) 16 | print(f"Client created.
{(await client.get_credit_summary_async()).aggregate_credits} credits remaining...") 17 | tasks = client.get_tasks(limit=5) 18 | print(f"{len(tasks)} tasks found...") 19 | 20 | print("Attempting to download a generation of the first task and show off some built-in helpers...") 21 | if tasks and tasks[0].generations: 22 | example = await tasks[0].generations[0].download_async() 23 | example.to_pil().show() # Convert the image to a PIL image and show it 24 | example.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1).show() # Show a version with left side transparent (for edits) 25 | example.to_pil_padded(0.5).show() # Show w/ 50% padding around the image, centered at (50%, 50%) 26 | example.to_pil_padded(0.4, cx=0.25, cy=0.25).show() # Show w/ 40% padding, centered at (25%, 25%) 27 | 28 | print("Attempting to do a text2im task...") 29 | completed_text2im_task = await client.text2im_async("A cute cat") 30 | async for image in completed_text2im_task.download_async(): 31 | image.to_pil().show() 32 | 33 | print("Attempting to create variations task on the first cat...") 34 | first_generation = completed_text2im_task.generations[0] 35 | completed_variation_task = first_generation.variations() 36 | first_variation = completed_variation_task.generations[0] 37 | first_image = (await first_variation.download_async()).to_pil() 38 | first_image.show() 39 | 40 | print("Attempting to create inpainting task and showing the mask...") 41 | mask = first_image.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1) 42 | mask.show("inpainting mask") 43 | completed_inpainting_task = await first_generation.inpainting_async("A cute cat, with a dark side", mask) 44 | async for image in completed_inpainting_task.download_async(): 45 | image.to_pil().show() 46 | 47 | 48 | if __name__ == '__main__': 49 | asyncio.run(main()) 50 | -------------------------------------------------------------------------------- /pydalle/imperative/client/utils.py: 
-------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from pydalle.functional.types import FlowError, T 3 | 4 | 5 | def requires_authentication(func: T) -> T: 6 | """ 7 | Decorator to ensure that the Dalle has authenticated before calling the decorated function 8 | (or, if it has authenticated but the token has expired, it will refresh the tokens then 9 | try again). 10 | """ 11 | 12 | @wraps(func) 13 | def wrapper(self, *args, **kwargs): 14 | # If we have never authenticated, do so now 15 | if not self.has_authenticated: 16 | self.refresh_tokens() 17 | else: 18 | # Otherwise, we'll try the request and see if it results in an authentication error 19 | try: 20 | return func(self, *args, **kwargs) 21 | except FlowError as e: 22 | if e.response.status_code == 401: 23 | try: 24 | if e.response.json()['error']['code'] == "invalid_api_key": 25 | # If it does, refresh the tokens and fall through to the last attempt 26 | self.refresh_tokens() 27 | except Exception: 28 | # If it has some other 401 error, reraise it 29 | raise e 30 | else: 31 | # If it's not a 401, reraise it 32 | raise e 33 | # If we've gotten here, we should definitely be authenticated 34 | return func(self, *args, **kwargs) 35 | 36 | return wrapper 37 | 38 | 39 | def requires_authentication_async(func: T) -> T: 40 | """ 41 | Async version of the :func:`requires_authentication` decorator. 
42 | """ 43 | 44 | @wraps(func) 45 | async def wrapper(self, *args, **kwargs): 46 | if not self.has_authenticated: 47 | await self.refresh_tokens_async() 48 | else: 49 | try: 50 | return await func(self, *args, **kwargs) 51 | except FlowError as e: 52 | if e.response.status_code == 401: 53 | try: 54 | if e.response.json()['error']['code'] == "invalid_api_key": 55 | await self.refresh_tokens_async() 56 | except Exception: 57 | raise e 58 | else: 59 | raise e 60 | return await func(self, *args, **kwargs) 61 | 62 | return wrapper 63 | -------------------------------------------------------------------------------- /pydalle/imperative/api/auth0.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the implementations of calls to the Auth0 API. 3 | """ 4 | from typing import Optional, Dict 5 | 6 | from pydalle.functional.api.flow.auth0 import get_access_token_flow 7 | from pydalle.functional.api.request.auth0 import urlsafe_b64encode_string 8 | from pydalle.imperative.outside.internet import session_flow, session_flow_async 9 | from pydalle.imperative.outside.sysrand import secure_random_choice 10 | 11 | 12 | def get_access_token_from_credentials(username: str, password: str, domain: str, client_id: str, 13 | audience: str, redirect_uri: str, scope: str, headers: Optional[Dict[str, str]] = None) -> str: 14 | return session_flow(get_access_token_flow, headers, 15 | username=username, password=password, domain=domain, 16 | client_id=client_id, audience=audience, 17 | redirect_uri=redirect_uri, scope=scope, 18 | code_verifier=_random_secure_string(), 19 | initial_state=_random_secure_urlsafe_b64encoded_string(), 20 | nonce=_random_secure_urlsafe_b64encoded_string()) 21 | 22 | 23 | async def get_access_token_from_credentials_async(username: str, password: str, domain: str, client_id: str, 24 | audience: str, redirect_uri: str, scope: str, 25 | headers: Optional[Dict[str, str]] = None) -> str: 26 | return await 
session_flow_async(get_access_token_flow, headers, 27 | username=username, password=password, domain=domain, 28 | client_id=client_id, audience=audience, 29 | redirect_uri=redirect_uri, scope=scope, 30 | code_verifier=_random_secure_string(), 31 | initial_state=_random_secure_urlsafe_b64encoded_string(), 32 | nonce=_random_secure_urlsafe_b64encoded_string()) 33 | 34 | 35 | def _random_secure_urlsafe_b64encoded_string() -> str: 36 | """ 37 | https://auth0.com/docs/get-started/authentication-and-authorization-flow/call-your-api-using-the-authorization-code-flow-with-pkce#javascript-sample 38 | """ 39 | return urlsafe_b64encode_string(_random_secure_string()) 40 | 41 | 42 | _RANDOM_CHARACTERS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-_~." 43 | 44 | 45 | def _random_secure_string() -> str: 46 | """ 47 | This is how it was basically implemented in auth0-spa-js 48 | """ 49 | return "".join(secure_random_choice(_RANDOM_CHARACTERS) for _ in range(43)) 50 | -------------------------------------------------------------------------------- /pydalle/functional/api/request/auth0.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions which are used to construct requests to the Auth0 API. 
3 | """ 4 | 5 | import json 6 | from base64 import urlsafe_b64encode 7 | from hashlib import sha256 8 | from urllib.parse import urlencode 9 | 10 | from pydalle.functional.assumptions import AUTH0_TOKEN_URL_TEMPLATE, AUTH0_AUTHORIZE_URL_TEMPLATE 11 | from pydalle.functional.types import HttpRequest 12 | 13 | 14 | def request_access_token(client_id, code, code_verifier, domain, redirect_uri): 15 | return HttpRequest(**{ 16 | "method": "post", 17 | "url": (AUTH0_TOKEN_URL_TEMPLATE % domain), 18 | "data": json.dumps({ 19 | "grant_type": "authorization_code", 20 | "code": code, 21 | "client_id": client_id, 22 | "code_verifier": code_verifier, 23 | "redirect_uri": redirect_uri, 24 | }), 25 | "headers": {"Content-Type": "application/json"}, 26 | }) 27 | 28 | 29 | def request_provide_username_password(password_url, username, password, state, sleep=None): 30 | return HttpRequest(**{ 31 | "method": "post", 32 | "url": password_url, 33 | "data": urlencode({ 34 | "username": username, 35 | "password": password, 36 | "action": "default", 37 | "state": state, 38 | }), 39 | "headers": {"Content-Type": "application/x-www-form-urlencoded"}, 40 | }, sleep=sleep) 41 | 42 | 43 | def request_provide_username(username_url, username, state): 44 | return HttpRequest(**{ 45 | "method": "post", 46 | "url": username_url, 47 | "data": urlencode({ 48 | "username": username, 49 | "action": "default", 50 | "state": state, 51 | }), 52 | "headers": {"Content-Type": "application/x-www-form-urlencoded"}, 53 | }) 54 | 55 | 56 | def request_authorization_code(audience, client_id, code_verifier, domain, initial_state, nonce, redirect_uri, scope): 57 | return HttpRequest(**{ 58 | "method": "get", 59 | "url": (AUTH0_AUTHORIZE_URL_TEMPLATE % domain), 60 | "params": { 61 | "client_id": client_id, 62 | "audience": audience, 63 | "redirect_uri": redirect_uri, 64 | "scope": scope, 65 | "response_type": "code", 66 | "response_mode": "query", 67 | "state": initial_state, 68 | "nonce": nonce, 69 | 
"code_challenge": _create_code_challenge(code_verifier), 70 | "code_challenge_method": "S256", 71 | "max_age": "0", 72 | }}) 73 | 74 | 75 | def urlsafe_b64encode_string(s: str) -> str: 76 | return _urlsafe_b64encode_hex_string(s.encode()) 77 | 78 | 79 | def _create_code_challenge(code_verifier: str) -> str: 80 | return _urlsafe_b64encode_hex_string(_sha256_string_hex(code_verifier)) 81 | 82 | 83 | def _urlsafe_b64encode_hex_string(s: bytes) -> str: 84 | return urlsafe_b64encode(s).rstrip(b"=").decode() 85 | 86 | 87 | def _sha256_string_hex(s: str) -> bytes: 88 | return sha256(s.encode()).digest() 89 | -------------------------------------------------------------------------------- /pydalle/functional/api/flow/auth0.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions which are used handle the flow of requests to the Auth0 API. 3 | """ 4 | 5 | from pydalle.functional.api.request.auth0 import request_access_token, request_provide_username_password, \ 6 | request_provide_username, request_authorization_code 7 | from pydalle.functional.types import HttpFlow, FlowError, HttpResponse 8 | from pydalle.functional.utils import get_query_param, send_from 9 | 10 | DEFAULT_INTERVAL = 1.0 11 | 12 | 13 | def get_access_token_flow(*args, **kwargs) -> HttpFlow[str]: 14 | def fn(response): 15 | try: 16 | return response.json()["access_token"] 17 | except Exception as e: 18 | raise FlowError("Failed to get access token from response", response) from e 19 | 20 | return send_from(get_access_token_response_flow(*args, **kwargs), fn) 21 | 22 | 23 | def get_access_token_response_flow( 24 | username: str, 25 | password: str, 26 | domain: str, 27 | client_id: str, 28 | audience: str, 29 | redirect_uri: str, 30 | scope: str, 31 | code_verifier: str, 32 | initial_state: str, 33 | nonce: str, 34 | ) -> HttpFlow[HttpResponse]: 35 | """ 36 | 
https://auth0.com/docs/get-started/authentication-and-authorization-flow/authorization-code-flow 37 | """ 38 | # Step 1: User -> Regular Web App: Click login link (No code necessary) 39 | # Step 2: Regular Web App -> Auth0 Tenant: Authorization Code Request to /authorize 40 | # Step 3: Auth0 Tenant -> User: Redirect to login/authorization prompt 41 | r = yield request_authorization_code(audience, client_id, code_verifier, domain, initial_state, nonce, 42 | redirect_uri, scope) 43 | if r.status_code != 200: 44 | raise FlowError("Failed to redirect to login/authorization prompt", r) 45 | # Step 4: User -> Auth0 Tenant: Authenticate and Consent 46 | try: 47 | state = get_query_param(r.url, "state") 48 | except Exception as e: 49 | raise FlowError("Failed to get state from redirect", r) from e 50 | r = yield request_provide_username(r.url, username, state) 51 | if r.status_code != 200: 52 | raise FlowError("Failed to provide username to auth0", r) 53 | # Step 4: User -> Auth0 Tenant: Authenticate and Consent (Continued) 54 | # Step 5: Auth0 Tenant -> Regular Web App: Authorization Code 55 | r = yield request_provide_username_password(r.url, username, password, state) 56 | while r.status_code == 504: 57 | r = yield request_provide_username_password(r.url, username, password, state, sleep=DEFAULT_INTERVAL) 58 | if r.status_code != 200: 59 | raise FlowError("Failed to provide password to auth0", r) 60 | # Step 6. Auth0 Tenant -> Regular Web App: Authorization Code + Client ID + Client Secret to /oauth/token 61 | # Step 7. Auth0 Tenant: Validate Authorization Code + Client ID + Client Secret 62 | # Step 8. 
Auth0 Tenant -> Regular Web App: ID Token and Access Token 63 | try: 64 | code = get_query_param(r.url, "code") 65 | except Exception as e: 66 | raise FlowError("Failed to get code from redirect", r) from e 67 | r = yield request_access_token(client_id, code, code_verifier, domain, redirect_uri) 68 | if r.status_code != 200: 69 | raise FlowError("Failed to get access token", r) 70 | return r 71 | -------------------------------------------------------------------------------- /pydalle/imperative/outside/internet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functions pydalle uses to make requests to the internet. 3 | """ 4 | 5 | import asyncio 6 | import time 7 | from typing import Optional, Dict 8 | 9 | try: 10 | import requests 11 | except ImportError as _e: 12 | from pydalle.functional.types import LazyImportError 13 | 14 | requests = LazyImportError("requests", _e) 15 | del LazyImportError 16 | 17 | try: 18 | import aiohttp 19 | except ImportError as _e: 20 | from pydalle.functional.types import LazyImportError 21 | 22 | aiohttp = LazyImportError("aiohttp", _e) 23 | del LazyImportError 24 | 25 | from pydalle.functional.types import HttpFlowFunc, T, HttpRequest, HttpResponse 26 | 27 | 28 | def session_flow(__flow: HttpFlowFunc[T], __headers=Optional[Dict[str, str]], /, **kwargs) -> T: 29 | handler = __flow(**kwargs) 30 | next_request = next(handler) 31 | session = requests.Session() 32 | if __headers: 33 | session.headers.update(__headers) 34 | while True: 35 | try: 36 | response = request(next_request, session=session) 37 | next_request = handler.send(response) 38 | except StopIteration as e: 39 | return e.value 40 | 41 | 42 | def request(r: HttpRequest, /, session: Optional['requests.Session'] = None) -> HttpResponse: 43 | if session is None: 44 | session = requests.Session() 45 | if r.sleep is not None: 46 | time.sleep(r.sleep) 47 | response = session.request(r.method, r.url, 
params=r.params, data=r.data, headers=r.headers) 48 | return _requests_response_to_http_response(response, r) 49 | 50 | 51 | async def session_flow_async(__flow: HttpFlowFunc[T], __headers=Optional[Dict[str, str]], /, **kwargs) -> T: 52 | handler = __flow(**kwargs) 53 | next_request = next(handler) 54 | async with aiohttp.ClientSession() as session: 55 | if __headers: 56 | session.headers.update(__headers) 57 | while True: 58 | try: 59 | response = await request_async(next_request, session=session) 60 | next_request = handler.send(response) 61 | except StopIteration as e: 62 | return e.value 63 | 64 | 65 | async def request_async(r: HttpRequest, /, session: Optional['aiohttp.ClientSession'] = None) -> HttpResponse: 66 | if not session: 67 | session = aiohttp 68 | if r.sleep is not None: 69 | await asyncio.sleep(r.sleep) 70 | async with session.request(r.method, r.url, params=r.params, data=r.data, headers=r.headers) as response: 71 | return await _aiohttp_response_to_http_response(response, r) 72 | 73 | 74 | def _requests_response_to_http_response(response: 'requests.Response', http_request: HttpRequest) -> HttpResponse: 75 | return HttpResponse(status_code=response.status_code, 76 | content=response.text if http_request.decode else response.content, 77 | url=response.url, request=http_request) 78 | 79 | 80 | async def _aiohttp_response_to_http_response(response: 'aiohttp.ClientResponse', 81 | http_request: HttpRequest) -> HttpResponse: 82 | return HttpResponse( 83 | status_code=response.status, 84 | content=(await response.text()) if http_request.decode else (await response.read()), 85 | url=str(response.url), 86 | request=http_request) 87 | -------------------------------------------------------------------------------- /examples/low_level/dev.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from io import BytesIO 4 | 5 | import base64 6 | from PIL import Image 7 | 8 | from 
pydalle.imperative.api.labs import get_bearer_token, get_tasks, create_text2im_task, poll_for_task_completion, \ 9 | create_variations_task, create_inpainting_task, download_generation, share_generation, save_generations, \ 10 | get_access_token, get_bearer_token_from_access_token, get_login_info, get_credit_summary 11 | 12 | OPENAI_USERNAME = os.environ.get('OPENAI_USERNAME') 13 | OPENAI_PASSWORD = os.environ.get('OPENAI_PASSWORD') 14 | 15 | 16 | def main(): 17 | print("Attempting to get token for DALL·E...") 18 | 19 | # If you just want the bearer token, do this: 20 | # token = get_bearer_token(OPENAI_USERNAME, OPENAI_PASSWORD) 21 | 22 | # Slightly more involved if you also want to use methods that need the access token: 23 | access_token = get_access_token(OPENAI_USERNAME, OPENAI_PASSWORD) 24 | token = get_bearer_token_from_access_token(access_token) 25 | print("Token:", token) 26 | print("Also printing credits using that access token...") 27 | login_info = get_login_info(access_token) 28 | remaining_credits = login_info.billing_info.aggregate_credits 29 | print(f"{remaining_credits} credits remaining...") 30 | 31 | # But, actually, you can just use the credit summary method with the bearer token now: 32 | credit_summary = get_credit_summary(token) 33 | print(f"SAME THING: {credit_summary.aggregate_credits} credits remaining...") 34 | 35 | print("Attempting to check tasks...") 36 | tasks = get_tasks(token) 37 | for task in tasks.data: 38 | print(task) 39 | download_and_show(task, token) 40 | if input("Do you want to share this generation? (y/n): ") == "y": 41 | r = share_generation(token, task.generations.data[0].id) 42 | print(f"Share URL: {r.share_url}") 43 | if input("Do you want to save this generation? 
(y/n): ") == "y": 44 | r = save_generations(token, [task.generations.data[0].id]) 45 | print(f"Saved to collection {r.name} ({r.alias})") 46 | break 47 | print() 48 | 49 | print("Attempting to create text2im task...") 50 | pending_task = create_text2im_task(token, "A cute cat") 51 | print(pending_task) 52 | print() 53 | 54 | task = wait_for_task(pending_task, token) 55 | download_and_show(task, token) 56 | 57 | print("Attempting to create variations task...") 58 | pending_task = create_variations_task(token, task.generations[0].id) 59 | print(pending_task) 60 | print() 61 | 62 | task = wait_for_task(pending_task, token) 63 | image = download_and_show(task, token) 64 | 65 | print("Attempting to create inpainting task and showing the mask...") 66 | # Make the right-side of the image transparent 67 | image = image.convert("RGBA") 68 | for i in range(image.width): 69 | if i > image.width / 2: 70 | for j in range(image.height): 71 | image.putpixel((i, j), (0, 0, 0, 0)) 72 | image.show("inpainting mask") 73 | # Convert image to a base64 png string 74 | with BytesIO() as buffer: 75 | image.save(buffer, format="PNG") 76 | base64_png = base64.b64encode(buffer.getvalue()).decode() 77 | pending_task = create_inpainting_task(token, 78 | caption="A cute cat, with a dark side", 79 | masked_image=base64_png, 80 | parent_id_or_image=task.generations[0].id) 81 | print(pending_task) 82 | print() 83 | 84 | task = wait_for_task(pending_task, token) 85 | 86 | download_and_show(task, token) 87 | 88 | 89 | def wait_for_task(pending_task, token): 90 | print("Waiting for task to complete...") 91 | task = poll_for_task_completion(token, pending_task.id) 92 | print(task) 93 | print() 94 | return task 95 | 96 | 97 | def download_and_show(task, token): 98 | print("Attempting to download first generated image and show...") 99 | image = Image.open(io.BytesIO(download_generation(token, task.generations.data[0].id))) 100 | image.show("generated image") 101 | print() 102 | return image 103 | 104 | 
105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /pydalle/functional/types.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains both type hints and structures used throughout the codebase. 3 | """ 4 | 5 | import json 6 | from copy import deepcopy 7 | from dataclasses import dataclass 8 | from typing import TypeVar, Protocol, Optional, Dict, Generator, Callable, Any, Union, List 9 | from urllib.parse import urlencode, parse_qs 10 | 11 | T = TypeVar("T") 12 | _T_co = TypeVar("_T_co", covariant=True) 13 | 14 | 15 | class SupportsLenAndGetItem(Protocol[_T_co]): 16 | def __len__(self) -> int: ... 17 | 18 | def __getitem__(self, __k: int) -> _T_co: ... 19 | 20 | 21 | @dataclass 22 | class HttpRequest: 23 | method: str 24 | url: str 25 | params: Optional[Dict[str, Union[int, str]]] = None 26 | headers: Optional[Dict[str, str]] = None 27 | data: Optional[str] = None 28 | sleep: Optional[float] = None 29 | decode: bool = True 30 | 31 | 32 | _CENSORED_REQUEST_KEYS = {"authorization", "password", "code", "code_verifier"} 33 | 34 | 35 | @dataclass 36 | class HttpResponse: 37 | status_code: int 38 | url: str 39 | content: Union[str, bytes] 40 | request: HttpRequest 41 | 42 | def json(self, **kwargs) -> 'JsonValue': 43 | return json.loads(self.content, **kwargs) 44 | 45 | def _to_censored_response(self) -> 'HttpResponse': 46 | """ 47 | Try to censor sensitive data in the request if this may be printed as part of an error message or a traceback 48 | """ 49 | new = deepcopy(self) 50 | # Censor parameters 51 | if new.request.params: 52 | for param in new.request.params: 53 | if param.lower() in _CENSORED_REQUEST_KEYS: 54 | new.request.params[param] = "***REDACTED***" 55 | # Censor headers 56 | if new.request.headers: 57 | for header in new.request.headers: 58 | if header.lower() in _CENSORED_REQUEST_KEYS: 59 | 
new.request.headers[header] = "***REDACTED***" 60 | # Censor data 61 | if new.request.data: 62 | try: 63 | # If it's JSON... 64 | data = json.loads(new.request.data) 65 | for key in data: 66 | if key.lower() in _CENSORED_REQUEST_KEYS: 67 | data[key] = "***REDACTED***" 68 | new.request.data = json.dumps(data) 69 | except json.JSONDecodeError: 70 | pass 71 | try: 72 | # If it's a query string... 73 | data = parse_qs(new.request.data) 74 | for key in data: 75 | if key.lower() in _CENSORED_REQUEST_KEYS: 76 | data[key] = ["***REDACTED***"] 77 | new.request.data = urlencode(data) 78 | except ValueError: 79 | pass 80 | return new 81 | 82 | 83 | HttpFlow = Generator[HttpRequest, HttpResponse, T] 84 | HttpFlowFunc = Callable[[Any], HttpFlow[T]] 85 | 86 | 87 | class FlowError(Exception): 88 | def __init__(self, message: str, response: HttpResponse, *args: Any, censor: bool = True): 89 | if censor: 90 | response = response._to_censored_response() 91 | super().__init__(message, response, *args) 92 | self.response = response 93 | 94 | 95 | # TODO: Recursive type hints. My IDE wasn't appreciating them for now. 96 | # JsonValue = Union[str, int, float, bool, None, 'JsonDict', 'JsonList'] 97 | JsonValue = Any 98 | JsonDict = Dict[str, JsonValue] 99 | JsonList = List[JsonValue] 100 | 101 | 102 | class LazyImportError: 103 | def __init__(self, name: str, e: ImportError): 104 | self.name = name 105 | self.e = e 106 | 107 | def throw(self, reason, *args, **kwargs): 108 | if reason == "__call__": 109 | prefix = f"{self.name}(" 110 | prefix += ", ".join(map(str, args)) 111 | prefix += ", " if args else "" 112 | prefix += ", ".join(f"{k}={v}" for k, v in kwargs.items()) 113 | prefix += ")" 114 | elif reason == "__getattr__": 115 | prefix = f"{self.name}.{args[0]}" 116 | else: 117 | prefix = f"{self.name}.{reason}" 118 | 119 | raise ImportError(f"""\ 120 | {prefix}: The {self.name} package is required for this module. 
121 | To install it, run: 122 | pip install {self.name}""") from self.e 123 | 124 | def __call__(self, *args, **kwargs): 125 | self.throw("__call__", *args, **kwargs) 126 | 127 | def __getattr__(self, name): 128 | # Help readthedocs put something here for optional dependencies 129 | if name == "__qualname__": 130 | return self.name 131 | if name == "__args__": 132 | return () 133 | self.throw("__getattr__", name) 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **ATTENTION: pydalle is now no longer supported as there is an [official API](https://beta.openai.com/docs/guides/images). With the recent addition of a captcha to the OpenAI login page, the automated approach to get a token used in this library won't work any longer anyway. Thanks for using pydalle while it was relevant!** 2 | 3 | # pydalle: A DALL·E 2 API Wrapper for Python 4 | 5 | ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pydalle) 6 | ![PyPI - Wheel](https://img.shields.io/pypi/wheel/pydalle) 7 | ![PyPI - License](https://img.shields.io/pypi/l/pydalle) 8 | 9 | This library provides basic programmatic access to the DALL·E 2 API. 10 | 11 | The intent of this library is to provide researchers with a means to easily layout 12 | results from DALL·E 2 into a jupyter notebook or similar. 13 | 14 | pydalle has two main modes of use: 15 | 16 | - **`pydalle.Dalle`**: This is the main class of the library. It provides a user-friendly 17 | interface to the DALL·E 2 API. [Read more here][4]. 18 | - **`pydalle.imperative.api.labs`**: This module provides a set of lower-level functions that 19 | can be used to interact with the DALL·E 2 API. [Read more here][5]. 
20 | 21 | ## Installation 22 | 23 | ### Install with all dependencies 24 | 25 | pip install pydalle[all] # Install all dependencies, recommended for most users 26 | 27 | ### Pick and choose your dependencies 28 | 29 | pip install pydalle # Just install the library with no optional dependencies 30 | pip install pydalle[sync] # Also installs requests (for synchronous networking) 31 | pip install pydalle[async] # Also installs aiohttp and aiofiles (required for async networking / file handling) 32 | pip install pydalle[images] # Also installs Pillow and numpy (required for help with image processing) 33 | 34 | ## Tips 35 | 36 | - Get access by signing up for the [DALL·E 2 waitlist][1]. 37 | 38 | - Ensure your usage of DALL·E 2 abides by DALL·E 2's [content policy][2] and [terms of use][3]. 39 | 40 | - Be mindful about how easy this library makes it for you to spend your money / DALL·E 2 credits. 41 | 42 | ## Getting Started 43 | 44 | Once you have installed pydalle, you can start using it by importing it and creating a `Dalle` object. 45 | You can find all the available methods on the [Dalle class][4]. 46 | 47 | ```python 48 | import os 49 | 50 | from pydalle import Dalle 51 | 52 | OPENAI_USERNAME = os.environ.get('OPENAI_USERNAME') 53 | OPENAI_PASSWORD = os.environ.get('OPENAI_PASSWORD') 54 | 55 | 56 | def main(): 57 | client = Dalle(OPENAI_USERNAME, OPENAI_PASSWORD) 58 | print(f"Client created. 
{client.get_credit_summary().aggregate_credits} credits remaining...") 59 | tasks = client.get_tasks(limit=5) 60 | print(f"{len(tasks)} tasks found...") 61 | 62 | print("Attempting to download a generation of the first task and show off some built-in helpers...") 63 | if tasks and tasks[0].generations: 64 | example = tasks[0].generations[0].download() 65 | example.to_pil().show() # Convert the image to a PIL image and show it 66 | example.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1).show() # Show a version with left side transparent (for edits) 67 | example.to_pil_padded(0.5).show() # Show w/ 50% padding around the image, centered at (50%, 50%) 68 | example.to_pil_padded(0.4, cx=0.25, cy=0.25).show() # Show w/ 40% padding, centered at (25%, 25%) 69 | 70 | print("Attempting to do a text2im task...") 71 | completed_text2im_task = client.text2im("A cute cat") 72 | for image in completed_text2im_task.download(): 73 | image.to_pil().show() 74 | 75 | print("Attempting to create variations task on the first cat...") 76 | first_generation = completed_text2im_task.generations[0] 77 | completed_variation_task = first_generation.variations() 78 | first_variation = completed_variation_task.generations[0] 79 | first_image = first_variation.download() 80 | first_image.to_pil().show() 81 | 82 | print("Attempting to create inpainting task and showing the mask...") 83 | # Make the right-side of the image transparent 84 | mask = first_image.to_pil_masked(x1=0.5, y1=0, x2=1, y2=1) 85 | mask.show("inpainting mask") 86 | completed_inpainting_task = first_generation.inpainting("A cute cat, with a dark side", mask) 87 | for image in completed_inpainting_task.download(): 88 | image.to_pil().show() 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | 94 | ``` 95 | 96 | For an equivalent async code example, see [examples/dev_client_async.py](./examples/dev_client_async.py). 
97 | 98 | For examples of the low-level API and using this in a notebook, see 99 | the [examples/low_level](./examples/low_level) directory. 100 | 101 | [1]: https://labs.openai.com/waitlist 102 | 103 | [2]: https://labs.openai.com/policies/content-policy 104 | 105 | [3]: https://labs.openai.com/policies/terms 106 | 107 | [4]: https://pydalle.readthedocs.io/en/latest/pydalle.imperative.client.html#pydalle.imperative.client.dalle.Dalle 108 | 109 | [5]: https://pydalle.readthedocs.io/en/latest/pydalle.imperative.api.html#module-pydalle.imperative.api.labs 110 | -------------------------------------------------------------------------------- /pydalle/functional/api/request/labs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions which are used to construct requests to the OpenAI Labs API. 3 | """ 4 | 5 | import json 6 | from typing import Optional, List 7 | 8 | from pydalle.functional.api.response.labs import TaskType 9 | from pydalle.functional.assumptions import OPENAI_LABS_TASKS_URL, OPENAI_LABS_LOGIN_URL, \ 10 | OPENAI_LABS_TASK_URL_TEMPLATE,OPENAI_LABS_GENERATION_URL_TEMPLATE, OPENAI_LABS_GENERATION_DOWNLOAD_URL_TEMPLATE, \ 11 | OPENAI_LABS_GENERATION_SHARE_URL_TEMPLATE, OPENAI_LABS_COLLECTION_GENERATION_URL_TEMPLATE, \ 12 | OPENAI_LABS_GENERATION_FLAG_URL_TEMPLATE, OPENAI_LABS_BILLING_CREDIT_SUMMARY_URL 13 | from pydalle.functional.types import HttpRequest 14 | from pydalle.functional.utils import filter_none 15 | 16 | 17 | def get_task_request(bearer_token: str, task_id: str, sleep: Optional[float] = None) -> HttpRequest: 18 | return HttpRequest(method="get", 19 | url=OPENAI_LABS_TASK_URL_TEMPLATE % task_id, 20 | headers={"Authorization": f"Bearer {bearer_token}"}, 21 | sleep=sleep) 22 | 23 | 24 | def get_tasks_request(bearer_token: str, limit: Optional[int] = None, from_ts: Optional[int] = None, sleep: Optional[float] = None) -> HttpRequest: 25 | return HttpRequest(method="get", 26 | 
url=OPENAI_LABS_TASKS_URL, 27 | params=filter_none({"from_ts": from_ts, "limit": limit}), 28 | headers={"Authorization": f"Bearer {bearer_token}"}, 29 | sleep=sleep) 30 | 31 | 32 | def get_generation_request(bearer_token: str, generation_id: str, sleep: Optional[float] = None) -> HttpRequest: 33 | return HttpRequest(method="get", 34 | url=OPENAI_LABS_GENERATION_URL_TEMPLATE % generation_id, 35 | headers={"Authorization": f"Bearer {bearer_token}"}, 36 | sleep=sleep) 37 | 38 | def login_request(access_token: str, sleep: Optional[float] = None) -> HttpRequest: 39 | return HttpRequest(method="post", url=OPENAI_LABS_LOGIN_URL, headers={"Authorization": f"Bearer {access_token}"}, 40 | sleep=sleep) 41 | 42 | 43 | def create_task_request(bearer_token: str, 44 | task_type: TaskType, 45 | batch_size: int, 46 | caption: Optional[str] = None, 47 | parent_id_or_image: Optional[str] = None, 48 | masked_image: Optional[str] = None, 49 | sleep: Optional[float] = None) -> HttpRequest: 50 | image_key = _classify_image_parameter(parent_id_or_image) 51 | return HttpRequest(method="post", 52 | url=OPENAI_LABS_TASKS_URL, 53 | data=json.dumps( 54 | filter_none( 55 | {"task_type": task_type, "prompt": 56 | filter_none({"caption": caption, 57 | "batch_size": batch_size, 58 | image_key: parent_id_or_image, 59 | "masked_image": masked_image})}) 60 | ), 61 | headers={"Authorization": f"Bearer {bearer_token}", 62 | "Content-Type": "application/json"}, 63 | sleep=sleep) 64 | 65 | 66 | def download_generation_request(bearer_token: str, generation_id: str, sleep: Optional[float] = None) -> HttpRequest: 67 | return HttpRequest(method="get", 68 | url=OPENAI_LABS_GENERATION_DOWNLOAD_URL_TEMPLATE % generation_id, 69 | headers={"Authorization": f"Bearer {bearer_token}"}, 70 | decode=False, 71 | sleep=sleep) 72 | 73 | 74 | def share_generation_request(bearer_token: str, generation_id: str, sleep: Optional[float] = None) -> HttpRequest: 75 | return HttpRequest(method="post", 76 | 
url=OPENAI_LABS_GENERATION_SHARE_URL_TEMPLATE % generation_id, 77 | headers={"Authorization": f"Bearer {bearer_token}"}, 78 | sleep=sleep) 79 | 80 | 81 | def save_generations_request(bearer_token: str, generation_ids: List[str], collection_id_or_alias: str, 82 | sleep: Optional[float] = None) -> HttpRequest: 83 | return HttpRequest(method="post", 84 | url=OPENAI_LABS_COLLECTION_GENERATION_URL_TEMPLATE % collection_id_or_alias, 85 | data=json.dumps({"generation_ids": generation_ids}), 86 | headers={"Authorization": f"Bearer {bearer_token}", 87 | "Content-Type": "application/json"}, 88 | sleep=sleep) 89 | 90 | 91 | def flag_generation_request(bearer_token: str, generation_id: str, description: str, 92 | sleep: Optional[float] = None) -> HttpRequest: 93 | return HttpRequest(method="post", 94 | url=OPENAI_LABS_GENERATION_FLAG_URL_TEMPLATE % generation_id, 95 | data=json.dumps({"description": description}), 96 | headers={"Authorization": f"Bearer {bearer_token}", 97 | "Content-Type": "application/json"}, 98 | sleep=sleep) 99 | 100 | 101 | def get_credit_summary_request(bearer_token: str, sleep: Optional[float] = None) -> HttpRequest: 102 | return HttpRequest(method="get", 103 | url=OPENAI_LABS_BILLING_CREDIT_SUMMARY_URL, 104 | headers={"Authorization": f"Bearer {bearer_token}"}, 105 | sleep=sleep) 106 | 107 | 108 | def _classify_image_parameter(parent_id_or_image): 109 | if parent_id_or_image is not None: 110 | if parent_id_or_image.startswith("generation-"): 111 | return "parent_generation_id" 112 | elif parent_id_or_image.startswith("prompt-"): 113 | return "parent_prompt_id" 114 | return "image" 115 | -------------------------------------------------------------------------------- /examples/low_level/dev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "outputs": [], 12 | 
"source": [ 13 | "import getpass\n", 14 | "from pydalle.imperative.api.labs import get_bearer_token\n", 15 | "\n", 16 | "print(\"Attempting to get token for DALL-E...\")\n", 17 | "token = get_bearer_token(\n", 18 | " input(\"OpenAI Username: \"),\n", 19 | " getpass.getpass('OpenAI Password: ')\n", 20 | ")\n", 21 | "print(\"Token:\", token[:5] + (\"*\" * (len(token) - 5)))" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "pycharm": { 29 | "name": "#%%\n" 30 | } 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "from pydalle.imperative.api.labs import get_tasks\n", 35 | "\n", 36 | "from_ts = 0\n", 37 | "print(\"Getting tasks for tasks starting from timestamp\", from_ts)\n", 38 | "tasks = get_tasks(token, from_ts)\n", 39 | "print(\"# of Tasks:\", len(tasks))" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "pycharm": { 47 | "name": "#%%\n" 48 | } 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "import io\n", 53 | "from pydalle.imperative.api.labs import download_generation\n", 54 | "from typing import List\n", 55 | "import matplotlib.pyplot as plt\n", 56 | "from PIL import Image as PILImage\n", 57 | "from PIL.Image import Image as PILImageType\n", 58 | "import textwrap, os\n", 59 | "\n", 60 | "def display_images(\n", 61 | " images: List[PILImageType],\n", 62 | " columns=4, width=20, height=8, max_images=24,\n", 63 | " label_wrap_length=50, label_font_size=8,\n", 64 | " disable_axis=True):\n", 65 | "\n", 66 | " if not images:\n", 67 | " print(\"No images to display.\")\n", 68 | " return\n", 69 | "\n", 70 | " if len(images) > max_images:\n", 71 | " print(f\"Showing {max_images} images of {len(images)}:\")\n", 72 | " images=images[0:max_images]\n", 73 | "\n", 74 | " height = max(height, int(len(images)/columns) * height)\n", 75 | " plt.figure(figsize=(width, height))\n", 76 | " for i, image in enumerate(images):\n", 77 | "\n", 78 | " plt.subplot(int(len(images) / 
columns + 1), columns, i + 1)\n", 79 | " if disable_axis:\n", 80 | " plt.axis('off')\n", 81 | " plt.imshow(image)\n", 82 | "\n", 83 | " if hasattr(image, 'filename'):\n", 84 | " title=image.filename\n", 85 | " if title.endswith(\"/\"): title = title[0:-1]\n", 86 | " title=os.path.basename(title)\n", 87 | " title=textwrap.wrap(title, label_wrap_length)\n", 88 | " title=\"\\n\".join(title)\n", 89 | " plt.title(title, fontsize=label_font_size)\n", 90 | "\n", 91 | "def generation_to_pil(g):\n", 92 | " img = io.BytesIO(download_generation(token, g.id))\n", 93 | " return PILImage.open(img)\n", 94 | "\n", 95 | "def display_task_generations(t):\n", 96 | " images = []\n", 97 | " for generation in t.generations:\n", 98 | " images.append(generation_to_pil(generation))\n", 99 | " display_images(images)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "pycharm": { 107 | "name": "#%%\n" 108 | } 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "if tasks:\n", 113 | " display_task_generations(tasks[0])\n", 114 | "else:\n", 115 | " print(\"No tasks found.\")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "pycharm": { 123 | "name": "#%%\n" 124 | } 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "from pydalle.imperative.api.labs import create_text2im_task\n", 129 | "\n", 130 | "print(\"Attempting to create text2im task...\")\n", 131 | "pending_task = create_text2im_task(token, \"A cute cat\")\n", 132 | "print(pending_task)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "pycharm": { 140 | "name": "#%%\n" 141 | } 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "from pydalle.imperative.api.labs import poll_for_task_completion\n", 146 | "\n", 147 | "print(\"Waiting for task to complete...\")\n", 148 | "task = poll_for_task_completion(token, pending_task.id)\n", 149 | 
"display_task_generations(task)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "pycharm": { 157 | "name": "#%%\n" 158 | } 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "from pydalle.imperative.api.labs import create_variations_task\n", 163 | "\n", 164 | "print(\"Attempting to create variations task...\")\n", 165 | "pending_task = create_variations_task(token, task.generations[0].id)\n", 166 | "print(pending_task)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "pycharm": { 174 | "name": "#%%\n" 175 | } 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "print(\"Waiting for task to complete...\")\n", 180 | "task = poll_for_task_completion(token, pending_task.id)\n", 181 | "display_task_generations(task)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "pycharm": { 189 | "name": "#%%\n" 190 | } 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "import base64\n", 195 | "from io import BytesIO\n", 196 | "\n", 197 | "# Convert generation to a cropped base64 PNG string\n", 198 | "image = generation_to_pil(task.generations[0])\n", 199 | "image = image.convert(\"RGBA\")\n", 200 | "for i in range(image.width):\n", 201 | " if i > image.width / 2:\n", 202 | " for j in range(image.height):\n", 203 | " image.putpixel((i, j), (0, 0, 0, 0))\n", 204 | "with BytesIO() as buffer:\n", 205 | " image.save(buffer, format=\"PNG\")\n", 206 | " base64_png = base64.b64encode(buffer.getvalue()).decode()\n", 207 | "image" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "pycharm": { 215 | "name": "#%%\n" 216 | } 217 | }, 218 | "outputs": [], 219 | "source": [ 220 | "from pydalle.imperative.api.labs import create_inpainting_task\n", 221 | "\n", 222 | "\n", 223 | "print(\"Attempting to create inpainting task...\")\n", 224 | "pending_task 
= create_inpainting_task(token, caption=\"A cute cat, with a dark side\",\n", 225 | " masked_image=base64_png,\n", 226 | " parent_id_or_image=task.generations[0].id)\n", 227 | "print(pending_task)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "pycharm": { 235 | "name": "#%%\n" 236 | } 237 | }, 238 | "outputs": [], 239 | "source": [ 240 | "print(\"Waiting for task to complete...\")\n", 241 | "task = poll_for_task_completion(token, pending_task.id)\n", 242 | "display_task_generations(task)" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 3 (ipykernel)", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.8.13" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 1 267 | } -------------------------------------------------------------------------------- /pydalle/functional/api/flow/labs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions which are used to handle the flow of requests to the labs API.
3 | """ 4 | 5 | from typing import Optional, List 6 | 7 | from pydalle.functional.api.request.labs import login_request, get_tasks_request, create_task_request, \ 8 | get_task_request, download_generation_request, save_generations_request, share_generation_request, \ 9 | flag_generation_request, get_credit_summary_request, get_generation_request 10 | from pydalle.functional.api.response.labs import TaskList, TaskType, Task, Generation, Collection, Login, UserFlag, \ 11 | BillingInfo 12 | from pydalle.functional.types import HttpFlow, FlowError, JsonDict 13 | from pydalle.functional.utils import send_from, try_json 14 | 15 | DEFAULT_INTERVAL = 1.0 16 | 17 | 18 | def get_login_info_json_flow(access_token: str) -> HttpFlow[JsonDict]: 19 | if access_token.startswith("sess-"): 20 | raise ValueError("Invalid access token: It appears you've passed in a session " 21 | "token instead of the expected access token") 22 | r = yield login_request(access_token) 23 | while r.status_code == 504: 24 | r = yield login_request(access_token, sleep=DEFAULT_INTERVAL) 25 | return try_json(r, status_code=200) 26 | 27 | 28 | def get_bearer_token_flow(access_token: str) -> HttpFlow[str]: 29 | def fn(response): 30 | try: 31 | return response["user"]["session"]["sensitive_id"] 32 | except Exception as e: 33 | raise FlowError("Failed to get bearer token from response", response) from e 34 | 35 | return send_from(get_login_info_json_flow(access_token), fn) 36 | 37 | 38 | def get_login_info_flow(access_token: str) -> HttpFlow[Login]: 39 | def fn(response): 40 | try: 41 | return Login.from_dict(response) 42 | except Exception as e: 43 | raise FlowError("Failed to parse response", response) from e 44 | 45 | return send_from(get_login_info_json_flow(access_token), fn) 46 | 47 | 48 | 49 | def get_tasks_flow(bearer_token: str, limit: Optional[int] = None, from_ts: Optional[int] = None) -> HttpFlow[TaskList]: 50 | r = yield get_tasks_request(bearer_token, limit, from_ts) 51 | while r.status_code == 
504: 52 | r = yield get_tasks_request(bearer_token, limit, from_ts, sleep=DEFAULT_INTERVAL) 53 | j = try_json(r, status_code=200) 54 | try: 55 | return TaskList.from_dict(j) 56 | except Exception as e: 57 | raise FlowError("Failed to parse response", r) from e 58 | 59 | 60 | def _create_task_flow(bearer_token: str, 61 | task_type: TaskType, 62 | batch_size: int, 63 | caption: Optional[str] = None, # for 'text2im' and 'inpainting' task types 64 | parent_id_or_image: Optional[str] = None, # for 'variations' and 'inpainting' task types 65 | masked_image: Optional[str] = None # for 'inpainting' task type 66 | ) -> HttpFlow[Task]: 67 | request = create_task_request(bearer_token, task_type, 68 | caption=caption, batch_size=batch_size, 69 | parent_id_or_image=parent_id_or_image, 70 | masked_image=masked_image) 71 | r = yield request 72 | while r.status_code == 504: 73 | request.sleep = DEFAULT_INTERVAL 74 | r = yield request 75 | j = try_json(r, status_code=200) 76 | try: 77 | return Task.from_dict(j) 78 | except Exception as e: 79 | raise FlowError("Failed to parse response", r) from e 80 | 81 | 82 | def create_text2im_task_flow(bearer_token: str, caption: str, batch_size: int = 4) -> HttpFlow[Task]: 83 | return _create_task_flow(bearer_token, caption=caption, task_type="text2im", batch_size=batch_size) 84 | 85 | 86 | def create_variations_task_flow(bearer_token: str, parent_id_or_image: str, batch_size: int = 3) -> HttpFlow[Task]: 87 | return _create_task_flow(bearer_token, task_type="variations", batch_size=batch_size, 88 | parent_id_or_image=parent_id_or_image) 89 | 90 | 91 | def create_inpainting_task_flow(bearer_token: str, caption: str, 92 | masked_image: str, 93 | parent_id_or_image: Optional[str] = None, 94 | batch_size: int = 3) -> HttpFlow[Task]: 95 | if parent_id_or_image is None: 96 | parent_id_or_image = masked_image 97 | return _create_task_flow(bearer_token, 98 | task_type="inpainting", 99 | batch_size=batch_size, 100 | caption=caption, 101 | 
parent_id_or_image=parent_id_or_image, 102 | masked_image=masked_image) 103 | 104 | 105 | def get_task_flow(bearer_token: str, task_id: str) -> HttpFlow[Task]: 106 | r = yield get_task_request(bearer_token, task_id=task_id) 107 | while r.status_code == 504: 108 | r = yield get_task_request(bearer_token, task_id=task_id, sleep=DEFAULT_INTERVAL) 109 | j = try_json(r, status_code=200) 110 | try: 111 | return Task.from_dict(j) 112 | except Exception as e: 113 | raise FlowError("Failed to parse response", r) from e 114 | 115 | 116 | def get_generation_flow(bearer_token: str, generation_id: str) -> HttpFlow[Generation]: 117 | r = yield get_generation_request(bearer_token, generation_id=generation_id) 118 | while r.status_code == 504: 119 | r = yield get_generation_request(bearer_token, generation_id=generation_id, sleep=DEFAULT_INTERVAL) 120 | j = try_json(r, status_code=200) 121 | try: 122 | return Generation.from_dict(j) 123 | except Exception as e: 124 | raise FlowError("Failed to parse response", r) from e 125 | 126 | 127 | def poll_for_task_completion_flow(bearer_token: str, 128 | task_id: str, 129 | interval: float = DEFAULT_INTERVAL, 130 | _max_attempts: int = 1000) -> HttpFlow[Task]: 131 | r = yield get_task_request(bearer_token, task_id=task_id) 132 | for _ in range(_max_attempts): 133 | if r.status_code != 504: 134 | j = try_json(r, status_code=200) 135 | if j["status"] != "pending": 136 | try: 137 | return Task.from_dict(j) 138 | except Exception as e: 139 | raise FlowError("Failed to parse response", r) from e 140 | r = yield get_task_request(bearer_token, task_id=task_id, sleep=interval) 141 | raise FlowError("Failed to poll for task completion: Reached max attempts", r) 142 | 143 | 144 | def download_generation_flow(bearer_token: str, generation_id: str) -> HttpFlow[bytes]: 145 | r = yield download_generation_request(bearer_token, generation_id) 146 | while r.status_code == 504: 147 | r = yield download_generation_request(bearer_token, generation_id, 
sleep=DEFAULT_INTERVAL) 148 | if r.status_code != 200: 149 | raise FlowError("Failed to download generation", r) 150 | return r.content 151 | 152 | 153 | def share_generation_flow(bearer_token: str, generation_id: str) -> HttpFlow[Generation]: 154 | r = yield share_generation_request(bearer_token, generation_id) 155 | while r.status_code == 504: 156 | r = yield share_generation_request(bearer_token, generation_id, sleep=DEFAULT_INTERVAL) 157 | if r.status_code != 200: 158 | raise FlowError("Failed to share generation", r) 159 | j = try_json(r, status_code=200) 160 | try: 161 | return Generation.from_dict(j) 162 | except Exception as e: 163 | raise FlowError("Failed to parse response", r) from e 164 | 165 | 166 | def flag_generation_flow(bearer_token: str, generation_id: str, description: str) -> HttpFlow[UserFlag]: 167 | r = yield flag_generation_request(bearer_token, generation_id, description) 168 | while r.status_code == 504: 169 | r = yield flag_generation_request(bearer_token, generation_id, description, sleep=DEFAULT_INTERVAL) 170 | if r.status_code != 200: 171 | raise FlowError("Failed to flag generation", r) 172 | j = try_json(r, status_code=200) 173 | try: 174 | return UserFlag.from_dict(j) 175 | except Exception as e: 176 | raise FlowError("Failed to parse response", r) from e 177 | 178 | 179 | def save_generations_flow(bearer_token: str, generation_ids: List[str], 180 | collection_id_or_alias: str) -> HttpFlow[Collection]: 181 | if isinstance(generation_ids, str): 182 | generation_ids = [generation_ids] 183 | r = yield save_generations_request(bearer_token, generation_ids, collection_id_or_alias) 184 | while r.status_code == 504: 185 | r = yield save_generations_request(bearer_token, generation_ids, collection_id_or_alias, sleep=DEFAULT_INTERVAL) 186 | if r.status_code != 200: 187 | raise FlowError("Failed to save generations", r) 188 | j = try_json(r, status_code=200) 189 | try: 190 | return Collection.from_dict(j) 191 | except Exception as e: 192 | 
raise FlowError("Failed to parse response", r) from e 193 | 194 | 195 | def get_credit_summary_flow(bearer_token: str) -> HttpFlow[BillingInfo]: 196 | r = yield get_credit_summary_request(bearer_token) 197 | while r.status_code == 504: 198 | r = yield get_credit_summary_request(bearer_token, sleep=DEFAULT_INTERVAL) 199 | if r.status_code != 200: 200 | raise FlowError("Failed to get credit summary", r) 201 | j = try_json(r, status_code=200) 202 | try: 203 | return BillingInfo.from_dict(j) 204 | except Exception as e: 205 | raise FlowError("Failed to parse response", r) from e 206 | -------------------------------------------------------------------------------- /pydalle/functional/api/response/labs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains dataclasses which represent the Labs API's response objects. 3 | """ 4 | 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Literal, Union, Any 7 | 8 | from pydalle.functional.assumptions import OPENAI_LABS_SHARE_URL_TEMPLATE 9 | 10 | 11 | @dataclass 12 | class TaskList: 13 | raw: dict 14 | object: Literal["list"] 15 | data: List['Task'] 16 | 17 | @classmethod 18 | def from_dict(cls, d: dict) -> 'TaskList': 19 | return cls(object=d["object"], 20 | data=[Task.from_dict(t) for t in d["data"]], 21 | raw=d) 22 | 23 | def __iter__(self): 24 | return iter(self.data) 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | def __getitem__(self, index): 30 | return self.data[index] 31 | 32 | 33 | TaskType = Union[Literal["inpainting"], Literal["text2im"], Literal["variations"]] 34 | 35 | 36 | @dataclass 37 | class Task: 38 | raw: dict 39 | object: Literal["task"] 40 | id: str 41 | created: int 42 | task_type: TaskType 43 | status: Union[Literal["succeeded"], Literal["pending"], Literal["rejected"]] 44 | status_information: 'StatusInformation' 45 | prompt_id: str 46 | generations: Optional["GenerationList"] 47 | prompt: "Prompt" 
48 | 49 | @classmethod 50 | def from_dict(cls, d: dict) -> 'Task': 51 | return cls(object=d["object"], 52 | id=d["id"], 53 | created=d["created"], 54 | task_type=d["task_type"], 55 | status=d["status"], 56 | status_information=StatusInformation.from_dict(d["status_information"]), 57 | prompt_id=d["prompt_id"], 58 | generations=GenerationList.from_dict(d["generations"]) if "generations" in d else None, 59 | prompt=Prompt.from_dict(d["prompt"]), 60 | raw=d) 61 | 62 | def __str__(self): 63 | return f"Task(id={self.id}, task_type={self.task_type}, status={self.status})" 64 | 65 | 66 | @dataclass 67 | class StatusInformation: 68 | raw: dict 69 | type: Optional[Literal["error"]] = None 70 | message: Optional[Literal["Your task failed as a result of our safety system."]] = None 71 | code: Optional[Literal["task_failed_text_safety_system"]] = None 72 | 73 | @classmethod 74 | def from_dict(cls, d: dict) -> 'StatusInformation': 75 | return cls(type=d.get("type"), 76 | message=d.get("message"), 77 | code=d.get("code"), 78 | raw=d) 79 | 80 | 81 | @dataclass 82 | class Prompt: 83 | raw: dict 84 | id: str 85 | object: Literal["prompt"] 86 | created: int 87 | prompt_type: Union[Literal["CaptionlessImagePrompt"], 88 | Literal["CaptionImagePrompt"], 89 | Literal["CaptionPrompt"]] 90 | prompt: "PromptData" 91 | parent_generation_id: Optional[str] = None 92 | 93 | @classmethod 94 | def from_dict(cls, d: dict) -> 'Prompt': 95 | return cls(id=d["id"], 96 | object=d["object"], 97 | created=d["created"], 98 | prompt_type=d["prompt_type"], 99 | prompt=PromptData.from_dict(d["prompt"]), 100 | parent_generation_id=d.get("parent_generation_id"), 101 | raw=d) 102 | 103 | 104 | @dataclass 105 | class PromptData: 106 | raw: dict 107 | caption: Optional[str] = None 108 | image_path: Optional[str] = None 109 | masked_image_path: Optional[str] = None 110 | 111 | @classmethod 112 | def from_dict(cls, d: dict) -> 'PromptData': 113 | return cls(caption=d.get("caption"), 114 | 
image_path=d.get("image_path"), 115 | masked_image_path=d.get("masked_image_path"), 116 | raw=d) 117 | 118 | 119 | @dataclass 120 | class GenerationList: 121 | raw: dict 122 | object: Literal["list"] 123 | data: List["Generation"] 124 | 125 | @classmethod 126 | def from_dict(cls, d: dict) -> 'GenerationList': 127 | return cls(object=d["object"], 128 | data=[Generation.from_dict(g) for g in d["data"]], 129 | raw=d) 130 | 131 | def __iter__(self): 132 | return iter(self.data) 133 | 134 | def __len__(self): 135 | return len(self.data) 136 | 137 | def __getitem__(self, index): 138 | return self.data[index] 139 | 140 | 141 | @dataclass 142 | class Generation: 143 | raw: dict 144 | id: str 145 | object: Literal["generation"] 146 | created: int 147 | generation_type: Literal["ImageGeneration"] 148 | generation: "GenerationData" 149 | task_id: str 150 | prompt_id: str 151 | is_public: bool 152 | 153 | @classmethod 154 | def from_dict(cls, d: dict) -> 'Generation': 155 | return cls(id=d["id"], 156 | object=d["object"], 157 | created=d["created"], 158 | generation_type=d["generation_type"], 159 | generation=GenerationData.from_dict(d["generation"]), 160 | task_id=d["task_id"], 161 | prompt_id=d["prompt_id"], 162 | is_public=d["is_public"], 163 | raw=d) 164 | 165 | @property 166 | def share_url(self): 167 | # The generation must be public for the share url to be available 168 | return OPENAI_LABS_SHARE_URL_TEMPLATE % (self.id.replace("generation-", "", 1)) 169 | 170 | 171 | @dataclass 172 | class GenerationData: 173 | raw: dict 174 | image_path: str 175 | 176 | @classmethod 177 | def from_dict(cls, d: dict) -> 'GenerationData': 178 | return cls(image_path=d["image_path"], 179 | raw=d) 180 | 181 | 182 | @dataclass 183 | class Collection: 184 | raw: dict 185 | object: Literal["collection"] 186 | id: str 187 | created: int 188 | name: str 189 | description: str 190 | is_public: bool 191 | alias: str 192 | 193 | @classmethod 194 | def from_dict(cls, d: dict) -> 'Collection': 195 
| return cls(object=d["object"], 196 | id=d["id"], 197 | created=d["created"], 198 | name=d["name"], 199 | description=d["description"], 200 | is_public=d["is_public"], 201 | alias=d["alias"], 202 | raw=d) 203 | 204 | 205 | @dataclass 206 | class Breakdown: 207 | raw: dict 208 | free: int 209 | 210 | @classmethod 211 | def from_dict(cls, d: dict) -> 'Breakdown': 212 | return cls(free=d["free"], 213 | raw=d) 214 | 215 | 216 | @dataclass 217 | class BillingInfo: 218 | raw: dict 219 | aggregate_credits: int 220 | next_grant_ts: int 221 | breakdown: Breakdown 222 | 223 | @classmethod 224 | def from_dict(cls, d: dict) -> 'BillingInfo': 225 | return cls(aggregate_credits=d["aggregate_credits"], 226 | next_grant_ts=d["next_grant_ts"], 227 | breakdown=Breakdown.from_dict(d["breakdown"]), 228 | raw=d) 229 | 230 | 231 | @dataclass 232 | class Features: 233 | raw: dict 234 | public_endpoints: bool 235 | image_uploads: bool 236 | 237 | @classmethod 238 | def from_dict(cls, d: dict) -> 'Features': 239 | return cls(public_endpoints=d["public_endpoints"], 240 | image_uploads=d["image_uploads"], 241 | raw=d) 242 | 243 | 244 | @dataclass 245 | class Organization: 246 | raw: dict 247 | object: Literal["organization"] 248 | id: str 249 | created: int 250 | title: str 251 | name: str 252 | description: str 253 | personal: bool 254 | is_default: bool 255 | role: str 256 | groups: List[str] 257 | 258 | @classmethod 259 | def from_dict(cls, d: dict) -> 'Organization': 260 | return cls(object=d["object"], 261 | id=d["id"], 262 | created=d["created"], 263 | title=d["title"], 264 | name=d["name"], 265 | description=d["description"], 266 | personal=d["personal"], 267 | is_default=d["is_default"], 268 | role=d["role"], 269 | groups=d["groups"], 270 | raw=d) 271 | 272 | 273 | @dataclass 274 | class OrganizationList: 275 | raw: dict 276 | object: Literal["list"] 277 | data: List[Organization] 278 | 279 | @classmethod 280 | def from_dict(cls, d: dict) -> 'OrganizationList': 281 | return 
cls(object=d["object"], 282 | data=[Organization.from_dict(o) for o in d["data"]], 283 | raw=d) 284 | 285 | 286 | @dataclass 287 | class Session: 288 | raw: dict 289 | sensitive_id: str 290 | object: Literal["session"] 291 | created: int 292 | last_use: int 293 | publishable: bool 294 | 295 | @classmethod 296 | def from_dict(cls, d: dict) -> 'Session': 297 | return cls(sensitive_id=d["sensitive_id"], 298 | object=d["object"], 299 | created=d["created"], 300 | last_use=d["last_use"], 301 | publishable=d["publishable"], 302 | raw=d) 303 | 304 | 305 | @dataclass 306 | class User: 307 | raw: dict 308 | object: Literal["user"] 309 | id: str 310 | email: str 311 | name: str 312 | picture: str 313 | created: int 314 | accepted_terms_at: int 315 | session: Session 316 | groups: List[str] 317 | orgs: OrganizationList 318 | intercom_hash: str 319 | accepted_terms: int 320 | seen_upload_guidelines: int 321 | seen_billing_onboarding: int 322 | 323 | @classmethod 324 | def from_dict(cls, d: dict) -> 'User': 325 | return cls(object=d["object"], 326 | id=d["id"], 327 | email=d["email"], 328 | name=d["name"], 329 | picture=d["picture"], 330 | created=d["created"], 331 | accepted_terms_at=d["accepted_terms_at"], 332 | session=Session.from_dict(d["session"]), 333 | groups=d["groups"], 334 | orgs=OrganizationList.from_dict(d["orgs"]), 335 | intercom_hash=d["intercom_hash"], 336 | accepted_terms=d["accepted_terms"], 337 | seen_upload_guidelines=d["seen_upload_guidelines"], 338 | seen_billing_onboarding=d["seen_billing_onboarding"], 339 | raw=d) 340 | 341 | 342 | @dataclass 343 | class Login: 344 | raw: dict 345 | object: Literal["login"] 346 | user: User 347 | invites: List[Any] 348 | features: Features 349 | billing_info: BillingInfo 350 | 351 | @classmethod 352 | def from_dict(cls, d: dict) -> 'Login': 353 | return cls(object=d["object"], 354 | user=User.from_dict(d["user"]), 355 | invites=d["invites"], 356 | features=Features.from_dict(d["features"]), 357 | 
billing_info=BillingInfo.from_dict(d["billing_info"]), 358 | raw=d) 359 | 360 | 361 | @dataclass 362 | class UserFlag: 363 | raw: dict 364 | object: Literal["user_flag"] 365 | id: str 366 | created: int 367 | generation_id: str 368 | description: str 369 | 370 | @classmethod 371 | def from_dict(cls, d: dict) -> 'UserFlag': 372 | return cls(object=d["object"], 373 | id=d["id"], 374 | created=d["created"], 375 | generation_id=d["generation_id"], 376 | description=d["description"], 377 | raw=d) 378 | -------------------------------------------------------------------------------- /pydalle/imperative/client/responses.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import asyncio 3 | from functools import cached_property 4 | from typing import Optional, List, Iterator, TYPE_CHECKING, Any, Union, Generator, AsyncGenerator 5 | 6 | from pydalle.functional.api.response.labs import TaskList, Task, Generation, Collection, UserFlag, BillingInfo, \ 7 | TaskType, Prompt, StatusInformation, GenerationData, Breakdown, Login, User, Features, GenerationList 8 | from pydalle.functional.types import HttpRequest, T 9 | from pydalle.imperative.outside import files 10 | from pydalle.imperative.outside.internet import request, request_async 11 | from pydalle.imperative.outside.pil import PILImageType, pil_image_to_png_bytes, image_bytes_to_png_bytes, \ 12 | bytes_to_pil_image, bytes_to_masked_pil_image, bytes_to_padded_pil_image 13 | from pydalle.imperative.outside.np import ndarray, pil_image_to_np_array, np_array_to_pil_image 14 | 15 | if TYPE_CHECKING: 16 | from pydalle.imperative.client.dalle import Dalle 17 | 18 | PNG_PREFIX = b"\x89PNG\r\n\x1a\n" 19 | PNG_BASE64_PREFIX = b"iVBORw0KGgo" 20 | PNG_BASE64_PREFIX_STR = PNG_BASE64_PREFIX.decode() 21 | 22 | 23 | class WrappedResponse: 24 | """ 25 | Generic wrapper for responses from the Dalle API. 
This has the following benefits: 26 | 27 | * The dalle client is passed in as an argument to the constructor, so that the client can be used to 28 | make requests related to the wrapped response. 29 | 30 | * Getters for attributes in the lower-level response helps protect against changes in the API. 31 | """ 32 | 33 | def __init__(self, wrapped: T, dalle: 'Dalle'): 34 | self.wrapped: T = wrapped 35 | self.dalle = dalle 36 | 37 | def __repr__(self): 38 | return f"{self.__class__.__name__}({self.wrapped!r})" 39 | 40 | 41 | class WrappedTaskList(WrappedResponse): 42 | wrapped: TaskList 43 | 44 | def __init__(self, task_list: TaskList, dalle: 'Dalle'): 45 | super().__init__(task_list, dalle) 46 | 47 | def __iter__(self) -> Iterator['WrappedTask']: 48 | yield from (WrappedTask(task, self.dalle) for task in self.wrapped) 49 | 50 | def __getitem__(self, index) -> 'WrappedTask': 51 | return WrappedTask(self.wrapped[index], self.dalle) 52 | 53 | def __len__(self) -> int: 54 | return len(self.wrapped) 55 | 56 | 57 | class WrappedTask(WrappedResponse): 58 | wrapped: Task 59 | 60 | def __init__(self, task: Task, dalle: 'Dalle'): 61 | super().__init__(task, dalle) 62 | 63 | @property 64 | def id(self) -> str: 65 | return self.wrapped.id 66 | 67 | @property 68 | def created(self) -> int: 69 | return self.wrapped.created 70 | 71 | @property 72 | def task_type(self) -> TaskType: 73 | return self.wrapped.task_type 74 | 75 | @property 76 | def status(self) -> str: 77 | return self.wrapped.status 78 | 79 | @property 80 | def status_information(self) -> 'StatusInformation': 81 | return self.wrapped.status_information 82 | 83 | @property 84 | def prompt_id(self) -> str: 85 | return self.wrapped.prompt_id 86 | 87 | @property 88 | def generations(self) -> Optional['WrappedGenerationList']: 89 | if self.wrapped.generations is None: 90 | return None 91 | return WrappedGenerationList(self.wrapped.generations, self.dalle) 92 | 93 | @property 94 | def prompt(self) -> 'Prompt': 95 | return 
self.wrapped.prompt 96 | 97 | @property 98 | def succeeded(self) -> bool: 99 | return self.wrapped.status == 'succeeded' 100 | 101 | @property 102 | def pending(self) -> bool: 103 | return self.wrapped.status == 'pending' 104 | 105 | @property 106 | def rejected(self) -> bool: 107 | return self.wrapped.status == 'rejected' 108 | 109 | def download(self, direct=False) -> Generator['WrappedImage', None, None]: 110 | yield from (generation.download(direct=direct) for generation in self.generations) 111 | 112 | async def download_async(self, direct=False) -> AsyncGenerator['WrappedImage', None]: 113 | download_tasks = [generation.download_async(direct=direct) for generation in self.generations] 114 | for download_task in asyncio.as_completed(download_tasks): 115 | yield await download_task 116 | 117 | def wait(self) -> 'WrappedTask': 118 | return self.dalle.poll_for_task_completion(self.id) 119 | 120 | async def wait_async(self) -> 'WrappedTask': 121 | return await self.dalle.poll_for_task_completion_async(self.id) 122 | 123 | 124 | class WrappedGenerationList: 125 | wrapped: GenerationList 126 | 127 | def __init__(self, generation_list: GenerationList, dalle: 'Dalle'): 128 | self.wrapped = generation_list 129 | self.dalle = dalle 130 | 131 | def __iter__(self) -> Iterator['WrappedGeneration']: 132 | yield from (WrappedGeneration(generation, self.dalle) for generation in self.wrapped) 133 | 134 | def __getitem__(self, index) -> 'WrappedGeneration': 135 | return WrappedGeneration(self.wrapped[index], self.dalle) 136 | 137 | def __len__(self) -> int: 138 | return len(self.wrapped) 139 | 140 | 141 | class WrappedGeneration(WrappedResponse): 142 | wrapped: Generation 143 | 144 | def __init__(self, generation: Generation, dalle: 'Dalle'): 145 | super().__init__(generation, dalle) 146 | 147 | @property 148 | def id(self) -> str: 149 | return self.wrapped.id 150 | 151 | @property 152 | def created(self) -> int: 153 | return self.wrapped.created 154 | 155 | @property 156 | def 
generation_type(self) -> str: 157 | return self.wrapped.generation_type 158 | 159 | @property 160 | def generation(self) -> 'GenerationData': 161 | return self.wrapped.generation 162 | 163 | @property 164 | def task_id(self) -> str: 165 | return self.wrapped.task_id 166 | 167 | @property 168 | def prompt_id(self) -> str: 169 | return self.wrapped.prompt_id 170 | 171 | @property 172 | def is_public(self) -> bool: 173 | return self.wrapped.is_public 174 | 175 | @property 176 | def direct_image_path(self) -> str: 177 | return self.generation.image_path 178 | 179 | def download(self, direct=False) -> 'WrappedImage': 180 | return self.dalle.download_generation(self, direct=direct) 181 | 182 | def variations(self, wait=True): 183 | return self.dalle.variations(self, wait=wait) 184 | 185 | def inpainting(self, caption: str, masked_image: 'ImageLike', wait=True): 186 | return self.dalle.inpainting(caption=caption, masked_image=masked_image, wait=wait) 187 | 188 | async def download_async(self, direct=False) -> 'WrappedImage': 189 | return await self.dalle.download_generation_async(self, direct=direct) 190 | 191 | async def variations_async(self, wait=True): 192 | return await self.dalle.variations_async(self, wait=wait) 193 | 194 | async def inpainting_async(self, caption: str, masked_image: 'ImageLike', wait=True): 195 | return await self.dalle.inpainting_async(caption=caption, masked_image=masked_image, wait=wait) 196 | 197 | 198 | GenerationLike = Union[WrappedGeneration, Generation, str] 199 | 200 | 201 | def get_generation_id(generation: GenerationLike) -> str: 202 | if isinstance(generation, (WrappedGeneration, Generation)): 203 | return generation.id 204 | if not str(generation).startswith("generation-"): 205 | raise ValueError("Unrecognized generation: {}".format(generation)) 206 | return str(generation) 207 | 208 | 209 | TaskLike = Union[WrappedTask, Task, str] 210 | 211 | 212 | def get_task_id(task: TaskLike) -> str: 213 | if isinstance(task, (WrappedTask, 
Task)): 214 | return task.id 215 | if not str(task).startswith("task-"): 216 | raise ValueError("Unrecognized task: {}".format(task)) 217 | return str(task) 218 | 219 | 220 | class WrappedCollection(WrappedResponse): 221 | wrapped: Collection 222 | 223 | def __init__(self, collection: Collection, dalle: 'Dalle'): 224 | super().__init__(collection, dalle) 225 | 226 | @property 227 | def id(self) -> str: 228 | return self.wrapped.id 229 | 230 | @property 231 | def created(self) -> int: 232 | return self.wrapped.created 233 | 234 | @property 235 | def name(self) -> str: 236 | return self.wrapped.name 237 | 238 | @property 239 | def description(self) -> str: 240 | return self.wrapped.description 241 | 242 | @property 243 | def is_public(self) -> bool: 244 | return self.wrapped.is_public 245 | 246 | @property 247 | def alias(self) -> str: 248 | return self.wrapped.alias 249 | 250 | 251 | class WrappedImage(WrappedResponse): 252 | wrapped: bytes 253 | 254 | def __init__(self, image: bytes, dalle: 'Dalle', filetype: str = 'png'): 255 | super().__init__(image, dalle) 256 | self.filetype = filetype.lower() 257 | 258 | def __bytes__(self): 259 | return self.png_bytes 260 | 261 | @cached_property 262 | def png_bytes(self) -> bytes: 263 | if self.filetype == 'png': 264 | return self.wrapped 265 | return image_bytes_to_png_bytes(self.wrapped) 266 | 267 | def to_pil(self) -> PILImageType: 268 | """ 269 | Returns a PIL image object for the image. 270 | 271 | :return: A PIL image object. 272 | """ 273 | return bytes_to_pil_image(bytes(self)) 274 | 275 | def to_numpy(self) -> ndarray: 276 | """ 277 | Returns a numpy array for the image. 278 | 279 | :return: A numpy array. 280 | """ 281 | return pil_image_to_np_array(self.to_pil()) 282 | 283 | def to_pil_masked(self, x1: float, y1: float, x2: float, y2: float) -> PILImageType: 284 | """ 285 | Returns a PIL image object for the image, with the given mask applied. 
286 | 287 | :param x1: The percentage of the image on the left before the mask starts. 288 | :param y1: The percentage of the image on the top before the mask starts. 289 | :param x2: The percentage of the image on the right after the mask ends. 290 | :param y2: The percentage of the image on the bottom after the mask ends. 291 | 292 | :return: A masked PIL image object. 293 | """ 294 | return bytes_to_masked_pil_image(bytes(self), x1, y1, x2, y2) 295 | 296 | def to_pil_padded(self, p: float, cx: float = 0.5, cy: float = 0.5) -> PILImageType: 297 | """ 298 | Returns a PIL image object for the image, with the given padding applied. 299 | 300 | :param p: The percentage of the image to pad. E.g. 0.5 means the image will be shrunk by 50%. 301 | :param cx: Where the newly shrunk image will be centered horizontally. Default is 0.5, the center. 302 | :param cy: Where the newly shrunk image will be centered vertically. Default is 0.5, the center. 303 | :return: A padded PIL image object. 304 | """ 305 | return bytes_to_padded_pil_image(bytes(self), p, cx, cy) 306 | 307 | 308 | PromptLike = Union[Prompt, str] 309 | 310 | ImageLike = Union[WrappedImage, PILImageType, bytes, str] 311 | ParentLike = Union[ImageLike, GenerationLike, PromptLike] 312 | 313 | 314 | def _get_image_png_base64_no_io(image: ImageLike) -> str: 315 | if isinstance(image, str): 316 | if image.startswith(PNG_BASE64_PREFIX.decode()): 317 | # If it's already a base64 encoded PNG, we're good 318 | return image 319 | try: 320 | # If it's a base64 encoded string in the wrong format, we'll try the PIL trick 321 | decoded = base64.b64decode(image) 322 | if image == base64.b64encode(decoded).decode(): 323 | return base64.b64encode(image_bytes_to_png_bytes(decoded)).decode() 324 | except ValueError: 325 | pass 326 | if isinstance(image, WrappedImage): 327 | return base64.b64encode(image.png_bytes).decode() 328 | if isinstance(image, bytes): 329 | if image.startswith(PNG_BASE64_PREFIX): 330 | # If it's already a 
base64 encoded PNG, we just need to decode it 331 | return image.decode() 332 | elif image.startswith(PNG_PREFIX): 333 | # If it's a PNG bytes, we'll base64 encode it 334 | return base64.b64encode(image).decode() 335 | # Check if the bytes are base64 encoded. If it is, 336 | # we'll assume it's an image in another format base64 encoded 337 | try: 338 | decoded = base64.b64decode(image) 339 | 340 | if image == base64.b64encode(decoded): 341 | image = decoded 342 | except ValueError: 343 | pass 344 | # So, now it's bytes that are not a PNG or base64 encoded PNG. 345 | # Best guess is that it's an image of some other format. 346 | # If the user has PIL installed, we'll try to convert it to PNG 347 | return base64.b64encode(image_bytes_to_png_bytes(image)).decode() 348 | try: 349 | if isinstance(image, ndarray): 350 | return base64.b64encode(pil_image_to_png_bytes(np_array_to_pil_image(image))).decode() 351 | except ImportError: 352 | pass 353 | try: 354 | if str(image.__class__).startswith(" str: 362 | if result := _get_image_png_base64_no_io(image): 363 | return result 364 | # Maybe it's a URL? 365 | if (lower := image.lower()).startswith("http://") or lower.startswith("https://"): 366 | # If it's a URL, we'll try to download it 367 | r = request(HttpRequest("get", image, headers=headers, decode=False)) 368 | if r.status_code == 200: 369 | return get_image_png_base64(r.content, headers) 370 | raise ValueError(f"Could not download image: {image}") 371 | # Maybe it's a file path? 372 | try: 373 | return get_image_png_base64(files.read_bytes(image), headers) 374 | except FileNotFoundError: 375 | pass 376 | # Out of ideas. 
Just raise an error 377 | raise ValueError(f"Could not convert image to PNG: {image}") 378 | 379 | 380 | async def get_image_png_base64_async(image: ImageLike, headers: Optional[dict] = None) -> str: 381 | if result := _get_image_png_base64_no_io(image): 382 | return result 383 | if (lower := image.lower()).startswith("http://") or lower.startswith("https://"): 384 | r = await request_async(HttpRequest("get", image, headers=headers, decode=False)) 385 | if r.status_code == 200: 386 | return await get_image_png_base64_async(r.content) 387 | raise ValueError(f"Could not download image: {image}") 388 | try: 389 | return await get_image_png_base64_async(await files.read_bytes_async(image)) 390 | except FileNotFoundError: 391 | pass 392 | raise ValueError(f"Could not convert image to PNG: {image}") 393 | 394 | 395 | def get_parent_id_or_png_base64(parent: ParentLike, headers: Optional[dict]) -> Union[str, bytes]: 396 | if isinstance(parent, (Prompt, Generation, WrappedGeneration)): 397 | return parent.id 398 | if isinstance(parent, str) and parent.startswith("generation-") or parent.startswith("prompt-"): 399 | return parent 400 | return get_image_png_base64(parent, headers) 401 | 402 | 403 | async def get_parent_id_or_png_base64_async(parent: ParentLike, headers: Optional[dict]) -> Union[str, bytes]: 404 | if isinstance(parent, (Prompt, Generation, WrappedGeneration)): 405 | return parent.id 406 | if isinstance(parent, str) and parent.startswith("generation-") or parent.startswith("prompt-"): 407 | return parent 408 | return await get_image_png_base64_async(parent, headers) 409 | 410 | 411 | class WrappedUserFlag(WrappedResponse): 412 | wrapped: UserFlag 413 | 414 | def __init__(self, user_flag: UserFlag, dalle: 'Dalle'): 415 | super().__init__(user_flag, dalle) 416 | 417 | @property 418 | def id(self) -> str: 419 | return self.wrapped.id 420 | 421 | @property 422 | def created(self) -> int: 423 | return self.wrapped.created 424 | 425 | @property 426 | def 
generation_id(self) -> str: 427 | return self.wrapped.generation_id 428 | 429 | @property 430 | def description(self) -> str: 431 | return self.wrapped.description 432 | 433 | 434 | class WrappedBillingInfo(WrappedResponse): 435 | wrapped: BillingInfo 436 | 437 | def __init__(self, billing_info: BillingInfo, dalle: 'Dalle'): 438 | super().__init__(billing_info, dalle) 439 | 440 | @property 441 | def aggregate_credits(self) -> int: 442 | return self.wrapped.aggregate_credits 443 | 444 | @property 445 | def next_grant_ts(self) -> int: 446 | return self.wrapped.next_grant_ts 447 | 448 | @property 449 | def breakdown(self) -> Breakdown: 450 | return self.wrapped.breakdown 451 | 452 | 453 | class WrappedLogin(WrappedResponse): 454 | wrapped: Login 455 | 456 | def __init__(self, login: Login, dalle: 'Dalle'): 457 | super().__init__(login, dalle) 458 | 459 | @property 460 | def user(self) -> User: 461 | return self.wrapped.user 462 | 463 | @property 464 | def invites(self) -> List[Any]: 465 | return self.wrapped.invites 466 | 467 | @property 468 | def features(self) -> Features: 469 | return self.wrapped.features 470 | -------------------------------------------------------------------------------- /pydalle/imperative/api/labs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the implementations of API calls to the labs API. 
3 | """ 4 | 5 | from typing import Optional, Dict, List 6 | 7 | from pydalle.functional.api.response.labs import TaskList, Task, Generation, Collection, Login, UserFlag, BillingInfo 8 | from pydalle.functional.assumptions import OPENAI_AUTH0_DOMAIN, OPENAI_AUTH0_CLIENT_ID, \ 9 | OPENAI_AUTH0_AUDIENCE, OPENAI_LABS_REDIRECT_URI, OPENAI_AUTH0_SCOPE 10 | from pydalle.functional.api.flow.labs import get_bearer_token_flow, get_tasks_flow, get_task_flow, \ 11 | create_text2im_task_flow, poll_for_task_completion_flow, create_variations_task_flow, \ 12 | create_inpainting_task_flow, download_generation_flow, share_generation_flow, save_generations_flow, \ 13 | get_login_info_flow, flag_generation_flow, get_credit_summary_flow, get_generation_flow 14 | from pydalle.imperative.api.auth0 import get_access_token_from_credentials, get_access_token_from_credentials_async 15 | from pydalle.imperative.outside.internet import session_flow, session_flow_async 16 | 17 | _LABS_AUTH0_PARAMS = { 18 | "domain": OPENAI_AUTH0_DOMAIN, 19 | "client_id": OPENAI_AUTH0_CLIENT_ID, 20 | "audience": OPENAI_AUTH0_AUDIENCE, 21 | "redirect_uri": OPENAI_LABS_REDIRECT_URI, 22 | "scope": OPENAI_AUTH0_SCOPE, 23 | } 24 | 25 | 26 | def get_access_token(username: str, password: str, headers: Optional[Dict[str, str]] = None) -> str: 27 | """ 28 | Get an access token from the given credentials. 29 | 30 | :param username: The username or email address associated with the OpenAI account. 31 | :param password: The password associated with the OpenAI account. 32 | :param headers: Optional headers to send with the request. 33 | :return: An access token, needed for retrieving a labs bearer token. 
34 | """ 35 | return get_access_token_from_credentials(username, password, **_LABS_AUTH0_PARAMS, headers=headers) 36 | 37 | 38 | async def get_access_token_async(username: str, password: str, 39 | headers: Optional[Dict[str, str]] = None) -> str: 40 | return await get_access_token_from_credentials_async(username, password, **_LABS_AUTH0_PARAMS, headers=headers) 41 | 42 | 43 | def get_bearer_token(username: str, password: str, headers: Optional[Dict[str, str]] = None) -> str: 44 | """ 45 | Get an access token from the given credentials. 46 | 47 | :param username: The username or email address associated with the OpenAI account. 48 | :param password: The password associated with the OpenAI account. 49 | :param headers: Optional headers to send with the request. 50 | :return: A bearer token, needed for most API calls. 51 | """ 52 | access_token = get_access_token_from_credentials(username, password, **_LABS_AUTH0_PARAMS, headers=headers) 53 | return session_flow(get_bearer_token_flow, headers, access_token=access_token) 54 | 55 | 56 | async def get_bearer_token_async(username: str, password: str, headers: Optional[Dict[str, str]] = None) -> str: 57 | access_token = ( 58 | await get_access_token_from_credentials_async(username, password, **_LABS_AUTH0_PARAMS, headers=headers)) 59 | return await session_flow_async(get_bearer_token_flow, headers, access_token=access_token) 60 | 61 | 62 | def get_login_info(access_token: str, headers: Optional[Dict[str, str]] = None) -> Login: 63 | """ 64 | Get the login information for the account authenticated by the given access token. 65 | 66 | :param access_token: The access token to use. 67 | :param headers: Optional headers to send with the request. 68 | :return: The login information for the account. 
69 | """ 70 | return session_flow(get_login_info_flow, headers, access_token=access_token) 71 | 72 | 73 | async def get_login_info_async(access_token: str, headers: Optional[Dict[str, str]] = None) -> Login: 74 | return await session_flow_async(get_login_info_flow, headers, access_token=access_token) 75 | 76 | 77 | def get_bearer_token_from_access_token(access_token: str, headers: Optional[Dict[str, str]] = None) -> str: 78 | """ 79 | Get a bearer token from the given access token. 80 | 81 | :param access_token: The access token to use. 82 | :param headers: Optional headers to send with the request. 83 | :return: A bearer token, needed for most API calls. 84 | """ 85 | return session_flow(get_bearer_token_flow, headers, access_token=access_token) 86 | 87 | 88 | async def get_bearer_token_from_access_token_async(access_token: str, headers: Optional[Dict[str, str]] = None) -> str: 89 | return await session_flow_async(get_bearer_token_flow, headers, access_token=access_token) 90 | 91 | 92 | def get_tasks(bearer_token: str, limit: Optional[int] = None, from_ts: Optional[int] = None, 93 | headers: Optional[Dict[str, str]] = None) -> TaskList: 94 | """ 95 | Get the list of tasks for the account authenticated by the given bearer token. 96 | 97 | :param bearer_token: The bearer token to use. 98 | :param from_ts: Optional unix timestamp to exclude tasks created before this time. 99 | :param limit: Optional limit on the number of tasks to return. Server-side and maximum default is 50. 100 | :param headers: Optional headers to send with the request. 101 | :return: The list of tasks for the account. 
102 | """ 103 | return session_flow(get_tasks_flow, headers, limit=limit, from_ts=from_ts, bearer_token=bearer_token) 104 | 105 | 106 | async def get_tasks_async(bearer_token: str, from_ts: Optional[int] = None, 107 | limit: Optional[int] = None, 108 | headers: Optional[Dict[str, str]] = None) -> TaskList: 109 | return await session_flow_async(get_tasks_flow, headers, limit=limit, from_ts=from_ts, bearer_token=bearer_token) 110 | 111 | 112 | def get_task(bearer_token: str, task_id: str, headers: Optional[Dict[str, str]] = None) -> Task: 113 | """ 114 | Get the task with the given ID for the account authenticated by the given bearer token. 115 | 116 | :param bearer_token: The bearer token to use. 117 | :param task_id: The ID of the task to get. 118 | :param headers: Optional headers to send with the request. 119 | :return: The task with the given ID. 120 | """ 121 | return session_flow(get_task_flow, headers, task_id=task_id, bearer_token=bearer_token) 122 | 123 | 124 | async def get_task_async(bearer_token: str, task_id: str, headers: Optional[Dict[str, str]] = None) -> Task: 125 | return await session_flow_async(get_task_flow, headers, task_id=task_id, bearer_token=bearer_token) 126 | 127 | 128 | def create_text2im_task(bearer_token: str, caption: str, batch_size: int = 4, 129 | headers: Optional[Dict[str, str]] = None) -> Task: 130 | """ 131 | Create a "text-to-image" task for a given caption. 132 | 133 | :param bearer_token: The bearer token to use. 134 | :param caption: The text to generate images for. 135 | :param batch_size: The number of images to generate per request. 136 | :param headers: Optional headers to send with the request. 137 | :return: The created task, which will either be pending or rejected. 
138 | """ 139 | return session_flow(create_text2im_task_flow, headers, caption=caption, batch_size=batch_size, 140 | bearer_token=bearer_token) 141 | 142 | 143 | async def create_text2im_task_async(bearer_token: str, caption: str, batch_size: int = 4, 144 | headers: Optional[Dict[str, str]] = None) -> Task: 145 | return await session_flow_async(create_text2im_task_flow, headers, caption=caption, batch_size=batch_size, 146 | bearer_token=bearer_token) 147 | 148 | 149 | def create_variations_task(bearer_token: str, parent_id_or_image: str, batch_size: int = 3, 150 | headers: Optional[Dict[str, str]] = None) -> Task: 151 | """ 152 | Create a "variations" task for a given image. 153 | 154 | :param bearer_token: The bearer token to use. 155 | :param parent_id_or_image: The ID of the parent (generation ID or prompt ID) or a base64-encoded PNG 156 | :param batch_size: The number of variations to generate per request. 157 | :param headers: Optional headers to send with the request. 158 | :return: The created task, which will either be pending or rejected. 159 | """ 160 | return session_flow(create_variations_task_flow, headers, parent_id_or_image=parent_id_or_image, 161 | batch_size=batch_size, bearer_token=bearer_token) 162 | 163 | 164 | async def create_variations_task_async(bearer_token: str, parent_id_or_image: str, 165 | batch_size: int = 3, headers: Optional[Dict[str, str]] = None) -> Task: 166 | return await session_flow_async(create_variations_task_flow, headers, 167 | parent_id_or_image=parent_id_or_image, 168 | batch_size=batch_size, bearer_token=bearer_token) 169 | 170 | 171 | def create_inpainting_task(bearer_token: str, caption: str, masked_image: str, parent_id_or_image: Optional[str] = None, 172 | batch_size: int = 3, headers: Optional[Dict[str, str]] = None) -> Task: 173 | """ 174 | Create an "inpainting" task for a given caption and masked image. 175 | 176 | :param bearer_token: The bearer token to use. 
177 | :param caption: The text to generate images for. 178 | :param masked_image: The base64-encoded PNG to mask. 179 | :param parent_id_or_image: The ID of the parent (generation ID or prompt ID) or a base64-encoded PNG 180 | :param batch_size: The number of images to generate per request. 181 | :param headers: Optional headers to send with the request. 182 | """ 183 | return session_flow(create_inpainting_task_flow, headers, caption=caption, parent_id_or_image=parent_id_or_image, 184 | masked_image=masked_image, batch_size=batch_size, bearer_token=bearer_token) 185 | 186 | 187 | async def create_inpainting_task_async(bearer_token: str, caption: str, masked_image: str, 188 | parent_id_or_image: Optional[str] = None, batch_size: int = 3, 189 | headers: Optional[Dict[str, str]] = None) -> Task: 190 | return await session_flow_async(create_inpainting_task_flow, headers, caption=caption, 191 | parent_id_or_image=parent_id_or_image, masked_image=masked_image, 192 | batch_size=batch_size, bearer_token=bearer_token) 193 | 194 | 195 | def poll_for_task_completion(bearer_token: str, task_id: str, interval: float = 1.0, 196 | max_attempts: int = 1000, headers: Optional[Dict[str, str]] = None) -> Task: 197 | """ 198 | Poll for the completion of a task. 199 | 200 | :param bearer_token: The bearer token to use. 201 | :param task_id: The ID of the task to poll. 202 | :param interval: The interval to wait between requests. 203 | :param max_attempts: The maximum number of times to poll before giving up. 204 | :param headers: Optional headers to send with the request. 205 | :return: The task with the given ID. 
206 | """ 207 | return session_flow(poll_for_task_completion_flow, headers, task_id=task_id, bearer_token=bearer_token, 208 | interval=interval, _max_attempts=max_attempts) 209 | 210 | 211 | async def poll_for_task_completion_async(bearer_token: str, task_id: str, interval: float = 1.0, 212 | max_attempts: int = 1000, headers: Optional[Dict[str, str]] = None) -> Task: 213 | return await session_flow_async(poll_for_task_completion_flow, headers, task_id=task_id, bearer_token=bearer_token, 214 | interval=interval, _max_attempts=max_attempts) 215 | 216 | 217 | def download_generation(bearer_token: str, generation_id: str, headers: Optional[Dict[str, str]] = None) -> bytes: 218 | """ 219 | Download a generated image by its ID. 220 | 221 | :param bearer_token: The bearer token to use. 222 | :param generation_id: The ID of the generation to download. 223 | :param headers: Optional headers to send with the request. 224 | :return: The bytes of the image. 225 | """ 226 | return session_flow(download_generation_flow, headers, generation_id=generation_id, bearer_token=bearer_token) 227 | 228 | 229 | async def download_generation_async(bearer_token: str, generation_id: str, 230 | headers: Optional[Dict[str, str]] = None) -> bytes: 231 | return await session_flow_async(download_generation_flow, headers, generation_id=generation_id, 232 | bearer_token=bearer_token) 233 | 234 | 235 | def share_generation(bearer_token: str, generation_id: str, headers: Optional[Dict[str, str]] = None) -> Generation: 236 | """ 237 | Share a generated image by its ID. This makes the image public, making the share_url available for access. 238 | 239 | :param bearer_token: The bearer token to use. 240 | :param generation_id: The ID of the generation to share. 241 | :param headers: Optional headers to send with the request. 242 | :return: The shared generation. 
243 | """ 244 | return session_flow(share_generation_flow, headers, generation_id=generation_id, bearer_token=bearer_token) 245 | 246 | 247 | async def share_generation_async(bearer_token: str, generation_id: str, 248 | headers: Optional[Dict[str, str]] = None) -> Generation: 249 | return await session_flow_async(share_generation_flow, headers, generation_id=generation_id, 250 | bearer_token=bearer_token) 251 | 252 | 253 | def save_generations(bearer_token: str, generation_ids: List[str], collection_id_or_alias="private", 254 | headers: Optional[Dict[str, str]] = None) -> Collection: 255 | """ 256 | Save a list of generations by their IDs to a collection. 257 | 258 | :param bearer_token: The bearer token to use. 259 | :param generation_ids: The IDs of the generations to save. 260 | :param collection_id_or_alias: The ID of the collection to save to. Defaults to your private collection. 261 | :param headers: Optional headers to send with the request. 262 | :return: The collection with the given ID. 
263 | """ 264 | return session_flow(save_generations_flow, headers, collection_id_or_alias=collection_id_or_alias, 265 | generation_ids=generation_ids, bearer_token=bearer_token) 266 | 267 | 268 | async def save_generations_async(bearer_token: str, generation_ids: List[str], collection_id_or_alias="private", 269 | headers: Optional[Dict[str, str]] = None) -> Collection: 270 | return await session_flow_async(save_generations_flow, headers, collection_id_or_alias=collection_id_or_alias, 271 | generation_ids=generation_ids, 272 | bearer_token=bearer_token) 273 | 274 | 275 | def _flag_generation(bearer_token: str, generation_id: str, description: str, 276 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 277 | return session_flow(flag_generation_flow, headers, generation_id=generation_id, reason=description, 278 | bearer_token=bearer_token) 279 | 280 | 281 | async def _flag_generation_async(bearer_token: str, generation_id: str, description: str, 282 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 283 | return await session_flow_async(flag_generation_flow, headers, generation_id=generation_id, reason=description, 284 | bearer_token=bearer_token) 285 | 286 | 287 | def flag_generation_sensitive(bearer_token: str, generation_id: str, 288 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 289 | """ 290 | Flag a generation as sensitive. 291 | 292 | :param bearer_token: The bearer token to use. 293 | :param generation_id: The ID of the generation to flag. 294 | :param headers: Optional headers to send with the request. 295 | :return: The user flag. 
296 | """ 297 | return _flag_generation(bearer_token, generation_id, "Sensitive", headers) 298 | 299 | 300 | async def flag_generation_sensitive_async(bearer_token: str, generation_id: str, 301 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 302 | return await _flag_generation_async(bearer_token, generation_id, "Sensitive", headers) 303 | 304 | 305 | def flag_generation_unexpected(bearer_token: str, generation_id: str, 306 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 307 | """ 308 | Flag a generation as unexpected. 309 | 310 | :param bearer_token: The bearer token to use. 311 | :param generation_id: The ID of the generation to flag. 312 | :param headers: Optional headers to send with the request. 313 | :return: The user flag. 314 | """ 315 | return _flag_generation(bearer_token, generation_id, "Unexpected", headers) 316 | 317 | 318 | async def flag_generation_unexpected_async(bearer_token: str, generation_id: str, 319 | headers: Optional[Dict[str, str]] = None) -> UserFlag: 320 | return await _flag_generation_async(bearer_token, generation_id, "Unexpected", headers) 321 | 322 | 323 | def get_credit_summary(bearer_token: str, headers: Optional[Dict[str, str]] = None) -> BillingInfo: 324 | """ 325 | Get the credit summary for the user. 326 | 327 | :param bearer_token: The bearer token to use. 328 | :param headers: Optional headers to send with the request. 329 | :return: The billing info. 330 | """ 331 | return session_flow(get_credit_summary_flow, headers, bearer_token=bearer_token) 332 | 333 | 334 | async def get_credit_summary_async(bearer_token: str, headers: Optional[Dict[str, str]] = None) -> BillingInfo: 335 | return await session_flow_async(get_credit_summary_flow, headers, bearer_token=bearer_token) 336 | 337 | 338 | def get_generation(bearer_token: str, generation_id: str, headers: Optional[Dict[str, str]] = None) -> Generation: 339 | """ 340 | Get a generation by its ID. 341 | 342 | :param bearer_token: The bearer token to use. 
343 | :param generation_id: The ID of the generation to get. 344 | :param headers: Optional headers to send with the request. 345 | :return: The generation. 346 | """ 347 | return session_flow(get_generation_flow, headers, generation_id=generation_id, bearer_token=bearer_token) 348 | 349 | 350 | async def get_generation_async(bearer_token: str, generation_id: str, 351 | headers: Optional[Dict[str, str]] = None) -> Generation: 352 | return await session_flow_async(get_generation_flow, headers, generation_id=generation_id, 353 | bearer_token=bearer_token) 354 | 355 | 356 | for name, func in list(globals().items()): 357 | if f"{name}_async" in locals(): 358 | if locals()[f"{name}_async"].__doc__ is None: 359 | locals()[f"{name}_async"].__doc__ = f"Async version of :func:`{name}`" 360 | -------------------------------------------------------------------------------- /pydalle/imperative/client/dalle.py: -------------------------------------------------------------------------------- 1 | """ 2 | A user-friendly interface for the low-level functional API of pydalle. 
3 | """ 4 | 5 | from typing import Optional, Union, Iterable 6 | 7 | from pydalle.functional.api.response.labs import Generation, Task 8 | from pydalle.functional.types import HttpRequest 9 | from pydalle.imperative.api import labs 10 | from pydalle.imperative.outside.internet import request, request_async 11 | from pydalle.imperative.client.responses import WrappedLogin, WrappedBillingInfo, WrappedUserFlag, WrappedCollection, \ 12 | WrappedGeneration, WrappedImage, WrappedTask, WrappedTaskList, GenerationLike, get_generation_id, TaskLike, \ 13 | get_task_id, ParentLike, get_parent_id_or_png_base64, get_parent_id_or_png_base64_async, ImageLike, \ 14 | get_image_png_base64, get_image_png_base64_async 15 | from pydalle.imperative.client.utils import requires_authentication, requires_authentication_async 16 | 17 | 18 | class Dalle: 19 | """ 20 | A user-friendly interface for the low-level functional API of pydalle. 21 | """ 22 | 23 | def __init__(self, username: str, password: str, /, headers: Optional[dict] = None): 24 | """ 25 | Creates a new Dalle instance. 26 | 27 | :param username: The username to use when logging in. 28 | :param password: The password to use when logging in. 29 | :param headers: Optional headers to use when making requests. 30 | """ 31 | if not username: 32 | raise ValueError("username must not be empty") 33 | if not password: 34 | raise ValueError("password must not be empty") 35 | 36 | self.__username = username 37 | self.__password = password 38 | self.__access_token = None 39 | self.__bearer_token = None 40 | 41 | self.headers = headers 42 | self.has_authenticated = False 43 | 44 | def refresh_tokens(self) -> None: 45 | """ 46 | Refreshes the access token and bearer token. 
47 | """ 48 | self.__access_token = labs.get_access_token(username=self.__username, password=self.__password, 49 | headers=self.headers) 50 | self.__bearer_token = labs.get_bearer_token_from_access_token(access_token=self.__access_token, 51 | headers=self.headers) 52 | self.has_authenticated = True 53 | 54 | async def refresh_tokens_async(self) -> None: 55 | """ 56 | Asynchronously refreshes the access token and bearer token. 57 | """ 58 | self.__access_token = await labs.get_access_token_async(username=self.__username, password=self.__password, 59 | headers=self.headers) 60 | self.__bearer_token = await labs.get_bearer_token_from_access_token_async(access_token=self.__access_token, 61 | headers=self.headers) 62 | self.has_authenticated = True 63 | 64 | @requires_authentication 65 | def get_tasks(self, limit: Optional[int] = None, from_ts: Optional[int] = None) -> WrappedTaskList: 66 | """ 67 | Gets a list of tasks. 68 | 69 | :param limit: The maximum number of tasks to return. 70 | :param from_ts: The timestamp to start from. 71 | :return: A list of tasks. 72 | """ 73 | return WrappedTaskList( 74 | labs.get_tasks(bearer_token=self.__bearer_token, from_ts=from_ts, headers=self.headers, limit=limit), 75 | self) 76 | 77 | @requires_authentication_async 78 | async def get_tasks_async(self, limit: Optional[int] = None, from_ts: Optional[int] = None) -> WrappedTaskList: 79 | """ 80 | Asynchronously a list of tasks. 81 | 82 | :param limit: The maximum number of tasks to return. 83 | :param from_ts: The timestamp to start from. 84 | :return: A list of tasks. 85 | """ 86 | return WrappedTaskList( 87 | await labs.get_tasks_async(bearer_token=self.__bearer_token, from_ts=from_ts, headers=self.headers, 88 | limit=limit), self) 89 | 90 | @requires_authentication 91 | def get_task(self, task: TaskLike) -> WrappedTask: 92 | """ 93 | Gets a task. 94 | 95 | :param task: The task to get (either a task ID or a task object). 96 | :return: The task. 
97 | """ 98 | return WrappedTask( 99 | labs.get_task(bearer_token=self.__bearer_token, task_id=get_task_id(task), headers=self.headers), self) 100 | 101 | @requires_authentication_async 102 | async def get_task_async(self, task: TaskLike) -> WrappedTask: 103 | """ 104 | Asynchronously gets a task. 105 | 106 | :param task: The task to get (either a task ID or a task object). 107 | :return: The task. 108 | """ 109 | return WrappedTask( 110 | await labs.get_task_async(bearer_token=self.__bearer_token, task_id=get_task_id(task), 111 | headers=self.headers), 112 | self) 113 | 114 | @requires_authentication 115 | def get_generation(self, generation: GenerationLike) -> WrappedGeneration: 116 | """ 117 | Gets a generation. 118 | 119 | :param generation: The generation to get (either a generation ID or a generation object). 120 | :return: The generation. 121 | """ 122 | return WrappedGeneration( 123 | labs.get_generation(bearer_token=self.__bearer_token, generation_id=get_generation_id(generation), 124 | headers=self.headers), self) 125 | 126 | @requires_authentication_async 127 | async def get_generation_async(self, generation: GenerationLike) -> WrappedGeneration: 128 | """ 129 | Asynchronously gets a generation. 130 | 131 | :param generation: The generation to get (either a generation ID or a generation object). 132 | :return: The generation. 133 | """ 134 | return WrappedGeneration( 135 | await labs.get_generation_async(bearer_token=self.__bearer_token, 136 | generation_id=get_generation_id(generation), 137 | headers=self.headers), self) 138 | 139 | @requires_authentication 140 | def create_text2im_task(self, caption: str, batch_size: int = 4) -> WrappedTask: 141 | """ 142 | Creates a text2im task. 143 | 144 | :param caption: The caption to use. 145 | :param batch_size: The batch size to use. 146 | :return: The task. 
147 | """ 148 | return WrappedTask( 149 | labs.create_text2im_task(bearer_token=self.__bearer_token, caption=caption, batch_size=batch_size, 150 | headers=self.headers), self) 151 | 152 | @requires_authentication_async 153 | async def create_text2im_task_async(self, caption: str, batch_size: int = 4) -> WrappedTask: 154 | """ 155 | Asynchronously creates a text2im task. 156 | 157 | :param caption: The caption to use. 158 | :param batch_size: The batch size to use. 159 | :return: The task. 160 | """ 161 | return WrappedTask( 162 | await labs.create_text2im_task_async(bearer_token=self.__bearer_token, caption=caption, 163 | batch_size=batch_size, 164 | headers=self.headers), self) 165 | 166 | @requires_authentication 167 | def text2im(self, caption: str, batch_size: int = 4, wait: bool = True) -> WrappedTask: 168 | """ 169 | Convenience function to create and wait a text2im task. 170 | 171 | :param caption: The caption to use. 172 | :param batch_size: The batch size to use. 173 | :param wait: Whether to wait for the task to finish, default is True. 174 | :return: The task. 175 | """ 176 | task = self.create_text2im_task(caption=caption, batch_size=batch_size) 177 | if wait: 178 | return task.wait() 179 | return task 180 | 181 | @requires_authentication_async 182 | async def text2im_async(self, caption: str, batch_size: int = 4, wait: bool = True) -> WrappedTask: 183 | """ 184 | Asynchronously creates and waits for a text2im task. 185 | 186 | :param caption: The caption to use. 187 | :param batch_size: The batch size to use. 188 | :param wait: Whether to wait for the task to finish, default is True. 189 | :return: The task. 190 | """ 191 | task = await self.create_text2im_task_async(caption=caption, batch_size=batch_size) 192 | if wait: 193 | return await task.wait_async() 194 | return task 195 | 196 | @requires_authentication 197 | def create_variations_task(self, parent: ParentLike, batch_size: int = 3) -> WrappedTask: 198 | """ 199 | Creates a variations task. 
200 | 201 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 202 | :param batch_size: The batch size to use. 203 | :return: The task. 204 | """ 205 | return WrappedTask( 206 | labs.create_variations_task(bearer_token=self.__bearer_token, 207 | parent_id_or_image=get_parent_id_or_png_base64(parent, self.headers), 208 | batch_size=batch_size, headers=self.headers), self) 209 | 210 | @requires_authentication_async 211 | async def create_variations_task_async(self, parent: ParentLike, batch_size: int = 3) -> WrappedTask: 212 | """ 213 | Asynchronously creates a variations task. 214 | 215 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 216 | :param batch_size: The batch size to use. 217 | :return: The task. 218 | """ 219 | return WrappedTask(await labs.create_variations_task_async( 220 | bearer_token=self.__bearer_token, 221 | parent_id_or_image=await get_parent_id_or_png_base64_async(parent, self.headers), 222 | batch_size=batch_size, headers=self.headers), self) 223 | 224 | @requires_authentication 225 | def variations(self, parent: ParentLike, batch_size: int = 3, wait: bool = True) -> WrappedTask: 226 | """ 227 | Convenience function to create and wait a variations task. 228 | 229 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 230 | :param batch_size: The batch size to use. 231 | :param wait: Whether to wait for the task to finish, default is True. 232 | :return: The task. 233 | """ 234 | task = self.create_variations_task(parent=parent, batch_size=batch_size) 235 | if wait: 236 | return task.wait() 237 | return task 238 | 239 | @requires_authentication_async 240 | async def variations_async(self, parent: ParentLike, batch_size: int = 3, wait: bool = True) -> WrappedTask: 241 | """ 242 | Asynchronously creates and waits for a variations task. 243 | 244 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 
245 | :param batch_size: The batch size to use. 246 | :param wait: Whether to wait for the task to finish, default is True. 247 | :return: The task. 248 | """ 249 | task = await self.create_variations_task_async(parent=parent, batch_size=batch_size) 250 | if wait: 251 | return await task.wait_async() 252 | return task 253 | 254 | @requires_authentication 255 | def create_inpainting_task(self, caption: str, masked_image: ImageLike, parent: Optional[ParentLike] = None, 256 | batch_size: int = 3) -> WrappedTask: 257 | """ 258 | Creates an inpainting task. 259 | 260 | :param caption: The caption to use. 261 | :param masked_image: The masked image to use. 262 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 263 | :param batch_size: The batch size to use. 264 | :return: The task. 265 | """ 266 | return WrappedTask( 267 | labs.create_inpainting_task( 268 | bearer_token=self.__bearer_token, caption=caption, 269 | masked_image=get_image_png_base64(masked_image, headers=self.headers), 270 | parent_id_or_image=get_parent_id_or_png_base64(parent, self.headers) if parent else None, 271 | batch_size=batch_size, 272 | headers=self.headers), self) 273 | 274 | @requires_authentication_async 275 | async def create_inpainting_task_async(self, caption: str, 276 | masked_image: ImageLike, 277 | parent: Optional[ParentLike] = None, 278 | batch_size: int = 3) -> WrappedTask: 279 | """ 280 | Asynchronously creates an inpainting task. 281 | 282 | :param caption: The caption to use. 283 | :param masked_image: The masked image to use. 284 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 285 | :param batch_size: The batch size to use. 286 | :return: The task. 
287 | """ 288 | return WrappedTask(await labs.create_inpainting_task_async( 289 | bearer_token=self.__bearer_token, caption=caption, 290 | masked_image=await get_image_png_base64_async(masked_image, self.headers), 291 | parent_id_or_image=(await get_parent_id_or_png_base64_async(parent, self.headers)) if parent else None, 292 | batch_size=batch_size, headers=self.headers), self) 293 | 294 | @requires_authentication 295 | def inpainting(self, caption: str, masked_image: ImageLike, parent: Optional[ParentLike] = None, 296 | batch_size: int = 3, wait: bool = True) -> WrappedTask: 297 | """ 298 | Convenience function to create and wait an inpainting task. 299 | 300 | :param caption: The caption to use. 301 | :param masked_image: The masked image to use. 302 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 303 | :param batch_size: The batch size to use. 304 | :param wait: Whether to wait for the task to finish, default is True. 305 | :return: The task. 306 | """ 307 | task = self.create_inpainting_task(caption=caption, masked_image=masked_image, parent=parent, 308 | batch_size=batch_size) 309 | if wait: 310 | return task.wait() 311 | return task 312 | 313 | @requires_authentication_async 314 | async def inpainting_async(self, caption: str, masked_image: ImageLike, parent: Optional[ParentLike] = None, 315 | batch_size: int = 3, wait: bool = True) -> WrappedTask: 316 | """ 317 | Asynchronously creates and waits for an inpainting task. 318 | 319 | :param caption: The caption to use. 320 | :param masked_image: The masked image to use. 321 | :param parent: The parent to use. (Either a prompt, a generation, or an image). 322 | :param batch_size: The batch size to use. 323 | :param wait: Whether to wait for the task to finish, default is True. 324 | :return: The task. 
325 | """ 326 | task = await self.create_inpainting_task_async(caption=caption, masked_image=masked_image, parent=parent, 327 | batch_size=batch_size) 328 | if wait: 329 | return await task.wait_async() 330 | return task 331 | 332 | @requires_authentication 333 | def poll_for_task_completion(self, task: TaskLike, interval: float = 1.0, max_attempts: int = 1000) -> WrappedTask: 334 | """ 335 | Polls for the completion of a task. 336 | 337 | :param task: The task to poll. 338 | :param interval: The interval to use (in seconds). 339 | :param max_attempts: The maximum number of attempts. 340 | """ 341 | if isinstance(task, (WrappedTask, Task)): 342 | if task.status != "pending": 343 | return task 344 | return WrappedTask( 345 | labs.poll_for_task_completion(bearer_token=self.__bearer_token, task_id=get_task_id(task), 346 | interval=interval, 347 | max_attempts=max_attempts, headers=self.headers), self) 348 | 349 | @requires_authentication_async 350 | async def poll_for_task_completion_async(self, task: TaskLike, interval: float = 1.0, 351 | max_attempts: int = 1000) -> WrappedTask: 352 | """ 353 | Asynchronously polls for the completion of a task. 354 | 355 | :param task: The task to poll. 356 | :param interval: The interval to use (in seconds). 357 | :param max_attempts: The maximum number of attempts. 358 | """ 359 | if isinstance(task, (WrappedTask, Task)): 360 | if task.status != "pending": 361 | return task 362 | return WrappedTask( 363 | await labs.poll_for_task_completion_async(bearer_token=self.__bearer_token, task_id=get_task_id(task), 364 | interval=interval, 365 | max_attempts=max_attempts, headers=self.headers), self) 366 | 367 | @requires_authentication 368 | def download_generation(self, generation: GenerationLike, direct: bool = False) -> WrappedImage: 369 | """ 370 | Downloads a generation. 371 | 372 | :param generation: The generation to download. 
373 | :param direct: Whether to download the generation using the direct download URL, which does not add a watermark. 374 | You should only use this if you intend to add the watermark to the image yourself. Unwatermarked images 375 | should not be shared publicly. 376 | :return: The image. 377 | """ 378 | if direct: 379 | return self.download_generation_direct(generation) 380 | return WrappedImage(labs.download_generation(bearer_token=self.__bearer_token, 381 | generation_id=get_generation_id(generation), 382 | headers=self.headers), self) 383 | 384 | @requires_authentication_async 385 | async def download_generation_async(self, generation: GenerationLike, direct: bool = False) -> WrappedImage: 386 | """ 387 | Asynchronously downloads a generation. 388 | 389 | :param generation: The generation to download. 390 | :param direct: Whether to download the generation using the direct download URL, which does not add a watermark. 391 | You should only use this if you intend to add the watermark to the image yourself. Unwatermarked images 392 | should not be shared publicly. 393 | :return: The image. 394 | """ 395 | if direct: 396 | return await self.download_generation_direct_async(generation) 397 | return WrappedImage(await labs.download_generation_async(bearer_token=self.__bearer_token, 398 | generation_id=get_generation_id(generation), 399 | headers=self.headers), self) 400 | 401 | @requires_authentication 402 | def download_generation_direct(self, generation: GenerationLike) -> WrappedImage: 403 | """ 404 | Downloads a generation using the direct download URL, which does not add a watermark. 405 | You should only use this if you intend to add the watermark to the image yourself. Unwatermarked images 406 | should not be shared publicly. 407 | 408 | :param generation: The generation to download. 409 | :return: The image. 
410 | """ 411 | if isinstance(generation, (WrappedGeneration, Generation)): 412 | image_path = generation.generation.image_path 413 | else: 414 | image_path = labs.get_generation(bearer_token=self.__bearer_token, 415 | generation_id=get_generation_id(generation), 416 | headers=self.headers).generation.image_path 417 | return WrappedImage( 418 | request(HttpRequest(method="get", url=image_path, headers=self.headers, decode=False)).content, self, 419 | filetype="webp") 420 | 421 | @requires_authentication_async 422 | async def download_generation_direct_async(self, generation: GenerationLike) -> WrappedImage: 423 | """ 424 | Asynchronously downloads a generation using the direct download URL, which does not add a watermark. 425 | 426 | :param generation: The generation to download. 427 | :return: The image. 428 | """ 429 | if isinstance(generation, (WrappedGeneration, Generation)): 430 | image_path = generation.generation.image_path 431 | else: 432 | image_path = (await labs.get_generation_async(bearer_token=self.__bearer_token, 433 | generation_id=get_generation_id(generation), 434 | headers=self.headers)).generation.image_path 435 | return WrappedImage( 436 | (await request_async( 437 | HttpRequest(method="get", url=image_path, headers=self.headers, decode=False))).content, self, 438 | filetype="webp") 439 | 440 | @requires_authentication 441 | def share_generation(self, generation: GenerationLike) -> WrappedGeneration: 442 | """ 443 | Shares a generation (i.e. people can then download the generation from the share URL). 444 | See DALL·E 2's `content policy `_ to see what is OK to share. 445 | 446 | :param generation: The generation to share. 447 | :return: The shared generation. 
448 | """ 449 | return WrappedGeneration( 450 | labs.share_generation(bearer_token=self.__bearer_token, generation_id=get_generation_id(generation), 451 | headers=self.headers), self) 452 | 453 | @requires_authentication_async 454 | async def share_generation_async(self, generation: GenerationLike) -> WrappedGeneration: 455 | """ 456 | Asynchronously shares a generation (i.e. people can then download the generation from the share URL). 457 | See DALL·E 2's `content policy `_ to see what is OK to share. 458 | 459 | :param generation: The generation to share. 460 | :return: The shared generation. 461 | """ 462 | return WrappedGeneration( 463 | await labs.share_generation_async(bearer_token=self.__bearer_token, 464 | generation_id=get_generation_id(generation), 465 | headers=self.headers), self) 466 | 467 | @requires_authentication 468 | def save_generations(self, generations: Union[Iterable[GenerationLike], GenerationLike]) -> WrappedCollection: 469 | """ 470 | Saves one or more generations to your personal collection. 471 | 472 | :param generations: The generation(s) to save. 473 | :return: The collection saved to. 474 | """ 475 | try: 476 | generation_ids = [get_generation_id(generations)] 477 | except ValueError: 478 | generation_ids = [get_generation_id(generation) for generation in generations] 479 | return WrappedCollection(labs.save_generations(bearer_token=self.__bearer_token, generation_ids=generation_ids, 480 | headers=self.headers), self) 481 | 482 | @requires_authentication_async 483 | async def save_generations_async(self, 484 | generations: Union[Iterable[GenerationLike], GenerationLike]) -> WrappedCollection: 485 | """ 486 | Asynchronously saves one or more generations to your personal collection. 487 | 488 | :param generations: The generation(s) to save. 489 | :return: The collection saved to. 
490 | """ 491 | try: 492 | generation_ids = [get_generation_id(generations)] 493 | except ValueError: 494 | generation_ids = [get_generation_id(generation) for generation in generations] 495 | return WrappedCollection( 496 | await labs.save_generations_async(bearer_token=self.__bearer_token, generation_ids=generation_ids, 497 | headers=self.headers), self) 498 | 499 | @requires_authentication 500 | def flag_generation_sensitive(self, generation: GenerationLike) -> WrappedUserFlag: 501 | """ 502 | Flags a generation as sensitive. 503 | 504 | :param generation: The generation to flag. 505 | :return: The user flag. 506 | """ 507 | return WrappedUserFlag( 508 | labs.flag_generation_sensitive(bearer_token=self.__bearer_token, 509 | generation_id=get_generation_id(generation), 510 | headers=self.headers), self) 511 | 512 | @requires_authentication_async 513 | async def flag_generation_sensitive_async(self, generation: GenerationLike) -> WrappedUserFlag: 514 | """ 515 | Asynchronously flags a generation as sensitive. 516 | 517 | :param generation: The generation to flag. 518 | :return: The user flag. 519 | """ 520 | return WrappedUserFlag( 521 | await labs.flag_generation_sensitive_async(bearer_token=self.__bearer_token, 522 | generation_id=get_generation_id(generation), 523 | headers=self.headers), self) 524 | 525 | @requires_authentication 526 | def flag_generation_unexpected(self, generation: GenerationLike) -> WrappedUserFlag: 527 | """ 528 | Flags a generation as unexpected. 529 | 530 | :param generation: The generation to flag. 531 | :return: The user flag. 532 | """ 533 | return WrappedUserFlag( 534 | labs.flag_generation_unexpected(bearer_token=self.__bearer_token, 535 | generation_id=get_generation_id(generation), 536 | headers=self.headers), self) 537 | 538 | @requires_authentication_async 539 | async def flag_generation_unexpected_async(self, generation: GenerationLike) -> WrappedUserFlag: 540 | """ 541 | Asynchronously flags a generation as unexpected. 
542 | 543 | :param generation: The generation to flag. 544 | :return: The user flag. 545 | """ 546 | return WrappedUserFlag( 547 | await labs.flag_generation_unexpected_async(bearer_token=self.__bearer_token, 548 | generation_id=get_generation_id(generation), 549 | headers=self.headers), self) 550 | 551 | @requires_authentication 552 | def get_credit_summary(self) -> WrappedBillingInfo: 553 | """ 554 | Gets the user's credit summary. 555 | 556 | :return: The user's credit summary. 557 | """ 558 | return WrappedBillingInfo(labs.get_credit_summary(bearer_token=self.__bearer_token, headers=self.headers), self) 559 | 560 | @requires_authentication_async 561 | async def get_credit_summary_async(self) -> WrappedBillingInfo: 562 | """ 563 | Asynchronously gets the user's credit summary. 564 | 565 | :return: The user's credit summary. 566 | """ 567 | return WrappedBillingInfo( 568 | await labs.get_credit_summary_async(bearer_token=self.__bearer_token, headers=self.headers), self) 569 | 570 | @requires_authentication 571 | def get_login_info(self) -> WrappedLogin: 572 | """ 573 | Gets the user's login information. 574 | 575 | :return: The user's login information. 576 | """ 577 | return WrappedLogin(labs.get_login_info(access_token=self.__access_token, headers=self.headers), self) 578 | 579 | @requires_authentication_async 580 | async def get_login_info_async(self) -> WrappedLogin: 581 | """ 582 | Asynchronously gets the user's login information. 583 | 584 | :return: The user's login information. 585 | """ 586 | return WrappedLogin(await labs.get_login_info_async(access_token=self.__access_token, headers=self.headers), 587 | self) 588 | --------------------------------------------------------------------------------