├── lama_cleaner
│   ├── __init__.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── opencv2.py
│   │   ├── lama.py
│   │   ├── base.py
│   │   ├── ddim_sampler.py
│   │   ├── sd.py
│   │   ├── ldm.py
│   │   ├── plms_sampler.py
│   │   ├── zits.py
│   │   ├── sd_pipeline.py
│   │   ├── utils.py
│   │   └── fcf.py
│   ├── asgi.py
│   ├── wsgi.py
│   ├── urls.py
│   ├── schema.py
│   ├── model_manager.py
│   ├── settings.py
│   └── helper.py
├── example_image
│   ├── image0.jpeg
│   └── image1.jpeg
├── requirements.txt
├── example.py
├── .github
│   └── FUNDING.yml
├── LICENSE
├── README.md
├── gradio.py
└── remwm.py

/lama_cleaner/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lama_cleaner/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/example_image/image0.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Damarcreative/rem-wm/HEAD/example_image/image0.jpeg
--------------------------------------------------------------------------------
/example_image/image1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Damarcreative/rem-wm/HEAD/example_image/image1.jpeg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.6.0.66
2 | pytest==7.1.3
3 | torch==2.2.0
4 | pydantic==1.10.2
5 | loguru==0.6.0
6 | tqdm==4.64.1
7 | Pillow==9.2.0
8 | diffusers==0.32.2
9 | transformers
10 | scikit-image==0.19.3
11 | gradio
12 | timm
--------------------------------------------------------------------------------
/lama_cleaner/asgi.py:
--------------------------------------------------------------------------------
1 | """
2 | ASGI config for lama_cleaner project.
3 |
4 | It exposes the ASGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.asgi import get_asgi_application
13 |
14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'lama_cleaner.settings')
15 |
16 | application = get_asgi_application()
17 |
--------------------------------------------------------------------------------
/lama_cleaner/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for lama_cleaner project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.1/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'lama_cleaner.settings') 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from remwm import WatermarkRemover 2 | 3 | # Initialize WatermarkRemover object with custom Florence model (optional) 4 | remover = WatermarkRemover(model_id='microsoft/Florence-2-large') 5 | 6 | # To process a single image 7 | input_image_path = "path/to/input/image.jpg" 8 | output_image_path = "path/to/output/image.jpg" 9 | remover.process_images_florence_lama(input_image_path, output_image_path) 10 | 11 | # To batch process images in a folder 12 | input_dir = "path/to/input/folder" 13 | output_dir = "path/to/output/folder" 14 | remover.process_batch(input_dir, output_dir, max_workers=4) -------------------------------------------------------------------------------- /lama_cleaner/model/opencv2.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from lama_cleaner.model.base import InpaintModel 3 | from lama_cleaner.schema import Config 4 | 5 | flag_map = { 6 | "INPAINT_NS": cv2.INPAINT_NS, 7 | "INPAINT_TELEA": cv2.INPAINT_TELEA 8 | } 9 | 10 | class OpenCV2(InpaintModel): 11 | pad_mod = 1 12 | 13 | @staticmethod 14 | def is_downloaded() -> bool: 15 | return True 16 | 17 | def forward(self, image, mask, config: Config): 18 | """Input image and output image have same size 19 | image: [H, W, C] RGB 20 | mask: [H, W, 1] 21 | return: BGR IMAGE 22 | """ 23 | cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag]) 24 | return cur_res 25 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: Damarcreative 3 | patreon: # Replace with a single Patreon username 4 | open_collective: # Replace with a single Open Collective username 5 | ko_fi: damarcreative # Replace with a single Ko-fi username 6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 8 | liberapay: # Replace with a single Liberapay username 9 | issuehunt: # Replace with a single IssueHunt username 10 | otechie: # Replace with a single Otechie username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /lama_cleaner/urls.py: -------------------------------------------------------------------------------- 1 | """lama_cleaner URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/4.1/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: path('', views.home, name='home') 9 | Class-based views 10 | 1. 
Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.urls import include, path 14 | 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) 15 | """ 16 | from django.contrib import admin 17 | from django.urls import path,include 18 | 19 | urlpatterns = [ 20 | path('admin/', admin.site.urls), 21 | path('inpainting/',include('inpainting.urls')), 22 | ] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Damar Jati 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/lama_cleaner/schema.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class HDStrategy(str, Enum):
7 |     ORIGINAL = "Original"
8 |     RESIZE = "Resize"
9 |     CROP = "Crop"
10 |
11 |
12 | class LDMSampler(str, Enum):
13 |     ddim = "ddim"
14 |     plms = "plms"
15 |
16 |
17 | class SDSampler(str, Enum):
18 |     ddim = "ddim"
19 |     pndm = "pndm"
20 |
21 |
22 | class Config(BaseModel):
23 |     ldm_steps: int
24 |     ldm_sampler: str = LDMSampler.plms
25 |     zits_wireframe: bool = True
26 |     hd_strategy: str
27 |     hd_strategy_crop_margin: int
28 |     hd_strategy_crop_trigger_size: int
29 |     hd_strategy_resize_limit: int
30 |
31 |     prompt: str = ""
32 |     # Always expressed in original-image coordinates
33 |     use_croper: bool = False
34 |     croper_x: int = None
35 |     croper_y: int = None
36 |     croper_height: int = None
37 |     croper_width: int = None
38 |
39 |     # sd
40 |     sd_mask_blur: int = 0
41 |     sd_strength: float = 0.75
42 |     sd_steps: int = 50
43 |     sd_guidance_scale: float = 7.5
44 |     sd_sampler: str = SDSampler.ddim
45 |     # -1 means random seed
46 |     sd_seed: int = 42
47 |
48 |     # cv2
49 |     cv2_flag: str = 'INPAINT_NS'
50 |     cv2_radius: int = 4
51 |
--------------------------------------------------------------------------------
/lama_cleaner/model_manager.py:
--------------------------------------------------------------------------------
1 | from lama_cleaner.model.fcf import FcF
2 | from lama_cleaner.model.lama import LaMa
3 | from lama_cleaner.model.ldm import LDM
4 | from lama_cleaner.model.mat import MAT
5 | from lama_cleaner.model.sd import SD14
6 | from lama_cleaner.model.zits import ZITS
7 | from lama_cleaner.model.opencv2 import OpenCV2
8 | from lama_cleaner.schema import Config
9 |
10 | models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.4": SD14, "cv2": OpenCV2}
11 |
12 |
13 | class ModelManager:
14 |     def __init__(self, name: str, device, **kwargs):
15 |         self.name = name
16 |         self.device = device
17 |         self.kwargs = kwargs
18 |         self.model = self.init_model(name, device, **kwargs)
19 |
20 |     def init_model(self, name: str, device, **kwargs):
21 |         if name in models:
22 |             model = models[name](device, **kwargs)
23 |         else:
24 |             raise NotImplementedError(f"Not supported model: {name}")
25 |         return model
26 |
27 |     def is_downloaded(self, name: str) -> bool:
28 |         if name in models:
29 |             return models[name].is_downloaded()
30 |         else:
31 |             raise NotImplementedError(f"Not supported model: {name}")
32 |
33 |     def __call__(self, image, mask, config: Config):
34 |         return self.model(image, mask, config)
35 |
36 |     def switch(self, new_name: str):
37 |         if new_name == self.name:
38 |             return
39 |         try:
40 |             self.model = self.init_model(new_name, self.device, **self.kwargs)
41 |             self.name = new_name
42 |         except NotImplementedError as e:
43 |             raise e
44 |
--------------------------------------------------------------------------------
/lama_cleaner/model/lama.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from loguru import logger
7 |
8 | from lama_cleaner.helper import pad_img_to_modulo, download_model, norm_img, get_cache_path_by_url
9 | from lama_cleaner.model.base import InpaintModel
10 | from lama_cleaner.schema import Config
11 |
12 | LAMA_MODEL_URL = os.environ.get(
13 |     "LAMA_MODEL_URL",
14 |     "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
15 | )
16 |
17 | #"https://drive.google.com/file/d/1bMD06F9hkkS1oi8cEmb4cSjXz54Pxs6A/view?usp=sharing" #big-lama.pt file
18 |
19 |
20 | class LaMa(InpaintModel):
21 |     pad_mod = 8
22 |
23 |     def init_model(self, device, **kwargs):
24 |         if os.environ.get("LAMA_MODEL"):
25 |             model_path = os.environ.get("LAMA_MODEL")
26 |             if not os.path.exists(model_path):
27 |                 raise FileNotFoundError(
28 |                     f"lama torchscript model not found: {model_path}"
29 |                 )
30 |         else:
31 |             model_path = download_model(LAMA_MODEL_URL)
32 |         logger.info(f"Load LaMa model from: {model_path}")
33 |         model = torch.jit.load(model_path, map_location="cpu")
34 |         model = model.to(device)
35 |         model.eval()
36 |         self.model = model
37 |         self.model_path = model_path
38 |
39 |     @staticmethod
40 |     def is_downloaded() -> bool:
41 |         return os.path.exists(get_cache_path_by_url(LAMA_MODEL_URL))
42 |
43 |     def forward(self, image, mask, config: Config):
44 |         """Input image and output image have same size
45 |         image: [H, W, C] RGB
46 |         mask: [H, W]
47 |         return: BGR IMAGE
48 |         """
49 |         image = norm_img(image)
50 |         mask = norm_img(mask)
51 |
52 |         mask = (mask > 0) * 1
53 |         image = torch.from_numpy(image).unsqueeze(0).to(self.device)
54 |         mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
55 |
56 |         inpainted_image = self.model(image, mask)
57 |
58 |         cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
59 |         cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
60 |         cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
61 |         return cur_res
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rem-WM: Watermark Remover using Florence and Lama Cleaner
2 |
3 | **Rem-WM** is a powerful watermark-removal tool that leverages the Microsoft Florence and Lama Cleaner models. It provides an easy-to-use interface for removing watermarks from images, with support for both individual images and batch processing.
4 |
5 | ### Demo
6 | https://huggingface.co/spaces/DamarJati/Remove-watermark
7 |
8 | ## Features
9 |
10 | - **Watermark Removal**: Automatically detect and remove watermarks from images.
11 | - **Batch Processing**: Efficiently process multiple images using threading.
12 | - **Custom Model Support**: Flexibility to use custom Florence models.
13 | - **Easy Integration**: Simple class-based interface for integration into your projects; the two pipeline stages can also be driven separately (see the sketch below).
14 |
15 | ## Installation
16 |
17 | First, clone the repository and navigate to the project directory:
18 |
19 | ```bash
20 | git clone https://github.com/Damarcreative/rem-wm.git
21 | cd rem-wm
22 | ```
23 |
24 | Install the necessary dependencies:
25 |
26 | ```bash
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## Usage
31 |
32 | ### Single Image Processing
33 |
34 | To remove a watermark from a single image, use the `WatermarkRemover` class:
35 |
36 | ```python
37 | from remwm import WatermarkRemover
38 |
39 | # Initialize the WatermarkRemover with the default Florence model
40 | remover = WatermarkRemover()
41 |
42 | # Define input and output paths
43 | input_image_path = "path/to/input/image.jpg"
44 | output_image_path = "path/to/output/image.jpg"
45 |
46 | # Process the image
47 | remover.process_images_florence_lama(input_image_path, output_image_path)
48 | ```
49 |
50 | ### Batch Processing
51 |
52 | To process multiple images in a directory, use the `process_batch` method:
53 |
54 | ```python
55 | from remwm import WatermarkRemover
56 |
57 | # Initialize the WatermarkRemover
58 | remover = WatermarkRemover()
59 |
60 | # Define input and output directories
61 | input_dir = "path/to/input/folder"
62 | output_dir = "path/to/output/folder"
63 |
64 | # Process the batch of images
65 | remover.process_batch(input_dir, output_dir, max_workers=4)
66 | ```
67 |
68 | ### Using a Custom Florence Model
69 |
70 | If you want to use a custom Florence model, simply provide the model ID during initialization:
71 |
72 | ```python
73 | from remwm import WatermarkRemover
74 |
75 | # Initialize with a custom Florence model
76 | remover = WatermarkRemover(model_id='facebook/custom-florence-model')
77 |
78 | # Process images as usual
79 | input_image_path = "path/to/input/image.jpg"
80 | output_image_path = "path/to/output/image.jpg"
81 | remover.process_images_florence_lama(input_image_path, output_image_path)
82 | ```
83 |
84 | ## Contributing
85 |
86 | We welcome contributions to enhance Rem-WM! To contribute, please follow these steps:
87 |
88 | 1. Fork the repository.
89 | 2. Create a new branch (`git checkout -b feature/your-feature-name`).
90 | 3. Make your changes.
91 | 4. Commit your changes (`git commit -m 'Add some feature'`).
92 | 5. Push to the branch (`git push origin feature/your-feature-name`).
93 | 6. Open a pull request.
94 |
95 | ## License
96 |
97 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
98 |
99 | ## Acknowledgements
100 |
101 | - [Microsoft Florence](https://huggingface.co/microsoft/Florence-2-large)
102 | - [Lama Cleaner](https://github.com/Sanster/IOPaint)
103 |
104 | ## Contact
105 |
106 | For any questions or inquiries, please open an issue or contact us at
107 | dev@damarcreative.my.id.
108 |
109 | ---
110 |
111 | Thank you for using Rem-WM! We hope this tool helps you effectively remove watermarks from your images.
112 |
--------------------------------------------------------------------------------
/lama_cleaner/settings.py:
--------------------------------------------------------------------------------
1 | """
2 | Django settings for lama_cleaner project.
3 |
4 | Generated by 'django-admin startproject' using Django 4.1.2.
5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/4.1/topics/settings/ 8 | 9 | For the full list of settings and their values, see 10 | https://docs.djangoproject.com/en/4.1/ref/settings/ 11 | """ 12 | 13 | from pathlib import Path 14 | 15 | # Build paths inside the project like this: BASE_DIR / 'subdir'. 16 | BASE_DIR = Path(__file__).resolve().parent.parent 17 | 18 | 19 | # Quick-start development settings - unsuitable for production 20 | # See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/ 21 | 22 | # SECURITY WARNING: keep the secret key used in production secret! 23 | SECRET_KEY = 'django-insecure-=x2n@zasb2nkq$)frp(&h*tsozyka+jb5(&3^7@u5@ven@-sdu' 24 | 25 | # SECURITY WARNING: don't run with debug turned on in production! 26 | DEBUG = True 27 | 28 | ALLOWED_HOSTS = [] 29 | 30 | 31 | # Application definition 32 | 33 | INSTALLED_APPS = [ 34 | 'django.contrib.admin', 35 | 'django.contrib.auth', 36 | 'django.contrib.contenttypes', 37 | 'django.contrib.sessions', 38 | 'django.contrib.messages', 39 | 'django.contrib.staticfiles', 40 | 'inpainting', 41 | ] 42 | 43 | MIDDLEWARE = [ 44 | 'django.middleware.security.SecurityMiddleware', 45 | 'django.contrib.sessions.middleware.SessionMiddleware', 46 | 'django.middleware.common.CommonMiddleware', 47 | 'django.middleware.csrf.CsrfViewMiddleware', 48 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 49 | 'django.contrib.messages.middleware.MessageMiddleware', 50 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 51 | ] 52 | 53 | ROOT_URLCONF = 'lama_cleaner.urls' 54 | 55 | TEMPLATES = [ 56 | { 57 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 58 | 'DIRS': [], 59 | 'APP_DIRS': True, 60 | 'OPTIONS': { 61 | 'context_processors': [ 62 | 'django.template.context_processors.debug', 63 | 'django.template.context_processors.request', 64 | 'django.contrib.auth.context_processors.auth', 65 | 'django.contrib.messages.context_processors.messages', 66 | ], 67 | }, 68 | }, 69 | ] 70 | 71 | WSGI_APPLICATION = 'lama_cleaner.wsgi.application' 72 | 73 | 74 | # Database 75 | # https://docs.djangoproject.com/en/4.1/ref/settings/#databases 76 | 77 | DATABASES = { 78 | 'default': { 79 | 'ENGINE': 'django.db.backends.sqlite3', 80 | 'NAME': BASE_DIR / 'db.sqlite3', 81 | } 82 | } 83 | 84 | 85 | # Password validation 86 | # https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators 87 | 88 | AUTH_PASSWORD_VALIDATORS = [ 89 | { 90 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 91 | }, 92 | { 93 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 94 | }, 95 | { 96 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 97 | }, 98 | { 99 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 100 | }, 101 | ] 102 | 103 | 104 | # Internationalization 105 | # https://docs.djangoproject.com/en/4.1/topics/i18n/ 106 | 107 | LANGUAGE_CODE = 'en-us' 108 | 109 | TIME_ZONE = 'UTC' 110 | 111 | USE_I18N = True 112 | 113 | USE_TZ = True 114 | 115 | 116 | # Static files (CSS, JavaScript, Images) 117 | # https://docs.djangoproject.com/en/4.1/howto/static-files/ 118 | 119 | STATIC_URL = 'static/' 120 | 121 | # Default primary key field type 122 | # https://docs.djangoproject.com/en/4.1/ref/settings/#default-auto-field 123 | 124 | DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' 125 | -------------------------------------------------------------------------------- /gradio.py: 
--------------------------------------------------------------------------------
1 | import gradio as gr  # NOTE: rename this file if running locally; a script named gradio.py shadows the gradio package
2 | import torch
3 | from lama_cleaner.model_manager import ModelManager
4 | from lama_cleaner.schema import Config, HDStrategy, LDMSampler
5 | from transformers import AutoProcessor, AutoModelForCausalLM
6 | import cv2
7 | import numpy as np
8 | from PIL import Image, ImageDraw
9 | import subprocess
10 | import spaces  # provides the @spaces.GPU decorator on Hugging Face (ZeroGPU) Spaces
11 | # Install necessary packages
12 | subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13 |
14 | # Select the device used by the Lama Cleaner and Florence models
15 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
16 |
17 | # Define available models
18 | available_models = [
19 |     'microsoft/Florence-2-base',
20 |     'microsoft/Florence-2-base-ft',
21 |     'microsoft/Florence-2-large',
22 |     'microsoft/Florence-2-large-ft'
23 | ]
24 |
25 | # Load all models and processors
26 | model_dict = {}
27 | for model_id in available_models:
28 |     florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
29 |     florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
30 |     model_dict[model_id] = (florence_model, florence_processor)
31 |
32 | @spaces.GPU()
33 | def process_image(image, mask, strategy, sampler, fx=1, fy=1):
34 |     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
35 |     mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
36 |
37 |     if fx != 1 or fy != 1:
38 |         image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
39 |         mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
40 |
41 |     config = Config(
42 |         ldm_steps=1,
43 |         ldm_sampler=sampler,
44 |         hd_strategy=strategy,
45 |         hd_strategy_crop_margin=32,
46 |         hd_strategy_crop_trigger_size=200,
47 |         hd_strategy_resize_limit=200,
48 |     )
49 |
50 |     model = ModelManager(name="lama", device=device)
51 |     result = model(image, mask, config)
52 |     return result
53 |
54 | def create_mask(image, prediction):
55 |     mask = Image.new("RGB", image.size, (0, 0, 0))  # Black background; RGB (not RGBA) so cv2's BGR2GRAY conversion accepts it
56 |     draw = ImageDraw.Draw(mask)
57 |     scale = 1
58 |     for polygons in prediction['polygons']:
59 |         for _polygon in polygons:
60 |             _polygon = np.array(_polygon).reshape(-1, 2)
61 |             if len(_polygon) < 3:
62 |                 continue
63 |             _polygon = (_polygon * scale).reshape(-1).tolist()
64 |             draw.polygon(_polygon, fill=(255, 255, 255))  # Make selected area white
65 |     return mask
66 |
67 | def process_images_florence_lama(image, model_choice):
68 |     florence_model, florence_processor = model_dict[model_choice]
69 |
70 |     # Convert image to OpenCV format
71 |     image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
72 |
73 |     # Run Florence to get mask
74 |     text_input = 'watermark'
75 |     task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'  # Florence-2 task token: returns segmentation polygons for a text phrase
76 |     image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))  # back to RGB before handing to Florence
77 |     inputs = florence_processor(text=task_prompt + text_input, images=image_pil, return_tensors="pt").to(device)
78 |     generated_ids = florence_model.generate(
79 |         input_ids=inputs["input_ids"],
80 |         pixel_values=inputs["pixel_values"],
81 |         max_new_tokens=1024,
82 |         early_stopping=False,
83 |         do_sample=False,
84 |         num_beams=3,
85 |     )
86 |     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
87 |     parsed_answer = florence_processor.post_process_generation(
88 |         generated_text,
89 |         task=task_prompt,
90 |         image_size=(image_pil.width, image_pil.height)
91 |     )
92 |
93 |     # Create mask and process the image with Lama Cleaner
94 |     mask_image = create_mask(image_pil, parsed_answer[task_prompt])
95 |     result_image = process_image(image_cv, np.array(mask_image), HDStrategy.RESIZE, LDMSampler.ddim)
96 |
97 |     # Convert result back to PIL Image
98 |     result_image_pil = Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
99 |
100 |     return result_image_pil
101 |
102 | # Define Gradio interface
103 | demo = gr.Interface(
104 |     fn=process_images_florence_lama,
105 |     inputs=[
106 |         gr.Image(type="pil", label="Input Image"),
107 |         gr.Dropdown(choices=available_models, value='microsoft/Florence-2-large', label="Choose Florence Model")
108 |     ],
109 |     outputs=gr.Image(type="pil", label="Output Image"),
110 |     title="Watermark Remover",
111 |     description="Upload images and remove selected watermarks using Florence and Lama Cleaner.\nhttps://github.com/Damarcreative/rem-wm.git"
112 | )
113 |
114 | if __name__ == "__main__":
115 |     demo.launch()
116 |
--------------------------------------------------------------------------------
/remwm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lama_cleaner.model_manager import ModelManager
3 | from lama_cleaner.schema import Config, HDStrategy, LDMSampler
4 | from transformers import AutoProcessor, AutoModelForCausalLM
5 | import cv2
6 | import numpy as np
7 | from PIL import Image, ImageDraw
8 | import subprocess
9 | import os
10 | from concurrent.futures import ThreadPoolExecutor
11 |
12 | # Install necessary packages
13 | subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
14 |
15 | class WatermarkRemover:
16 |     def __init__(self, model_id='microsoft/Florence-2-large'):
17 |         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
18 |
19 |         # Initialize Florence model
20 |         self.florence_model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(self.device).eval()
21 |         self.florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
22 |
23 |         # Initialize Lama Cleaner model
24 |         self.model_manager = ModelManager(name="lama", device=self.device)
25 |
26 |     def process_image(self, image, mask, strategy=HDStrategy.RESIZE, sampler=LDMSampler.ddim, fx=1, fy=1):
27 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
28 |         mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
29 |
30 |         if fx != 1 or fy != 1:
31 |             image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
32 |             mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
33 |
34 |         config = Config(
35 |             ldm_steps=1,
36 |             ldm_sampler=sampler,
37 |             hd_strategy=strategy,
38 |             hd_strategy_crop_margin=32,
39 |             hd_strategy_crop_trigger_size=200,
40 |             hd_strategy_resize_limit=200,
41 |         )
42 |
43 |         result = self.model_manager(image, mask, config)
44 |         return result
45 |
46 |     def create_mask(self, image, prediction):
47 |         mask = Image.new("RGB", image.size, (0, 0, 0))  # Black background; RGB (not RGBA) so cv2's BGR2GRAY conversion accepts it
48 |         draw = ImageDraw.Draw(mask)
49 |         scale = 1
50 |         for polygons in prediction['polygons']:
51 |             for _polygon in polygons:
52 |                 _polygon = np.array(_polygon).reshape(-1, 2)
53 |                 if len(_polygon) < 3:
54 |                     continue
55 |                 _polygon = (_polygon * scale).reshape(-1).tolist()
56 |                 draw.polygon(_polygon, fill=(255, 255, 255))  # Make selected area white
57 |         return mask
58 |
59 |     def process_images_florence_lama(self, input_image_path, output_image_path):
60 |         # Load input image
61 |         image = Image.open(input_image_path).convert("RGB")
62 |
63 |         # Convert image to OpenCV format
64 |         image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
65 |
66 |         # Run Florence to get mask
67 |         text_input = 'watermark'  # Text prompt so Florence recognizes the watermark
68 |         task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'  # Florence-2 task token: returns segmentation polygons for a text phrase
69 |         inputs = self.florence_processor(text=task_prompt + text_input, images=image, return_tensors="pt").to(self.device)
70 |         generated_ids = self.florence_model.generate(
71 |             input_ids=inputs["input_ids"],
72 |             pixel_values=inputs["pixel_values"],
73 |             max_new_tokens=1024,
74 |             early_stopping=False,
75 |             do_sample=False,
76 |             num_beams=3,
77 |         )
78 |         generated_text = self.florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
79 |         parsed_answer = self.florence_processor.post_process_generation(
80 |             generated_text,
81 |             task=task_prompt,
82 |             image_size=(image.width, image.height)
83 |         )
84 |
85 |         # Create mask and process image with Lama Cleaner
86 |         mask_image = self.create_mask(image, parsed_answer[task_prompt])
87 |         result_image = self.process_image(image_cv, np.array(mask_image), HDStrategy.RESIZE, LDMSampler.ddim)
88 |
89 |         # Convert result back to PIL Image
90 |         result_image_pil = Image.fromarray(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
91 |
92 |         # Save output image
93 |         result_image_pil.save(output_image_path)
94 |
95 |     def process_batch(self, input_dir, output_dir, max_workers=4):
96 |         input_images = [os.path.join(input_dir, img) for img in os.listdir(input_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]
97 |         output_images = [os.path.join(output_dir, os.path.basename(img)) for img in input_images]
98 |
99 |         with ThreadPoolExecutor(max_workers=max_workers) as executor:
100 |             executor.map(self.process_images_florence_lama, input_images, output_images)
101 |
--------------------------------------------------------------------------------
/lama_cleaner/helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from typing import List, Optional
4 |
5 | from urllib.parse import urlparse
6 | import cv2
7 | import numpy as np
8 | import torch
9 | from loguru import logger
10 | from torch.hub import download_url_to_file, get_dir
11 |
12 |
13 | def get_cache_path_by_url(url):
14 |     parts = urlparse(url)
15 |     hub_dir = get_dir()
16 |     model_dir = os.path.join(hub_dir, "checkpoints")
17 |     if not os.path.isdir(model_dir):
18 |         os.makedirs(model_dir)  # create the same directory that is checked and returned below
19 |     filename = os.path.basename(parts.path)
20 |     cached_file = os.path.join(model_dir, filename)
21 |     return cached_file
22 |
23 |
24 | def download_model(url):
25 |     cached_file = get_cache_path_by_url(url)
26 |     if not os.path.exists(cached_file):
27 |         sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
28 |         hash_prefix = None
29 |         download_url_to_file(url, cached_file, hash_prefix, progress=True)
30 |     return cached_file
31 |
32 |
33 | def ceil_modulo(x, mod):
34 |     if x % mod == 0:
35 |         return x
36 |     return (x // mod + 1) * mod
37 |
38 |
39 | def load_jit_model(url_or_path, device):
40 |     if os.path.exists(url_or_path):
41 |         model_path = url_or_path
42 |     else:
43 |         model_path = download_model(url_or_path)
44 |     logger.info(f"Load model from: {model_path}")
45 |     try:
46 |         model = torch.jit.load(model_path).to(device)
47 |     except Exception:
48 |         logger.error(
49 |             f"Failed to load {model_path}, delete model and restart lama-cleaner"
50 |         )
51 |         exit(-1)
52 |     model.eval()
53 |     return model
54 |
55 |
56 | def load_model(model: torch.nn.Module, url_or_path, device):
57 |     if os.path.exists(url_or_path):
58 |         model_path = url_or_path
59 |     else:
60 |         model_path = download_model(url_or_path)
61 |
62 |     try:
63 |         state_dict = torch.load(model_path, map_location='cpu')
64 |         model.load_state_dict(state_dict, strict=True)
65 |         model.to(device)
66 |         logger.info(f"Load model from: {model_path}")
67 |     except Exception:
68 |         logger.error(
69 |             f"Failed to load {model_path}, delete model and restart lama-cleaner"
70 |         )
71 |         exit(-1)
72 |     model.eval()
73 |     return model
74 |
75 |
76 | def numpy_to_bytes(image_numpy: np.ndarray, ext: str) -> bytes:
77 |     data = cv2.imencode(
78 |         f".{ext}",
79 |         image_numpy,
80 |         [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
81 |     )[1]
82 |     image_bytes = data.tobytes()
83 |     return image_bytes
84 |
85 |
86 | def load_img(img_bytes, gray: bool = False):
87 |     alpha_channel = None
88 |     nparr = np.frombuffer(img_bytes, np.uint8)
89 |     if gray:
90 |         np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
91 |     else:
92 |         np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
93 |         if len(np_img.shape) == 3 and np_img.shape[2] == 4:
94 |             alpha_channel = np_img[:, :, -1]
95 |             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
96 |         else:
97 |             np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
98 |
99 |     return np_img, alpha_channel
100 |
101 |
102 | def norm_img(np_img):
103 |     if len(np_img.shape) == 2:
104 |         np_img = np_img[:, :, np.newaxis]
105 |     np_img = np.transpose(np_img, (2, 0, 1))
106 |     np_img = np_img.astype("float32") / 255
107 |     return np_img
108 |
109 |
110 | def resize_max_size(
111 |     np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
112 | ) -> np.ndarray:
113 |     # Resize the image's longer side to size_limit if it exceeds size_limit
114 |     h, w = np_img.shape[:2]
115 |     if max(h, w) > size_limit:
116 |         ratio = size_limit / max(h, w)
117 |         new_w = int(w * ratio + 0.5)
118 |         new_h = int(h * ratio + 0.5)
119 |         return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
120 |     else:
121 |         return np_img
122 |
123 |
124 | def pad_img_to_modulo(
125 |     img: np.ndarray, mod: int, square: bool = False, min_size: Optional[int] = None
126 | ):
127 |     """
128 |
129 |     Args:
130 |         img: [H, W, C]
131 |         mod:
132 |         square: whether to pad the image out to a square
133 |         min_size:
134 |
135 |     Returns:
136 |
137 |     """
138 |     if len(img.shape) == 2:
139 |         img = img[:, :, np.newaxis]
140 |     height, width = img.shape[:2]
141 |     out_height = ceil_modulo(height, mod)
142 |     out_width = ceil_modulo(width, mod)
143 |
144 |     if min_size is not None:
145 |         assert min_size % mod == 0
146 |         out_width = max(min_size, out_width)
147 |         out_height = max(min_size, out_height)
148 |
149 |     if square:
150 |         max_size = max(out_height, out_width)
151 |         out_height = max_size
152 |         out_width = max_size
153 |
154 |     return np.pad(
155 |         img,
156 |         ((0, out_height - height), (0, out_width - width), (0, 0)),
157 |         mode="symmetric",
158 |     )
159 |
160 |
161 | def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
162 |     """
163 |     Args:
164 |         mask: (h, w, 1) 0~255
165 |
166 |     Returns:
167 |
168 |     """
169 |     height, width = mask.shape[:2]
170 |     _, thresh = cv2.threshold(mask, 127, 255, 0)
171 |     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
172 |
173 |     boxes = []
174 |     for cnt in contours:
175 |         x, y, w, h = cv2.boundingRect(cnt)
176 |         box = np.array([x, y, x + w, y + h]).astype(int)
177 |
178 |         box[::2] = np.clip(box[::2], 0, width)
179 |         box[1::2] = np.clip(box[1::2], 0, height)
180 |         boxes.append(box)
181 |
182 |     return boxes
183 |
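
A quick sanity check of how the padding and box helpers above compose (illustrative only; the shapes follow the docstrings):

```python
import numpy as np
from lama_cleaner.helper import boxes_from_mask, pad_img_to_modulo, resize_max_size

# A 100x150 RGB image and a mask containing one white rectangle.
image = np.zeros((100, 150, 3), dtype=np.uint8)
mask = np.zeros((100, 150), dtype=np.uint8)
mask[20:40, 30:90] = 255

# One [x1, y1, x2, y2] box per external contour.
print(boxes_from_mask(mask))                        # [array([30, 20, 90, 40])]

# Pad H and W up to the next multiple of `mod` (LaMa uses pad_mod=8),
# mirroring border pixels.
print(pad_img_to_modulo(image, mod=8).shape)        # (104, 152, 3)

# Shrink the longer side to the limit, preserving aspect ratio.
print(resize_max_size(image, size_limit=75).shape)  # (50, 75, 3)
```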
--------------------------------------------------------------------------------
/lama_cleaner/model/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Optional
3 |
4 | import cv2
5 | import torch
6 | from loguru import logger
7 |
8 | from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
9 | from lama_cleaner.schema import Config, HDStrategy
10 |
11 |
12 | class InpaintModel:
13 |     min_size: Optional[int] = None
14 |     pad_mod = 8
15 |     pad_to_square = False
16 |
17 |     def __init__(self, device, **kwargs):
18 |         """
19 |
20 |         Args:
21 |             device:
22 |         """
23 |         self.device = device
24 |         self.init_model(device, **kwargs)
25 |
26 |     @abc.abstractmethod
27 |     def init_model(self, device, **kwargs):
28 |         ...
29 |
30 |     @staticmethod
31 |     @abc.abstractmethod
32 |     def is_downloaded() -> bool:
33 |         ...
34 |
35 |     @abc.abstractmethod
36 |     def forward(self, image, mask, config: Config):
37 |         """Input images and output images have same size
38 |         images: [H, W, C] RGB
39 |         masks: [H, W, 1], 255 marks the masked region
40 |         return: BGR IMAGE
41 |         """
42 |         ...
43 |
44 |     def _pad_forward(self, image, mask, config: Config):
45 |         origin_height, origin_width = image.shape[:2]
46 |         pad_image = pad_img_to_modulo(
47 |             image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
48 |         )
49 |         pad_mask = pad_img_to_modulo(
50 |             mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
51 |         )
52 |
53 |         logger.info(f"final forward pad size: {pad_image.shape}")
54 |
55 |         result = self.forward(pad_image, pad_mask, config)
56 |         result = result[0:origin_height, 0:origin_width, :]
57 |
58 |         original_pixel_indices = mask < 127
59 |         result[original_pixel_indices] = image[:, :, ::-1][original_pixel_indices]
60 |         return result
61 |
62 |     @torch.no_grad()
63 |     def __call__(self, image, mask, config: Config):
64 |         """
65 |         images: [H, W, C] RGB, not normalized
66 |         masks: [H, W]
67 |         return: BGR IMAGE
68 |         """
69 |         inpaint_result = None
70 |         logger.info(f"hd_strategy: {config.hd_strategy}")
71 |         if config.hd_strategy == HDStrategy.CROP:
72 |             if max(image.shape) > config.hd_strategy_crop_trigger_size:
73 |                 logger.info("Run crop strategy")
74 |                 boxes = boxes_from_mask(mask)
75 |                 crop_result = []
76 |                 for box in boxes:
77 |                     crop_image, crop_box = self._run_box(image, mask, box, config)
78 |                     crop_result.append((crop_image, crop_box))
79 |
80 |                 inpaint_result = image[:, :, ::-1].copy()  # copy so the caller's image is not mutated
81 |                 for crop_image, crop_box in crop_result:
82 |                     x1, y1, x2, y2 = crop_box
83 |                     inpaint_result[y1:y2, x1:x2, :] = crop_image
84 |
85 |         elif config.hd_strategy == HDStrategy.RESIZE:
86 |             if max(image.shape) > config.hd_strategy_resize_limit:
87 |                 origin_size = image.shape[:2]
88 |                 downsize_image = resize_max_size(
89 |                     image, size_limit=config.hd_strategy_resize_limit
90 |                 )
91 |                 downsize_mask = resize_max_size(
92 |                     mask, size_limit=config.hd_strategy_resize_limit
93 |                 )
94 |
95 |                 logger.info(
96 |                     f"Run resize strategy, origin size: {image.shape} forward size: {downsize_image.shape}"
97 |                 )
98 |                 inpaint_result = self._pad_forward(
99 |                     downsize_image, downsize_mask, config
100 |                 )
101 |
102 |                 # only paste masked area result
103 |                 inpaint_result = cv2.resize(
104 |                     inpaint_result,
105 |                     (origin_size[1], origin_size[0]),
106 |                     interpolation=cv2.INTER_CUBIC,
107 |                 )
108 |                 original_pixel_indices = mask < 127
109 |                 inpaint_result[original_pixel_indices] = image[:, :, ::-1][
110 |                     original_pixel_indices
111 |                 ]
112 |
113 |         if inpaint_result is None:
114 |             inpaint_result = self._pad_forward(image, mask, config)
115 |
116 |         return inpaint_result
117 |
118 |     def _crop_box(self, image, mask, box, config: Config):
119 |         """
120 |
121 |         Args:
122 |             image: [H, W, C] RGB
123 |             mask: [H, W, 1]
124 |             box: [left,top,right,bottom]
125 |
126 |         Returns:
127 |             crop_img, crop_mask, [l, t, r, b]
128 |         """
129 |         box_h = box[3] - box[1]
130 |         box_w = box[2] - box[0]
131 |         cx = (box[0] + box[2]) // 2
132 |         cy = (box[1] + box[3]) // 2
133 |         img_h, img_w = image.shape[:2]
134 |
135 |         w = box_w + config.hd_strategy_crop_margin * 2
136 |         h = box_h + config.hd_strategy_crop_margin * 2
137 |
138 |         _l = cx - w // 2
139 |         _r = cx + w // 2
140 |         _t = cy - h // 2
141 |         _b = cy + h // 2
142 |
143 |         l = max(_l, 0)
144 |         r = min(_r, img_w)
145 |         t = max(_t, 0)
146 |         b = min(_b, img_h)
147 |
148 |         # try to get more context when crop around image edge
149 |         if _l < 0:
150 |             r += abs(_l)
151 |         if _r > img_w:
152 |             l -= _r - img_w
153 |         if _t < 0:
154 |             b += abs(_t)
155 |         if _b > img_h:
156 |             t -= _b - img_h
157 |
158 |         l = max(l, 0)
159 |         r = min(r, img_w)
160 |         t = max(t, 0)
161 |         b = min(b, img_h)
162 |
163 |         crop_img = image[t:b, l:r, :]
164 |         crop_mask = mask[t:b, l:r]
165 |
166 |         logger.info(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
167 |
168 |         return crop_img, crop_mask, [l, t, r, b]
169 |
170 |     def _run_box(self, image, mask, box, config: Config):
171 |         """
172 |
173 |         Args:
174 |             image: [H, W, C] RGB
175 |             mask: [H, W, 1]
176 |             box: [left,top,right,bottom]
177 |
178 |         Returns:
179 |             BGR IMAGE
180 |         """
181 |         crop_img, crop_mask, [l, t, r, b] = self._crop_box(image, mask, box, config)
182 |
183 |         return self._pad_forward(crop_img, crop_mask, config), [l, t, r, b]
--------------------------------------------------------------------------------
/lama_cleaner/model/ddim_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 | from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
6 |
7 | from loguru import logger
8 |
9 |
10 | class DDIMSampler(object):
11 |     def __init__(self, model, schedule="linear"):
12 |         super().__init__()
13 |         self.model = model
14 |         self.ddpm_num_timesteps = model.num_timesteps
15 |         self.schedule = schedule
16 |
17 |     def register_buffer(self, name, attr):
18 |         setattr(self, name, attr)
19 |
20 |     def make_schedule(
21 |         self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
22 |     ):
23 |         self.ddim_timesteps = make_ddim_timesteps(
24 |             ddim_discr_method=ddim_discretize,
25 |             num_ddim_timesteps=ddim_num_steps,
26 |             # array([1])
27 |             num_ddpm_timesteps=self.ddpm_num_timesteps,
28 |             verbose=verbose,
29 |         )
30 |         alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
31 |         assert (
32 |             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
33 |         ), "alphas have to be defined for each timestep"
34 |         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
35 |
36 |         self.register_buffer("betas", to_torch(self.model.betas))
37 |         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
38 |         self.register_buffer(
39 |             "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
40 |         )
41 |
42 |         # calculations for diffusion q(x_t | x_{t-1}) and others
43 |         self.register_buffer(
44 |             "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
45 |         )
46 |         self.register_buffer(
47 |             "sqrt_one_minus_alphas_cumprod",
48 | 
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), 49 | ) 50 | self.register_buffer( 51 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) 52 | ) 53 | self.register_buffer( 54 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) 55 | ) 56 | self.register_buffer( 57 | "sqrt_recipm1_alphas_cumprod", 58 | to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), 59 | ) 60 | 61 | # ddim sampling parameters 62 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 63 | alphacums=alphas_cumprod.cpu(), 64 | ddim_timesteps=self.ddim_timesteps, 65 | eta=ddim_eta, 66 | verbose=verbose, 67 | ) 68 | self.register_buffer("ddim_sigmas", ddim_sigmas) 69 | self.register_buffer("ddim_alphas", ddim_alphas) 70 | self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) 71 | self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) 72 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 73 | (1 - self.alphas_cumprod_prev) 74 | / (1 - self.alphas_cumprod) 75 | * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) 76 | ) 77 | self.register_buffer( 78 | "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps 79 | ) 80 | 81 | @torch.no_grad() 82 | def sample(self, steps, conditioning, batch_size, shape): 83 | self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False) 84 | # sampling 85 | C, H, W = shape 86 | size = (batch_size, C, H, W) 87 | 88 | # samples: 1,3,128,128 89 | return self.ddim_sampling( 90 | conditioning, 91 | size, 92 | quantize_denoised=False, 93 | ddim_use_original_steps=False, 94 | noise_dropout=0, 95 | temperature=1.0, 96 | ) 97 | 98 | @torch.no_grad() 99 | def ddim_sampling( 100 | self, 101 | cond, 102 | shape, 103 | ddim_use_original_steps=False, 104 | quantize_denoised=False, 105 | temperature=1.0, 106 | noise_dropout=0.0, 107 | ): 108 | device = self.model.betas.device 109 | b = shape[0] 110 | img = torch.randn(shape, device=device, dtype=cond.dtype) 111 | timesteps = ( 112 | self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 113 | ) 114 | 115 | time_range = ( 116 | reversed(range(0, timesteps)) 117 | if ddim_use_original_steps 118 | else np.flip(timesteps) 119 | ) 120 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 121 | logger.info(f"Running DDIM Sampling with {total_steps} timesteps") 122 | 123 | iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) 124 | 125 | for i, step in enumerate(iterator): 126 | index = total_steps - i - 1 127 | ts = torch.full((b,), step, device=device, dtype=torch.long) 128 | 129 | outs = self.p_sample_ddim( 130 | img, 131 | cond, 132 | ts, 133 | index=index, 134 | use_original_steps=ddim_use_original_steps, 135 | quantize_denoised=quantize_denoised, 136 | temperature=temperature, 137 | noise_dropout=noise_dropout, 138 | ) 139 | img, _ = outs 140 | 141 | return img 142 | 143 | @torch.no_grad() 144 | def p_sample_ddim( 145 | self, 146 | x, 147 | c, 148 | t, 149 | index, 150 | repeat_noise=False, 151 | use_original_steps=False, 152 | quantize_denoised=False, 153 | temperature=1.0, 154 | noise_dropout=0.0, 155 | ): 156 | b, *_, device = *x.shape, x.device 157 | e_t = self.model.apply_model(x, t, c) 158 | 159 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 160 | alphas_prev = ( 161 | self.model.alphas_cumprod_prev 162 | if use_original_steps 163 | else self.ddim_alphas_prev 164 | ) 165 | sqrt_one_minus_alphas = ( 166 | 
self.model.sqrt_one_minus_alphas_cumprod
167 |             if use_original_steps
168 |             else self.ddim_sqrt_one_minus_alphas
169 |         )
170 |         sigmas = (
171 |             self.model.ddim_sigmas_for_original_num_steps
172 |             if use_original_steps
173 |             else self.ddim_sigmas
174 |         )
175 |         # select parameters corresponding to the currently considered timestep
176 |         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
177 |         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
178 |         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
179 |         sqrt_one_minus_at = torch.full(
180 |             (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
181 |         )
182 |
183 |         # current prediction for x_0
184 |         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
185 |         if quantize_denoised:  # unused here
186 |             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
187 |         # direction pointing to x_t
188 |         dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
189 |         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
190 |         if noise_dropout > 0.0:  # unused here
191 |             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
192 |         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
193 |         return x_prev, pred_x0
194 |
--------------------------------------------------------------------------------
/lama_cleaner/model/sd.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import PIL.Image
4 | import cv2
5 | import numpy as np
6 | import torch
7 | from diffusers import PNDMScheduler, DDIMScheduler
8 | from loguru import logger
9 | from transformers import FeatureExtractionMixin, ImageFeatureExtractionMixin
10 |
11 | from lama_cleaner.helper import norm_img
12 |
13 | from lama_cleaner.model.base import InpaintModel
14 | from lama_cleaner.schema import Config, SDSampler
15 |
16 |
17 | #
18 | #
19 | # def preprocess_image(image):
20 | #     w, h = image.size
21 | #     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
22 | #     image = image.resize((w, h), resample=PIL.Image.LANCZOS)
23 | #     image = np.array(image).astype(np.float32) / 255.0
24 | #     image = image[None].transpose(0, 3, 1, 2)
25 | #     image = torch.from_numpy(image)
26 | #     # [-1, 1]
27 | #     return 2.0 * image - 1.0
28 | #
29 | #
30 | # def preprocess_mask(mask):
31 | #     mask = mask.convert("L")
32 | #     w, h = mask.size
33 | #     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
34 | #     mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
35 | #     mask = np.array(mask).astype(np.float32) / 255.0
36 | #     mask = np.tile(mask, (4, 1, 1))
37 | #     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
38 | # mask = 1 - mask # repaint white, keep black 39 | # mask = torch.from_numpy(mask) 40 | # return mask 41 | 42 | class DummyFeatureExtractorOutput: 43 | def __init__(self, pixel_values): 44 | self.pixel_values = pixel_values 45 | 46 | def to(self, device): 47 | return self 48 | 49 | 50 | class DummyFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): 51 | def __init__(self, **kwargs): 52 | super().__init__(**kwargs) 53 | 54 | def __call__(self, *args, **kwargs): 55 | return DummyFeatureExtractorOutput(torch.empty(0, 3)) 56 | 57 | 58 | class DummySafetyChecker: 59 | def __init__(self, *args, **kwargs): 60 | pass 61 | 62 | def __call__(self, clip_input, images): 63 | return images, False 64 | 65 | 66 | class SD(InpaintModel): 67 | pad_mod = 64 # current diffusers only support 64 https://github.com/huggingface/diffusers/pull/505 68 | min_size = 512 69 | 70 | def init_model(self, device: torch.device, **kwargs): 71 | from .sd_pipeline import StableDiffusionInpaintPipeline 72 | 73 | model_kwargs = {"local_files_only": kwargs['sd_run_local']} 74 | if kwargs['sd_disable_nsfw']: 75 | logger.info("Disable Stable Diffusion Model NSFW checker") 76 | model_kwargs.update(dict( 77 | feature_extractor=DummyFeatureExtractor(), 78 | safety_checker=DummySafetyChecker(), 79 | )) 80 | 81 | self.model = StableDiffusionInpaintPipeline.from_pretrained( 82 | self.model_id_or_path, 83 | revision="fp16" if torch.cuda.is_available() else "main", 84 | torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, 85 | use_auth_token=kwargs["hf_access_token"], 86 | **model_kwargs 87 | ) 88 | # https://huggingface.co/docs/diffusers/v0.3.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing 89 | self.model.enable_attention_slicing() 90 | self.model = self.model.to(device) 91 | 92 | if kwargs['sd_cpu_textencoder']: 93 | logger.info("Run Stable Diffusion TextEncoder on CPU") 94 | self.model.text_encoder = self.model.text_encoder.to(torch.device('cpu'), non_blocking=True) 95 | self.model.text_encoder = self.model.text_encoder.to(torch.float32, non_blocking=True ) 96 | 97 | self.callbacks = kwargs.pop("callbacks", None) 98 | 99 | @torch.cuda.amp.autocast() 100 | def forward(self, image, mask, config: Config): 101 | """Input image and output image have same size 102 | image: [H, W, C] RGB 103 | mask: [H, W, 1] 255 means area to repaint 104 | return: BGR IMAGE 105 | """ 106 | 107 | # image = norm_img(image) # [0, 1] 108 | # image = image * 2 - 1 # [0, 1] -> [-1, 1] 109 | 110 | # resize to latent feature map size 111 | # h, w = mask.shape[:2] 112 | # mask = cv2.resize(mask, (h // 8, w // 8), interpolation=cv2.INTER_AREA) 113 | # mask = norm_img(mask) 114 | # 115 | # image = torch.from_numpy(image).unsqueeze(0).to(self.device) 116 | # mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) 117 | 118 | if config.sd_sampler == SDSampler.ddim: 119 | scheduler = DDIMScheduler( 120 | beta_start=0.00085, 121 | beta_end=0.012, 122 | beta_schedule="scaled_linear", 123 | clip_sample=False, 124 | set_alpha_to_one=False, 125 | ) 126 | elif config.sd_sampler == SDSampler.pndm: 127 | PNDM_kwargs = { 128 | "tensor_format": "pt", 129 | "beta_schedule": "scaled_linear", 130 | "beta_start": 0.00085, 131 | "beta_end": 0.012, 132 | "num_train_timesteps": 1000, 133 | "skip_prk_steps": True, 134 | } 135 | scheduler = PNDMScheduler(**PNDM_kwargs) 136 | else: 137 | raise ValueError(config.sd_sampler) 138 | 139 | self.model.scheduler = scheduler 140 | 141 | seed = 
config.sd_seed 142 | random.seed(seed) 143 | np.random.seed(seed) 144 | torch.manual_seed(seed) 145 | torch.cuda.manual_seed_all(seed) 146 | 147 | if config.sd_mask_blur != 0: 148 | k = 2 * config.sd_mask_blur + 1 149 | mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] 150 | 151 | output = self.model( 152 | prompt=config.prompt, 153 | init_image=PIL.Image.fromarray(image), 154 | mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"), 155 | strength=config.sd_strength, 156 | num_inference_steps=config.sd_steps, 157 | guidance_scale=config.sd_guidance_scale, 158 | output_type="np.array", 159 | callbacks=self.callbacks, 160 | ).images[0] 161 | 162 | output = (output * 255).round().astype("uint8") 163 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 164 | return output 165 | 166 | @torch.no_grad() 167 | def __call__(self, image, mask, config: Config): 168 | """ 169 | images: [H, W, C] RGB, not normalized 170 | masks: [H, W] 171 | return: BGR IMAGE 172 | """ 173 | img_h, img_w = image.shape[:2] 174 | 175 | # boxes = boxes_from_mask(mask) 176 | if config.use_croper: 177 | logger.info("use croper") 178 | l, t, w, h = ( 179 | config.croper_x, 180 | config.croper_y, 181 | config.croper_width, 182 | config.croper_height, 183 | ) 184 | r = l + w 185 | b = t + h 186 | 187 | l = max(l, 0) 188 | r = min(r, img_w) 189 | t = max(t, 0) 190 | b = min(b, img_h) 191 | 192 | crop_img = image[t:b, l:r, :] 193 | crop_mask = mask[t:b, l:r] 194 | 195 | crop_image = self._pad_forward(crop_img, crop_mask, config) 196 | 197 | inpaint_result = image[:, :, ::-1] 198 | inpaint_result[t:b, l:r, :] = crop_image 199 | else: 200 | inpaint_result = self._pad_forward(image, mask, config) 201 | 202 | return inpaint_result 203 | 204 | @staticmethod 205 | def is_downloaded() -> bool: 206 | # the model is downloaded when the app starts, and can't be switched in the frontend settings 207 | return True 208 | 209 | 210 | class SD14(SD): 211 | model_id_or_path = "CompVis/stable-diffusion-v1-4" 212 | 213 | 214 | class SD15(SD): 215 | model_id_or_path = "runwayml/stable-diffusion-v1-5" # fixed: the v1.5 weights are published under the runwayml org; "CompVis/stable-diffusion-v1-5" does not exist 216 | -------------------------------------------------------------------------------- /lama_cleaner/model/ldm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from loguru import logger 6 | 7 | from lama_cleaner.model.base import InpaintModel 8 | from lama_cleaner.model.ddim_sampler import DDIMSampler 9 | from lama_cleaner.model.plms_sampler import PLMSSampler 10 | from lama_cleaner.schema import Config, LDMSampler 11 | 12 | torch.manual_seed(42) 13 | import torch.nn as nn 14 | from lama_cleaner.helper import ( 15 | download_model, 16 | norm_img, 17 | get_cache_path_by_url, 18 | load_jit_model, 19 | ) 20 | from lama_cleaner.model.utils import ( 21 | make_beta_schedule, 22 | timestep_embedding, 23 | ) 24 | 25 | LDM_ENCODE_MODEL_URL = os.environ.get( 26 | "LDM_ENCODE_MODEL_URL", 27 | "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_encode.pt", 28 | ) 29 | 30 | LDM_DECODE_MODEL_URL = os.environ.get( 31 | "LDM_DECODE_MODEL_URL", 32 | "https://github.com/Sanster/models/releases/download/add_ldm/cond_stage_model_decode.pt", 33 | ) 34 | 35 | LDM_DIFFUSION_MODEL_URL = os.environ.get( 36 | "LDM_DIFFUSION_MODEL_URL", 37 | "https://github.com/Sanster/models/releases/download/add_ldm/diffusion.pt", 38 | ) 39 | 40 | 41 | class DDPM(nn.Module): 42 | # classic DDPM with Gaussian diffusion, in image space 43 | def __init__( 44 | self, 45 |
device, 46 | timesteps=1000, 47 | beta_schedule="linear", 48 | linear_start=0.0015, 49 | linear_end=0.0205, 50 | cosine_s=0.008, 51 | original_elbo_weight=0.0, 52 | v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta 53 | l_simple_weight=1.0, 54 | parameterization="eps", # all assuming fixed variance schedules 55 | use_positional_encodings=False, 56 | ): 57 | super().__init__() 58 | self.device = device 59 | self.parameterization = parameterization 60 | self.use_positional_encodings = use_positional_encodings 61 | 62 | self.v_posterior = v_posterior 63 | self.original_elbo_weight = original_elbo_weight 64 | self.l_simple_weight = l_simple_weight 65 | 66 | self.register_schedule( 67 | beta_schedule=beta_schedule, 68 | timesteps=timesteps, 69 | linear_start=linear_start, 70 | linear_end=linear_end, 71 | cosine_s=cosine_s, 72 | ) 73 | 74 | def register_schedule( 75 | self, 76 | given_betas=None, 77 | beta_schedule="linear", 78 | timesteps=1000, 79 | linear_start=1e-4, 80 | linear_end=2e-2, 81 | cosine_s=8e-3, 82 | ): 83 | betas = make_beta_schedule( 84 | self.device, 85 | beta_schedule, 86 | timesteps, 87 | linear_start=linear_start, 88 | linear_end=linear_end, 89 | cosine_s=cosine_s, 90 | ) 91 | alphas = 1.0 - betas 92 | alphas_cumprod = np.cumprod(alphas, axis=0) 93 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 94 | 95 | (timesteps,) = betas.shape 96 | self.num_timesteps = int(timesteps) 97 | self.linear_start = linear_start 98 | self.linear_end = linear_end 99 | assert ( 100 | alphas_cumprod.shape[0] == self.num_timesteps 101 | ), "alphas have to be defined for each timestep" 102 | 103 | to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device) 104 | 105 | self.register_buffer("betas", to_torch(betas)) 106 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 107 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 108 | 109 | # calculations for diffusion q(x_t | x_{t-1}) and others 110 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 111 | self.register_buffer( 112 | "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) 113 | ) 114 | self.register_buffer( 115 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) 116 | ) 117 | self.register_buffer( 118 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) 119 | ) 120 | self.register_buffer( 121 | "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) 122 | ) 123 | 124 | # calculations for posterior q(x_{t-1} | x_t, x_0) 125 | posterior_variance = (1 - self.v_posterior) * betas * ( 126 | 1.0 - alphas_cumprod_prev 127 | ) / (1.0 - alphas_cumprod) + self.v_posterior * betas 128 | # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) 129 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 130 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 131 | self.register_buffer( 132 | "posterior_log_variance_clipped", 133 | to_torch(np.log(np.maximum(posterior_variance, 1e-20))), 134 | ) 135 | self.register_buffer( 136 | "posterior_mean_coef1", 137 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), 138 | ) 139 | self.register_buffer( 140 | "posterior_mean_coef2", 141 | to_torch( 142 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) 143 | ), 144 | ) 145 | 146 | if self.parameterization == "eps": 147 | lvlb_weights = self.betas**2 / ( 148 | 2 149 | * self.posterior_variance 150 | * to_torch(alphas) 151 | * (1 - self.alphas_cumprod) 152 | ) 153 | elif self.parameterization == "x0": 154 | lvlb_weights = ( 155 | 0.5 156 | * np.sqrt(torch.Tensor(alphas_cumprod)) 157 | / (2.0 * 1 - torch.Tensor(alphas_cumprod)) 158 | ) 159 | else: 160 | raise NotImplementedError("mu not supported") 161 | # TODO how to choose this term 162 | lvlb_weights[0] = lvlb_weights[1] 163 | self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) 164 | assert not torch.isnan(self.lvlb_weights).all() 165 | 166 | 167 | class LatentDiffusion(DDPM): 168 | def __init__( 169 | self, 170 | diffusion_model, 171 | device, 172 | cond_stage_key="image", 173 | cond_stage_trainable=False, 174 | concat_mode=True, 175 | scale_factor=1.0, 176 | scale_by_std=False, 177 | *args, 178 | **kwargs, 179 | ): 180 | self.num_timesteps_cond = 1 181 | self.scale_by_std = scale_by_std 182 | super().__init__(device, *args, **kwargs) 183 | self.diffusion_model = diffusion_model 184 | self.concat_mode = concat_mode 185 | self.cond_stage_trainable = cond_stage_trainable 186 | self.cond_stage_key = cond_stage_key 187 | self.num_downs = 2 188 | self.scale_factor = scale_factor 189 | 190 | def make_cond_schedule( 191 | self, 192 | ): 193 | self.cond_ids = torch.full( 194 | size=(self.num_timesteps,), 195 | fill_value=self.num_timesteps - 1, 196 | dtype=torch.long, 197 | ) 198 | ids = torch.round( 199 | torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) 200 | ).long() 201 | self.cond_ids[: self.num_timesteps_cond] = ids 202 | 203 | def register_schedule( 204 | self, 205 | given_betas=None, 206 | beta_schedule="linear", 207 | timesteps=1000, 208 | linear_start=1e-4, 209 | linear_end=2e-2, 210 | cosine_s=8e-3, 211 | ): 212 | super().register_schedule( 213 | given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s 214 | ) 215 | 216 | self.shorten_cond_schedule = self.num_timesteps_cond > 1 217 | if self.shorten_cond_schedule: 218 | self.make_cond_schedule() 219 | 220 | def apply_model(self, x_noisy, t, cond): 221 | # x_recon = self.model(x_noisy, t, cond['c_concat'][0]) # cond['c_concat'][0].shape 1,4,128,128 222 | t_emb = timestep_embedding(x_noisy.device, t, 256, repeat_only=False) 223 | x_recon = self.diffusion_model(x_noisy, t_emb, cond) 224 | return x_recon 225 | 226 | 227 | class LDM(InpaintModel): 228 | pad_mod = 32 229 | 230 | def __init__(self, device, fp16: bool = True, **kwargs): 231 | self.fp16 = fp16 232 | super().__init__(device) 233 | self.device = device 234 | 235 | def init_model(self, device, **kwargs): 236 | self.diffusion_model = load_jit_model(LDM_DIFFUSION_MODEL_URL, device) 237 | self.cond_stage_model_decode = load_jit_model(LDM_DECODE_MODEL_URL, device) 238 | 
self.cond_stage_model_encode = load_jit_model(LDM_ENCODE_MODEL_URL, device) 239 | if self.fp16 and "cuda" in str(device): 240 | self.diffusion_model = self.diffusion_model.half() 241 | self.cond_stage_model_decode = self.cond_stage_model_decode.half() 242 | self.cond_stage_model_encode = self.cond_stage_model_encode.half() 243 | 244 | self.model = LatentDiffusion(self.diffusion_model, device) 245 | 246 | @staticmethod 247 | def is_downloaded() -> bool: 248 | model_paths = [ 249 | get_cache_path_by_url(LDM_DIFFUSION_MODEL_URL), 250 | get_cache_path_by_url(LDM_DECODE_MODEL_URL), 251 | get_cache_path_by_url(LDM_ENCODE_MODEL_URL), 252 | ] 253 | return all([os.path.exists(it) for it in model_paths]) 254 | 255 | @torch.cuda.amp.autocast() 256 | def forward(self, image, mask, config: Config): 257 | """ 258 | image: [H, W, C] RGB 259 | mask: [H, W, 1] 260 | return: BGR IMAGE 261 | """ 262 | # image [1,3,512,512] float32 263 | # mask: [1,1,512,512] float32 264 | # masked_image: [1,3,512,512] float32 265 | if config.ldm_sampler == LDMSampler.ddim: 266 | sampler = DDIMSampler(self.model) 267 | elif config.ldm_sampler == LDMSampler.plms: 268 | sampler = PLMSSampler(self.model) 269 | else: 270 | raise ValueError() 271 | 272 | steps = config.ldm_steps 273 | image = norm_img(image) 274 | mask = norm_img(mask) 275 | 276 | mask[mask < 0.5] = 0 277 | mask[mask >= 0.5] = 1 278 | 279 | image = torch.from_numpy(image).unsqueeze(0).to(self.device) 280 | mask = torch.from_numpy(mask).unsqueeze(0).to(self.device) 281 | masked_image = (1 - mask) * image 282 | 283 | mask = self._norm(mask) 284 | masked_image = self._norm(masked_image) 285 | 286 | c = self.cond_stage_model_encode(masked_image) 287 | torch.cuda.empty_cache() 288 | 289 | cc = torch.nn.functional.interpolate(mask, size=c.shape[-2:]) # 1,1,128,128 290 | c = torch.cat((c, cc), dim=1) # 1,4,128,128 291 | 292 | shape = (c.shape[1] - 1,) + c.shape[2:] 293 | samples_ddim = sampler.sample( 294 | steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape 295 | ) 296 | torch.cuda.empty_cache() 297 | x_samples_ddim = self.cond_stage_model_decode( 298 | samples_ddim 299 | ) # samples_ddim: 1, 3, 128, 128 float32 300 | torch.cuda.empty_cache() 301 | 302 | # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) 303 | # mask = torch.clamp((mask + 1.0) / 2.0, min=0.0, max=1.0) 304 | inpainted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 305 | 306 | # inpainted = (1 - mask) * image + mask * predicted_image 307 | inpainted_image = inpainted_image.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 308 | inpainted_image = inpainted_image.astype(np.uint8)[:, :, ::-1] 309 | return inpainted_image 310 | 311 | def _norm(self, tensor): 312 | return tensor * 2.0 - 1.0 313 | -------------------------------------------------------------------------------- /lama_cleaner/model/plms_sampler.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py 2 | import torch 3 | import numpy as np 4 | from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like 5 | from tqdm import tqdm 6 | 7 | 8 | class PLMSSampler(object): 9 | def __init__(self, model, schedule="linear", **kwargs): 10 | super().__init__() 11 | self.model = model 12 | self.ddpm_num_timesteps = model.num_timesteps 13 | self.schedule = schedule 14 | 15 | def register_buffer(self, name, attr): 16 | setattr(self, name, attr) 17 | 18 | def 
make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 19 | if ddim_eta != 0: 20 | raise ValueError('ddim_eta must be 0 for PLMS') 21 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 22 | num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) 23 | alphas_cumprod = self.model.alphas_cumprod 24 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 25 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 26 | 27 | self.register_buffer('betas', to_torch(self.model.betas)) 28 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 29 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 30 | 31 | # calculations for diffusion q(x_t | x_{t-1}) and others 32 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 33 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 34 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 35 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 36 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 37 | 38 | # ddim sampling parameters 39 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 40 | ddim_timesteps=self.ddim_timesteps, 41 | eta=ddim_eta, verbose=verbose) 42 | self.register_buffer('ddim_sigmas', ddim_sigmas) 43 | self.register_buffer('ddim_alphas', ddim_alphas) 44 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 45 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 46 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 47 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 48 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 49 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 50 | 51 | @torch.no_grad() 52 | def sample(self, 53 | steps, 54 | batch_size, 55 | shape, 56 | conditioning=None, 57 | callback=None, 58 | normals_sequence=None, 59 | img_callback=None, 60 | quantize_x0=False, 61 | eta=0., 62 | mask=None, 63 | x0=None, 64 | temperature=1., 65 | noise_dropout=0., 66 | score_corrector=None, 67 | corrector_kwargs=None, 68 | verbose=False, 69 | x_T=None, 70 | log_every_t=100, 71 | unconditional_guidance_scale=1., 72 | unconditional_conditioning=None, 73 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
74 | **kwargs 75 | ): 76 | if conditioning is not None: 77 | if isinstance(conditioning, dict): 78 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 79 | if cbs != batch_size: 80 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 81 | else: 82 | if conditioning.shape[0] != batch_size: 83 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 84 | 85 | self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose) 86 | # sampling 87 | C, H, W = shape 88 | size = (batch_size, C, H, W) 89 | print(f'Data shape for PLMS sampling is {size}') 90 | 91 | samples = self.plms_sampling(conditioning, size, 92 | callback=callback, 93 | img_callback=img_callback, 94 | quantize_denoised=quantize_x0, 95 | mask=mask, x0=x0, 96 | ddim_use_original_steps=False, 97 | noise_dropout=noise_dropout, 98 | temperature=temperature, 99 | score_corrector=score_corrector, 100 | corrector_kwargs=corrector_kwargs, 101 | x_T=x_T, 102 | log_every_t=log_every_t, 103 | unconditional_guidance_scale=unconditional_guidance_scale, 104 | unconditional_conditioning=unconditional_conditioning, 105 | ) 106 | return samples 107 | 108 | @torch.no_grad() 109 | def plms_sampling(self, cond, shape, 110 | x_T=None, ddim_use_original_steps=False, 111 | callback=None, timesteps=None, quantize_denoised=False, 112 | mask=None, x0=None, img_callback=None, log_every_t=100, 113 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 114 | unconditional_guidance_scale=1., unconditional_conditioning=None, ): 115 | device = self.model.betas.device 116 | b = shape[0] 117 | if x_T is None: 118 | img = torch.randn(shape, device=device) 119 | else: 120 | img = x_T 121 | 122 | if timesteps is None: 123 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 124 | elif timesteps is not None and not ddim_use_original_steps: 125 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 126 | timesteps = self.ddim_timesteps[:subset_end] 127 | 128 | time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) 129 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 130 | print(f"Running PLMS Sampling with {total_steps} timesteps") 131 | 132 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 133 | old_eps = [] 134 | 135 | for i, step in enumerate(iterator): 136 | index = total_steps - i - 1 137 | ts = torch.full((b,), step, device=device, dtype=torch.long) 138 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 139 | 140 | if mask is not None: 141 | assert x0 is not None 142 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 143 | img = img_orig * mask + (1. 
- mask) * img 144 | 145 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 146 | quantize_denoised=quantize_denoised, temperature=temperature, 147 | noise_dropout=noise_dropout, score_corrector=score_corrector, 148 | corrector_kwargs=corrector_kwargs, 149 | unconditional_guidance_scale=unconditional_guidance_scale, 150 | unconditional_conditioning=unconditional_conditioning, 151 | old_eps=old_eps, t_next=ts_next) 152 | img, pred_x0, e_t = outs 153 | old_eps.append(e_t) 154 | if len(old_eps) >= 4: 155 | old_eps.pop(0) 156 | if callback: callback(i) 157 | if img_callback: img_callback(pred_x0, i) 158 | 159 | return img 160 | 161 | @torch.no_grad() 162 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 163 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 164 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 165 | b, *_, device = *x.shape, x.device 166 | 167 | def get_model_output(x, t): 168 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 169 | e_t = self.model.apply_model(x, t, c) 170 | else: 171 | x_in = torch.cat([x] * 2) 172 | t_in = torch.cat([t] * 2) 173 | c_in = torch.cat([unconditional_conditioning, c]) 174 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 175 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 176 | 177 | if score_corrector is not None: 178 | assert self.model.parameterization == "eps" 179 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 180 | 181 | return e_t 182 | 183 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 184 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 185 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 186 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 187 | 188 | def get_x_prev_and_pred_x0(e_t, index): 189 | # select parameters corresponding to the currently considered timestep 190 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 191 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 192 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 193 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) 194 | 195 | # current prediction for x_0 196 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 197 | if quantize_denoised: 198 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 199 | # direction pointing to x_t 200 | dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * e_t 201 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 202 | if noise_dropout > 0.: 203 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 204 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 205 | return x_prev, pred_x0 206 | 207 | e_t = get_model_output(x, t) 208 | if len(old_eps) == 0: 209 | # Pseudo Improved Euler (2nd order) 210 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 211 | e_t_next = get_model_output(x_prev, t_next) 212 | e_t_prime = (e_t + e_t_next) / 2 213 | elif len(old_eps) == 1: 214 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 215 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 216 | elif len(old_eps) == 2: 217 | # 3rd order Pseudo Linear Multistep (Adams-Bashforth) 218 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 219 | elif len(old_eps) >= 3: 220 | # 4th order Pseudo Linear Multistep (Adams-Bashforth); see the editor's sanity-check sketch below 221 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 222 | 223 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 224 | 225 | return x_prev, pred_x0, e_t 226 | -------------------------------------------------------------------------------- /lama_cleaner/model/zits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import cv2 5 | import skimage 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from lama_cleaner.helper import get_cache_path_by_url, load_jit_model 10 | from lama_cleaner.schema import Config 11 | import numpy as np 12 | 13 | from lama_cleaner.model.base import InpaintModel 14 | 15 | ZITS_INPAINT_MODEL_URL = os.environ.get( 16 | "ZITS_INPAINT_MODEL_URL", 17 | "https://github.com/Sanster/models/releases/download/add_zits/zits-inpaint-0717.pt", 18 | ) 19 | 20 | ZITS_EDGE_LINE_MODEL_URL = os.environ.get( 21 | "ZITS_EDGE_LINE_MODEL_URL", 22 | "https://github.com/Sanster/models/releases/download/add_zits/zits-edge-line-0717.pt", 23 | ) 24 | 25 | ZITS_STRUCTURE_UPSAMPLE_MODEL_URL = os.environ.get( 26 | "ZITS_STRUCTURE_UPSAMPLE_MODEL_URL", 27 | "https://github.com/Sanster/models/releases/download/add_zits/zits-structure-upsample-0717.pt", 28 | ) 29 | 30 | ZITS_WIRE_FRAME_MODEL_URL = os.environ.get( 31 | "ZITS_WIRE_FRAME_MODEL_URL", 32 | "https://github.com/Sanster/models/releases/download/add_zits/zits-wireframe-0717.pt", 33 | ) 34 | 35 | 36 | def resize(img, height, width, center_crop=False): 37 | imgh, imgw = img.shape[0:2] 38 | 39 | if center_crop and imgh != imgw: 40 | # center crop 41 | side = np.minimum(imgh, imgw) 42 | j = (imgh - side) // 2 43 | i = (imgw - side) // 2 44 | img = img[j : j + side, i : i + side, ...]
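# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original repo): the multistep weights used
# in p_sample_plms above -- (3e - e1)/2, (23e - 16e1 + 5e2)/12 and
# (55e - 59e1 + 37e2 - 9e3)/24 -- are the classical Adams-Bashforth
# coefficients. A minimal sanity check, integrating dy/dt = -y, whose exact
# solution is exp(-t):
import math

def ab_integrate(f, y0, t0, t1, n):
    h = (t1 - t0) / n
    y, hist = y0, []  # hist mirrors old_eps: newest derivative last, max 3 kept
    for _ in range(n):
        e = f(y)
        if len(hist) == 0:
            # plain Euler bootstrap here; the sampler instead uses a 2nd-order
            # "pseudo improved Euler" step with a second model evaluation
            e_prime = e
        elif len(hist) == 1:
            e_prime = (3 * e - hist[-1]) / 2
        elif len(hist) == 2:
            e_prime = (23 * e - 16 * hist[-1] + 5 * hist[-2]) / 12
        else:
            e_prime = (55 * e - 59 * hist[-1] + 37 * hist[-2] - 9 * hist[-3]) / 24
        y += h * e_prime
        hist.append(e)
        if len(hist) > 3:
            hist.pop(0)
    return y

# y' = -y, y(0) = 1  =>  y(1) = exp(-1) ~= 0.367879
assert abs(ab_integrate(lambda y: -y, 1.0, 0.0, 1.0, 100) - math.exp(-1)) < 1e-4
# ---------------------------------------------------------------------------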
45 | 46 | if imgh > height and imgw > width: 47 | inter = cv2.INTER_AREA 48 | else: 49 | inter = cv2.INTER_LINEAR 50 | img = cv2.resize(img, (width, height), interpolation=inter) # fixed: cv2.resize expects (width, height); all call sites here pass square sizes, so behavior is unchanged 51 | 52 | return img 53 | 54 | 55 | def to_tensor(img, scale=True, norm=False): 56 | if img.ndim == 2: 57 | img = img[:, :, np.newaxis] 58 | c = img.shape[-1] 59 | 60 | if scale: 61 | img_t = torch.from_numpy(img).permute(2, 0, 1).float().div(255) 62 | else: 63 | img_t = torch.from_numpy(img).permute(2, 0, 1).float() 64 | 65 | if norm: 66 | mean = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) 67 | std = torch.tensor([0.5, 0.5, 0.5]).reshape(c, 1, 1) 68 | img_t = (img_t - mean) / std 69 | return img_t 70 | 71 | 72 | def load_masked_position_encoding(mask): 73 | ones_filter = np.ones((3, 3), dtype=np.float32) 74 | d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32) 75 | d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32) 76 | d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32) 77 | d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32) 78 | str_size = 256 79 | pos_num = 128 80 | 81 | ori_mask = mask.copy() 82 | ori_h, ori_w = ori_mask.shape[0:2] 83 | ori_mask = ori_mask / 255 84 | mask = cv2.resize(mask, (str_size, str_size), interpolation=cv2.INTER_AREA) 85 | mask[mask > 0] = 255 86 | h, w = mask.shape[0:2] 87 | mask3 = mask.copy() 88 | mask3 = 1.0 - (mask3 / 255.0) 89 | pos = np.zeros((h, w), dtype=np.int32) 90 | direct = np.zeros((h, w, 4), dtype=np.int32) 91 | i = 0 92 | while np.sum(1 - mask3) > 0: 93 | i += 1 94 | mask3_ = cv2.filter2D(mask3, -1, ones_filter) 95 | mask3_[mask3_ > 0] = 1 96 | sub_mask = mask3_ - mask3 97 | pos[sub_mask == 1] = i 98 | 99 | m = cv2.filter2D(mask3, -1, d_filter1) 100 | m[m > 0] = 1 101 | m = m - mask3 102 | direct[m == 1, 0] = 1 103 | 104 | m = cv2.filter2D(mask3, -1, d_filter2) 105 | m[m > 0] = 1 106 | m = m - mask3 107 | direct[m == 1, 1] = 1 108 | 109 | m = cv2.filter2D(mask3, -1, d_filter3) 110 | m[m > 0] = 1 111 | m = m - mask3 112 | direct[m == 1, 2] = 1 113 | 114 | m = cv2.filter2D(mask3, -1, d_filter4) 115 | m[m > 0] = 1 116 | m = m - mask3 117 | direct[m == 1, 3] = 1 118 | 119 | mask3 = mask3_ 120 | 121 | abs_pos = pos.copy() 122 | rel_pos = pos / (str_size / 2) # scaled to roughly [0, 1]; may exceed 1 123 | rel_pos = (rel_pos * pos_num).astype(np.int32) 124 | rel_pos = np.clip(rel_pos, 0, pos_num - 1) 125 | 126 | if ori_w != w or ori_h != h: 127 | rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) 128 | rel_pos[ori_mask == 0] = 0 129 | direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) 130 | direct[ori_mask == 0, :] = 0 131 | 132 | return rel_pos, abs_pos, direct 133 | 134 | 135 | def load_image(img, mask, device, sigma256=3.0): 136 | """ 137 | Args: 138 | img: [H, W, C] RGB 139 | mask: [H, W] where 255 marks the masked region 140 | sigma256: 141 | 142 | Returns: 143 | 144 | """ 145 | h, w, _ = img.shape 146 | imgh, imgw = img.shape[0:2] 147 | img_256 = resize(img, 256, 256) 148 | 149 | mask = (mask > 127).astype(np.uint8) * 255 150 | mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) 151 | mask_256[mask_256 > 0] = 255 152 | 153 | mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_AREA) 154 | mask_512[mask_512 > 0] = 255 155 | 156 | # original skimage implementation 157 | # https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny 158 | # low_threshold: Lower bound for hysteresis thresholding (linking edges).
If None, low_threshold is set to 10% of dtype’s max. 159 | # high_threshold: Upper bound for hysteresis thresholding (linking edges). If None, high_threshold is set to 20% of dtype’s max. 160 | gray_256 = skimage.color.rgb2gray(img_256) 161 | edge_256 = skimage.feature.canny(gray_256, sigma=sigma256, mask=None).astype(float) 162 | # cv2.imwrite("skimage_gray.jpg", (_gray_256*255).astype(np.uint8)) 163 | # cv2.imwrite("skimage_edge.jpg", (_edge_256*255).astype(np.uint8)) 164 | 165 | # gray_256 = cv2.cvtColor(img_256, cv2.COLOR_RGB2GRAY) 166 | # gray_256_blured = cv2.GaussianBlur(gray_256, ksize=(3,3), sigmaX=sigma256, sigmaY=sigma256) 167 | # edge_256 = cv2.Canny(gray_256_blured, threshold1=int(255*0.1), threshold2=int(255*0.2)) 168 | # cv2.imwrite("edge.jpg", edge_256) 169 | 170 | # line 171 | img_512 = resize(img, 512, 512) 172 | 173 | rel_pos, abs_pos, direct = load_masked_position_encoding(mask) 174 | 175 | batch = dict() 176 | batch["images"] = to_tensor(img.copy()).unsqueeze(0).to(device) 177 | batch["img_256"] = to_tensor(img_256, norm=True).unsqueeze(0).to(device) 178 | batch["masks"] = to_tensor(mask).unsqueeze(0).to(device) 179 | batch["mask_256"] = to_tensor(mask_256).unsqueeze(0).to(device) 180 | batch["mask_512"] = to_tensor(mask_512).unsqueeze(0).to(device) 181 | batch["edge_256"] = to_tensor(edge_256, scale=False).unsqueeze(0).to(device) 182 | batch["img_512"] = to_tensor(img_512).unsqueeze(0).to(device) 183 | batch["rel_pos"] = torch.LongTensor(rel_pos).unsqueeze(0).to(device) 184 | batch["abs_pos"] = torch.LongTensor(abs_pos).unsqueeze(0).to(device) 185 | batch["direct"] = torch.LongTensor(direct).unsqueeze(0).to(device) 186 | batch["h"] = imgh 187 | batch["w"] = imgw 188 | 189 | return batch 190 | 191 | 192 | def to_device(data, device): 193 | if isinstance(data, torch.Tensor): 194 | return data.to(device) 195 | if isinstance(data, dict): 196 | for key in data: 197 | if isinstance(data[key], torch.Tensor): 198 | data[key] = data[key].to(device) 199 | return data 200 | if isinstance(data, list): 201 | return [to_device(d, device) for d in data] 202 | 203 | 204 | class ZITS(InpaintModel): 205 | min_size = 256 206 | pad_mod = 32 207 | pad_to_square = True 208 | 209 | def __init__(self, device, **kwargs): 210 | """ 211 | 212 | Args: 213 | device: 214 | """ 215 | super().__init__(device) 216 | self.device = device 217 | self.sample_edge_line_iterations = 1 218 | 219 | def init_model(self, device, **kwargs): 220 | self.wireframe = load_jit_model(ZITS_WIRE_FRAME_MODEL_URL, device) 221 | self.edge_line = load_jit_model(ZITS_EDGE_LINE_MODEL_URL, device) 222 | self.structure_upsample = load_jit_model( 223 | ZITS_STRUCTURE_UPSAMPLE_MODEL_URL, device 224 | ) 225 | self.inpaint = load_jit_model(ZITS_INPAINT_MODEL_URL, device) 226 | 227 | @staticmethod 228 | def is_downloaded() -> bool: 229 | model_paths = [ 230 | get_cache_path_by_url(ZITS_WIRE_FRAME_MODEL_URL), 231 | get_cache_path_by_url(ZITS_EDGE_LINE_MODEL_URL), 232 | get_cache_path_by_url(ZITS_STRUCTURE_UPSAMPLE_MODEL_URL), 233 | get_cache_path_by_url(ZITS_INPAINT_MODEL_URL), 234 | ] 235 | return all([os.path.exists(it) for it in model_paths]) 236 | 237 | def wireframe_edge_and_line(self, items, enable: bool): 238 | # ultimately adds "edge" and "line" keys to items 239 | if not enable: 240 | items["edge"] = torch.zeros_like(items["masks"]) 241 | items["line"] = torch.zeros_like(items["masks"]) 242 | return 243 | 244 | start = time.time() 245 | try: 246 | line_256 = self.wireframe_forward( 247 | items["img_512"], 248 | h=256, 249 | w=256, 250 |
masks=items["mask_512"], 251 | mask_th=0.85, 252 | ) 253 | except Exception: 254 | line_256 = torch.zeros_like(items["mask_256"]) 255 | 256 | print(f"wireframe_forward time: {(time.time() - start) * 1000:.2f}ms") 257 | 258 | # np_line = (line[0][0].numpy() * 255).astype(np.uint8) 259 | # cv2.imwrite("line.jpg", np_line) 260 | 261 | start = time.time() 262 | edge_pred, line_pred = self.sample_edge_line_logits( 263 | context=[items["img_256"], items["edge_256"], line_256], 264 | mask=items["mask_256"].clone(), 265 | iterations=self.sample_edge_line_iterations, 266 | add_v=0.05, 267 | mul_v=4, 268 | ) 269 | print(f"sample_edge_line_logits time: {(time.time() - start) * 1000:.2f}ms") 270 | 271 | # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) 272 | # cv2.imwrite("edge_pred.jpg", np_edge_pred) 273 | # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) 274 | # cv2.imwrite("line_pred.jpg", np_line_pred) 275 | # exit() 276 | 277 | input_size = min(items["h"], items["w"]) 278 | if input_size > 256: # simplified: the extra "!= 256" check was redundant 279 | while edge_pred.shape[2] < input_size: 280 | edge_pred = self.structure_upsample(edge_pred) 281 | edge_pred = torch.sigmoid((edge_pred + 2) * 2) 282 | 283 | line_pred = self.structure_upsample(line_pred) 284 | line_pred = torch.sigmoid((line_pred + 2) * 2) 285 | 286 | edge_pred = F.interpolate( 287 | edge_pred, 288 | size=(input_size, input_size), 289 | mode="bilinear", 290 | align_corners=False, 291 | ) 292 | line_pred = F.interpolate( 293 | line_pred, 294 | size=(input_size, input_size), 295 | mode="bilinear", 296 | align_corners=False, 297 | ) 298 | 299 | # np_edge_pred = (edge_pred[0][0].numpy() * 255).astype(np.uint8) 300 | # cv2.imwrite("edge_pred_upsample.jpg", np_edge_pred) 301 | # np_line_pred = (line_pred[0][0].numpy() * 255).astype(np.uint8) 302 | # cv2.imwrite("line_pred_upsample.jpg", np_line_pred) 303 | # exit() 304 | 305 | items["edge"] = edge_pred.detach() 306 | items["line"] = line_pred.detach() 307 | 308 | @torch.no_grad() 309 | def forward(self, image, mask, config: Config): 310 | """Input images and output images have the same size 311 | images: [H, W, C] RGB 312 | masks: [H, W] 313 | return: BGR IMAGE 314 | """ 315 | mask = mask[:, :, 0] 316 | items = load_image(image, mask, device=self.device) 317 | 318 | self.wireframe_edge_and_line(items, config.zits_wireframe) 319 | 320 | inpainted_image = self.inpaint( 321 | items["images"], 322 | items["masks"], 323 | items["edge"], 324 | items["line"], 325 | items["rel_pos"], 326 | items["direct"], 327 | ) 328 | 329 | inpainted_image = inpainted_image * 255.0 330 | inpainted_image = ( 331 | inpainted_image.cpu().permute(0, 2, 3, 1)[0].numpy().astype(np.uint8) 332 | ) 333 | inpainted_image = inpainted_image[:, :, ::-1] 334 | 335 | # cv2.imwrite("inpainted.jpg", inpainted_image) 336 | # exit() 337 | 338 | return inpainted_image 339 | 340 | def wireframe_forward(self, images, h, w, masks, mask_th=0.925): 341 | lcnn_mean = torch.tensor([109.730, 103.832, 98.681]).reshape(1, 3, 1, 1) 342 | lcnn_std = torch.tensor([22.275, 22.124, 23.229]).reshape(1, 3, 1, 1) 343 | images = images * 255.0 344 | # lcnn fills masked pixels with the value 127.5 345 | masked_images = images * (1 - masks) + torch.ones_like(images) * masks * 127.5 346 | masked_images = (masked_images - lcnn_mean) / lcnn_std 347 | 348 | def to_int(x): 349 | return tuple(map(int, x)) 350 | 351 | lines_tensor = [] 352 | lmap = np.zeros((h, w)) 353 | 354 | output_masked = self.wireframe(masked_images) 355 | 356 | output_masked = to_device(output_masked, "cpu")
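# ---------------------------------------------------------------------------
# Editor's sketch (illustrative, toy-sized): the core of
# load_masked_position_encoding earlier in this file is a ring-by-ring
# dilation -- each pass of cv2.filter2D with a 3x3 ones kernel grows the known
# region by one pixel, and the pass index at which a hole pixel is first
# reached is its distance to the nearest known pixel:
import numpy as np
import cv2

toy_mask = np.zeros((5, 5), np.uint8)
toy_mask[1:4, 1:4] = 255                 # 255 = hole, as in the real code

known = 1.0 - toy_mask / 255.0           # 1 = known pixel ("mask3" above)
pos = np.zeros(toy_mask.shape, np.int32)
ring = np.ones((3, 3), np.float32)
step = 0
while np.sum(1 - known) > 0:
    step += 1
    grown = cv2.filter2D(known, -1, ring)
    grown[grown > 0] = 1
    pos[(grown - known) == 1] = step     # pixels first reached on this pass
    known = grown
# pos now holds 1 on the outer ring of the hole and 2 at its center pixel
# ---------------------------------------------------------------------------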
357 | if output_masked["num_proposals"] == 0: 358 | lines_masked = [] 359 | scores_masked = [] 360 | else: 361 | lines_masked = output_masked["lines_pred"].numpy() 362 | lines_masked = [ 363 | [line[1] * h, line[0] * w, line[3] * h, line[2] * w] 364 | for line in lines_masked 365 | ] 366 | scores_masked = output_masked["lines_score"].numpy() 367 | 368 | for line, score in zip(lines_masked, scores_masked): 369 | if score > mask_th: 370 | rr, cc, value = skimage.draw.line_aa( 371 | *to_int(line[0:2]), *to_int(line[2:4]) 372 | ) 373 | lmap[rr, cc] = np.maximum(lmap[rr, cc], value) 374 | 375 | lmap = np.clip(lmap * 255, 0, 255).astype(np.uint8) 376 | lines_tensor.append(to_tensor(lmap).unsqueeze(0)) 377 | 378 | lines_tensor = torch.cat(lines_tensor, dim=0) 379 | return lines_tensor.detach().to(self.device) 380 | 381 | def sample_edge_line_logits( 382 | self, context, mask=None, iterations=1, add_v=0, mul_v=4 383 | ): 384 | [img, edge, line] = context 385 | 386 | img = img * (1 - mask) 387 | edge = edge * (1 - mask) 388 | line = line * (1 - mask) 389 | 390 | for i in range(iterations): 391 | edge_logits, line_logits = self.edge_line(img, edge, line, masks=mask) 392 | 393 | edge_pred = torch.sigmoid(edge_logits) 394 | line_pred = torch.sigmoid((line_logits + add_v) * mul_v) 395 | edge = edge + edge_pred * mask 396 | edge[edge >= 0.25] = 1 397 | edge[edge < 0.25] = 0 398 | line = line + line_pred * mask 399 | 400 | b, _, h, w = edge_pred.shape 401 | edge_pred = edge_pred.reshape(b, -1, 1) 402 | line_pred = line_pred.reshape(b, -1, 1) 403 | mask = mask.reshape(b, -1) 404 | 405 | edge_probs = torch.cat([1 - edge_pred, edge_pred], dim=-1) 406 | line_probs = torch.cat([1 - line_pred, line_pred], dim=-1) 407 | edge_probs[:, :, 1] += 0.5 408 | line_probs[:, :, 1] += 0.5 409 | edge_max_probs = edge_probs.max(dim=-1)[0] + (1 - mask) * (-100) 410 | line_max_probs = line_probs.max(dim=-1)[0] + (1 - mask) * (-100) 411 | 412 | indices = torch.sort( 413 | edge_max_probs + line_max_probs, dim=-1, descending=True 414 | )[1] 415 | 416 | for ii in range(b): 417 | keep = int((i + 1) / iterations * torch.sum(mask[ii, ...])) 418 | 419 | assert torch.sum(mask[ii][indices[ii, :keep]]) == keep, "Error!!!" 
420 | mask[ii][indices[ii, :keep]] = 0 421 | 422 | mask = mask.reshape(b, 1, h, w) 423 | edge = edge * (1 - mask) 424 | line = line * (1 - mask) 425 | 426 | edge, line = edge.to(torch.float32), line.to(torch.float32) 427 | return edge, line 428 | -------------------------------------------------------------------------------- /lama_cleaner/model/sd_pipeline.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import List, Optional, Union, Callable 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import PIL 8 | from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler 9 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput 10 | from diffusers.utils import logging 11 | from tqdm.auto import tqdm 12 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | 17 | def preprocess_image(image): 18 | w, h = image.size 19 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 20 | image = image.resize((w, h), resample=PIL.Image.LANCZOS) 21 | image = np.array(image).astype(np.float32) / 255.0 22 | image = image[None].transpose(0, 3, 1, 2) 23 | image = torch.from_numpy(image) 24 | return 2.0 * image - 1.0 25 | 26 | 27 | def preprocess_mask(mask): 28 | mask = mask.convert("L") 29 | w, h = mask.size 30 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 31 | mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) 32 | mask = np.array(mask).astype(np.float32) / 255.0 33 | mask = np.tile(mask, (4, 1, 1)) 34 | mask = mask[None].transpose(0, 1, 2, 3) # adds a batch dim; transpose(0, 1, 2, 3) is an identity permutation, leaving shape (1, 4, h//8, w//8) 35 | mask = 1 - mask # repaint white, keep black 36 | mask = torch.from_numpy(mask) 37 | return mask 38 | 39 | 40 | class StableDiffusionInpaintPipeline(DiffusionPipeline): 41 | r""" 42 | Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. 43 | 44 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 45 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 46 | 47 | Args: 48 | vae ([`AutoencoderKL`]): 49 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 50 | text_encoder ([`CLIPTextModel`]): 51 | Frozen text-encoder. Stable Diffusion uses the text portion of 52 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 53 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 54 | tokenizer (`CLIPTokenizer`): 55 | Tokenizer of class 56 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 57 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 58 | scheduler ([`SchedulerMixin`]): 59 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 60 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 61 | safety_checker ([`StableDiffusionSafetyChecker`]): 62 | Classification module that estimates whether generated images could be considered offensive or harmful.
63 | Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. 64 | feature_extractor ([`CLIPFeatureExtractor`]): 65 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | vae: AutoencoderKL, 71 | text_encoder: CLIPTextModel, 72 | tokenizer: CLIPTokenizer, 73 | unet: UNet2DConditionModel, 74 | scheduler: Union[DDIMScheduler, PNDMScheduler], 75 | safety_checker: StableDiffusionSafetyChecker, 76 | feature_extractor: CLIPFeatureExtractor, 77 | ): 78 | super().__init__() 79 | scheduler = scheduler.set_format("pt") 80 | logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") 81 | self.register_modules( 82 | vae=vae, 83 | text_encoder=text_encoder, 84 | tokenizer=tokenizer, 85 | unet=unet, 86 | scheduler=scheduler, 87 | safety_checker=safety_checker, 88 | feature_extractor=feature_extractor, 89 | ) 90 | 91 | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): 92 | r""" 93 | Enable sliced attention computation. 94 | 95 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 96 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 97 | 98 | Args: 99 | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): 100 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 101 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, 102 | `attention_head_dim` must be a multiple of `slice_size`. 103 | """ 104 | if slice_size == "auto": 105 | # half the attention head size is usually a good trade-off between 106 | # speed and memory 107 | slice_size = self.unet.config.attention_head_dim // 2 108 | self.unet.set_attention_slice(slice_size) 109 | 110 | def disable_attention_slicing(self): 111 | r""" 112 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go 113 | back to computing attention in one step. 114 | """ 115 | # set slice_size = `None` to disable `set_attention_slice` 116 | self.enable_attention_slicing(None) # fixed: was `enable_attention_slice`, a method that doesn't exist 117 | 118 | @torch.no_grad() 119 | def __call__( 120 | self, 121 | prompt: Union[str, List[str]], 122 | init_image: Union[torch.FloatTensor, PIL.Image.Image], 123 | mask_image: Union[torch.FloatTensor, PIL.Image.Image], 124 | strength: float = 0.8, 125 | num_inference_steps: Optional[int] = 50, 126 | guidance_scale: Optional[float] = 7.5, 127 | eta: Optional[float] = 0.0, 128 | generator: Optional[torch.Generator] = None, 129 | output_type: Optional[str] = "pil", 130 | return_dict: bool = True, 131 | callbacks: List[Callable[[int], None]] = None 132 | ): 133 | r""" 134 | Function invoked when calling the pipeline for generation. 135 | 136 | Args: 137 | prompt (`str` or `List[str]`): 138 | The prompt or prompts to guide the image generation. 139 | init_image (`torch.FloatTensor` or `PIL.Image.Image`): 140 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 141 | process. This is the image whose masked region will be inpainted. 142 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`): 143 | `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be 144 | replaced by noise and therefore repainted, while black pixels will be preserved.
The mask image will be 145 | converted to a single channel (luminance) before use. 146 | strength (`float`, *optional*, defaults to 0.8): 147 | Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` 148 | is 1, the denoising process will be run on the masked area for the full number of iterations specified 149 | in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more 150 | noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. 151 | num_inference_steps (`int`, *optional*, defaults to 50): 152 | The reference number of denoising steps. More denoising steps usually lead to a higher quality image at 153 | the expense of slower inference. This parameter will be modulated by `strength`, as explained above. 154 | guidance_scale (`float`, *optional*, defaults to 7.5): 155 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 156 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 157 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 158 | 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, 159 | usually at the expense of lower image quality. 160 | eta (`float`, *optional*, defaults to 0.0): 161 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 162 | [`schedulers.DDIMScheduler`], will be ignored for others. 163 | generator (`torch.Generator`, *optional*): 164 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 165 | deterministic. 166 | output_type (`str`, *optional*, defaults to `"pil"`): 167 | The output format of the generated image. Choose between 168 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. 169 | return_dict (`bool`, *optional*, defaults to `True`): 170 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 171 | plain tuple. 172 | 173 | Returns: 174 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 175 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 176 | When returning a tuple, the first element is a list with the generated images, and the second element is a 177 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 178 | (nsfw) content, according to the `safety_checker`.
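Editor's note (illustrative, not original documentation): `strength` and
`num_inference_steps` interact through `init_timestep = int(num_inference_steps * strength) + offset`
and `t_start = max(num_inference_steps - init_timestep + offset, 0)` in the
body below, so with the defaults (strength=0.8, 50 steps, offset=1 for
schedulers that accept an offset) the denoising loop runs over
`self.scheduler.timesteps[10:]`, i.e. 40 of the 50 steps:

    >>> num_inference_steps, strength, offset = 50, 0.8, 1
    >>> init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    >>> max(num_inference_steps - init_timestep + offset, 0)
    10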
179 | """ 180 | if isinstance(prompt, str): 181 | batch_size = 1 182 | elif isinstance(prompt, list): 183 | batch_size = len(prompt) 184 | else: 185 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 186 | 187 | if strength < 0 or strength > 1: 188 | raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") 189 | 190 | # set timesteps 191 | accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) 192 | extra_set_kwargs = {} 193 | offset = 0 194 | if accepts_offset: 195 | offset = 1 196 | extra_set_kwargs["offset"] = 1 197 | 198 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 199 | 200 | # preprocess image 201 | init_image = preprocess_image(init_image).to(self.device) 202 | 203 | # encode the init image into latents and scale the latents 204 | init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist 205 | init_latents = init_latent_dist.sample(generator=generator) 206 | 207 | init_latents = 0.18215 * init_latents 208 | 209 | # Expand init_latents for batch_size 210 | init_latents = torch.cat([init_latents] * batch_size) 211 | init_latents_orig = init_latents 212 | 213 | # preprocess mask 214 | mask = preprocess_mask(mask_image).to(self.device) 215 | mask = torch.cat([mask] * batch_size) 216 | 217 | # check sizes 218 | if not mask.shape == init_latents.shape: 219 | raise ValueError("The mask and init_image should be the same size!") 220 | 221 | # get the original timestep using init_timestep 222 | init_timestep = int(num_inference_steps * strength) + offset 223 | init_timestep = min(init_timestep, num_inference_steps) 224 | timesteps = self.scheduler.timesteps[-init_timestep] 225 | timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) 226 | 227 | # add noise to latents using the timesteps 228 | noise = torch.randn(init_latents.shape, generator=generator, device=self.device) 229 | init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) 230 | 231 | # get prompt text embeddings 232 | text_input = self.tokenizer( 233 | prompt, 234 | padding="max_length", 235 | max_length=self.tokenizer.model_max_length, 236 | truncation=True, 237 | return_tensors="pt", 238 | ) 239 | text_encoder_device = self.text_encoder.device 240 | 241 | text_embeddings = self.text_encoder(text_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True) 242 | 243 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 244 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 245 | # corresponds to doing no classifier free guidance. 246 | do_classifier_free_guidance = guidance_scale > 1.0 247 | # get unconditional embeddings for classifier free guidance 248 | if do_classifier_free_guidance: 249 | max_length = text_input.input_ids.shape[-1] 250 | uncond_input = self.tokenizer( 251 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 252 | ) 253 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(text_encoder_device, non_blocking=True))[0].to(self.device, non_blocking=True) 254 | 255 | # For classifier free guidance, we need to do two forward passes. 
256 | # Here we concatenate the unconditional and text embeddings into a single batch 257 | # to avoid doing two forward passes 258 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 259 | 260 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 261 | # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers. 262 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 263 | # and should be between [0, 1] 264 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 265 | extra_step_kwargs = {} 266 | if accepts_eta: 267 | extra_step_kwargs["eta"] = eta 268 | 269 | latents = init_latents 270 | t_start = max(num_inference_steps - init_timestep + offset, 0) 271 | for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): 272 | # expand the latents if we are doing classifier free guidance 273 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 274 | # predict the noise residual 275 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 276 | 277 | # perform guidance 278 | if do_classifier_free_guidance: 279 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 280 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 281 | 282 | # compute the previous noisy sample x_t -> x_t-1 283 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 284 | 285 | # masking 286 | init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) 287 | latents = (init_latents_proper * mask) + (latents * (1 - mask)) 288 | 289 | if callbacks is not None: 290 | for callback in callbacks: 291 | callback(i) 292 | 293 | # scale and decode the image latents with vae 294 | latents = 1 / 0.18215 * latents 295 | image = self.vae.decode(latents).sample 296 | 297 | image = (image / 2 + 0.5).clamp(0, 1) 298 | image = image.cpu().permute(0, 2, 3, 1).numpy() 299 | 300 | # run safety checker 301 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) 302 | image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) 303 | 304 | if output_type == "pil": 305 | image = self.numpy_to_pil(image) 306 | 307 | if not return_dict: 308 | return (image, has_nsfw_concept) 309 | 310 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 311 | -------------------------------------------------------------------------------- /lama_cleaner/model/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | 4 | import torch 5 | import numpy as np 6 | import collections 7 | from itertools import repeat 8 | 9 | from torch import conv2d, conv_transpose2d 10 | 11 | 12 | def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 13 | if schedule == "linear": 14 | betas = ( 15 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 16 | ) 17 | 18 | elif schedule == "cosine": 19 | timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device) 20 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 21 | alphas = torch.cos(alphas).pow(2).to(device) 22 | alphas = alphas / alphas[0] 23 | betas = 1 - alphas[1:] / alphas[:-1] 24 |
betas = np.clip(betas, a_min=0, a_max=0.999) 25 | 26 | elif schedule == "sqrt_linear": 27 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 28 | elif schedule == "sqrt": 29 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 30 | else: 31 | raise ValueError(f"schedule '{schedule}' unknown.") 32 | return betas.numpy() 33 | 34 | 35 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 36 | # select alphas for computing the variance schedule 37 | alphas = alphacums[ddim_timesteps] 38 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 39 | 40 | # according to the formula provided in https://arxiv.org/abs/2010.02502 41 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 42 | if verbose: 43 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 44 | print(f'For the chosen value of eta, which is {eta}, ' 45 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 46 | return sigmas, alphas, alphas_prev 47 | 48 | 49 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 50 | if ddim_discr_method == 'uniform': 51 | c = num_ddpm_timesteps // num_ddim_timesteps 52 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 53 | elif ddim_discr_method == 'quad': 54 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 55 | else: 56 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 57 | 58 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 59 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 60 | steps_out = ddim_timesteps + 1 61 | if verbose: 62 | print(f'Selected timesteps for ddim sampler: {steps_out}') 63 | return steps_out 64 | 65 | 66 | def noise_like(shape, device, repeat=False): 67 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 68 | noise = lambda: torch.randn(shape, device=device) 69 | return repeat_noise() if repeat else noise() 70 | 71 | 72 | def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=False): 73 | """ 74 | Create sinusoidal timestep embeddings. 75 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 76 | These may be fractional. 77 | :param dim: the dimension of the output. 78 | :param max_period: controls the minimum frequency of the embeddings. 79 | :return: an [N x dim] Tensor of positional embeddings.
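Editor's note (illustrative sketch of the same formula, not original code):
the first `dim // 2` outputs are cosines at geometrically spaced frequencies
falling from 1 toward 1/max_period, and the second half holds the matching
sines. A minimal shape check:

    >>> import math, torch
    >>> t, half = torch.tensor([5.0]), 2
    >>> freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    >>> args = t[:, None] * freqs[None]
    >>> torch.cat([torch.cos(args), torch.sin(args)], dim=-1).shape
    torch.Size([1, 4])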
80 | """ 81 | half = dim // 2 82 | freqs = torch.exp( 83 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 84 | ).to(device=device) 85 | 86 | args = timesteps[:, None].float() * freqs[None] 87 | 88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 89 | if dim % 2: 90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 91 | return embedding 92 | 93 | 94 | ###### MAT and FcF ####### 95 | 96 | 97 | def normalize_2nd_moment(x, dim=1, eps=1e-8): 98 | return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() 99 | 100 | 101 | class EasyDict(dict): 102 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 103 | 104 | def __getattr__(self, name: str) -> Any: 105 | try: 106 | return self[name] 107 | except KeyError: 108 | raise AttributeError(name) 109 | 110 | def __setattr__(self, name: str, value: Any) -> None: 111 | self[name] = value 112 | 113 | def __delattr__(self, name: str) -> None: 114 | del self[name] 115 | 116 | 117 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 118 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 119 | """ 120 | assert isinstance(x, torch.Tensor) 121 | assert clamp is None or clamp >= 0 122 | spec = activation_funcs[act] 123 | alpha = float(alpha if alpha is not None else spec.def_alpha) 124 | gain = float(gain if gain is not None else spec.def_gain) 125 | clamp = float(clamp if clamp is not None else -1) 126 | 127 | # Add bias. 128 | if b is not None: 129 | assert isinstance(b, torch.Tensor) and b.ndim == 1 130 | assert 0 <= dim < x.ndim 131 | assert b.shape[0] == x.shape[dim] 132 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 133 | 134 | # Evaluate activation function. 135 | alpha = float(alpha) 136 | x = spec.func(x, alpha=alpha) 137 | 138 | # Scale by gain. 139 | gain = float(gain) 140 | if gain != 1: 141 | x = x * gain 142 | 143 | # Clamp. 144 | if clamp >= 0: 145 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 146 | return x 147 | 148 | 149 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'): 150 | r"""Fused bias and activation function. 151 | 152 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 153 | and scales the result by `gain`. Each of the steps is optional. In most cases, 154 | the fused op is considerably more efficient than performing the same calculation 155 | using standard PyTorch ops. It supports first and second order gradients, 156 | but not third order gradients. 157 | 158 | Args: 159 | x: Input activation tensor. Can be of any shape. 160 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 161 | as `x`. The shape must be known, and it must match the dimension of `x` 162 | corresponding to `dim`. 163 | dim: The dimension in `x` corresponding to the elements of `b`. 164 | The value of `dim` is ignored if `b` is not specified. 165 | act: Name of the activation function to evaluate, or `"linear"` to disable. 166 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 167 | See `activation_funcs` for a full list. `None` is not allowed. 168 | alpha: Shape parameter for the activation function, or `None` to use the default. 169 | gain: Scaling factor for the output tensor, or `None` to use default. 170 | See `activation_funcs` for the default scaling of each activation function. 
171 | If unsure, consider specifying 1. 172 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 173 | the clamping (default). 174 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"`; this port defaults to `"ref"` and always runs the reference implementation. 175 | 176 | Returns: 177 | Tensor of the same shape and datatype as `x`. 178 | """ 179 | assert isinstance(x, torch.Tensor) 180 | assert impl in ['ref', 'cuda'] 181 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 182 | 183 | 184 | def _get_filter_size(f): 185 | if f is None: 186 | return 1, 1 187 | 188 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 189 | fw = f.shape[-1] 190 | fh = f.shape[0] 191 | 192 | fw = int(fw) 193 | fh = int(fh) 194 | assert fw >= 1 and fh >= 1 195 | return fw, fh 196 | 197 | 198 | def _get_weight_shape(w): 199 | shape = [int(sz) for sz in w.shape] 200 | return shape 201 | 202 | 203 | def _parse_scaling(scaling): 204 | if isinstance(scaling, int): 205 | scaling = [scaling, scaling] 206 | assert isinstance(scaling, (list, tuple)) 207 | assert all(isinstance(x, int) for x in scaling) 208 | sx, sy = scaling 209 | assert sx >= 1 and sy >= 1 210 | return sx, sy 211 | 212 | 213 | def _parse_padding(padding): 214 | if isinstance(padding, int): 215 | padding = [padding, padding] 216 | assert isinstance(padding, (list, tuple)) 217 | assert all(isinstance(x, int) for x in padding) 218 | if len(padding) == 2: 219 | padx, pady = padding 220 | padding = [padx, padx, pady, pady] 221 | padx0, padx1, pady0, pady1 = padding 222 | return padx0, padx1, pady0, pady1 223 | 224 | 225 | def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): 226 | r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. 227 | 228 | Args: 229 | f: Torch tensor, numpy array, or python list of the shape 230 | `[filter_height, filter_width]` (non-separable), 231 | `[filter_taps]` (separable), 232 | `[]` (impulse), or 233 | `None` (identity). 234 | device: Result device (default: cpu). 235 | normalize: Normalize the filter so that it retains the magnitude 236 | for constant input signal (DC)? (default: True). 237 | flip_filter: Flip the filter? (default: False). 238 | gain: Overall scaling factor for signal magnitude (default: 1). 239 | separable: Return a separable filter? (default: select automatically). 240 | 241 | Returns: 242 | Float32 tensor of the shape 243 | `[filter_height, filter_width]` (non-separable) or 244 | `[filter_taps]` (separable). 245 | """ 246 | # Validate. 247 | if f is None: 248 | f = 1 249 | f = torch.as_tensor(f, dtype=torch.float32) 250 | assert f.ndim in [0, 1, 2] 251 | assert f.numel() > 0 252 | if f.ndim == 0: 253 | f = f[np.newaxis] 254 | 255 | # Separable? 256 | if separable is None: 257 | separable = (f.ndim == 1 and f.numel() >= 8) 258 | if f.ndim == 1 and not separable: 259 | f = f.ger(f) 260 | assert f.ndim == (1 if separable else 2) 261 | 262 | # Apply normalize, flip, gain, and device.
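    # Note: a short 1D filter (fewer than 8 taps) was already expanded to its 2D
    # outer product above, so e.g. setup_filter([1, 3, 3, 1]) yields a 4x4 tensor.
    # Normalization makes the taps sum to 1 (unit DC response), and the gain term
    # below scales the magnitude by gain ** (ndim / 2), so that a separable filter
    # applied along both axes ends up with an overall DC gain of `gain`.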
263 | if normalize: 264 | f /= f.sum() 265 | if flip_filter: 266 | f = f.flip(list(range(f.ndim))) 267 | f = f * (gain ** (f.ndim / 2)) 268 | f = f.to(device=device) 269 | return f 270 | 271 | 272 | def _ntuple(n): 273 | def parse(x): 274 | if isinstance(x, collections.abc.Iterable): 275 | return x 276 | return tuple(repeat(x, n)) 277 | 278 | return parse 279 | 280 | 281 | to_2tuple = _ntuple(2) 282 | 283 | activation_funcs = { 284 | 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 285 | 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, 286 | ref='y', has_2nd_grad=False), 287 | 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, 288 | def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 289 | 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', 290 | has_2nd_grad=True), 291 | 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', 292 | has_2nd_grad=True), 293 | 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', 294 | has_2nd_grad=True), 295 | 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', 296 | has_2nd_grad=True), 297 | 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, 298 | ref='y', has_2nd_grad=True), 299 | 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', 300 | has_2nd_grad=True), 301 | } 302 | 303 | 304 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 305 | r"""Pad, upsample, filter, and downsample a batch of 2D images. 306 | 307 | Performs the following sequence of operations for each channel: 308 | 309 | 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 310 | 311 | 2. Pad the image with the specified number of zeros on each side (`padding`). 312 | Negative padding corresponds to cropping the image. 313 | 314 | 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it 315 | so that the footprint of all output pixels lies within the input image. 316 | 317 | 4. Downsample the image by keeping every Nth pixel (`down`). 318 | 319 | This sequence of operations bears close resemblance to scipy.signal.upfirdn(). 320 | The fused op is considerably more efficient than performing the same calculation 321 | using standard PyTorch ops. It supports gradients of arbitrary order. 322 | 323 | Args: 324 | x: Float32/float64/float16 input tensor of the shape 325 | `[batch_size, num_channels, in_height, in_width]`. 326 | f: Float32 FIR filter of the shape 327 | `[filter_height, filter_width]` (non-separable), 328 | `[filter_taps]` (separable), or 329 | `None` (identity). 330 | up: Integer upsampling factor. Can be a single int or a list/tuple 331 | `[x, y]` (default: 1). 332 | down: Integer downsampling factor. Can be a single int or a list/tuple 333 | `[x, y]` (default: 1). 334 | padding: Padding with respect to the upsampled image. Can be a single number 335 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 336 | (default: 0). 337 | flip_filter: False = convolution, True = correlation (default: False). 338 | gain: Overall scaling factor for signal magnitude (default: 1). 
339 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`); this port always falls back to the reference implementation. 340 | 341 | Returns: 342 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 343 | """ 344 | assert isinstance(x, torch.Tensor) 345 | assert impl in ['ref', 'cuda'] 346 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 347 | 348 | 349 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 350 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 351 | """ 352 | # Validate arguments. 353 | assert isinstance(x, torch.Tensor) and x.ndim == 4 354 | if f is None: 355 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 356 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 357 | assert f.dtype == torch.float32 and not f.requires_grad 358 | batch_size, num_channels, in_height, in_width = x.shape 359 | upx, upy = _parse_scaling(up) 360 | downx, downy = _parse_scaling(down) 361 | 362 | # _parse_padding() accepts a plain int as well as [x, y] or [x0, x1, y0, y1] 363 | # values, which keeps the default padding=0 valid for the internal calls. 364 | 365 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 366 | 367 | 368 | # Upsample by inserting zeros. 369 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 370 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 371 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 372 | 373 | # Pad or crop. 374 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 375 | x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] 376 | 377 | # Setup filter. 378 | f = f * (gain ** (f.ndim / 2)) 379 | f = f.to(x.dtype) 380 | if not flip_filter: 381 | f = f.flip(list(range(f.ndim))) 382 | 383 | # Convolve with the filter. 384 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 385 | if f.ndim == 4: 386 | x = conv2d(input=x, weight=f, groups=num_channels) 387 | else: 388 | x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 389 | x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 390 | 391 | # Downsample by throwing away pixels. 392 | x = x[:, :, ::downy, ::downx] 393 | return x 394 | 395 | 396 | def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 397 | r"""Downsample a batch of 2D images using the given 2D FIR filter. 398 | 399 | By default, the result is padded so that its shape is a fraction of the input. 400 | User-specified padding is applied on top of that, with negative values 401 | indicating cropping. Pixels outside the image are assumed to be zero. 402 | 403 | Args: 404 | x: Float32/float64/float16 input tensor of the shape 405 | `[batch_size, num_channels, in_height, in_width]`. 406 | f: Float32 FIR filter of the shape 407 | `[filter_height, filter_width]` (non-separable), 408 | `[filter_taps]` (separable), or 409 | `None` (identity). 410 | down: Integer downsampling factor. Can be a single int or a list/tuple 411 | `[x, y]` (default: 2). 412 | padding: Padding with respect to the input. Can be a single number or a 413 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 414 | (default: 0). 415 | flip_filter: False = convolution, True = correlation (default: False). 416 | gain: Overall scaling factor for signal magnitude (default: 1). 417 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
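    Example (an illustrative sketch; `x` here stands for any `[N, C, H, W]`
    float tensor):

        f = setup_filter([1, 3, 3, 1])  # normalized 4-tap low-pass filter
        y = downsample2d(x, f, down=2)  # y: [N, C, H // 2, W // 2]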
418 | 419 | Returns: 420 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 421 | """ 422 | downx, downy = _parse_scaling(down) 423 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 424 | 425 | 426 | fw, fh = _get_filter_size(f) 427 | p = [ 428 | padx0 + (fw - downx + 1) // 2, 429 | padx1 + (fw - downx) // 2, 430 | pady0 + (fh - downy + 1) // 2, 431 | pady1 + (fh - downy) // 2, 432 | ] 433 | return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) 434 | 435 | 436 | def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): 437 | r"""Upsample a batch of 2D images using the given 2D FIR filter. 438 | 439 | By default, the result is padded so that its shape is a multiple of the input. 440 | User-specified padding is applied on top of that, with negative values 441 | indicating cropping. Pixels outside the image are assumed to be zero. 442 | 443 | Args: 444 | x: Float32/float64/float16 input tensor of the shape 445 | `[batch_size, num_channels, in_height, in_width]`. 446 | f: Float32 FIR filter of the shape 447 | `[filter_height, filter_width]` (non-separable), 448 | `[filter_taps]` (separable), or 449 | `None` (identity). 450 | up: Integer upsampling factor. Can be a single int or a list/tuple 451 | `[x, y]` (default: 2). 452 | padding: Padding with respect to the output. Can be a single number or a 453 | list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 454 | (default: 0). 455 | flip_filter: False = convolution, True = correlation (default: False). 456 | gain: Overall scaling factor for signal magnitude (default: 1). 457 | impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). 458 | 459 | Returns: 460 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 461 | """ 462 | upx, upy = _parse_scaling(up) 463 | 464 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 465 | 466 | fw, fh = _get_filter_size(f) 467 | p = [ 468 | padx0 + (fw + upx - 1) // 2, 469 | padx1 + (fw - upx) // 2, 470 | pady0 + (fh + upy - 1) // 2, 471 | pady1 + (fh - upy) // 2, 472 | ] 473 | return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) 474 | 475 | 476 | class MinibatchStdLayer(torch.nn.Module): 477 | def __init__(self, group_size, num_channels=1): 478 | super().__init__() 479 | self.group_size = group_size 480 | self.num_channels = num_channels 481 | 482 | def forward(self, x): 483 | N, C, H, W = x.shape 484 | G = torch.min(torch.as_tensor(self.group_size), 485 | torch.as_tensor(N)) if self.group_size is not None else N 486 | F = self.num_channels 487 | c = C // F 488 | 489 | y = x.reshape(G, -1, F, c, H, 490 | W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. 491 | y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. 492 | y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. 493 | y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. 494 | y = y.mean(dim=[2, 3, 4]) # [nF] Take average over channels and pixels. 495 | y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions. 496 | y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels. 497 | x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
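        # Shape sketch: with group_size=4, num_channels=1 and x of shape
        # [8, 512, 4, 4], the stddev map y comes out as [8, 1, 4, 4], so the
        # layer returns an [8, 513, 4, 4] tensor; the one extra channel tells
        # the discriminator how much the samples within each group vary.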
498 | return x 499 | 500 | 501 | class FullyConnectedLayer(torch.nn.Module): 502 | def __init__(self, 503 | in_features, # Number of input features. 504 | out_features, # Number of output features. 505 | bias=True, # Apply additive bias before the activation function? 506 | activation='linear', # Activation function: 'relu', 'lrelu', etc. 507 | lr_multiplier=1, # Learning rate multiplier. 508 | bias_init=0, # Initial value for the additive bias. 509 | ): 510 | super().__init__() 511 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) 512 | self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None 513 | self.activation = activation 514 | 515 | self.weight_gain = lr_multiplier / np.sqrt(in_features) 516 | self.bias_gain = lr_multiplier 517 | 518 | def forward(self, x): 519 | w = self.weight * self.weight_gain 520 | b = self.bias 521 | if b is not None and self.bias_gain != 1: 522 | b = b * self.bias_gain 523 | 524 | if self.activation == 'linear' and b is not None: 525 | # out = torch.addmm(b.unsqueeze(0), x, w.t()) 526 | x = x.matmul(w.t()) 527 | out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) 528 | else: 529 | x = x.matmul(w.t()) 530 | out = bias_act(x, b, act=self.activation, dim=x.ndim - 1) 531 | return out 532 | 533 | 534 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 535 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 536 | """ 537 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 538 | 539 | # Flip weight if requested. 540 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 541 | w = w.flip([2, 3]) 542 | 543 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 544 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 545 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 546 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 547 | if out_channels <= 4 and groups == 1: 548 | in_shape = x.shape 549 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 550 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 551 | else: 552 | x = x.to(memory_format=torch.contiguous_format) 553 | w = w.to(memory_format=torch.contiguous_format) 554 | x = conv2d(x, w, groups=groups) 555 | return x.to(memory_format=torch.channels_last) 556 | 557 | # Otherwise => execute using conv2d_gradfix. 558 | op = conv_transpose2d if transpose else conv2d 559 | return op(x, w, stride=stride, padding=padding, groups=groups) 560 | 561 | 562 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 563 | r"""2D convolution with optional up/downsampling. 564 | 565 | Padding is performed only once at the beginning, not between the operations. 566 | 567 | Args: 568 | x: Input tensor of shape 569 | `[batch_size, in_channels, in_height, in_width]`. 570 | w: Weight tensor of shape 571 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 572 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 573 | calling setup_filter(). None = identity (default). 574 | up: Integer upsampling factor (default: 1). 575 | down: Integer downsampling factor (default: 1). 576 | padding: Padding with respect to the upsampled image. 
Can be a single number 577 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 578 | (default: 0). 579 | groups: Split input channels into N groups (default: 1). 580 | flip_weight: False = convolution, True = correlation (default: True). 581 | flip_filter: False = convolution, True = correlation (default: False). 582 | 583 | Returns: 584 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 585 | """ 586 | # Validate arguments. 587 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 588 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 589 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 590 | assert isinstance(up, int) and (up >= 1) 591 | assert isinstance(down, int) and (down >= 1) 592 | assert isinstance(groups, int) and (groups >= 1) 593 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 594 | fw, fh = _get_filter_size(f) 595 | px0, px1, py0, py1 = _parse_padding(padding) 596 | 597 | 598 | # Adjust padding to account for up/downsampling. 599 | if up > 1: 600 | px0 += (fw + up - 1) // 2 601 | px1 += (fw - up) // 2 602 | py0 += (fh + up - 1) // 2 603 | py1 += (fh - up) // 2 604 | if down > 1: 605 | px0 += (fw - down + 1) // 2 606 | px1 += (fw - down) // 2 607 | py0 += (fh - down + 1) // 2 608 | py1 += (fh - down) // 2 609 | 610 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 611 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 612 | x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) 613 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 614 | return x 615 | 616 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 617 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 618 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 619 | x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) 620 | return x 621 | 622 | # Fast path: downsampling only => use strided convolution. 623 | if down > 1 and up == 1: 624 | x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) 625 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 626 | return x 627 | 628 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 629 | if up > 1: 630 | if groups == 1: 631 | w = w.transpose(0, 1) 632 | else: 633 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 634 | w = w.transpose(1, 2) 635 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 636 | px0 -= kw - 1 637 | px1 -= kw - up 638 | py0 -= kh - 1 639 | py1 -= kh - up 640 | pxt = max(min(-px0, -px1), 0) 641 | pyt = max(min(-py0, -py1), 0) 642 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True, 643 | flip_weight=(not flip_weight)) 644 | x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2, 645 | flip_filter=flip_filter) 646 | if down > 1: 647 | x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 648 | return x 649 | 650 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
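    # (Every up/down combination except plain convolution has already returned
    # above, so only symmetric non-negative padding remains to special-case.)
    # Usage sketch with hypothetical shapes, as the synthesis layers below use it:
    #   w = torch.randn([256, 512, 3, 3])  # [out_channels, in_channels, kh, kw]
    #   y = conv2d_resample(x, w, f=setup_filter([1, 3, 3, 1]), up=2, padding=1)
    #   # y has twice the spatial resolution of x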
651 | if up == 1 and down == 1: 652 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 653 | return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight) 654 | 655 | # Fallback: Generic reference implementation. 656 | x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2, 657 | flip_filter=flip_filter) 658 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 659 | if down > 1: 660 | x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 661 | return x 662 | 663 | 664 | class Conv2dLayer(torch.nn.Module): 665 | def __init__(self, 666 | in_channels, # Number of input channels. 667 | out_channels, # Number of output channels. 668 | kernel_size, # Width and height of the convolution kernel. 669 | bias=True, # Apply additive bias before the activation function? 670 | activation='linear', # Activation function: 'relu', 'lrelu', etc. 671 | up=1, # Integer upsampling factor. 672 | down=1, # Integer downsampling factor. 673 | resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. 674 | conv_clamp=None, # Clamp the output to +-X, None = disable clamping. 675 | channels_last=False, # Expect the input to have memory_format=channels_last? 676 | trainable=True, # Update the weights of this layer during training? 677 | ): 678 | super().__init__() 679 | self.activation = activation 680 | self.up = up 681 | self.down = down 682 | self.register_buffer('resample_filter', setup_filter(resample_filter)) 683 | self.conv_clamp = conv_clamp 684 | self.padding = kernel_size // 2 685 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 686 | self.act_gain = activation_funcs[activation].def_gain 687 | 688 | memory_format = torch.channels_last if channels_last else torch.contiguous_format 689 | weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format) 690 | bias = torch.zeros([out_channels]) if bias else None 691 | if trainable: 692 | self.weight = torch.nn.Parameter(weight) 693 | self.bias = torch.nn.Parameter(bias) if bias is not None else None 694 | else: 695 | self.register_buffer('weight', weight) 696 | if bias is not None: 697 | self.register_buffer('bias', bias) 698 | else: 699 | self.bias = None 700 | 701 | def forward(self, x, gain=1): 702 | w = self.weight * self.weight_gain 703 | x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down, 704 | padding=self.padding) 705 | 706 | act_gain = self.act_gain * gain 707 | act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None 708 | out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp) 709 | return out 710 | -------------------------------------------------------------------------------- /lama_cleaner/model/fcf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | import torch.fft as fft 8 | 9 | from lama_cleaner.schema import Config 10 | 11 | from lama_cleaner.helper import load_model, get_cache_path_by_url, norm_img, boxes_from_mask, resize_max_size 12 | from lama_cleaner.model.base import InpaintModel 13 | from torch import conv2d, nn 14 | import torch.nn.functional as F 15 | 16 | from lama_cleaner.model.utils import setup_filter, _parse_scaling, _parse_padding, Conv2dLayer, FullyConnectedLayer, \ 17 | MinibatchStdLayer, activation_funcs, conv2d_resample, bias_act, 
upsample2d, normalize_2nd_moment, downsample2d 18 | 19 | 20 | def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): 21 | assert isinstance(x, torch.Tensor) 22 | return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) 23 | 24 | 25 | def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): 26 | """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. 27 | """ 28 | # Validate arguments. 29 | assert isinstance(x, torch.Tensor) and x.ndim == 4 30 | if f is None: 31 | f = torch.ones([1, 1], dtype=torch.float32, device=x.device) 32 | assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] 33 | assert f.dtype == torch.float32 and not f.requires_grad 34 | batch_size, num_channels, in_height, in_width = x.shape 35 | upx, upy = _parse_scaling(up) 36 | downx, downy = _parse_scaling(down) 37 | padx0, padx1, pady0, pady1 = _parse_padding(padding) 38 | 39 | # Upsample by inserting zeros. 40 | x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) 41 | x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) 42 | x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) 43 | 44 | # Pad or crop. 45 | x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) 46 | x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] 47 | 48 | # Setup filter. 49 | f = f * (gain ** (f.ndim / 2)) 50 | f = f.to(x.dtype) 51 | if not flip_filter: 52 | f = f.flip(list(range(f.ndim))) 53 | 54 | # Convolve with the filter. 55 | f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) 56 | if f.ndim == 4: 57 | x = conv2d(input=x, weight=f, groups=num_channels) 58 | else: 59 | x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) 60 | x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) 61 | 62 | # Downsample by throwing away pixels. 63 | x = x[:, :, ::downy, ::downx] 64 | return x 65 | 66 | 67 | class EncoderEpilogue(torch.nn.Module): 68 | def __init__(self, 69 | in_channels, # Number of input channels. 70 | cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label. 71 | z_dim, # Output Latent (Z) dimensionality. 72 | resolution, # Resolution of this block. 73 | img_channels, # Number of input color channels. 74 | architecture='resnet', # Architecture: 'orig', 'skip', 'resnet'. 75 | mbstd_group_size=4, # Group size for the minibatch standard deviation layer, None = entire minibatch. 76 | mbstd_num_channels=1, # Number of features for the minibatch standard deviation layer, 0 = disable. 77 | activation='lrelu', # Activation function: 'relu', 'lrelu', etc. 78 | conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 
79 | ): 80 | assert architecture in ['orig', 'skip', 'resnet'] 81 | super().__init__() 82 | self.in_channels = in_channels 83 | self.cmap_dim = cmap_dim 84 | self.resolution = resolution 85 | self.img_channels = img_channels 86 | self.architecture = architecture 87 | 88 | if architecture == 'skip': 89 | self.fromrgb = Conv2dLayer(self.img_channels, in_channels, kernel_size=1, activation=activation) 90 | self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, 91 | num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None 92 | self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, 93 | conv_clamp=conv_clamp) 94 | self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), z_dim, activation=activation) 95 | self.dropout = torch.nn.Dropout(p=0.5) 96 | 97 | def forward(self, x, cmap, force_fp32=False): 98 | _ = force_fp32 # unused 99 | dtype = torch.float32 100 | memory_format = torch.contiguous_format 101 | 102 | # FromRGB. 103 | x = x.to(dtype=dtype, memory_format=memory_format) 104 | 105 | # Main layers. 106 | if self.mbstd is not None: 107 | x = self.mbstd(x) 108 | const_e = self.conv(x) 109 | x = self.fc(const_e.flatten(1)) 110 | x = self.dropout(x) 111 | 112 | # Conditioning. 113 | if self.cmap_dim > 0: 114 | x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim)) 115 | 116 | assert x.dtype == dtype 117 | return x, const_e 118 | 119 | 120 | class EncoderBlock(torch.nn.Module): 121 | def __init__(self, 122 | in_channels, # Number of input channels, 0 = first block. 123 | tmp_channels, # Number of intermediate channels. 124 | out_channels, # Number of output channels. 125 | resolution, # Resolution of this block. 126 | img_channels, # Number of input color channels. 127 | first_layer_idx, # Index of the first layer. 128 | architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. 129 | activation='lrelu', # Activation function: 'relu', 'lrelu', etc. 130 | resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. 131 | conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 132 | use_fp16=False, # Use FP16 for this block? 133 | fp16_channels_last=False, # Use channels-last memory format with FP16? 134 | freeze_layers=0, # Freeze-D: Number of layers to freeze. 
135 | ): 136 | assert in_channels in [0, tmp_channels] 137 | assert architecture in ['orig', 'skip', 'resnet'] 138 | super().__init__() 139 | self.in_channels = in_channels 140 | self.resolution = resolution 141 | self.img_channels = img_channels + 1 142 | self.first_layer_idx = first_layer_idx 143 | self.architecture = architecture 144 | self.use_fp16 = use_fp16 145 | self.channels_last = (use_fp16 and fp16_channels_last) 146 | self.register_buffer('resample_filter', setup_filter(resample_filter)) 147 | 148 | self.num_layers = 0 149 | 150 | def trainable_gen(): 151 | while True: 152 | layer_idx = self.first_layer_idx + self.num_layers 153 | trainable = (layer_idx >= freeze_layers) 154 | self.num_layers += 1 155 | yield trainable 156 | 157 | trainable_iter = trainable_gen() 158 | 159 | if in_channels == 0: 160 | self.fromrgb = Conv2dLayer(self.img_channels, tmp_channels, kernel_size=1, activation=activation, 161 | trainable=next(trainable_iter), conv_clamp=conv_clamp, 162 | channels_last=self.channels_last) 163 | 164 | self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation, 165 | trainable=next(trainable_iter), conv_clamp=conv_clamp, 166 | channels_last=self.channels_last) 167 | 168 | self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2, 169 | trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, 170 | channels_last=self.channels_last) 171 | 172 | if architecture == 'resnet': 173 | self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2, 174 | trainable=next(trainable_iter), resample_filter=resample_filter, 175 | channels_last=self.channels_last) 176 | 177 | def forward(self, x, img, force_fp32=False): 178 | # dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32 179 | dtype = torch.float32 180 | memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format 181 | 182 | # Input. 183 | if x is not None: 184 | x = x.to(dtype=dtype, memory_format=memory_format) 185 | 186 | # FromRGB. 187 | if self.in_channels == 0: 188 | img = img.to(dtype=dtype, memory_format=memory_format) 189 | y = self.fromrgb(img) 190 | x = x + y if x is not None else y 191 | img = downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None 192 | 193 | # Main layers. 194 | if self.architecture == 'resnet': 195 | y = self.skip(x, gain=np.sqrt(0.5)) 196 | x = self.conv0(x) 197 | feat = x.clone() 198 | x = self.conv1(x, gain=np.sqrt(0.5)) 199 | x = y.add_(x) 200 | else: 201 | x = self.conv0(x) 202 | feat = x.clone() 203 | x = self.conv1(x) 204 | 205 | assert x.dtype == dtype 206 | return x, img, feat 207 | 208 | 209 | class EncoderNetwork(torch.nn.Module): 210 | def __init__(self, 211 | c_dim, # Conditioning label (C) dimensionality. 212 | z_dim, # Input latent (Z) dimensionality. 213 | img_resolution, # Input resolution. 214 | img_channels, # Number of input color channels. 215 | architecture='orig', # Architecture: 'orig', 'skip', 'resnet'. 216 | channel_base=16384, # Overall multiplier for the number of channels. 217 | channel_max=512, # Maximum number of channels in any layer. 218 | num_fp16_res=0, # Use FP16 for the N highest resolutions. 219 | conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 220 | cmap_dim=None, # Dimensionality of mapped conditioning label, None = default. 221 | block_kwargs={}, # Arguments for DiscriminatorBlock. 
222 | mapping_kwargs={}, # Arguments for MappingNetwork. 223 | epilogue_kwargs={}, # Arguments for EncoderEpilogue. 224 | ): 225 | super().__init__() 226 | self.c_dim = c_dim 227 | self.z_dim = z_dim 228 | self.img_resolution = img_resolution 229 | self.img_resolution_log2 = int(np.log2(img_resolution)) 230 | self.img_channels = img_channels 231 | self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)] 232 | channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]} 233 | fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8) 234 | 235 | if cmap_dim is None: 236 | cmap_dim = channels_dict[4] 237 | if c_dim == 0: 238 | cmap_dim = 0 239 | 240 | common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp) 241 | cur_layer_idx = 0 242 | for res in self.block_resolutions: 243 | in_channels = channels_dict[res] if res < img_resolution else 0 244 | tmp_channels = channels_dict[res] 245 | out_channels = channels_dict[res // 2] 246 | use_fp16 = (res >= fp16_resolution) 247 | use_fp16 = False 248 | block = EncoderBlock(in_channels, tmp_channels, out_channels, resolution=res, 249 | first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs) 250 | setattr(self, f'b{res}', block) 251 | cur_layer_idx += block.num_layers 252 | if c_dim > 0: 253 | self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, 254 | **mapping_kwargs) 255 | self.b4 = EncoderEpilogue(channels_dict[4], cmap_dim=cmap_dim, z_dim=z_dim * 2, resolution=4, **epilogue_kwargs, 256 | **common_kwargs) 257 | 258 | def forward(self, img, c, **block_kwargs): 259 | x = None 260 | feats = {} 261 | for res in self.block_resolutions: 262 | block = getattr(self, f'b{res}') 263 | x, img, feat = block(x, img, **block_kwargs) 264 | feats[res] = feat 265 | 266 | cmap = None 267 | if self.c_dim > 0: 268 | cmap = self.mapping(None, c) 269 | x, const_e = self.b4(x, cmap) 270 | feats[4] = const_e 271 | 272 | B, _ = x.shape 273 | z = torch.zeros((B, self.z_dim), requires_grad=False, dtype=x.dtype, 274 | device=x.device) ## Noise for Co-Modulation 275 | return x, z, feats 276 | 277 | 278 | def fma(a, b, c): # => a * b + c 279 | return _FusedMultiplyAdd.apply(a, b, c) 280 | 281 | 282 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 283 | @staticmethod 284 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 285 | out = torch.addcmul(c, a, b) 286 | ctx.save_for_backward(a, b) 287 | ctx.c_shape = c.shape 288 | return out 289 | 290 | @staticmethod 291 | def backward(ctx, dout): # pylint: disable=arguments-differ 292 | a, b = ctx.saved_tensors 293 | c_shape = ctx.c_shape 294 | da = None 295 | db = None 296 | dc = None 297 | 298 | if ctx.needs_input_grad[0]: 299 | da = _unbroadcast(dout * b, a.shape) 300 | 301 | if ctx.needs_input_grad[1]: 302 | db = _unbroadcast(dout * a, b.shape) 303 | 304 | if ctx.needs_input_grad[2]: 305 | dc = _unbroadcast(dout, c_shape) 306 | 307 | return da, db, dc 308 | 309 | 310 | def _unbroadcast(x, shape): 311 | extra_dims = x.ndim - len(shape) 312 | assert extra_dims >= 0 313 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 314 | if len(dim): 315 | x = x.sum(dim=dim, keepdim=True) 316 | if extra_dims: 317 | x = x.reshape(-1, *x.shape[extra_dims + 1:]) 318 | assert x.shape == shape 319 | return x 320 | 321 | 322 | def modulated_conv2d( 323 | x, # Input tensor of shape 
[batch_size, in_channels, in_height, in_width]. 324 | weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. 325 | styles, # Modulation coefficients of shape [batch_size, in_channels]. 326 | noise=None, # Optional noise tensor to add to the output activations. 327 | up=1, # Integer upsampling factor. 328 | down=1, # Integer downsampling factor. 329 | padding=0, # Padding with respect to the upsampled image. 330 | resample_filter=None, 331 | # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling setup_filter(). 332 | demodulate=True, # Apply weight demodulation? 333 | flip_weight=True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). 334 | fused_modconv=True, # Perform modulation, convolution, and demodulation as a single fused operation? 335 | ): 336 | batch_size = x.shape[0] 337 | out_channels, in_channels, kh, kw = weight.shape 338 | 339 | # Pre-normalize inputs to avoid FP16 overflow. 340 | if x.dtype == torch.float16 and demodulate: 341 | weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], 342 | keepdim=True)) # max_Ikk 343 | styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I 344 | 345 | # Calculate per-sample weights and demodulation coefficients. 346 | w = None 347 | dcoefs = None 348 | if demodulate or fused_modconv: 349 | w = weight.unsqueeze(0) # [NOIkk] 350 | w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] 351 | if demodulate: 352 | dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt() # [NO] 353 | if demodulate and fused_modconv: 354 | w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] 355 | # Execute by scaling the activations before and after the convolution. 356 | if not fused_modconv: 357 | x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) 358 | x = conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, 359 | padding=padding, flip_weight=flip_weight) 360 | if demodulate and noise is not None: 361 | x = fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) 362 | elif demodulate: 363 | x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1) 364 | elif noise is not None: 365 | x = x.add_(noise.to(x.dtype)) 366 | return x 367 | 368 | # Execute as one fused op using grouped convolution. 369 | batch_size = int(batch_size) 370 | x = x.reshape(1, -1, *x.shape[2:]) 371 | w = w.reshape(-1, in_channels, kh, kw) 372 | x = conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, 373 | groups=batch_size, flip_weight=flip_weight) 374 | x = x.reshape(batch_size, -1, *x.shape[2:]) 375 | if noise is not None: 376 | x = x.add_(noise) 377 | return x 378 | 379 | 380 | class SynthesisLayer(torch.nn.Module): 381 | def __init__(self, 382 | in_channels, # Number of input channels. 383 | out_channels, # Number of output channels. 384 | w_dim, # Intermediate latent (W) dimensionality. 385 | resolution, # Resolution of this layer. 386 | kernel_size=3, # Convolution kernel size. 387 | up=1, # Integer upsampling factor. 388 | use_noise=True, # Enable noise input? 389 | activation='lrelu', # Activation function: 'relu', 'lrelu', etc. 390 | resample_filter=[1, 3, 3, 1], # Low-pass filter to apply when resampling activations. 391 | conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. 392 | channels_last=False, # Use channels_last format for the weights?
393 | ): 394 | super().__init__() 395 | self.resolution = resolution 396 | self.up = up 397 | self.use_noise = use_noise 398 | self.activation = activation 399 | self.conv_clamp = conv_clamp 400 | self.register_buffer('resample_filter', setup_filter(resample_filter)) 401 | self.padding = kernel_size // 2 402 | self.act_gain = activation_funcs[activation].def_gain 403 | 404 | self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) 405 | memory_format = torch.channels_last if channels_last else torch.contiguous_format 406 | self.weight = torch.nn.Parameter( 407 | torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) 408 | if use_noise: 409 | self.register_buffer('noise_const', torch.randn([resolution, resolution])) 410 | self.noise_strength = torch.nn.Parameter(torch.zeros([])) 411 | self.bias = torch.nn.Parameter(torch.zeros([out_channels])) 412 | 413 | def forward(self, x, w, noise_mode='none', fused_modconv=True, gain=1): 414 | assert noise_mode in ['random', 'const', 'none'] 415 | in_resolution = self.resolution // self.up 416 | styles = self.affine(w) 417 | 418 | noise = None 419 | if self.use_noise and noise_mode == 'random': 420 | noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], 421 | device=x.device) * self.noise_strength 422 | if self.use_noise and noise_mode == 'const': 423 | noise = self.noise_const * self.noise_strength 424 | 425 | flip_weight = (self.up == 1) # slightly faster 426 | x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up, 427 | padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, 428 | fused_modconv=fused_modconv) 429 | 430 | act_gain = self.act_gain * gain 431 | act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None 432 | # Add the learned bias, then apply the configured activation together with 433 | # its gain and optional clamping in one fused step via bias_act(). 434 | x = bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp) 435 | return x 436 | 437 | 438 | 439 | 440 | class ToRGBLayer(torch.nn.Module): 441 | def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False): 442 | super().__init__() 443 | self.conv_clamp = conv_clamp 444 | self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) 445 | memory_format = torch.channels_last if channels_last else torch.contiguous_format 446 | self.weight = torch.nn.Parameter( 447 | torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)) 448 | self.bias = torch.nn.Parameter(torch.zeros([out_channels])) 449 | self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2)) 450 | 451 | def forward(self, x, w, fused_modconv=True): 452 | styles = self.affine(w) * self.weight_gain 453 | x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) 454 | x = bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp) 455 | return x 456 | 457 | 458 | class SynthesisForeword(torch.nn.Module): 459 | def __init__(self, 460 | z_dim, # Output Latent (Z) dimensionality. 461 | resolution, # Resolution of this block. 462 | in_channels, 463 | img_channels, # Number of input color channels. 464 | architecture='skip', # Architecture: 'orig', 'skip', 'resnet'. 465 | activation='lrelu', # Activation function: 'relu', 'lrelu', etc.
466 | 467 | ): 468 | super().__init__() 469 | self.in_channels = in_channels 470 | self.z_dim = z_dim 471 | self.resolution = resolution 472 | self.img_channels = img_channels 473 | self.architecture = architecture 474 | 475 | self.fc = FullyConnectedLayer(self.z_dim, (self.z_dim // 2) * 4 * 4, activation=activation) 476 | self.conv = SynthesisLayer(self.in_channels, self.in_channels, w_dim=(z_dim // 2) * 3, resolution=4) 477 | 478 | if architecture == 'skip': 479 | self.torgb = ToRGBLayer(self.in_channels, self.img_channels, kernel_size=1, w_dim=(z_dim // 2) * 3) 480 | 481 | def forward(self, x, ws, feats, img, force_fp32=False): 482 | _ = force_fp32 # unused 483 | dtype = torch.float32 484 | memory_format = torch.contiguous_format 485 | 486 | x_global = x.clone() 487 | # ToRGB. 488 | x = self.fc(x) 489 | x = x.view(-1, self.z_dim // 2, 4, 4) 490 | x = x.to(dtype=dtype, memory_format=memory_format) 491 | 492 | # Main layers. 493 | x_skip = feats[4].clone() 494 | x = x + x_skip 495 | 496 | mod_vector = [] 497 | mod_vector.append(ws[:, 0]) 498 | mod_vector.append(x_global.clone()) 499 | mod_vector = torch.cat(mod_vector, dim=1) 500 | 501 | x = self.conv(x, mod_vector) 502 | 503 | mod_vector = [] 504 | mod_vector.append(ws[:, 2 * 2 - 3]) 505 | mod_vector.append(x_global.clone()) 506 | mod_vector = torch.cat(mod_vector, dim=1) 507 | 508 | if self.architecture == 'skip': 509 | img = self.torgb(x, mod_vector) 510 | img = img.to(dtype=torch.float32, memory_format=torch.contiguous_format) 511 | 512 | assert x.dtype == dtype 513 | return x, img 514 | 515 | 516 | class SELayer(nn.Module): 517 | def __init__(self, channel, reduction=16): 518 | super(SELayer, self).__init__() 519 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 520 | self.fc = nn.Sequential( 521 | nn.Linear(channel, channel // reduction, bias=False), 522 | nn.ReLU(inplace=False), 523 | nn.Linear(channel // reduction, channel, bias=False), 524 | nn.Sigmoid() 525 | ) 526 | 527 | def forward(self, x): 528 | b, c, _, _ = x.size() 529 | y = self.avg_pool(x).view(b, c) 530 | y = self.fc(y).view(b, c, 1, 1) 531 | res = x * y.expand_as(x) 532 | return res 533 | 534 | 535 | class FourierUnit(nn.Module): 536 | 537 | def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', 538 | spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): 539 | # bn_layer not used 540 | super(FourierUnit, self).__init__() 541 | self.groups = groups 542 | 543 | self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), 544 | out_channels=out_channels * 2, 545 | kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) 546 | self.relu = torch.nn.ReLU(inplace=False) 547 | 548 | # squeeze and excitation block 549 | self.use_se = use_se 550 | if use_se: 551 | if se_kwargs is None: 552 | se_kwargs = {} 553 | self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) 554 | 555 | self.spatial_scale_factor = spatial_scale_factor 556 | self.spatial_scale_mode = spatial_scale_mode 557 | self.spectral_pos_encoding = spectral_pos_encoding 558 | self.ffc3d = ffc3d 559 | self.fft_norm = fft_norm 560 | 561 | def forward(self, x): 562 | batch = x.shape[0] 563 | 564 | if self.spatial_scale_factor is not None: 565 | orig_size = x.shape[-2:] 566 | x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, 567 | align_corners=False) 568 | 569 | r_size = x.size() 570 | # (batch, c, h, w/2+1, 2) 571 | fft_dim = (-3, -2, 
-1) if self.ffc3d else (-2, -1) 572 | ffted = fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) 573 | ffted = torch.stack((ffted.real, ffted.imag), dim=-1) 574 | ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) 575 | ffted = ffted.view((batch, -1,) + ffted.size()[3:]) 576 | 577 | if self.spectral_pos_encoding: 578 | height, width = ffted.shape[-2:] 579 | coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) 580 | coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) 581 | ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) 582 | 583 | if self.use_se: 584 | ffted = self.se(ffted) 585 | 586 | ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) 587 | ffted = self.relu(ffted) 588 | 589 | ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( 590 | 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) 591 | ffted = torch.complex(ffted[..., 0], ffted[..., 1]) 592 | 593 | ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] 594 | output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) 595 | 596 | if self.spatial_scale_factor is not None: 597 | output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) 598 | 599 | return output 600 | 601 | 602 | class SpectralTransform(nn.Module): 603 | 604 | def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): 605 | # bn_layer not used 606 | super(SpectralTransform, self).__init__() 607 | self.enable_lfu = enable_lfu 608 | if stride == 2: 609 | self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) 610 | else: 611 | self.downsample = nn.Identity() 612 | 613 | self.stride = stride 614 | self.conv1 = nn.Sequential( 615 | nn.Conv2d(in_channels, out_channels // 616 | 2, kernel_size=1, groups=groups, bias=False), 617 | # nn.BatchNorm2d(out_channels // 2), 618 | nn.ReLU(inplace=True) 619 | ) 620 | self.fu = FourierUnit( 621 | out_channels // 2, out_channels // 2, groups, **fu_kwargs) 622 | if self.enable_lfu: 623 | self.lfu = FourierUnit( 624 | out_channels // 2, out_channels // 2, groups) 625 | self.conv2 = torch.nn.Conv2d( 626 | out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) 627 | 628 | def forward(self, x): 629 | 630 | x = self.downsample(x) 631 | x = self.conv1(x) 632 | output = self.fu(x) 633 | 634 | if self.enable_lfu: 635 | n, c, h, w = x.shape 636 | split_no = 2 637 | split_s = h // split_no 638 | xs = torch.cat(torch.split( 639 | x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() 640 | xs = torch.cat(torch.split(xs, split_s, dim=-1), 641 | dim=1).contiguous() 642 | xs = self.lfu(xs) 643 | xs = xs.repeat(1, 1, split_no, split_no).contiguous() 644 | else: 645 | xs = 0 646 | 647 | output = self.conv2(x + output + xs) 648 | 649 | return output 650 | 651 | 652 | class FFC(nn.Module): 653 | 654 | def __init__(self, in_channels, out_channels, kernel_size, 655 | ratio_gin, ratio_gout, stride=1, padding=0, 656 | dilation=1, groups=1, bias=False, enable_lfu=True, 657 | padding_type='reflect', gated=False, **spectral_kwargs): 658 | super(FFC, self).__init__() 659 | 660 | assert stride == 1 or stride == 2, "Stride should be 1 or 2." 
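        # Channel split: part of the channels goes through ordinary local
        # convolutions, the rest through SpectralTransform (global, Fourier-
        # domain) branches. For example, in_channels=64 with ratio_gin=0.75
        # gives in_cg = 48 global and in_cl = 16 local input channels below.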
661 | self.stride = stride 662 | 663 | in_cg = int(in_channels * ratio_gin) 664 | in_cl = in_channels - in_cg 665 | out_cg = int(out_channels * ratio_gout) 666 | out_cl = out_channels - out_cg 667 | # groups_g = 1 if groups == 1 else int(groups * ratio_gout) 668 | # groups_l = 1 if groups == 1 else groups - groups_g 669 | 670 | self.ratio_gin = ratio_gin 671 | self.ratio_gout = ratio_gout 672 | self.global_in_num = in_cg 673 | 674 | module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d 675 | self.convl2l = module(in_cl, out_cl, kernel_size, 676 | stride, padding, dilation, groups, bias, padding_mode=padding_type) 677 | module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d 678 | self.convl2g = module(in_cl, out_cg, kernel_size, 679 | stride, padding, dilation, groups, bias, padding_mode=padding_type) 680 | module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d 681 | self.convg2l = module(in_cg, out_cl, kernel_size, 682 | stride, padding, dilation, groups, bias, padding_mode=padding_type) 683 | module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform 684 | self.convg2g = module( 685 | in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) 686 | 687 | self.gated = gated 688 | module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d 689 | self.gate = module(in_channels, 2, 1) 690 | 691 | def forward(self, x, fname=None): 692 | x_l, x_g = x if type(x) is tuple else (x, 0) 693 | out_xl, out_xg = 0, 0 694 | 695 | if self.gated: 696 | total_input_parts = [x_l] 697 | if torch.is_tensor(x_g): 698 | total_input_parts.append(x_g) 699 | total_input = torch.cat(total_input_parts, dim=1) 700 | 701 | gates = torch.sigmoid(self.gate(total_input)) 702 | g2l_gate, l2g_gate = gates.chunk(2, dim=1) 703 | else: 704 | g2l_gate, l2g_gate = 1, 1 705 | 706 | spec_x = self.convg2g(x_g) 707 | 708 | if self.ratio_gout != 1: 709 | out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate 710 | if self.ratio_gout != 0: 711 | out_xg = self.convl2g(x_l) * l2g_gate + spec_x 712 | 713 | return out_xl, out_xg 714 | 715 | 716 | class FFC_BN_ACT(nn.Module): 717 | 718 | def __init__(self, in_channels, out_channels, 719 | kernel_size, ratio_gin, ratio_gout, 720 | stride=1, padding=0, dilation=1, groups=1, bias=False, 721 | norm_layer=nn.SyncBatchNorm, activation_layer=nn.Identity, 722 | padding_type='reflect', 723 | enable_lfu=True, **kwargs): 724 | super(FFC_BN_ACT, self).__init__() 725 | self.ffc = FFC(in_channels, out_channels, kernel_size, 726 | ratio_gin, ratio_gout, stride, padding, dilation, 727 | groups, bias, enable_lfu, padding_type=padding_type, **kwargs) 728 | lnorm = nn.Identity if ratio_gout == 1 else norm_layer 729 | gnorm = nn.Identity if ratio_gout == 0 else norm_layer 730 | global_channels = int(out_channels * ratio_gout) 731 | # self.bn_l = lnorm(out_channels - global_channels) 732 | # self.bn_g = gnorm(global_channels) 733 | 734 | lact = nn.Identity if ratio_gout == 1 else activation_layer 735 | gact = nn.Identity if ratio_gout == 0 else activation_layer 736 | self.act_l = lact(inplace=True) 737 | self.act_g = gact(inplace=True) 738 | 739 | def forward(self, x, fname=None): 740 | x_l, x_g = self.ffc(x, fname=fname, ) 741 | x_l = self.act_l(x_l) 742 | x_g = self.act_g(x_g) 743 | return x_l, x_g 744 | 745 | 746 | class FFCResnetBlock(nn.Module): 747 | def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, 748 | spatial_transform_kwargs=None, inline=False, ratio_gin=0.75, 
ratio_gout=0.75): 749 | super().__init__() 750 | self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, 751 | norm_layer=norm_layer, 752 | activation_layer=activation_layer, 753 | padding_type=padding_type, 754 | ratio_gin=ratio_gin, ratio_gout=ratio_gout) 755 | self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, 756 | norm_layer=norm_layer, 757 | activation_layer=activation_layer, 758 | padding_type=padding_type, 759 | ratio_gin=ratio_gin, ratio_gout=ratio_gout) 760 | self.inline = inline 761 | 762 | def forward(self, x, fname=None): 763 | if self.inline: 764 | x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] 765 | else: 766 | x_l, x_g = x if type(x) is tuple else (x, 0) 767 | 768 | id_l, id_g = x_l, x_g 769 | 770 | x_l, x_g = self.conv1((x_l, x_g), fname=fname) 771 | x_l, x_g = self.conv2((x_l, x_g), fname=fname) 772 | 773 | x_l, x_g = id_l + x_l, id_g + x_g 774 | out = x_l, x_g 775 | if self.inline: 776 | out = torch.cat(out, dim=1) 777 | return out 778 | 779 | 780 | class ConcatTupleLayer(nn.Module): 781 | def forward(self, x): 782 | assert isinstance(x, tuple) 783 | x_l, x_g = x 784 | assert torch.is_tensor(x_l) or torch.is_tensor(x_g) 785 | if not torch.is_tensor(x_g): 786 | return x_l 787 | return torch.cat(x, dim=1) 788 | 789 | 790 | class FFCBlock(torch.nn.Module): 791 | def __init__(self, 792 | dim, # Number of output/input channels. 793 | kernel_size, # Width and height of the convolution kernel. 794 | padding, 795 | ratio_gin=0.75, 796 | ratio_gout=0.75, 797 | activation='linear', # Activation function: 'relu', 'lrelu', etc. 798 | ): 799 | super().__init__() 800 | if activation == 'linear': 801 | self.activation = nn.Identity 802 | else: 803 | self.activation = nn.ReLU 804 | self.padding = padding 805 | self.kernel_size = kernel_size 806 | self.ffc_block = FFCResnetBlock(dim=dim, 807 | padding_type='reflect', 808 | norm_layer=nn.SyncBatchNorm, 809 | activation_layer=self.activation, 810 | dilation=1, 811 | ratio_gin=ratio_gin, 812 | ratio_gout=ratio_gout) 813 | 814 | self.concat_layer = ConcatTupleLayer() 815 | 816 | def forward(self, gen_ft, mask, fname=None): 817 | x = gen_ft.float() 818 | 819 | x_l, x_g = x[:, :-self.ffc_block.conv1.ffc.global_in_num], x[:, -self.ffc_block.conv1.ffc.global_in_num:] 820 | id_l, id_g = x_l, x_g 821 | 822 | x_l, x_g = self.ffc_block((x_l, x_g), fname=fname) 823 | x_l, x_g = id_l + x_l, id_g + x_g 824 | x = self.concat_layer((x_l, x_g)) 825 | 826 | return x + gen_ft.float() 827 | 828 | 829 | class FFCSkipLayer(torch.nn.Module): 830 | def __init__(self, 831 | dim, # Number of input/output channels. 832 | kernel_size=3, # Convolution kernel size. 833 | ratio_gin=0.75, 834 | ratio_gout=0.75, 835 | ): 836 | super().__init__() 837 | self.padding = kernel_size // 2 838 | 839 | self.ffc_act = FFCBlock(dim=dim, kernel_size=kernel_size, activation=nn.ReLU, 840 | padding=self.padding, ratio_gin=ratio_gin, ratio_gout=ratio_gout) 841 | 842 | def forward(self, gen_ft, mask, fname=None): 843 | x = self.ffc_act(gen_ft, mask, fname=fname) 844 | return x 845 | 846 | 847 | class SynthesisBlock(torch.nn.Module): 848 | def __init__(self, 849 | in_channels, # Number of input channels, 0 = first block. 850 | out_channels, # Number of output channels. 851 | w_dim, # Intermediate latent (W) dimensionality. 852 | resolution, # Resolution of this block. 853 | img_channels, # Number of output color channels. 854 | is_last, # Is this the last block? 
class SynthesisBlock(torch.nn.Module):
    def __init__(self,
                 in_channels,                   # Number of input channels, 0 = first block.
                 out_channels,                  # Number of output channels.
                 w_dim,                         # Intermediate latent (W) dimensionality.
                 resolution,                    # Resolution of this block.
                 img_channels,                  # Number of output color channels.
                 is_last,                       # Is this the last block?
                 architecture='skip',           # Architecture: 'orig', 'skip', 'resnet'.
                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
                 conv_clamp=None,               # Clamp the output of convolution layers to +-X, None = disable clamping.
                 use_fp16=False,                # Use FP16 for this block?
                 fp16_channels_last=False,      # Use channels-last memory format with FP16?
                 **layer_kwargs,                # Arguments for SynthesisLayer.
                 ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0
        self.res_ffc = {4: 0, 8: 0, 16: 0, 32: 1, 64: 1, 128: 1, 256: 1, 512: 1}

        if in_channels != 0 and resolution >= 8:
            self.ffc_skip = nn.ModuleList()
            for _ in range(self.res_ffc[resolution]):
                self.ffc_skip.append(FFCSkipLayer(dim=out_channels))

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim * 3, resolution=resolution, up=2,
                                        resample_filter=resample_filter, conv_clamp=conv_clamp,
                                        channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim * 3, resolution=resolution,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim * 3,
                                    conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                                    resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, mask, feats, img, ws, fname=None, force_fp32=False, fused_modconv=None, **layer_kwargs):
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        dtype = torch.float32  # FP16 is deliberately overridden: the block always runs in FP32.
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        x = x.to(dtype=dtype, memory_format=memory_format)
        x_skip = feats[self.resolution].clone().to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, ws[1], fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
            if len(self.ffc_skip) > 0:
                mask = F.interpolate(mask, size=x_skip.shape[2:])
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, ws[0].clone(), fused_modconv=fused_modconv, **layer_kwargs)
            if len(self.ffc_skip) > 0:
                mask = F.interpolate(mask, size=x_skip.shape[2:])
                z = x + x_skip
                for fres in self.ffc_skip:
                    z = fres(z, mask)
                x = x + z
            else:
                x = x + x_skip
            x = self.conv1(x, ws[1].clone(), fused_modconv=fused_modconv, **layer_kwargs)
        # ToRGB.
        if img is not None:
            img = upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, ws[2].clone(), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        x = x.to(dtype=dtype)
        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img
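# Illustrative sketch (construction only; a full forward also needs the
# encoder feature pyramid `feats`): at 32x32 and above each SynthesisBlock
# carries one FFCSkipLayer (see res_ffc), and the w_dim * 3 passed to each
# SynthesisLayer reflects that every modulation vector assembled in
# SynthesisNetwork.forward is a per-layer w concatenated with the encoder's
# global feature vector. Constructor kwargs mirror how SynthesisNetwork
# builds its blocks; SynthesisLayer/ToRGBLayer are assumed to construct on CPU.
def _synthesis_block_layout_example():
    block = SynthesisBlock(in_channels=512, out_channels=512, w_dim=512,
                           resolution=32, img_channels=3, is_last=False)
    assert len(block.ffc_skip) == 1                      # one FFC skip layer at 32x32
    assert block.num_conv == 2 and block.num_torgb == 1  # conv0, conv1, torgb ('skip' arch)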
class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
                 w_dim,               # Intermediate latent (W) dimensionality.
                 z_dim,               # Output latent (Z) dimensionality.
                 img_resolution,      # Output image resolution.
                 img_channels,        # Number of color channels.
                 channel_base=16384,  # Overall multiplier for the number of channels.
                 channel_max=512,     # Maximum number of channels in any layer.
                 num_fp16_res=0,      # Use FP16 for the N highest resolutions.
                 **block_kwargs,      # Arguments for SynthesisBlock.
                 ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(3, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.foreword = SynthesisForeword(img_channels=img_channels, in_channels=min(channel_base // 4, channel_max),
                                          z_dim=z_dim * 2, resolution=4)

        self.num_ws = self.img_resolution_log2 * 2 - 2
        for res in self.block_resolutions:
            if res // 2 in channels_dict:
                in_channels = channels_dict[res // 2] if res > 4 else 0
            else:
                in_channels = min(channel_base // (res // 2), channel_max)
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            use_fp16 = False  # FP16 is deliberately disabled for all blocks; everything runs in FP32.
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                                   img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            setattr(self, f'b{res}', block)

    def forward(self, x_global, mask, feats, ws, fname=None, **block_kwargs):
        img = None

        x, img = self.foreword(x_global, ws, feats, img)

        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            mod_vector0 = torch.cat([ws[:, int(np.log2(res)) * 2 - 5], x_global.clone()], dim=1)
            mod_vector1 = torch.cat([ws[:, int(np.log2(res)) * 2 - 4], x_global.clone()], dim=1)
            mod_vector_rgb = torch.cat([ws[:, int(np.log2(res)) * 2 - 3], x_global.clone()], dim=1)
            x, img = block(x, mask, feats, img, (mod_vector0, mod_vector1, mod_vector_rgb), fname=fname, **block_kwargs)
        return img
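# Illustrative sketch: the channel schedule implied by __init__ above for the
# channel_base=32768 / channel_max=512 configuration that FcF.init_model uses
# below. For a 512x512 generator this works out to
#   {8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128, 512: 64}
# and num_ws = 2 * log2(512) - 2 = 16 w vectors.
def _channel_schedule_example(channel_base=32768, channel_max=512, img_resolution=512):
    log2res = int(np.log2(img_resolution))
    resolutions = [2 ** i for i in range(3, log2res + 1)]
    channels = {res: min(channel_base // res, channel_max) for res in resolutions}
    num_ws = log2res * 2 - 2
    return channels, num_ws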
class MappingNetwork(torch.nn.Module):
    def __init__(self,
                 z_dim,                # Input latent (Z) dimensionality, 0 = no latent.
                 c_dim,                # Conditioning label (C) dimensionality, 0 = no label.
                 w_dim,                # Intermediate latent (W) dimensionality.
                 num_ws,               # Number of intermediate latents to output, None = do not broadcast.
                 num_layers=8,         # Number of mapping layers.
                 embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
                 layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
                 activation='lrelu',   # Activation function: 'relu', 'lrelu', etc.
                 lr_multiplier=0.01,   # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.995,     # Decay for tracking the moving average of W during training, None = do not track.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
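# Illustrative sketch: truncation lerps each w toward the tracked average
# w_avg; psi=1 is a no-op and psi=0 collapses to the mean. FcF.forward below
# calls the full model with truncation_psi=0.1, i.e. very strong truncation,
# trading latent diversity for stable inpainting. Small dimensions keep the
# sketch cheap; the real network uses z_dim = w_dim = 512.
def _mapping_truncation_example():
    mapping = MappingNetwork(z_dim=64, c_dim=0, w_dim=64, num_ws=4, num_layers=2)
    w = mapping(torch.randn(2, 64), c=None, truncation_psi=0.5)
    assert w.shape == (2, 4, 64)  # broadcast to num_ws truncated latents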
class Generator(torch.nn.Module):
    def __init__(self,
                 z_dim,                # Input latent (Z) dimensionality.
                 c_dim,                # Conditioning label (C) dimensionality.
                 w_dim,                # Intermediate latent (W) dimensionality.
                 img_resolution,       # Output resolution.
                 img_channels,         # Number of output color channels.
                 encoder_kwargs={},    # Arguments for EncoderNetwork.
                 mapping_kwargs={},    # Arguments for MappingNetwork.
                 synthesis_kwargs={},  # Arguments for SynthesisNetwork.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.encoder = EncoderNetwork(c_dim=c_dim, z_dim=z_dim, img_resolution=img_resolution,
                                      img_channels=img_channels, **encoder_kwargs)
        self.synthesis = SynthesisNetwork(z_dim=z_dim, w_dim=w_dim, img_resolution=img_resolution,
                                          img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, img, c, fname=None, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        mask = img[:, -1].unsqueeze(1)
        x_global, z, feats = self.encoder(img, c)
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(x_global, mask, feats, ws, fname=fname, **synthesis_kwargs)
        return img


FCF_MODEL_URL = os.environ.get(
    "FCF_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_fcf/places_512_G.pth",
)
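# Illustrative sketch: this mirrors the configuration FcF.init_model builds
# below. Constructed standalone the Generator only has random weights;
# init_model loads the places_512_G.pth checkpoint from FCF_MODEL_URL. The
# forward pass expects a 4-channel input whose extra channel encodes the mask.
def _build_fcf_generator():
    common = {'channel_base': 32768, 'channel_max': 512,
              'num_fp16_res': 4, 'conv_clamp': 256}
    return Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512,
                     img_channels=3, synthesis_kwargs=common,
                     encoder_kwargs=common, mapping_kwargs={'num_layers': 2})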
class FcF(InpaintModel):
    min_size = 512
    pad_mod = 512
    pad_to_square = True

    def init_model(self, device, **kwargs):
        seed = 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        kwargs = {'channel_base': 1 * 32768, 'channel_max': 512, 'num_fp16_res': 4, 'conv_clamp': 256}
        G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3,
                      synthesis_kwargs=kwargs, encoder_kwargs=kwargs, mapping_kwargs={'num_layers': 2})
        self.model = load_model(G, FCF_MODEL_URL, device)
        self.label = torch.zeros([1, self.model.c_dim], device=device)

    @staticmethod
    def is_downloaded() -> bool:
        return os.path.exists(get_cache_path_by_url(FCF_MODEL_URL))

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
        images: [H, W, C] RGB, not normalized
        masks: [H, W]
        return: BGR IMAGE
        """
        if image.shape[0] == 512 and image.shape[1] == 512:
            return self._pad_forward(image, mask, config)

        boxes = boxes_from_mask(mask)
        crop_result = []
        config.hd_strategy_crop_margin = 128
        for box in boxes:
            crop_image, crop_mask, crop_box = self._crop_box(image, mask, box, config)
            origin_size = crop_image.shape[:2]
            resize_image = resize_max_size(crop_image, size_limit=512)
            resize_mask = resize_max_size(crop_mask, size_limit=512)
            inpaint_result = self._pad_forward(resize_image, resize_mask, config)

            # Only paste the masked-area result back; unmasked pixels keep their original values.
            inpaint_result = cv2.resize(inpaint_result, (origin_size[1], origin_size[0]), interpolation=cv2.INTER_CUBIC)

            original_pixel_indices = crop_mask < 127
            inpaint_result[original_pixel_indices] = crop_image[:, :, ::-1][original_pixel_indices]

            crop_result.append((inpaint_result, crop_box))

        # Copy before pasting: image[:, :, ::-1] is a view, and writing the crop
        # results through it would silently modify the caller's input array.
        inpaint_result = image[:, :, ::-1].copy()
        for crop_image, crop_box in crop_result:
            x1, y1, x2, y2 = crop_box
            inpaint_result[y1:y2, x1:x2, :] = crop_image

        return inpaint_result

    def forward(self, image, mask, config: Config):
        """Input images and output images have same size
        images: [H, W, C] RGB
        masks: [H, W] mask area == 255
        return: BGR IMAGE
        """
        image = norm_img(image)  # [0, 255] -> [0, 1]
        image = image * 2 - 1    # [0, 1] -> [-1, 1]
        mask = (mask > 120) * 255
        mask = norm_img(mask)

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)

        erased_img = image * (1 - mask)
        input_image = torch.cat([0.5 - mask, erased_img], dim=1)

        output = self.model(input_image, self.label, truncation_psi=0.1, noise_mode='none')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 127.5).round().clamp(0, 255).to(torch.uint8)
        output = output[0].cpu().numpy()
        cur_res = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return cur_res
--------------------------------------------------------------------------------
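For reference, a minimal standalone sketch of the preprocessing that FcF.forward performs before calling the generator, assuming norm_img from lama_cleaner.helper scales [0, 255] HWC arrays to [0, 1] CHW float32 (that is how the code above uses it); it is not part of the repository:

import numpy as np
import torch

from lama_cleaner.helper import norm_img

def build_generator_input(image_rgb: np.ndarray, mask: np.ndarray) -> torch.Tensor:
    """Mirror of FcF.forward preprocessing: RGB scaled to [-1, 1] with masked
    pixels erased, plus a (0.5 - mask) channel in front, batched to [1, 4, H, W]."""
    img = torch.from_numpy(norm_img(image_rgb) * 2 - 1).unsqueeze(0)  # [-1, 1]
    m = torch.from_numpy(norm_img((mask > 120) * 255)).unsqueeze(0)   # {0, 1}
    erased = img * (1 - m)                                            # zero out the hole
    return torch.cat([0.5 - m, erased], dim=1)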