├── LICENSE ├── README.md ├── __assets__ └── github │ ├── method_arch.png │ └── teaser.png ├── config ├── ddpm │ ├── v1.yaml │ └── v2-upsample.yaml ├── encoders │ ├── clip.yaml │ └── openclip.yaml ├── unet │ ├── inpainting │ │ ├── v1.yaml │ │ └── v2.yaml │ ├── upsample │ │ └── v2.yaml │ ├── v │ │ └── v2.yaml │ ├── v1.yaml │ └── v2.yaml ├── vae-upsample.yaml └── vae.yaml ├── data ├── masks │ ├── 0_rgb.png │ ├── 1_rgb.png │ └── 2_rgb.png ├── metadata │ ├── 0.json │ ├── 1.json │ └── 2.json └── outputs │ ├── 0_rgb.png │ ├── 1_rgb.png │ └── 2_rgb.png ├── requirements.txt ├── src ├── smplfusion │ ├── __init__.py │ ├── animation.py │ ├── api │ │ └── __init__.py │ ├── common.py │ ├── config.py │ ├── ddim.py │ ├── libimage │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── iimage.cpython-310.pyc │ │ │ ├── iimage.cpython-39.pyc │ │ │ ├── iimage_gallery.cpython-310.pyc │ │ │ ├── iimage_gallery.cpython-39.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ └── utils.cpython-39.pyc │ │ ├── iimage.py │ │ ├── iimage_gallery.py │ │ └── utils.py │ ├── libpath.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── unet.cpython-39.pyc │ │ │ ├── util.cpython-39.pyc │ │ │ └── vae.cpython-39.pyc │ │ ├── encoders │ │ │ ├── __pycache__ │ │ │ │ └── clip_embedder.cpython-39.pyc │ │ │ ├── clip_embedder.py │ │ │ ├── clip_image_embedder.py │ │ │ ├── modules.py │ │ │ ├── open_clip_embedder.py │ │ │ ├── open_clip_image_embedder.py │ │ │ └── t5_mebedder.py │ │ ├── unet.py │ │ ├── util.py │ │ └── vae.py │ ├── modelzoo │ │ ├── __init__.py │ │ ├── dreamshaper8.py │ │ └── dreamshaper8_inpainting.py │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── autoencoder.cpython-39.pyc │ │ │ ├── distributions.cpython-39.pyc │ │ │ ├── ema.cpython-39.pyc │ │ │ └── util.cpython-39.pyc │ │ ├── attention │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── basic_transformer_block.cpython-39.pyc │ │ │ │ ├── cross_attention.cpython-39.pyc │ │ │ │ ├── feed_forward.cpython-39.pyc │ │ │ │ ├── memory_efficient_cross_attention.cpython-39.pyc │ │ │ │ └── spatial_transformer.cpython-39.pyc │ │ │ ├── basic_transformer_block.py │ │ │ ├── cross_attention.py │ │ │ ├── feed_forward.py │ │ │ ├── memory_efficient_cross_attention.py │ │ │ └── spatial_transformer.py │ │ ├── autoencoder.py │ │ ├── distributions.py │ │ ├── ema.py │ │ ├── partial_conv2d.py │ │ └── util.py │ ├── patches │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── router.cpython-39.pyc │ │ ├── attentionpatch │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── attstore.cpython-39.pyc │ │ │ │ ├── boosted_maskscore.cpython-39.pyc │ │ │ │ ├── default.cpython-39.pyc │ │ │ │ ├── ediff.cpython-39.pyc │ │ │ │ ├── inpaint.cpython-39.pyc │ │ │ │ ├── introvert.cpython-39.pyc │ │ │ │ ├── shuffled.cpython-39.pyc │ │ │ │ └── zeropaint.cpython-39.pyc │ │ │ ├── attstore.py │ │ │ ├── boosted_maskscore copy.py │ │ │ ├── boosted_maskscore.py │ │ │ ├── default.py │ │ │ ├── inpaint.py │ │ │ ├── introvert.py │ │ │ ├── maskscore.py │ │ │ ├── other.py │ │ │ ├── shuffled.py │ │ │ └── zeropaint.py │ │ ├── router.py │ │ └── transformerpatch │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── default.cpython-39.pyc │ │ │ ├── guided.cpython-39.pyc │ │ │ ├── introvert.cpython-39.pyc │ │ │ └── weighting_versions.cpython-39.pyc 
│ │ │ ├── default.py │ │ │ ├── guided.py │ │ │ ├── introvert.py │ │ │ └── weighting_versions.py │ ├── scheduler.py │ ├── share.py │ ├── util.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── input_image.cpython-39.pyc │ │ ├── input_mask.cpython-39.pyc │ │ ├── input_shape.cpython-39.pyc │ │ └── layer_mask.cpython-39.pyc │ │ ├── input_image.py │ │ ├── input_mask.py │ │ ├── input_shape.py │ │ └── layer_mask.py └── zeropainter │ ├── __init__.py │ ├── convert_diffusers.py │ ├── dreamshaper.py │ ├── generation.py │ ├── inpainting.py │ ├── models.py │ ├── segmentation.py │ ├── zero_painter_dataset.py │ └── zero_painter_pipline.py └── zero_painter.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Picsart AI Research (PAIR) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis [CVPR 2024] 2 | 3 | This repository is the official implementation of [Zero-Painter](https://arxiv.org/abs/2406.04032). 4 | 5 | 6 | **[Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis](https://arxiv.org/abs/2406.04032)** 7 |
8 | Marianna Ohanyan*, 9 | Hayk Manukyan*, 10 | Zhangyang Wang, 11 | Shant Navasardyan, 12 | [Humphrey Shi](https://www.humphreyshi.com) 13 |
14 | 15 | [Arxiv](https://arxiv.org/abs/2406.04032) 16 | 17 |

18 | 19 |
20 | 21 | We present Zero-Painter, a novel training-free framework for layout-conditional text-to-image synthesis that facilitates the creation of detailed and controlled imagery from textual prompts. Our method utilizes object masks and individual descriptions, coupled with a global text prompt, to generate images with high fidelity. Zero-Painter employs a two-stage process involving our novel Prompt-Adjusted Cross-Attention (PACA) and Region-Grouped Cross-Attention (ReGCA) blocks, ensuring precise alignment of generated objects with textual prompts and mask shapes. Our extensive experiments demonstrate that Zero-Painter surpasses current state-of-the-art methods in preserving textual details and adhering to mask shapes. 22 | 23 | 24 |

25 | 26 | ## 🔥 News 27 | - [2024.06.06] Zero-Painter paper and code are released. 28 | - [2024.02.27] Paper is accepted to CVPR 2024. 29 | 30 | 31 | ## ⚒️ Installation 32 | 33 | 38 | Install with `pip`: 39 | ```bash 40 | pip3 install -r requirements.txt 41 | ``` 42 | 43 | ## 💃 Inference: Generate images with Zero-Painter 44 | 45 | 1. Download [models](https://huggingface.co/PAIR/Zero-Painter) and put them in the `models` folder. 46 | 2. Use the following script to perform inference on a given mask and prompt pair: 47 | ``` 48 | python zero_painter.py \ 49 | --mask-path data/masks/1_rgb.png \ 50 | --metadata data/metadata/1.json \ 51 | --output-dir data/outputs/ 52 | ``` 53 | 54 | The `metadata` file should be in the following format: 55 | ``` 56 | [{ 57 | "prompt": "Brown gift box beside red candle.", 58 | "color_context_dict": { 59 | "(244, 54, 32)": "Brown gift box", 60 | "(54, 245, 32)": "red candle" 61 | } 62 | }] 63 | ``` 64 | 73 | 74 | ## Method 75 | 76 | <img src="__assets__/github/method_arch.png"> 77 | 78 | --- 79 | 80 | ## 🎓 Citation 81 | If you use our work in your research, please cite our publication: 82 | ``` 83 | @article{Zeropainter, 84 | title={Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis}, 85 | url={http://arxiv.org/abs/2406.04032}, 86 | publisher={arXiv}, 87 | author={Ohanyan, Marianna and Manukyan, Hayk and Wang, Zhangyang and Navasardyan, Shant and Shi, Humphrey}, 88 | year={2024}} 89 | 90 | ``` -------------------------------------------------------------------------------- /__assets__/github/method_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/__assets__/github/method_arch.png -------------------------------------------------------------------------------- /__assets__/github/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/__assets__/github/teaser.png -------------------------------------------------------------------------------- /config/ddpm/v1.yaml: -------------------------------------------------------------------------------- 1 | linear_start: 0.00085 2 | linear_end: 0.0120 3 | num_timesteps_cond: 1 4 | log_every_t: 200 5 | timesteps: 1000 6 | first_stage_key: "jpg" 7 | cond_stage_key: "txt" 8 | image_size: 64 9 | channels: 4 10 | cond_stage_trainable: false 11 | conditioning_key: crossattn 12 | monitor: val/loss_simple_ema 13 | scale_factor: 0.18215 14 | use_ema: False # we set this to false because this is an inference-only config -------------------------------------------------------------------------------- /config/ddpm/v2-upsample.yaml: -------------------------------------------------------------------------------- 1 | parameterization: "v" 2 | low_scale_key: "lr" 3 | linear_start: 0.0001 4 | linear_end: 0.02 5 | num_timesteps_cond: 1 6 | log_every_t: 200 7 | timesteps: 1000 8 | first_stage_key: "jpg" 9 | cond_stage_key: "txt" 10 | image_size: 128 11 | channels: 4 12 | cond_stage_trainable: false 13 | conditioning_key: "hybrid-adm" 14 | monitor: val/loss_simple_ema 15 | scale_factor: 0.08333 16 | use_ema: False 17 | 18 | low_scale_config: 19 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation 20 | params: 21 | noise_schedule_config: # image space 22 | linear_start: 0.0001 23 | linear_end: 0.02 24 | max_noise_level: 350 25 |
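Note on the schedule parameters above: `linear_start`, `linear_end`, and `timesteps` define the diffusion noise schedule that the sampler consumes later in this repository (`src/smplfusion/ddim.py` builds it via `scheduler.linear(1000, config.linear_start, config.linear_end)` and indexes it through `sqrt_alphas` / `sqrt_one_minus_alphas`). The `scheduler.py` module itself is not included in this listing, so the snippet below is only a minimal sketch of a compatible linear schedule, assuming the standard Stable Diffusion convention of interpolating betas in square-root space; the helper name `linear_schedule` is illustrative, not the repository's API.

```python
import torch

def linear_schedule(timesteps=1000, linear_start=0.00085, linear_end=0.0120):
    # Standard LDM "linear" schedule: betas are interpolated linearly in sqrt space.
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return {
        # Coefficient of the clean latent z_0 in q(z_t | z_0).
        "sqrt_alphas": alphas_cumprod.sqrt(),
        # Coefficient of the noise eps in q(z_t | z_0).
        "sqrt_one_minus_alphas": (1.0 - alphas_cumprod).sqrt(),
    }

# Values from config/ddpm/v1.yaml; a noisy latent is then
# z_t = sqrt_alphas[t] * z_0 + sqrt_one_minus_alphas[t] * eps,
# which matches how DDIM.__call__ reconstructs z_0 further below.
schedule = linear_schedule(1000, 0.00085, 0.0120)
```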
-------------------------------------------------------------------------------- /config/encoders/clip.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder -------------------------------------------------------------------------------- /config/encoders/openclip.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder 2 | __init__: 3 | freeze: True 4 | layer: "penultimate" -------------------------------------------------------------------------------- /config/unet/inpainting/v1.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.unet.UNetModel 2 | __init__: 3 | image_size: 32 # unused 4 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 5 | out_channels: 4 6 | model_channels: 320 7 | attention_resolutions: [ 4, 2, 1 ] 8 | num_res_blocks: 2 9 | channel_mult: [ 1, 2, 4, 4 ] 10 | num_heads: 8 11 | use_spatial_transformer: True 12 | transformer_depth: 1 13 | context_dim: 768 14 | use_checkpoint: False 15 | legacy: False -------------------------------------------------------------------------------- /config/unet/inpainting/v2.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.unet.UNetModel 2 | __init__: 3 | use_checkpoint: False 4 | image_size: 32 # unused 5 | in_channels: 9 6 | out_channels: 4 7 | model_channels: 320 8 | attention_resolutions: [ 4, 2, 1 ] 9 | num_res_blocks: 2 10 | channel_mult: [ 1, 2, 4, 4 ] 11 | num_head_channels: 64 # need to fix for flash-attn 12 | use_spatial_transformer: True 13 | use_linear_in_transformer: True 14 | transformer_depth: 1 15 | context_dim: 1024 16 | legacy: False -------------------------------------------------------------------------------- /config/unet/upsample/v2.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.unet.UNetModel 2 | __init__: 3 | use_checkpoint: False 4 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) 5 | image_size: 128 6 | in_channels: 7 7 | out_channels: 4 8 | model_channels: 256 9 | attention_resolutions: [ 2,4,8] 10 | num_res_blocks: 2 11 | channel_mult: [ 1, 2, 2, 4] 12 | disable_self_attentions: [True, True, True, False] 13 | disable_middle_self_attn: False 14 | num_heads: 8 15 | use_spatial_transformer: True 16 | transformer_depth: 1 17 | context_dim: 1024 18 | legacy: False 19 | use_linear_in_transformer: True -------------------------------------------------------------------------------- /config/unet/v/v2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/config/unet/v/v2.yaml -------------------------------------------------------------------------------- /config/unet/v1.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.unet.UNetModel 2 | __init__: 3 | image_size: 32 # unused 4 | in_channels: 4 5 | out_channels: 4 6 | model_channels: 320 7 | attention_resolutions: [ 4, 2, 1 ] 8 | num_res_blocks: 2 9 | channel_mult: [ 1, 2, 4, 4 ] 10 | num_heads: 8 11 | use_spatial_transformer: True 12 | transformer_depth: 1 13 | context_dim: 768 14 | 
use_checkpoint: False 15 | legacy: False -------------------------------------------------------------------------------- /config/unet/v2.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.unet.UNetModel 2 | __init__: 3 | use_checkpoint: False 4 | use_fp16: True 5 | image_size: 32 # unused 6 | in_channels: 4 7 | out_channels: 4 8 | model_channels: 320 9 | attention_resolutions: [ 4, 2, 1 ] 10 | num_res_blocks: 2 11 | channel_mult: [ 1, 2, 4, 4 ] 12 | num_head_channels: 64 13 | use_spatial_transformer: True 14 | use_linear_in_transformer: True 15 | transformer_depth: 1 16 | context_dim: 1024 17 | legacy: False -------------------------------------------------------------------------------- /config/vae-upsample.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.vae.AutoencoderKL 2 | __init__: 3 | embed_dim: 4 4 | ddconfig: 5 | double_z: True 6 | z_channels: 4 7 | resolution: 256 8 | in_channels: 3 9 | out_ch: 3 10 | ch: 128 11 | ch_mult: [ 1,2,4 ] 12 | num_res_blocks: 2 13 | attn_resolutions: [ ] 14 | dropout: 0.0 15 | lossconfig: 16 | target: torch.nn.Identity -------------------------------------------------------------------------------- /config/vae.yaml: -------------------------------------------------------------------------------- 1 | __class__: smplfusion.models.vae.AutoencoderKL 2 | __init__: 3 | embed_dim: 4 4 | monitor: val/rec_loss 5 | ddconfig: 6 | double_z: true 7 | z_channels: 4 8 | resolution: 256 9 | in_channels: 3 10 | out_ch: 3 11 | ch: 128 12 | ch_mult: [1,2,4,4] 13 | num_res_blocks: 2 14 | attn_resolutions: [] 15 | dropout: 0.0 16 | lossconfig: 17 | target: torch.nn.Identity -------------------------------------------------------------------------------- /data/masks/0_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/0_rgb.png -------------------------------------------------------------------------------- /data/masks/1_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/1_rgb.png -------------------------------------------------------------------------------- /data/masks/2_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/2_rgb.png -------------------------------------------------------------------------------- /data/metadata/0.json: -------------------------------------------------------------------------------- 1 | [{ 2 | "prompt": "Brown gift box beside red candle.", 3 | "color_context_dict": { 4 | "(244, 54, 32)": "Brown gift box", 5 | "(54, 245, 32)": "red candle" 6 | } 7 | }] 8 | -------------------------------------------------------------------------------- /data/metadata/1.json: -------------------------------------------------------------------------------- 1 | [ { 2 | "prompt": "Brown tabby cat on white stairs photo", 3 | "color_context_dict": { 4 | "(244, 54, 32)": "Brown tabby cat", 5 | "(54, 245, 32)": "blue vase", 6 | "(4,54,200)": "red apple" 7 | } 8 | }] -------------------------------------------------------------------------------- 
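The metadata files in `data/metadata/` all follow the same convention: a single global `prompt` plus a `color_context_dict` whose keys are the exact RGB colors painted in the matching mask image (`data/masks/<id>_rgb.png`). The repository's own loader lives in `src/zeropainter/zero_painter_dataset.py` and is not shown in full here; the snippet below is only an illustrative sketch, with a hypothetical `load_sample` helper, of how such a mask/metadata pair can be parsed into per-object binary masks.

```python
import json
from ast import literal_eval

import numpy as np
from PIL import Image

def load_sample(mask_path, metadata_path):
    # Each metadata file holds a list with one record: a global prompt plus
    # a mapping from mask color to the per-object prompt.
    with open(metadata_path) as f:
        record = json.load(f)[0]
    rgb_mask = np.array(Image.open(mask_path).convert("RGB"))

    objects = []
    for color_str, object_prompt in record["color_context_dict"].items():
        color = np.array(literal_eval(color_str), dtype=np.uint8)  # "(244, 54, 32)" -> (244, 54, 32)
        binary_mask = np.all(rgb_mask == color, axis=-1)           # H x W boolean mask for this object
        objects.append({"prompt": object_prompt, "mask": binary_mask})
    return record["prompt"], objects

global_prompt, objects = load_sample("data/masks/1_rgb.png", "data/metadata/1.json")
```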
/data/metadata/2.json: -------------------------------------------------------------------------------- 1 | [{ 2 | "prompt": "Green succulent on white and pink pot photo", 3 | "color_context_dict": { 4 | "(213, 24, 207)": "white and pink pot", 5 | "(78, 213, 24)": "Green succulent " 6 | } 7 | }] -------------------------------------------------------------------------------- /data/outputs/0_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/0_rgb.png -------------------------------------------------------------------------------- /data/outputs/1_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/1_rgb.png -------------------------------------------------------------------------------- /data/outputs/2_rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/2_rgb.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision 3 | einops 4 | pytorch_lightning 5 | xformers 6 | open_clip_torch 7 | ipywidgets 8 | transformers==4.40.2 9 | segment-anything 10 | imageio 11 | scipy 12 | opencv-python 13 | matplotlib 14 | omegaconf -------------------------------------------------------------------------------- /src/smplfusion/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO: Remove farancia.Path dependency!
2 | from .libpath import Path 3 | from .libimage import IImage 4 | 5 | print (Path.cwd(__file__)) 6 | config = Path.cwd(__file__).config 7 | models = Path.cwd(__file__).config 8 | options = Path.cwd(__file__).config -------------------------------------------------------------------------------- /src/smplfusion/animation.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib import animation 3 | from IPython.display import HTML, Image, display 4 | 5 | class Animation: 6 | JS = 0 7 | HTML = 1 8 | ANIMATION_MODE = HTML 9 | def __init__(self, frames, fps = 30): 10 | """_summary_ 11 | 12 | Args: 13 | frames (np.ndarray): _description_ 14 | """ 15 | self.frames = frames 16 | self.fps = fps 17 | self.anim_obj = None 18 | self.anim_str = None 19 | def render(self): 20 | size = (self.frames.shape[2],self.frames.shape[1]) 21 | self.fig = plt.figure(figsize = size, dpi = 1) 22 | plt.axis('off') 23 | img = plt.imshow(self.frames[0], cmap = 'gray') 24 | self.fig.subplots_adjust(0,0,1,1) 25 | self.anim_obj = animation.FuncAnimation( 26 | self.fig, 27 | lambda i: img.set_data(self.frames[i,:,:,:]), 28 | frames=self.frames.shape[0], 29 | interval = 1000 / self.fps 30 | ) 31 | plt.close() 32 | if Animation.ANIMATION_MODE == Animation.HTML: 33 | self.anim_str = self.anim_obj.to_html5_video() 34 | elif Animation.ANIMATION_MODE == Animation.JS: 35 | self.anim_str = self.anim_obj.to_jshtml() 36 | return self.anim_obj 37 | def _repr_html_(self): 38 | if self.anim_obj is None: self.render() 39 | return self.anim_str -------------------------------------------------------------------------------- /src/smplfusion/api/__init__.py: -------------------------------------------------------------------------------- 1 | # API goes here! -------------------------------------------------------------------------------- /src/smplfusion/common.py: -------------------------------------------------------------------------------- 1 | from . import share, scheduler 2 | from .ddim import DDIM 3 | from .patches import router, attentionpatch, transformerpatch 4 | from . 
import options 5 | from pytorch_lightning import seed_everything 6 | import open_clip 7 | 8 | def count_tokens(prompt): 9 | tokens = open_clip.tokenize(prompt)[0] 10 | return (tokens > 0).sum() 11 | 12 | def tokenize(prompt): 13 | tokens = open_clip.tokenize(prompt)[0] 14 | return [open_clip.tokenizer._tokenizer.decoder[x.item()] for x in tokens] 15 | 16 | def get_token_idx(prompt, prefix, positive_prompt): 17 | prompt = prefix.format(prompt) 18 | return list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index(''))) + [tokenize(prompt + positive_prompt).index('')] 19 | 20 | def load_model_v2_inpainting(folder): 21 | model_config = options.ddpm.v1_yaml 22 | 23 | unet = options.unet.inpainting.v2_yaml.eval().cuda() 24 | unet.load_state_dict(folder.unet_ckpt) 25 | unet = unet.requires_grad_(False) 26 | 27 | vae = options.vae_yaml.eval().cuda() 28 | vae.load_state_dict(folder.vae_ckpt) 29 | vae = vae.requires_grad_(False) 30 | 31 | encoder = options.encoders.openclip_yaml.eval().cuda() 32 | encoder.load_state_dict(folder.encoder_ckpt) 33 | encoder = encoder.requires_grad_(False) 34 | 35 | return model_config, unet, vae, encoder 36 | 37 | def load_model_v15_inpainting(folder): 38 | model_config = options.ddpm.v1_yaml 39 | 40 | unet = options.unet.inpainting.v1_yaml.eval().cuda() 41 | unet.load_state_dict(folder.unet_ckpt) 42 | unet = unet.requires_grad_(False) 43 | 44 | vae = options.vae_yaml.eval().cuda() 45 | vae.load_state_dict(folder.vae_ckpt) 46 | vae = vae.requires_grad_(False) 47 | 48 | encoder = options.encoders.clip_yaml.eval().cuda() 49 | encoder.load_state_dict(folder.encoder_ckpt) 50 | encoder = encoder.requires_grad_(False) 51 | 52 | return model_config, unet, vae, encoder 53 | 54 | def load_model_v2(folder): 55 | model_config = options.ddpm.v1_yaml 56 | unet = options.unet.v2_yaml.eval().cuda() 57 | unet.load_state_dict(folder.unet_ckpt) 58 | unet = unet.requires_grad_(False) 59 | vae = options.vae_yaml.eval().cuda() 60 | vae.load_state_dict(folder.vae_ckpt) 61 | vae = vae.requires_grad_(False) 62 | encoder = options.encoders.openclip_yaml.eval().cuda() 63 | encoder.load_state_dict(folder.encoder_ckpt) 64 | encoder = encoder.requires_grad_(False) 65 | return model_config, unet, vae, encoder 66 | 67 | def load_model_v1(folder): 68 | model_config = options.ddpm.v1_yaml 69 | unet = options.unet.v1_yaml.eval().cuda() 70 | unet.load_state_dict(folder.unet_ckpt) 71 | unet = unet.requires_grad_(False) 72 | vae = options.vae_yaml.eval().cuda() 73 | vae.load_state_dict(folder.vae_ckpt) 74 | vae = vae.requires_grad_(False) 75 | encoder = options.encoders.clip_yaml.eval().cuda() 76 | encoder.load_state_dict(folder.encoder_ckpt) 77 | encoder = encoder.requires_grad_(False) 78 | return model_config, unet, vae, encoder 79 | 80 | def load_unet_v2(folder): 81 | unet = options.unet.v2_yaml.eval().cuda() 82 | unet.load_state_dict(folder.unet_ckpt) 83 | unet = unet.requires_grad_(False) 84 | return unet -------------------------------------------------------------------------------- /src/smplfusion/config.py: -------------------------------------------------------------------------------- 1 | IMG_THUMBSIZE = None -------------------------------------------------------------------------------- /src/smplfusion/ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm.notebook import tqdm 3 | from . import scheduler 4 | from . 
import share 5 | 6 | from .libimage import IImage 7 | 8 | class DDIM: 9 | def __init__(self, config, vae, encoder, unet): 10 | self.vae = vae 11 | self.encoder = encoder 12 | self.unet = unet 13 | self.config = config 14 | self.schedule = scheduler.linear(1000, config.linear_start, config.linear_end) 15 | 16 | def __call__( 17 | self, prompt = '', dt = 50, shape = (1,4,64,64), seed = None, negative_prompt = '', unet_condition = None, 18 | context = None, verbose = True): 19 | if seed is not None: torch.manual_seed(seed) 20 | if unet_condition is not None: 21 | zT = torch.randn((1,4) + unet_condition.shape[2:]).cuda() 22 | else: 23 | zT = torch.randn(shape).cuda() 24 | 25 | with torch.autocast('cuda'), torch.no_grad(): 26 | if context is None: context = self.encoder.encode([negative_prompt, prompt]) 27 | 28 | zt = zT 29 | pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt) 30 | for timestep in share.DDIMIterator(pbar): 31 | _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1) 32 | eps_uncond, eps = self.unet( 33 | torch.cat([_zt, _zt]), 34 | timesteps = torch.tensor([timestep, timestep]).cuda(), 35 | context = context 36 | ).chunk(2) 37 | 38 | eps = (eps_uncond + 7.5 * (eps - eps_uncond)) 39 | 40 | z0 = (zt - self.schedule.sqrt_one_minus_alphas[timestep] * eps) / self.schedule.sqrt_alphas[timestep] 41 | zt = self.schedule.sqrt_alphas[timestep - dt] * z0 + self.schedule.sqrt_one_minus_alphas[timestep - dt] * eps 42 | return IImage(self.vae.decode(z0 / self.config.scale_factor)) 43 | 44 | def encode(self, image): 45 | return self.vae.encode(image.padx(64).torch().cuda()).mean * self.config.scale_factor 46 | def decode(self, latent): 47 | return IImage(self.vae.decode(latent / self.config.scale_factor)) 48 | 49 | def get_inpainting_condition(self, image, mask): 50 | latent_size = [x//8 for x in image.size] 51 | condition_x0 = self.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * self.config.scale_factor 52 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float() 53 | 54 | # condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask) 55 | return torch.cat([condition_mask, condition_x0], 1) 56 | 57 | inpainting_condition = get_inpainting_condition 58 | 59 | -------------------------------------------------------------------------------- /src/smplfusion/libimage/__init__.py: -------------------------------------------------------------------------------- 1 | from .iimage import IImage 2 | from .iimage_gallery import ImageGallery 3 | from .utils import bytes2html 4 | 5 | import math 6 | import numpy as np 7 | import warnings 8 | 9 | # ========= STATIC FUNCTIONS ============= 10 | def find_max_h(images): 11 | return max([x.size[1] for x in images]) 12 | def find_max_w(images): 13 | return max([x.size[0] for x in images]) 14 | def find_max_size(images): 15 | return find_max_w(images), find_max_h(images) 16 | 17 | 18 | def stack(images, axis = 0): 19 | return IImage(np.concatenate([x.data for x in images], axis)) 20 | def tstack(images): 21 | w,h = find_max_size(images) 22 | images = [x.pad2wh(w,h) for x in images] 23 | return IImage(np.concatenate([x.data for x in images], 0)) 24 | def hstack(images): 25 | h = find_max_h(images) 26 | images = [x.pad2wh(h = h) for x in images] 27 | return IImage(np.concatenate([x.data for x in images], 2)) 28 | def vstack(images): 29 | w = find_max_w(images) 30 | images = [x.pad2wh(w = w) for x in images] 31 | return IImage(np.concatenate([x.data for x in images], 1)) 32 | 
33 | def grid(images, nrows = None, ncols = None): 34 | combined = stack(images) 35 | if nrows is not None: 36 | ncols = math.ceil(combined.data.shape[0] / nrows) 37 | elif ncols is not None: 38 | nrows = math.ceil(combined.data.shape[0] / ncols) 39 | else: 40 | warnings.warn("No dimensions specified, creating a grid with 5 columns (default)") 41 | ncols = 5 42 | nrows = math.ceil(combined.data.shape[0] / ncols) 43 | 44 | pad = nrows * ncols - combined.data.shape[0] 45 | data = np.pad(combined.data, ((0,pad),(0,0),(0,0),(0,0))) 46 | rows = [np.concatenate(x,1,dtype=np.uint8) for x in np.array_split(data, nrows)] 47 | return IImage(np.concatenate(rows, 0, dtype = np.uint8)[None]) -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/iimage.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage.cpython-310.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/iimage.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-310.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- 
/src/smplfusion/libimage/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/libimage/iimage_gallery.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .utils import bytes2html 3 | 4 | class ImageGallery: 5 | def __init__(self, images, captions = None, size='auto', max_rows = -1, root_path = '/', caption_font_size = '2em'): 6 | self.size = size 7 | self.images = images 8 | self.display_str = None 9 | self.captions = captions if captions is not None else [''] * len(self.images) 10 | self.max_rows = max_rows 11 | self.root_path = root_path 12 | self.caption_font_size = caption_font_size 13 | 14 | def generate_display(self): 15 | """Shows a set of images in a gallery that flexes with the width of the notebook. 16 | 17 | Parameters 18 | ---------- 19 | images: list of str or bytes 20 | URLs or bytes of images to display 21 | 22 | row_height: str 23 | CSS height value to assign to all images. Set to 'auto' by default to show images 24 | with their native dimensions. Set to a value e.g. '250px' to make all rows 25 | in the gallery equal height. 26 | """ 27 | figures = [] 28 | row_figures = 0 29 | for image, caption in zip(self.images, self.captions): 30 | if isinstance(image, str): 31 | with open(image,'rb') as f: 32 | link = os.path.relpath(image, self.root_path) 33 | # src = bytes2html(f.read(), width = self.size) 34 | src = bytes2html(f.read()) 35 | src = f'{src}' 36 | else: 37 | if image.display_str is None: image.generate_display() 38 | if image.is_video(): 39 | src = image.display_str 40 | else: 41 | # src = _src_from_data(image.display_str) 42 | src = image.to_html(width = "100%", root_path = self.root_path) 43 | if caption != '': 44 | caption = f'
<figcaption style="font-size: {self.caption_font_size}">{caption}</figcaption>
' 45 | figures.append(f''' 46 | <figure style="margin: 5px;">
47 | {caption} 48 | {src} 49 | </figure>
50 | ''') 51 | row_figures += 1 52 | if row_figures == self.max_rows: 53 | row_figures = 0 54 | figures.append('<br>
') 55 | self.display_str = f''' 56 | <div style="display: flex; flex-flow: row wrap; text-align: center;">
57 | {''.join(figures)} 58 | </div>
59 | ''' 60 | 61 | def _repr_html_(self): 62 | if self.display_str is None: self.generate_display() 63 | return self.display_str 64 | 65 | def save(self, path): 66 | if self.display_str is None: self.generate_display() 67 | with open(path, 'w') as f: 68 | f.write(self.display_str) -------------------------------------------------------------------------------- /src/smplfusion/libimage/utils.py: -------------------------------------------------------------------------------- 1 | from IPython.display import Image as IpyImage 2 | 3 | def bytes2html(data, width='auto'): 4 | img_obj = IpyImage(data=data, format='JPG') 5 | for bundle in img_obj._repr_mimebundle_(): 6 | for mimetype, b64value in bundle.items(): 7 | if mimetype.startswith('image/'): 8 | return f'' -------------------------------------------------------------------------------- /src/smplfusion/libpath.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import * 3 | import sys 4 | 5 | import pickle 6 | import inspect 7 | import json 8 | import importlib.util 9 | import importlib 10 | import random 11 | import csv 12 | 13 | # Requirements 14 | import yaml 15 | from omegaconf import OmegaConf 16 | import torch 17 | import numpy as np 18 | from PIL import Image 19 | from IPython.display import Markdown # TODO: Remove this requirement 20 | 21 | from .libimage import IImage 22 | # from .experiment import SmartFolder 23 | 24 | def instantiate_from_config(config): 25 | if not "target" in config: 26 | raise KeyError("Expected key `target` to instantiate.") 27 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 28 | 29 | def load_obj(objyaml): 30 | if "__init__" in objyaml: 31 | return get_obj_from_str(objyaml['__class__'])(**objyaml["__init__"]) 32 | else: 33 | return get_obj_from_str(objyaml['__class__'])() 34 | 35 | def get_obj_from_str(string): 36 | module, cls = string.rsplit(".", 1) 37 | try: 38 | return getattr(importlib.import_module(module, package=None), cls) 39 | except: 40 | return getattr(importlib.import_module('lib.' 
+ module, package=None), cls) 41 | 42 | 43 | def path2var(path): 44 | return path.replace(" ", "_").replace(".", "_").replace("-", "_").replace(":", "_") 45 | 46 | 47 | def var2path(path): 48 | return path.replace("ß", " ").replace("ä", ".").replace("ö", "-") 49 | 50 | 51 | def save(obj, path): 52 | pass 53 | 54 | class FileType: 55 | pass 56 | 57 | def cvt(x: str): 58 | try: return int(x) 59 | except: 60 | try: return float(x) 61 | except: return x 62 | 63 | class Path: 64 | FOLDER = FileType() 65 | 66 | def cwd(path): 67 | return Path(os.path.dirname(os.path.abspath(path))) 68 | def __init__(self, path): 69 | self.path = path 70 | if os.path.isdir(self.path): 71 | self.reload() 72 | else: 73 | self.extension = self.path.split(".")[-1] 74 | self.subpaths = {} 75 | self.filename = os.path.basename(self.path) 76 | 77 | def reload(self): 78 | if os.path.isdir(self.path): 79 | self.subpaths = {path2var(x): x for x in os.listdir(self.path)} 80 | self.keyorder = sorted(list(self.subpaths.keys())) 81 | self.extension = None 82 | 83 | def read(self, is_preview=False): 84 | if self.extension is None or is_preview: 85 | # if self.isdir and os.path.exists(f'{self}/__smart__'): 86 | # return SmartFolder.open(str(self)) 87 | return self 88 | elif self.extension == "yaml": 89 | yaml = OmegaConf.load(self.path) 90 | if "__class__" in yaml: 91 | return load_obj(yaml) 92 | if "model" in yaml: 93 | return instantiate_from_config(yaml.model) 94 | return yaml 95 | elif self.extension == "yamlobj": 96 | return load_obj(OmegaConf.load(self.path)) 97 | elif self.extension in ["jpg", "jpeg", "png"]: 98 | return IImage.open(self.path) 99 | elif self.extension in ["ckpt", "pt"]: 100 | return torch.load(self.path) 101 | elif self.extension in ["pkl", "bin"]: 102 | with open(self.path, "rb") as f: 103 | return pickle.load(f) 104 | elif self.extension in ["txt"]: 105 | with open(self.path) as f: 106 | return f.read() 107 | elif self.extension in ["lst", "list"]: 108 | with open(self.path) as f: 109 | return [cvt(x) for x in f.read().split("\n")] 110 | elif self.extension in ["md"]: 111 | with open(self.path) as f: 112 | return Markdown(f.read()) 113 | elif self.extension in ["json"]: 114 | with open(self.path) as f: 115 | return json.load(f) 116 | elif self.extension in ["npy", "npz"]: 117 | return np.load(self.path) 118 | elif self.extension in ['csv', 'table']: 119 | with open(self.path) as f: 120 | return [[cvt(x) for x in row] for row in csv.reader(f, delimiter=' ')] 121 | elif self.extension in ["py"]: 122 | spec = importlib.util.spec_from_file_location(f"autoload.module", self.path) 123 | module = importlib.util.module_from_spec(spec) 124 | spec.loader.exec_module(module) 125 | return module 126 | else: 127 | return self 128 | 129 | def exists(self): 130 | return os.path.exists(str(self)) 131 | 132 | def sample(self): 133 | return getattr(self, random.choice(list(self.subpaths.keys()))) 134 | 135 | def __dir__(self): 136 | self.reload() 137 | return list(self.subpaths.keys()) 138 | 139 | def __len__(self): 140 | return len(self.subpaths) 141 | 142 | def ls(self): 143 | self.reload() 144 | return list(self.subpaths.values()) 145 | # return [str(self + x) for x in self.subpaths.values()] 146 | # return dir(self) 147 | 148 | @property 149 | def parent(self): return Path(os.path.dirname(self.path)) 150 | @property 151 | def isdir(self): return os.path.isdir(os.path.abspath(self.path)) 152 | 153 | def __getattr__(self, __name: str): 154 | # print ("Func", inspect.stack()[1].function, "name", __name) 155 | is_preview 
= inspect.stack()[1].function == "getattr_paths" 156 | 157 | self.reload() 158 | if __name in self.subpaths and self.subpaths[__name] in os.listdir(self.path): 159 | return Path(join(self.path, self.subpaths[__name])).read(is_preview) 160 | elif __name in ['__wrapped__']: 161 | raise AttributeError() 162 | else: 163 | return Path(join(self.path, __name)) 164 | 165 | def __setattr__(self, __name: str, __value: any): 166 | if __value == Path.FOLDER: 167 | os.makedirs(f'{self}/{__name}', exist_ok=True) 168 | self.reload() 169 | else: 170 | return super().__setattr__(__name, __value) 171 | 172 | def __hasattr__(self, __name: str): 173 | self.reload() 174 | return self.subpaths is not None and __name in self.subpaths 175 | 176 | # def __getitem__(self, __name): 177 | # if __name in self.subpaths and self.subpaths[__name] in os.listdir(self.path): 178 | # return Path(join(self.path, self.subpaths[__name])).read() 179 | 180 | def __add__(self, other): 181 | assert other is not str 182 | return Path(join(self.path, other)) 183 | 184 | def __call__(self, idx=None): 185 | if idx is None: 186 | return str(self) 187 | if isinstance(idx, str): 188 | return Path(join(self.path, idx)) 189 | else: 190 | return Path(join(self.path, self.subpaths[self.keyorder[idx]])) 191 | 192 | def __getitem__(self, idx): 193 | if isinstance(idx, str): 194 | return Path(join(self.path, idx)).read() 195 | else: 196 | return Path(join(self.path, self.subpaths[self.keyorder[idx]])).read() 197 | 198 | def __str__(self): 199 | return os.path.abspath(self.path) 200 | 201 | def __repr__(self): 202 | return f"Path reference to: {os.path.abspath(self.path)}" 203 | 204 | def new_child(self, ext = ''): 205 | idx = 0 206 | while os.path.exists(join(str(self), f"file_{idx}{ext}")): 207 | idx += 1 208 | return join(str(self), f"file_{idx}{ext}") -------------------------------------------------------------------------------- /src/smplfusion/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__init__.py -------------------------------------------------------------------------------- /src/smplfusion/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/models/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/models/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/models/__pycache__/vae.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/vae.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/__pycache__/clip_embedder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/encoders/__pycache__/clip_embedder.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/clip_embedder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from transformers import CLIPTokenizer, CLIPTextModel 3 | 4 | class FrozenCLIPEmbedder(nn.Module): 5 | """Uses the CLIP transformer encoder for text (from huggingface)""" 6 | LAYERS = [ 7 | "last", 8 | "pooled", 9 | "hidden" 10 | ] 11 | 12 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, 13 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 14 | super().__init__() 15 | assert layer in self.LAYERS 16 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 17 | self.transformer = CLIPTextModel.from_pretrained(version) 18 | self.device = device 19 | self.max_length = max_length 20 | if freeze: 21 | self.freeze() 22 | self.layer = layer 23 | self.layer_idx = layer_idx 24 | if layer == "hidden": 25 | assert layer_idx is not None 26 | assert 0 <= abs(layer_idx) <= 12 27 | 28 | def freeze(self): 29 | self.transformer = self.transformer.eval() 30 | # self.train = disabled_train 31 | for param in self.parameters(): 32 | param.requires_grad = False 33 | 34 | def forward(self, text): 35 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 36 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 37 | tokens = batch_encoding["input_ids"].to(self.device) 38 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") 39 | if self.layer == "last": 40 | z = outputs.last_hidden_state 41 | elif self.layer == "pooled": 42 | z = outputs.pooler_output[:, None, :] 43 | else: 44 | z = outputs.hidden_states[self.layer_idx] 45 | return z 46 | 47 | def encode(self, text): 48 | return self(text) -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/clip_image_embedder.py: -------------------------------------------------------------------------------- 1 | class ClipImageEmbedder(nn.Module): 2 | def __init__( 3 | self, 4 | model, 5 | jit=False, 6 | device='cuda' if torch.cuda.is_available() else 'cpu', 7 | antialias=True, 8 | ucg_rate=0. 9 | ): 10 | super().__init__() 11 | from clip import load as load_clip 12 | self.model, _ = load_clip(name=model, device=device, jit=jit) 13 | 14 | self.antialias = antialias 15 | 16 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 17 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 18 | self.ucg_rate = ucg_rate 19 | 20 | def preprocess(self, x): 21 | # normalize to [0,1] 22 | x = kornia.geometry.resize(x, (224, 224), 23 | interpolation='bicubic', align_corners=True, 24 | antialias=self.antialias) 25 | x = (x + 1.) / 2. 
26 | # re-normalize according to clip 27 | x = kornia.enhance.normalize(x, self.mean, self.std) 28 | return x 29 | 30 | def forward(self, x, no_dropout=False): 31 | # x is assumed to be in range [-1,1] 32 | out = self.model.encode_image(self.preprocess(x)) 33 | out = out.to(x.dtype) 34 | if self.ucg_rate > 0. and not no_dropout: 35 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out 36 | return out -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import kornia 4 | from torch.utils.checkpoint import checkpoint 5 | 6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel 7 | 8 | import open_clip 9 | from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation 10 | from ldm.modules.diffusionmodules.openaimodel import Timestep 11 | from ...util import default, count_params, autocast 12 | 13 | 14 | class ClassEmbedder(nn.Module): 15 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): 16 | super().__init__() 17 | self.key = key 18 | self.embedding = nn.Embedding(n_classes, embed_dim) 19 | self.n_classes = n_classes 20 | self.ucg_rate = ucg_rate 21 | 22 | def forward(self, batch, key=None, disable_dropout=False): 23 | if key is None: 24 | key = self.key 25 | # this is for use in crossattn 26 | c = batch[key][:, None] 27 | if self.ucg_rate > 0. and not disable_dropout: 28 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) 29 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) 30 | c = c.long() 31 | c = self.embedding(c) 32 | return c 33 | 34 | def get_unconditional_conditioning(self, bs, device="cuda"): 35 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) 36 | uc = torch.ones((bs,), device=device) * uc_class 37 | uc = {self.key: uc} 38 | return uc 39 | 40 | 41 | def disabled_train(self, mode=True): 42 | """Overwrite model.train with this function to make sure train/eval mode 43 | does not change anymore.""" 44 | return self 45 | 46 | 47 | class FrozenCLIPT5Encoder(AbstractEncoder): 48 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", 49 | clip_max_length=77, t5_max_length=77): 50 | super().__init__() 51 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) 52 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) 53 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " 54 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") 55 | 56 | def encode(self, text): 57 | return self(text) 58 | 59 | def forward(self, text): 60 | clip_z = self.clip_encoder.encode(text) 61 | t5_z = self.t5_encoder.encode(text) 62 | return [clip_z, t5_z] 63 | 64 | class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): 65 | def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | if clip_stats_path is None: 68 | clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim) 69 | else: 70 | clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") 71 | self.register_buffer("data_mean", clip_mean[None, :], persistent=False) 72 | self.register_buffer("data_std", clip_std[None, :], persistent=False) 73 | self.time_embed = Timestep(timestep_dim) 74 | 75 | def scale(self, x): 76 | # re-normalize to centered mean and unit variance 77 | x = (x - self.data_mean) * 1. 
/ self.data_std 78 | return x 79 | 80 | def unscale(self, x): 81 | # back to original data stats 82 | x = (x * self.data_std) + self.data_mean 83 | return x 84 | 85 | def forward(self, x, noise_level=None): 86 | if noise_level is None: 87 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() 88 | else: 89 | assert isinstance(noise_level, torch.Tensor) 90 | x = self.scale(x) 91 | z = self.q_sample(x, noise_level) 92 | z = self.unscale(z) 93 | noise_level = self.time_embed(noise_level) 94 | return z, noise_level 95 | -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/open_clip_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | import open_clip 6 | 7 | class FrozenOpenCLIPEmbedder(nn.Module): 8 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 9 | freeze=True, layer="last"): 10 | super().__init__() 11 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 12 | del model.visual 13 | self.model = model 14 | 15 | self.device = device 16 | self.max_length = max_length 17 | if freeze: self.freeze() 18 | self.layer = layer 19 | if self.layer == "last": 20 | self.layer_idx = 0 21 | elif self.layer == "penultimate": 22 | self.layer_idx = 1 23 | else: 24 | raise NotImplementedError() 25 | 26 | def freeze(self): 27 | self.model = self.model.eval() 28 | for param in self.parameters(): 29 | param.requires_grad = False 30 | 31 | def forward(self, text): 32 | tokens = open_clip.tokenize(text) 33 | z = self.encode_with_transformer(tokens.to(self.device)) 34 | return z 35 | 36 | def encode_with_transformer(self, text): 37 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 38 | x = x + self.model.positional_embedding 39 | x = x.permute(1, 0, 2) # NLD -> LND 40 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 41 | x = x.permute(1, 0, 2) # LND -> NLD 42 | x = self.model.ln_final(x) 43 | return x 44 | 45 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 46 | for i, r in enumerate(self.model.transformer.resblocks): 47 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 48 | break 49 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 50 | x = checkpoint(r, x, attn_mask) 51 | else: 52 | x = r(x, attn_mask=attn_mask) 53 | return x 54 | 55 | def encode(self, text): 56 | return self(text) -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/open_clip_image_embedder.py: -------------------------------------------------------------------------------- 1 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder): 2 | """ 3 | Uses the OpenCLIP vision transformer encoder for images 4 | """ 5 | 6 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 7 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.): 8 | super().__init__() 9 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), 10 | pretrained=version, ) 11 | del model.transformer 12 | self.model = model 13 | 14 | self.device = device 15 | self.max_length = max_length 16 | if freeze: 17 | self.freeze() 18 | self.layer = layer 19 | if self.layer == "penultimate": 20 | raise 
NotImplementedError() 21 | self.layer_idx = 1 22 | 23 | self.antialias = antialias 24 | 25 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 26 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 27 | self.ucg_rate = ucg_rate 28 | 29 | def preprocess(self, x): 30 | x = kornia.geometry.resize(x, (224, 224), 31 | interpolation='bicubic', align_corners=True, 32 | antialias=self.antialias) 33 | x = (x + 1.) / 2. 34 | x = kornia.enhance.normalize(x, self.mean, self.std) 35 | return x 36 | 37 | def freeze(self): 38 | self.model = self.model.eval() 39 | for param in self.parameters(): 40 | param.requires_grad = False 41 | 42 | @autocast 43 | def forward(self, image, no_dropout=False): 44 | z = self.encode_with_vision_transformer(image) 45 | if self.ucg_rate > 0. and not no_dropout: 46 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z 47 | return z 48 | 49 | def encode_with_vision_transformer(self, img): 50 | img = self.preprocess(img) 51 | x = self.model.visual(img) 52 | return x 53 | 54 | def encode(self, text): 55 | return self(text) -------------------------------------------------------------------------------- /src/smplfusion/models/encoders/t5_mebedder.py: -------------------------------------------------------------------------------- 1 | class FrozenT5Embedder(AbstractEncoder): 2 | """Uses the T5 transformer encoder for text""" 3 | 4 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, 5 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 6 | super().__init__() 7 | self.tokenizer = T5Tokenizer.from_pretrained(version) 8 | self.transformer = T5EncoderModel.from_pretrained(version) 9 | self.device = device 10 | self.max_length = max_length # TODO: typical value? 
11 | if freeze: 12 | self.freeze() 13 | 14 | def freeze(self): 15 | self.transformer = self.transformer.eval() 16 | # self.train = disabled_train 17 | for param in self.parameters(): 18 | param.requires_grad = False 19 | 20 | def forward(self, text): 21 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 22 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 23 | tokens = batch_encoding["input_ids"].to(self.device) 24 | outputs = self.transformer(input_ids=tokens) 25 | 26 | z = outputs.last_hidden_state 27 | return z 28 | 29 | def encode(self, text): 30 | return self(text) -------------------------------------------------------------------------------- /src/smplfusion/models/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | from contextlib import contextmanager 5 | 6 | from ..modules.autoencoder import Encoder, Decoder 7 | from ..modules.distributions import DiagonalGaussianDistribution 8 | 9 | from ..util import instantiate_from_config 10 | from ..modules.ema import LitEma 11 | 12 | class AutoencoderKL(pl.LightningModule): 13 | def __init__(self, 14 | ddconfig, 15 | lossconfig, 16 | embed_dim, 17 | ckpt_path=None, 18 | ignore_keys=[], 19 | image_key="image", 20 | colorize_nlabels=None, 21 | monitor=None, 22 | ema_decay=None, 23 | learn_logvar=False 24 | ): 25 | super().__init__() 26 | self.learn_logvar = learn_logvar 27 | self.image_key = image_key 28 | self.encoder = Encoder(**ddconfig) 29 | self.decoder = Decoder(**ddconfig) 30 | self.loss = instantiate_from_config(lossconfig) 31 | assert ddconfig["double_z"] 32 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 33 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 34 | self.embed_dim = embed_dim 35 | if colorize_nlabels is not None: 36 | assert type(colorize_nlabels)==int 37 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 38 | if monitor is not None: 39 | self.monitor = monitor 40 | 41 | self.use_ema = ema_decay is not None 42 | if self.use_ema: 43 | self.ema_decay = ema_decay 44 | assert 0. < ema_decay < 1. 
45 | self.model_ema = LitEma(self, decay=ema_decay) 46 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 47 | 48 | if ckpt_path is not None: 49 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 50 | 51 | def init_from_ckpt(self, path, ignore_keys=list()): 52 | sd = torch.load(path, map_location="cpu")["state_dict"] 53 | keys = list(sd.keys()) 54 | for k in keys: 55 | for ik in ignore_keys: 56 | if k.startswith(ik): 57 | print("Deleting key {} from state_dict.".format(k)) 58 | del sd[k] 59 | self.load_state_dict(sd, strict=False) 60 | print(f"Restored from {path}") 61 | 62 | @contextmanager 63 | def ema_scope(self, context=None): 64 | if self.use_ema: 65 | self.model_ema.store(self.parameters()) 66 | self.model_ema.copy_to(self) 67 | if context is not None: 68 | print(f"{context}: Switched to EMA weights") 69 | try: 70 | yield None 71 | finally: 72 | if self.use_ema: 73 | self.model_ema.restore(self.parameters()) 74 | if context is not None: 75 | print(f"{context}: Restored training weights") 76 | 77 | def on_train_batch_end(self, *args, **kwargs): 78 | if self.use_ema: 79 | self.model_ema(self) 80 | 81 | def encode(self, x): 82 | h = self.encoder(x) 83 | moments = self.quant_conv(h) 84 | posterior = DiagonalGaussianDistribution(moments) 85 | return posterior 86 | 87 | def decode(self, z): 88 | z = self.post_quant_conv(z) 89 | dec = self.decoder(z) 90 | return dec 91 | 92 | def forward(self, input, sample_posterior=True): 93 | posterior = self.encode(input) 94 | if sample_posterior: 95 | z = posterior.sample() 96 | else: 97 | z = posterior.mode() 98 | dec = self.decode(z) 99 | return dec, posterior 100 | 101 | def get_input(self, batch, k): 102 | x = batch[k] 103 | if len(x.shape) == 3: 104 | x = x[..., None] 105 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() 106 | return x 107 | 108 | def training_step(self, batch, batch_idx, optimizer_idx): 109 | inputs = self.get_input(batch, self.image_key) 110 | reconstructions, posterior = self(inputs) 111 | 112 | if optimizer_idx == 0: 113 | # train encoder+decoder+logvar 114 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 115 | last_layer=self.get_last_layer(), split="train") 116 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 117 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) 118 | return aeloss 119 | 120 | if optimizer_idx == 1: 121 | # train the discriminator 122 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, 123 | last_layer=self.get_last_layer(), split="train") 124 | 125 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 126 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) 127 | return discloss 128 | 129 | def validation_step(self, batch, batch_idx): 130 | log_dict = self._validation_step(batch, batch_idx) 131 | with self.ema_scope(): 132 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 133 | return log_dict 134 | 135 | def _validation_step(self, batch, batch_idx, postfix=""): 136 | inputs = self.get_input(batch, self.image_key) 137 | reconstructions, posterior = self(inputs) 138 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, 139 | last_layer=self.get_last_layer(), split="val"+postfix) 140 | 141 | discloss, log_dict_disc = self.loss(inputs, 
reconstructions, posterior, 1, self.global_step, 142 | last_layer=self.get_last_layer(), split="val"+postfix) 143 | 144 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 145 | self.log_dict(log_dict_ae) 146 | self.log_dict(log_dict_disc) 147 | return self.log_dict 148 | 149 | def configure_optimizers(self): 150 | lr = self.learning_rate 151 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( 152 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) 153 | if self.learn_logvar: 154 | print(f"{self.__class__.__name__}: Learning logvar") 155 | ae_params_list.append(self.loss.logvar) 156 | opt_ae = torch.optim.Adam(ae_params_list, 157 | lr=lr, betas=(0.5, 0.9)) 158 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 159 | lr=lr, betas=(0.5, 0.9)) 160 | return [opt_ae, opt_disc], [] 161 | 162 | def get_last_layer(self): 163 | return self.decoder.conv_out.weight 164 | 165 | @torch.no_grad() 166 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): 167 | log = dict() 168 | x = self.get_input(batch, self.image_key) 169 | x = x.to(self.device) 170 | if not only_inputs: 171 | xrec, posterior = self(x) 172 | if x.shape[1] > 3: 173 | # colorize with random projection 174 | assert xrec.shape[1] > 3 175 | x = self.to_rgb(x) 176 | xrec = self.to_rgb(xrec) 177 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 178 | log["reconstructions"] = xrec 179 | if log_ema or self.use_ema: 180 | with self.ema_scope(): 181 | xrec_ema, posterior_ema = self(x) 182 | if x.shape[1] > 3: 183 | # colorize with random projection 184 | assert xrec_ema.shape[1] > 3 185 | xrec_ema = self.to_rgb(xrec_ema) 186 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) 187 | log["reconstructions_ema"] = xrec_ema 188 | log["inputs"] = x 189 | return log 190 | 191 | def to_rgb(self, x): 192 | assert self.image_key == "segmentation" 193 | if not hasattr(self, "colorize"): 194 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 195 | x = F.conv2d(x, weight=self.colorize) 196 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 197 | return x -------------------------------------------------------------------------------- /src/smplfusion/modelzoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modelzoo/__init__.py -------------------------------------------------------------------------------- /src/smplfusion/modelzoo/dreamshaper8.py: -------------------------------------------------------------------------------- 1 | print ("Loading model: Dreamshaper V8") 2 | 3 | from os.path import dirname 4 | import importlib 5 | from omegaconf import OmegaConf 6 | import torch 7 | import safetensors 8 | import safetensors.torch 9 | import open_clip 10 | 11 | PROJECT_DIR = dirname(dirname(dirname(dirname(__file__)))) 12 | LIB_DIR = dirname(dirname(__file__)) 13 | print (PROJECT_DIR) 14 | 15 | CONFIG_FOLDER = f'{LIB_DIR}/config/' 16 | ASSETS_FOLDER = f'{PROJECT_DIR}/assets/' 17 | MODEL_FOLDER = f'{ASSETS_FOLDER}/models/' 18 | 19 | def get_obj_from_str(string): 20 | module, cls = string.rsplit(".", 1) 21 | try: 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | except: 24 | return getattr(importlib.import_module('lib.' 
+ module, package=None), cls) 25 | def load_obj(path): 26 | objyaml = OmegaConf.load(path) 27 | return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {})) 28 | 29 | state_dict = safetensors.torch.load_file(f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8.safetensors') 30 | 31 | config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml') 32 | unet = load_obj(f'{CONFIG_FOLDER}/unet/v1.yaml').eval().cuda() 33 | vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda() 34 | encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda() 35 | 36 | extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} 37 | unet_state = extract(state_dict, 'model.diffusion_model') 38 | encoder_state = extract(state_dict, 'cond_stage_model') 39 | vae_state = extract(state_dict, 'first_stage_model') 40 | 41 | unet.load_state_dict(unet_state, strict=False); 42 | encoder.load_state_dict(encoder_state, strict=False); 43 | vae.load_state_dict(vae_state, strict=False); 44 | 45 | unet = unet.requires_grad_(False) 46 | encoder = encoder.requires_grad_(False) 47 | vae = vae.requires_grad_(False) -------------------------------------------------------------------------------- /src/smplfusion/modelzoo/dreamshaper8_inpainting.py: -------------------------------------------------------------------------------- 1 | print ("Loading model: Dreamshaper Inpainting V8") 2 | 3 | from os.path import dirname 4 | import importlib 5 | from omegaconf import OmegaConf 6 | import torch 7 | import safetensors 8 | import safetensors.torch 9 | import open_clip 10 | 11 | PROJECT_DIR = dirname(dirname(dirname(dirname(__file__)))) 12 | LIB_DIR = dirname(dirname(__file__)) 13 | print (PROJECT_DIR) 14 | 15 | CONFIG_FOLDER = f'{LIB_DIR}/config/' 16 | ASSETS_FOLDER = f'{PROJECT_DIR}/assets/' 17 | MODEL_FOLDER = f'{ASSETS_FOLDER}/models/' 18 | 19 | def get_obj_from_str(string): 20 | module, cls = string.rsplit(".", 1) 21 | try: 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | except: 24 | return getattr(importlib.import_module('lib.' 
+ module, package=None), cls) 25 | def load_obj(path): 26 | objyaml = OmegaConf.load(path) 27 | return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {})) 28 | 29 | state_dict = safetensors.torch.load_file(f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors') 30 | 31 | config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml') 32 | unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda() 33 | vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda() 34 | encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda() 35 | 36 | extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x} 37 | unet_state = extract(state_dict, 'model.diffusion_model') 38 | encoder_state = extract(state_dict, 'cond_stage_model') 39 | vae_state = extract(state_dict, 'first_stage_model') 40 | 41 | unet.load_state_dict(unet_state); 42 | encoder.load_state_dict(encoder_state); 43 | vae.load_state_dict(vae_state); 44 | 45 | unet = unet.requires_grad_(False) 46 | encoder = encoder.requires_grad_(False) 47 | vae = vae.requires_grad_(False) -------------------------------------------------------------------------------- /src/smplfusion/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__init__.py -------------------------------------------------------------------------------- /src/smplfusion/modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/autoencoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__init__.py -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/basic_transformer_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/basic_transformer_block.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/cross_attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/cross_attention.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/feed_forward.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/feed_forward.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/memory_efficient_cross_attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/memory_efficient_cross_attention.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/__pycache__/spatial_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/spatial_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/basic_transformer_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .feed_forward import FeedForward 4 | 5 | try: 6 | from .cross_attention import PatchedCrossAttention as CrossAttention 7 | except: 8 | try: 9 | from .memory_efficient_cross_attention import MemoryEfficientCrossAttention as CrossAttention 10 | except: 11 | from .cross_attention import CrossAttention 12 | from ..util import checkpoint 13 | from ...patches import router 14 | 15 | class BasicTransformerBlock(nn.Module): 16 | def __init__( 17 | 
self,dim,n_heads,d_head,dropout=0.0,context_dim=None, 18 | gated_ff=True,checkpoint=True,disable_self_attn=False, 19 | ): 20 | super().__init__() 21 | self.disable_self_attn = disable_self_attn 22 | # is a self-attention if not self.disable_self_attn 23 | self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None) 24 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 25 | # is self-attn if context is none 26 | self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout) 27 | self.norm1 = nn.LayerNorm(dim) 28 | self.norm2 = nn.LayerNorm(dim) 29 | self.norm3 = nn.LayerNorm(dim) 30 | self.checkpoint = checkpoint 31 | 32 | def forward(self, x, context=None): 33 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 34 | 35 | def _forward(self, x, context=None): 36 | x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) 37 | x = x + self.attn2(self.norm2(x), context=context) 38 | x = x + self.ff(self.norm3(x)) 39 | return x 40 | 41 | class PatchedBasicTransformerBlock(nn.Module): 42 | def __init__( 43 | self,dim,n_heads,d_head,dropout=0.0,context_dim=None, 44 | gated_ff=True,checkpoint=True,disable_self_attn=False, 45 | ): 46 | super().__init__() 47 | self.disable_self_attn = disable_self_attn 48 | # is a self-attention if not self.disable_self_attn 49 | self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None) 50 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 51 | # is self-attn if context is none 52 | self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout) 53 | self.norm1 = nn.LayerNorm(dim) 54 | self.norm2 = nn.LayerNorm(dim) 55 | self.norm3 = nn.LayerNorm(dim) 56 | self.checkpoint = checkpoint 57 | 58 | def forward(self, x, context=None): 59 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 60 | 61 | def _forward(self, x, context=None): 62 | return router.basic_transformer_forward(self, x, context) 63 | -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/cross_attention.py: -------------------------------------------------------------------------------- 1 | # CrossAttn precision handling 2 | import os 3 | 4 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from torch import einsum 10 | from einops import rearrange, repeat 11 | import torch 12 | from torch import nn 13 | from typing import Optional, Any 14 | from ...patches import router 15 | 16 | class CrossAttention(nn.Module): 17 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 18 | super().__init__() 19 | inner_dim = dim_head * heads 20 | context_dim = context_dim or query_dim 21 | 22 | self.scale = dim_head**-0.5 23 | self.heads = heads 24 | 25 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 26 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 27 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 28 | 29 | self.to_out = nn.Sequential( 30 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 31 | ) 32 | 33 | def forward(self, x, context=None, mask=None): 34 | h = self.heads 35 | 36 | q = self.to_q(x) 37 | context = x if context is None else 
context 38 | k = self.to_k(context) 39 | v = self.to_v(context) 40 | 41 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 42 | 43 | # force cast to fp32 to avoid overflowing 44 | if _ATTN_PRECISION == "fp32": 45 | with torch.autocast(enabled=False, device_type="cuda"): 46 | q, k = q.float(), k.float() 47 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 48 | else: 49 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 50 | 51 | del q, k 52 | 53 | if mask is not None: 54 | mask = rearrange(mask, "b ... -> b (...)") 55 | max_neg_value = -torch.finfo(sim.dtype).max 56 | mask = repeat(mask, "b j -> (b h) () j", h=h) 57 | sim.masked_fill_(~mask, max_neg_value) 58 | 59 | # attention, what we cannot get enough of 60 | sim = sim.softmax(dim=-1) 61 | 62 | out = einsum("b i j, b j d -> b i d", sim, v) 63 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 64 | return self.to_out(out) 65 | 66 | class PatchedCrossAttention(nn.Module): 67 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 68 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 69 | super().__init__() 70 | inner_dim = dim_head * heads 71 | context_dim = context_dim or query_dim 72 | 73 | self.heads = heads 74 | self.dim_head = dim_head 75 | self.scale = dim_head**-0.5 76 | 77 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 78 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 79 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 80 | 81 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 82 | self.attention_op: Optional[Any] = None 83 | 84 | def forward(self, x, context=None, mask=None): 85 | return router.attention_forward(self, x, context, mask) -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GEGLU(nn.Module): 7 | def __init__(self, dim_in, dim_out): 8 | super().__init__() 9 | self.proj = nn.Linear(dim_in, dim_out * 2) 10 | 11 | def forward(self, x): 12 | x, gate = self.proj(x).chunk(2, dim=-1) 13 | return x * F.gelu(gate) 14 | 15 | 16 | class FeedForward(nn.Module): 17 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 18 | super().__init__() 19 | inner_dim = int(dim * mult) 20 | dim_out = dim_out or dim 21 | project_in = nn.Sequential( 22 | nn.Linear(dim, inner_dim), 23 | nn.GELU() 24 | ) if not glu else GEGLU(dim, inner_dim) 25 | 26 | self.net = nn.Sequential( 27 | project_in, 28 | nn.Dropout(dropout), 29 | nn.Linear(inner_dim, dim_out) 30 | ) 31 | 32 | def forward(self, x): 33 | return self.net(x) -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/memory_efficient_cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Any 4 | 5 | try: 6 | import xformers 7 | import xformers.ops 8 | XFORMERS_IS_AVAILBLE = True 9 | except: 10 | XFORMERS_IS_AVAILBLE = False 11 | 12 | class MemoryEfficientCrossAttention(nn.Module): 13 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 14 | def 
__init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 15 | super().__init__() 16 | # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.") 17 | inner_dim = dim_head * heads 18 | context_dim = context_dim or query_dim 19 | 20 | self.heads = heads 21 | self.dim_head = dim_head 22 | 23 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 24 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 25 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 26 | 27 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 28 | self.attention_op: Optional[Any] = None 29 | 30 | def forward(self, x, context=None, mask=None): 31 | q = self.to_q(x) 32 | context = x if context is None else context 33 | k = self.to_k(context) 34 | v = self.to_v(context) 35 | 36 | b, _, _ = q.shape 37 | q, k, v = map( 38 | lambda t: t.unsqueeze(3) 39 | .reshape(b, t.shape[1], self.heads, self.dim_head) 40 | .permute(0, 2, 1, 3) 41 | .reshape(b * self.heads, t.shape[1], self.dim_head) 42 | .contiguous(), 43 | (q, k, v), 44 | ) 45 | 46 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 47 | 48 | if mask is not None: 49 | raise NotImplementedError 50 | out = ( 51 | out.unsqueeze(0) 52 | .reshape(b, self.heads, out.shape[1], self.dim_head) 53 | .permute(0, 2, 1, 3) 54 | .reshape(b, out.shape[1], self.heads * self.dim_head) 55 | ) 56 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/modules/attention/spatial_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | from torch import einsum 6 | from einops import rearrange, repeat 7 | from .basic_transformer_block import PatchedBasicTransformerBlock as BasicTransformerBlock 8 | 9 | def Normalize(in_channels): 10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 11 | 12 | def init_(tensor): 13 | dim = tensor.shape[-1] 14 | std = 1 / math.sqrt(dim) 15 | tensor.uniform_(-std, std) 16 | return tensor 17 | 18 | 19 | def zero_module(module): 20 | for p in module.parameters(): 21 | p.detach().zero_() 22 | return module 23 | 24 | 25 | class SpatialTransformer(nn.Module): 26 | """ 27 | Transformer block for image-like data. 28 | First, project the input (aka embedding) 29 | and reshape to b, t, d. 30 | Then apply standard transformer action. 
31 | Finally, reshape to image 32 | NEW: use_linear for more efficiency instead of the 1x1 convs 33 | """ 34 | def __init__(self, in_channels, n_heads, d_head, 35 | depth=1, dropout=0., context_dim=None, 36 | disable_self_attn=False, use_linear=False, 37 | use_checkpoint=True): 38 | super().__init__() 39 | if context_dim is not None and not isinstance(context_dim, list): 40 | context_dim = [context_dim] 41 | self.in_channels = in_channels 42 | inner_dim = n_heads * d_head 43 | self.norm = Normalize(in_channels) 44 | if not use_linear: 45 | self.proj_in = nn.Conv2d(in_channels,inner_dim,kernel_size=1,stride=1,padding=0) 46 | else: 47 | self.proj_in = nn.Linear(in_channels, inner_dim) 48 | 49 | self.transformer_blocks = nn.ModuleList( 50 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 51 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 52 | for d in range(depth)] 53 | ) 54 | if not use_linear: 55 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 56 | in_channels, 57 | kernel_size=1, 58 | stride=1, 59 | padding=0)) 60 | else: 61 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 62 | self.use_linear = use_linear 63 | 64 | def forward(self, x, context=None): 65 | # note: if no context is given, cross-attention defaults to self-attention 66 | if not isinstance(context, list): 67 | context = [context] 68 | b, c, h, w = x.shape 69 | x_in = x 70 | x = self.norm(x) 71 | if not self.use_linear: 72 | x = self.proj_in(x) 73 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 74 | if self.use_linear: 75 | x = self.proj_in(x) 76 | for i, block in enumerate(self.transformer_blocks): 77 | x = block(x, context=context[i]) 78 | if self.use_linear: 79 | x = self.proj_out(x) 80 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 81 | if not self.use_linear: 82 | x = self.proj_out(x) 83 | return x + x_in 84 | 85 | -------------------------------------------------------------------------------- /src/smplfusion/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class DiagonalGaussianDistribution(object): 6 | def __init__(self, parameters, deterministic=False): 7 | self.parameters = parameters 8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 10 | self.deterministic = deterministic 11 | self.std = torch.exp(0.5 * self.logvar) 12 | self.var = torch.exp(self.logvar) 13 | if self.deterministic: 14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 15 | 16 | def sample(self): 17 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 18 | return x 19 | 20 | def kl(self, other=None): 21 | if self.deterministic: 22 | return torch.Tensor([0.]) 23 | else: 24 | if other is None: 25 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 26 | + self.var - 1.0 - self.logvar, 27 | dim=[1, 2, 3]) 28 | else: 29 | return 0.5 * torch.sum( 30 | torch.pow(self.mean - other.mean, 2) / other.var 31 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 32 | dim=[1, 2, 3]) 33 | 34 | def nll(self, sample, dims=[1,2,3]): 35 | if self.deterministic: 36 | return torch.Tensor([0.]) 37 | logtwopi = np.log(2.0 * np.pi) 38 | return 0.5 * torch.sum( 39 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 40 | dim=dims) 41 | 42 | def mode(self): 43 | return self.mean 44 | 45 | 46 | 
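# Usage sketch (editorial addition, not part of the original file; names are illustrative):
# AutoencoderKL.encode() wraps the encoder output in this distribution, so sampling a latent
# and computing the KL-to-prior term looks roughly like:
#
#     posterior = DiagonalGaussianDistribution(quant_conv(encoder(x)))  # moments = [mean, logvar] along dim=1
#     z = posterior.sample()            # mean + std * N(0, 1), used during training
#     z_det = posterior.mode()          # deterministic alternative: just the mean
#     kl_prior = posterior.kl().mean()  # KL(q(z|x) || N(0, I)), summed over C,H,W per sample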
def normal_kl(mean1, logvar1, mean2, logvar2): 47 | """ 48 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 49 | Compute the KL divergence between two gaussians. 50 | Shapes are automatically broadcasted, so batches can be compared to 51 | scalars, among other use cases. 52 | """ 53 | tensor = None 54 | for obj in (mean1, logvar1, mean2, logvar2): 55 | if isinstance(obj, torch.Tensor): 56 | tensor = obj 57 | break 58 | assert tensor is not None, "at least one argument must be a Tensor" 59 | 60 | # Force variances to be Tensors. Broadcasting helps convert scalars to 61 | # Tensors, but it does not work for torch.exp(). 62 | logvar1, logvar2 = [ 63 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 64 | for x in (logvar1, logvar2) 65 | ] 66 | 67 | return 0.5 * ( 68 | -1.0 69 | + logvar2 70 | - logvar1 71 | + torch.exp(logvar1 - logvar2) 72 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 73 | ) 74 | -------------------------------------------------------------------------------- /src/smplfusion/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 
65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) 81 | -------------------------------------------------------------------------------- /src/smplfusion/modules/partial_conv2d.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # BSD 3-Clause License 3 | # 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Author & Contact: Guilin Liu (guilinl@nvidia.com) 7 | ############################################################################### 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn, cuda 12 | 13 | from .. import share 14 | 15 | partial_res = [8, 16, 32, 64] 16 | 17 | # class PartialConv2d(nn.Conv2d): 18 | # def __init__(self, *args, **kwargs): 19 | 20 | # # whether the mask is multi-channel or not 21 | # if 'multi_channel' in kwargs: 22 | # self.multi_channel = kwargs['multi_channel'] 23 | # kwargs.pop('multi_channel') 24 | # else: 25 | # self.multi_channel = False 26 | 27 | # if 'return_mask' in kwargs: 28 | # self.return_mask = kwargs['return_mask'] 29 | # kwargs.pop('return_mask') 30 | # else: 31 | # self.return_mask = False 32 | 33 | # super(PartialConv2d, self).__init__(*args, **kwargs) 34 | 35 | # if self.multi_channel: 36 | # self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]) 37 | # else: 38 | # self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]) 39 | 40 | # self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] 41 | 42 | # self.last_size = (None, None, None, None) 43 | # self.update_mask = None 44 | # self.mask_ratio = None 45 | 46 | # def forward(self, input, mask_in=None): 47 | # assert len(input.shape) == 4 48 | # if mask_in is not None or self.last_size != tuple(input.shape): 49 | # self.last_size = tuple(input.shape) 50 | 51 | # with torch.no_grad(): 52 | # if self.weight_maskUpdater.type() != input.type(): 53 | # self.weight_maskUpdater = self.weight_maskUpdater.to(input) 54 | 55 | # if mask_in is None: 56 | # # if mask is not provided, create a mask 57 | # if self.multi_channel: 58 | # mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input) 59 | # else: 60 | # mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input) 61 | # else: 62 | # mask = mask_in 63 | 64 | # self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1) 65 | 66 | # # for mixed precision training, change 1e-8 to 1e-6 67 | # self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8) 68 | # # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8) 69 | # self.update_mask = 
torch.clamp(self.update_mask, 0, 1) 70 | # self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) 71 | 72 | 73 | # raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input) 74 | 75 | # if self.bias is not None: 76 | # bias_view = self.bias.view(1, self.out_channels, 1, 1) 77 | # output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view 78 | # output = torch.mul(output, self.update_mask) 79 | # else: 80 | # output = torch.mul(raw_out, self.mask_ratio) 81 | 82 | 83 | # if self.return_mask: 84 | # return output, self.update_mask 85 | # else: 86 | # return output 87 | 88 | 89 | class PartialConv2d(nn.Conv2d): 90 | """ 91 | NOTE: You need to use share.set_mask(original_mask) before inferencing with PartialConv2d. 92 | """ 93 | 94 | def __init__(self, *args, **kwargs): 95 | super(PartialConv2d, self).__init__(*args, **kwargs) 96 | 97 | self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]) 98 | self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] 99 | 100 | def forward(self, input): 101 | if share.input_mask is None: 102 | raise Exception('Please set share.set_mask(original_mask) before inferencing with PartialConv2d.') 103 | 104 | bias_view = torch.zeros((1, self.out_channels, 1, 1), dtype=input.dtype, device=input.device) 105 | if self.bias is not None: 106 | bias_view = self.bias.view(1, self.out_channels, 1, 1) 107 | 108 | # Get the resized mask for the current resolution 109 | res = max(input.shape[2:]) 110 | 111 | if res == max(share.input_mask.shape64): 112 | mask = share.input_mask.val64 113 | mask_down = share.input_mask.val32 114 | elif res == max(share.input_mask.shape32): 115 | mask = share.input_mask.val32 116 | mask_down = share.input_mask.val16 117 | elif res == max(share.input_mask.shape16): 118 | mask = share.input_mask.val16 119 | mask_down = share.input_mask.val8 120 | elif res == max(share.input_mask.shape8): 121 | mask = share.input_mask.val8 122 | 123 | mask = mask.to(input.device) 124 | 125 | # Separately perform the convolution operation on masked and known regions 126 | masked_input = torch.mul(input, mask) 127 | known_input = torch.mul(input, 1-mask) 128 | 129 | input_out =super(PartialConv2d, self).forward(input) 130 | masked_out = super(PartialConv2d, self).forward(masked_input) 131 | known_out = super(PartialConv2d, self).forward(known_input) 132 | 133 | # Calculate the rescaling weights for known and unknown regions 134 | 135 | # ############ Weighting strategy No 1 ################# 136 | 137 | # pixel_counts = F.conv2d( 138 | # F.pad(mask, [*self.padding, *self.padding], mode='reflect'), 139 | # self.weight_maskUpdater.to(mask.device), 140 | # bias=None, 141 | # stride=self.stride, 142 | # padding=(0, 0), 143 | # dilation=self.dilation, 144 | # groups=1 145 | # ) 146 | 147 | # mask_ratio_unknown = self.slide_winsize/ (pixel_counts) 148 | # mask_ratio_known = self.slide_winsize / (self.slide_winsize - pixel_counts) 149 | 150 | # ################## End of No 1 ######################### 151 | 152 | # ################ Weighting strategy No 2 ############### 153 | 154 | # ones_input = torch.ones_like(input) 155 | # masks_input = mask.repeat(1, ones_input.shape[1], 1, 1) 156 | 157 | # sum_overall = F.conv2d( 158 | # F.pad(ones_input, [*self.padding, *self.padding], mode='reflect'), 159 | # torch.abs(self.weight), 160 | # bias=None, 161 | # stride=self.stride, 162 | # padding=(0, 0), 163 | # 
dilation=self.dilation, 164 | # groups=1 165 | # ) 166 | 167 | # sum_masked = F.conv2d( 168 | # F.pad(masks_input, [*self.padding, *self.padding], mode='reflect'), 169 | # torch.abs(self.weight), 170 | # bias=None, 171 | # stride=self.stride, 172 | # padding=(0, 0), 173 | # dilation=self.dilation, 174 | # groups=1 175 | # ) 176 | 177 | # mask_ratio_unknown = sum_overall / (sum_masked) 178 | # mask_ratio_known = sum_overall / (sum_overall - sum_masked) 179 | 180 | # ################## End of No 2 ######################### 181 | 182 | ################ Weighting strategy No 3 ############### 183 | if res not in partial_res: 184 | return input_out 185 | 186 | input_norm = torch.norm(input_out - bias_view, dim=1, keepdim=True) 187 | known_norm = torch.norm(known_out - bias_view, dim=1, keepdim=True) 188 | masked_norm = torch.norm(masked_out - bias_view, dim=1, keepdim=True) 189 | 190 | mask_ratio_unknown = input_norm / masked_norm 191 | mask_ratio_known = input_norm / known_norm 192 | 193 | ################## End of No 3 ######################### 194 | 195 | # Replace nan and inf with 0.0 196 | mask_ratio_unknown = torch.nan_to_num(mask_ratio_unknown, nan=0.0, posinf=0.0, neginf=0.0) 197 | mask_ratio_known = torch.nan_to_num(mask_ratio_known, nan=0.0, posinf=0.0, neginf=0.0) 198 | 199 | ###################### DEBUG ############################ 200 | # if res == 8: 201 | # print(mask_ratio_unknown[0][0], mask_ratio_unknown.shape) 202 | # print(mask_ratio_known[0][0], mask_ratio_known.shape) 203 | 204 | ################### END OF DEBUG ######################## 205 | 206 | # If set to true, doesn't rescale the convolution outputs 207 | ignore_mask_ratio = False 208 | if ignore_mask_ratio: 209 | mask_ratio_known = 1.0 210 | mask_ratio_unknown = 1.0 211 | 212 | known_out = known_out - bias_view 213 | masked_out = masked_out - bias_view 214 | 215 | if max(self.stride) > 1: 216 | mask_down = mask_down.to(input.device) 217 | out = masked_out * mask_down * mask_ratio_unknown + known_out * (1-mask_down) * mask_ratio_known 218 | else: 219 | out = masked_out * mask * mask_ratio_unknown + known_out * (1-mask) * mask_ratio_known 220 | 221 | out = out + bias_view 222 | 223 | return out -------------------------------------------------------------------------------- /src/smplfusion/patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__init__.py -------------------------------------------------------------------------------- /src/smplfusion/patches/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/__pycache__/router.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__pycache__/router.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import zeropaint 2 | from . import default 3 | from . import attstore 4 | from . import inpaint 5 | from . import shuffled 6 | from . import introvert 7 | from . import boosted_maskscore 8 | -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/attstore.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/attstore.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/boosted_maskscore.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/boosted_maskscore.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/default.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/default.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/ediff.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/ediff.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/inpaint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/inpaint.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/introvert.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/introvert.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/shuffled.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/shuffled.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/__pycache__/zeropaint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/zeropaint.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/attstore.py: -------------------------------------------------------------------------------- 1 | # CrossAttn precision handling 2 | import os 3 | import torch 4 | from torch import nn 5 | 6 | from torch import einsum 7 | from einops import rearrange, repeat 8 | from ... import share 9 | 10 | att_res = [16 * 16] 11 | force_idx = 1 12 | 13 | def forward(self, x, context=None, mask=None): 14 | h = self.heads 15 | 16 | q = self.to_q(x) 17 | att_type = "self" if context is None else "cross" 18 | context = x if context is None else context 19 | k = self.to_k(context) 20 | v = self.to_v(context) 21 | 22 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 23 | 24 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 25 | 26 | if att_type == "cross" and q.shape[1] in att_res: 27 | share.sim.append(sim) 28 | 29 | # attention, what we cannot get enough of 30 | sim = sim.softmax(dim=-1) 31 | 32 | out = einsum("b i j, b j d -> b i d", sim, v) 33 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 34 | return self.to_out(out) 35 | 36 | def forward_force(self, x, context=None, mask=None): 37 | h = self.heads 38 | 39 | q = self.to_q(x) 40 | att_type = "self" if context is None else "cross" 41 | context = x if context is None else context 42 | k = self.to_k(context) 43 | v = self.to_v(context) 44 | 45 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 46 | 47 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 48 | 49 | if att_type == "cross": 50 | context_dim = context.shape[1] # Number of tokens, might not be 77 51 | 52 | # if q.shape[1] == share.input_mask.res64: 53 | # _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1) 54 | # _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1) 55 | # sim[:,share.input_mask.val64.reshape(-1) > 0,:] = _sim[:,share.input_mask.val64.reshape(-1) > 0,:] 56 | # if q.shape[1] == share.input_mask.res32: 57 | # _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1) 58 | # _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1) 59 | # sim[:,share.input_mask.val32.reshape(-1) > 0,:] = _sim[:,share.input_mask.val32.reshape(-1) > 0,:] 60 | if q.shape[1] == share.input_mask.res16: 61 | # print (sim.shape) 62 | _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0]//2, sim.shape[1], 1) 63 | _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0]//2, sim.shape[1], 1) 64 | # sim[sim.shape[0]//2:,share.input_mask.val16.reshape(-1) > 0,:] = _sim[:,share.input_mask.val16.reshape(-1) > 0,:] 65 | # sim[sim.shape[0]//2:, share.input_mask.val16.reshape(-1) > 0, 0] 
*= 0.8 66 | share.sim.append(sim[sim.shape[0]//2:]) 67 | 68 | # attention, what we cannot get enough of 69 | sim = sim.softmax(dim=-1) 70 | 71 | out = einsum("b i j, b j d -> b i d", sim, v) 72 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 73 | if q.shape[1] == share.input_mask.res16: 74 | if att_type == "cross": 75 | share.cross_out.append(out.detach()) 76 | if att_type == "self": 77 | share.self_out.append(out.detach()) 78 | 79 | return self.to_out(out) 80 | -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/boosted_maskscore copy.py: -------------------------------------------------------------------------------- 1 | from ... import share 2 | import torch 3 | 4 | # import xformers 5 | # import xformers.ops 6 | 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch import nn, einsum 10 | from inspect import isfunction 11 | from einops import rearrange, repeat 12 | 13 | qkv_reduce_dims = [-1] 14 | increase_indices = [1,2] 15 | 16 | w8 = 0. 17 | w16 = 0. 18 | w32 = 0. 19 | w64 = 0. 20 | 21 | def forward_and_save(self, x, context=None, mask=None): 22 | att_type = "self" if context is None else "cross" 23 | 24 | h = self.heads 25 | q = self.to_q(x) 26 | is_cross = context is not None 27 | context = x if context is None else context 28 | k = self.to_k(context) 29 | v = self.to_v(context) 30 | 31 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 32 | 33 | scale = self.dim_head**-0.5 34 | sim = einsum("b i d, b j d -> b i j", q, k) * scale 35 | 36 | if is_cross: 37 | N = sim.shape[0] // 2 38 | 39 | delta_uncond = torch.zeros_like(sim[:N]) 40 | 41 | if q.shape[1] == share.input_shape.res8: 42 | W = w8 + torch.zeros_like(sim[:N]) 43 | if q.shape[1] == share.input_shape.res16: 44 | W = w16 + torch.zeros_like(sim[:N]) 45 | if q.shape[1] == share.input_shape.res32: 46 | W = w32 + torch.zeros_like(sim[:N]) 47 | if q.shape[1] == share.input_shape.res64: 48 | W = w64 + torch.zeros_like(sim[:N]) 49 | 50 | tokens = torch.eye(77)[increase_indices].sum(0).cuda() 51 | 52 | sim[N:].argmax() 53 | 54 | max_sim = sim[N:].detach().amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1) 55 | sigma = share.schedule.sqrt_noise_signal_ratio[share.timestep] 56 | mask = share.input_mask.get_res(q, 'cuda').reshape(1,-1,1) 57 | 58 | delta_cond = W * max_sim * mask * tokens 59 | sim += torch.cat([delta_uncond, delta_cond]) 60 | 61 | if hasattr(share, '_crossattn_similarity') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 62 | share._crossattn_similarity.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 63 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': 64 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 65 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 66 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 67 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': 68 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate 
the unconditional and conditional parts 69 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': 70 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 71 | 72 | sim = sim.softmax(dim=-1) 73 | out = einsum("b i j, b j d -> b i d", sim, v) 74 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 75 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/boosted_maskscore.py: -------------------------------------------------------------------------------- 1 | from ... import share 2 | import torch 3 | 4 | # import xformers 5 | # import xformers.ops 6 | 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch import nn, einsum 10 | from inspect import isfunction 11 | from einops import rearrange, repeat 12 | import numpy as np 13 | 14 | qkv_reduce_dims = [-1] 15 | increase_indices = [1,2] 16 | topk_heads = 0.5 17 | 18 | w8 = 0. 19 | w16 = 0. 20 | w32 = 0. 21 | w64 = 0. 22 | 23 | def forward(self, x, context=None, mask=None): 24 | att_type = "self" if context is None else "cross" 25 | 26 | h = self.heads 27 | q = self.to_q(x) 28 | is_cross = context is not None 29 | context = x if context is None else context 30 | k = self.to_k(context) 31 | v = self.to_v(context) 32 | 33 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 34 | 35 | scale = self.dim_head**-0.5 36 | sim = einsum("b i d, b j d -> b i j", q, k) * scale 37 | 38 | if is_cross: 39 | if share.timestep < 0: 40 | # Choose the weight 41 | if q.shape[1] == share.input_shape.res8: w = w8 42 | if q.shape[1] == share.input_shape.res16: w = w16 43 | if q.shape[1] == share.input_shape.res32: w = w32 44 | if q.shape[1] == share.input_shape.res64: w = w64 45 | mask = share.input_mask.get_res(q, 'cuda').reshape(-1) 46 | 47 | sim = rearrange(sim, "(b h) n d -> b h n d", h=h) # (2x10x4096x77) 48 | sim[1,...,0] -= 0.2 * sim[1, ..., 0] * mask 49 | 50 | # # Get coefficients 51 | # sigma = (1 - share.schedule.alphas[share.timestep]) 52 | # sim_max = sim.detach().amax(dim=qkv_reduce_dims, keepdim = True) # (2x10x1x1) 53 | 54 | # # Use to modify only the conditional part 55 | # batch_one_hot = torch.tensor([0,1.])[:, None, None, None].cuda() # (2x1x1x1) 56 | 57 | # for token_idx in increase_indices: 58 | # # Get the index of the head to be modified 59 | # n_topk = int(sim.shape[1] * topk_heads) 60 | # head_indices = (sim[1,...,1]).amax(dim = 1).topk(n_topk).indices.cpu() 61 | # head_one_hot = torch.eye(h)[head_indices] 62 | # # head_one_hot[n_topk // 2:] *= 0.5 63 | # head_one_hot = head_one_hot.sum(0).cuda()[None,:,None,None] # (1x10x1x1) 64 | 65 | # # head_one_hot = (1 / (1 + torch.arange(sim.shape[1]))) ** 0.5 66 | # # head_one_hot = head_one_hot.cuda()[None,:,None,None] 67 | 68 | # # Get the one hot token index 69 | # token_one_hot = torch.eye(77)[token_idx].cuda()[None,None,None,:] # (1x1x1x77) 70 | 71 | # sim += w * sim_max * mask * token_one_hot * head_one_hot * batch_one_hot 72 | 73 | sim = rearrange(sim, "b h n d -> (b h) n d", h=h) # (2x10x4096x77) 74 | if not is_cross and q.shape[1] == share.input_shape.res16: 75 | # shape of sim: 20 x 4096 x 4096 76 | if share.timestep < 0: 77 | sim_max = sim.detach().amax(dim=-1, keepdim = True) # (20x1x1) 78 | 79 | mask = share.input_mask.get_res(q, 'cuda').reshape(-1) 80 | delta = (mask[:,None] @ 
mask[None,:])[None] 81 | gamma = ((1 - mask[:,None]) @ (mask[None,:]))[None] 82 | 83 | sim = sim * (1 + 1.0 * share.schedule.sqrt_one_minus_alphas[share.timestep] * delta) 84 | sim = sim * (1 - 1.0 * share.schedule.sqrt_one_minus_alphas[share.timestep] * gamma) 85 | # sim = sim + 1.0 * share.schedule.sqrt_noise_signal_ratio[share.timestep] * sim_max * delta 86 | 87 | # Chunk into 2 parts to differentiate the unconditional and conditional parts 88 | if hasattr(share, '_crossattn_similarity') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 89 | share._crossattn_similarity.append(torch.stack(share.reshape(sim).chunk(2))) 90 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': 91 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) 92 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 93 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) 94 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': 95 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) 96 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': 97 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) 98 | 99 | sim = sim.softmax(dim=-1) 100 | out = einsum("b i j, b j d -> b i d", sim, v) 101 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 102 | return self.to_out(out) 103 | 104 | forward_and_save = forward -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/default.py: -------------------------------------------------------------------------------- 1 | from ... import share 2 | 3 | try: 4 | import xformers 5 | import xformers.ops 6 | XFORMERS_IS_AVAILBLE = True 7 | except: 8 | XFORMERS_IS_AVAILBLE = False 9 | print("No module 'xformers'. Proceeding without it.") 10 | 11 | 12 | import torch 13 | from torch import nn, einsum 14 | import torchvision.transforms.functional as TF 15 | from einops import rearrange, repeat 16 | 17 | _ATTN_PRECISION = None 18 | 19 | def forward_sd2(self, x, context=None, mask=None): 20 | h = self.heads 21 | q = self.to_q(x) 22 | context = x if context is None else context 23 | k = self.to_k(context) 24 | v = self.to_v(context) 25 | 26 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 27 | 28 | if _ATTN_PRECISION =="fp32": # force cast to fp32 to avoid overflowing 29 | with torch.autocast(enabled=False, device_type = 'cuda'): 30 | q, k = q.float(), k.float() 31 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 32 | else: 33 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 34 | del q, k 35 | 36 | if mask is not None: 37 | mask = rearrange(mask, 'b ... 
-> b (...)') 38 | max_neg_value = -torch.finfo(sim.dtype).max 39 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 40 | sim.masked_fill_(~mask, max_neg_value) 41 | 42 | # attention, what we cannot get enough of 43 | sim = sim.softmax(dim=-1) 44 | 45 | out = einsum('b i j, b j d -> b i d', sim, v) 46 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 47 | return self.to_out(out) 48 | 49 | def forward_xformers(self, x, context=None, mask=None): 50 | q = self.to_q(x) 51 | context = x if context is None else context 52 | k = self.to_k(context) 53 | v = self.to_v(context) 54 | 55 | b, _, _ = q.shape 56 | q, k, v = map( 57 | lambda t: t.unsqueeze(3) 58 | .reshape(b, t.shape[1], self.heads, self.dim_head) 59 | .permute(0, 2, 1, 3) 60 | .reshape(b * self.heads, t.shape[1], self.dim_head) 61 | .contiguous(), 62 | (q, k, v), 63 | ) 64 | 65 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 66 | 67 | if mask is not None: 68 | raise NotImplementedError 69 | out = ( 70 | out.unsqueeze(0) 71 | .reshape(b, self.heads, out.shape[1], self.dim_head) 72 | .permute(0, 2, 1, 3) 73 | .reshape(b, out.shape[1], self.heads * self.dim_head) 74 | ) 75 | return self.to_out(out) 76 | 77 | if XFORMERS_IS_AVAILBLE: 78 | forward = forward_xformers 79 | else: 80 | forward = forward_sd2 81 | 82 | def forward_and_save(self, x, context=None, mask=None): 83 | att_type = "self" if context is None else "cross" 84 | 85 | h = self.heads 86 | q = self.to_q(x) 87 | context = x if context is None else context 88 | k = self.to_k(context) 89 | v = self.to_v(context) 90 | 91 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 92 | 93 | # with torch.autocast(enabled=False, device_type = 'cuda'): 94 | # q, k = q.float(), k.float() 95 | # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 96 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 97 | # del q,k 98 | 99 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': 100 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 101 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 102 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 103 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': 104 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 105 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': 106 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 107 | 108 | # attention, what we cannot get enough of 109 | sim = sim.softmax(dim=-1) 110 | out = einsum("b i j, b j d -> b i d", sim, v) 111 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 112 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/inpaint.py: -------------------------------------------------------------------------------- 1 | from ... 
import share 2 | import torch 3 | 4 | # import xformers 5 | # import xformers.ops 6 | 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch import nn, einsum 10 | from inspect import isfunction 11 | from einops import rearrange, repeat 12 | 13 | qkv_reduce_dims = [-1] 14 | increase_indices = [1,2] 15 | 16 | def forward(self, x, context=None, mask=None): 17 | h = self.heads 18 | q = self.to_q(x) 19 | is_cross = context is not None 20 | context = x if context is None else context 21 | k = self.to_k(context) 22 | v = self.to_v(context) 23 | 24 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 25 | 26 | scale = self.dim_head**-0.5 27 | sim = einsum("b i d, b j d -> b i j", q, k) * scale 28 | 29 | if is_cross: 30 | if q.shape[1] in [share.input_shape.res16]: 31 | # For simplicity, assume token 1 is target, token 2 is (1 word label) 32 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096) 33 | N = sim.shape[0]//2 34 | sim[N:] += ( 35 | share.w 36 | * share.noise_signal_ratio[share.timestep] 37 | * share.input_mask.get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1) 38 | * sim[N:].amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1) 39 | ) * (torch.eye(77)[increase_indices].sum(0).cuda()) # (1,1,77) 40 | 41 | del q, k 42 | sim = sim.softmax(dim=-1) 43 | 44 | out = einsum("b i j, b j d -> b i d", sim, v) 45 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 46 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/introvert.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from torch import nn, einsum 9 | from einops import rearrange, repeat 10 | 11 | from ... import share 12 | from ...libimage import IImage 13 | 14 | 15 | # Default for version 4 16 | introvert_res = [16, 32, 64] 17 | introvert_on = True 18 | token_idx = [1,2] 19 | 20 | # Visualization purpose 21 | viz_image = None 22 | viz_mask = None 23 | 24 | video_frames_selfattn = [] 25 | video_frames_crossattn = [] 26 | visualize_resolution = 16 27 | visualize_selfattn = False 28 | visualize_crossattn = False 29 | 30 | class GaussianSmoothing(nn.Module): 31 | """ 32 | Apply gaussian smoothing on a 33 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 34 | in the input using a depthwise convolution. 35 | Arguments: 36 | channels (int, sequence): Number of channels of the input tensors. Output will 37 | have this number of channels as well. 38 | kernel_size (int, sequence): Size of the gaussian kernel. 39 | sigma (float, sequence): Standard deviation of the gaussian kernel. 40 | dim (int, optional): The number of dimensions of the data. 41 | Default value is 2 (spatial). 42 | """ 43 | def __init__(self, channels, kernel_size, sigma, dim=2): 44 | super(GaussianSmoothing, self).__init__() 45 | if isinstance(kernel_size, numbers.Number): 46 | kernel_size = [kernel_size] * dim 47 | if isinstance(sigma, numbers.Number): 48 | sigma = [sigma] * dim 49 | 50 | # The gaussian kernel is the product of the 51 | # gaussian function of each dimension. 
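# NOTE: the torch.meshgrid call a few lines below relies on the implicit default indexing, which matches indexing='ij'; newer PyTorch releases emit a deprecation warning unless indexing= is passed explicitly.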
52 | kernel = 1 53 | meshgrids = torch.meshgrid( 54 | [ 55 | torch.arange(size, dtype=torch.float32) 56 | for size in kernel_size 57 | ] 58 | ) 59 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 60 | mean = (size - 1) / 2 61 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 62 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 63 | 64 | # Make sure sum of values in gaussian kernel equals 1. 65 | kernel = kernel / torch.sum(kernel) 66 | 67 | # Reshape to depthwise convolutional weight 68 | kernel = kernel.view(1, 1, *kernel.size()) 69 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 70 | 71 | self.register_buffer('weight', kernel) 72 | self.groups = channels 73 | 74 | if dim == 1: 75 | self.conv = F.conv1d 76 | elif dim == 2: 77 | self.conv = F.conv2d 78 | elif dim == 3: 79 | self.conv = F.conv3d 80 | else: 81 | raise RuntimeError( 82 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 83 | ) 84 | 85 | def forward(self, input): 86 | """ 87 | Apply gaussian filter to input. 88 | Arguments: 89 | input (torch.Tensor): Input to apply gaussian filter on. 90 | Returns: 91 | filtered (torch.Tensor): Filtered output. 92 | """ 93 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding='same') 94 | 95 | 96 | def forward(self, x, context=None, mask=None): 97 | is_cross = context is not None 98 | att_type = "self" if context is None else "cross" 99 | 100 | h = self.heads 101 | 102 | q = self.to_q(x) 103 | context = x if context is None else context 104 | k = self.to_k(context) 105 | v = self.to_v(context) 106 | 107 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 108 | 109 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 110 | sim_before = sim 111 | del q, k 112 | 113 | if mask is not None: 114 | mask = rearrange(mask, 'b ... 
-> b (...)') 115 | max_neg_value = -torch.finfo(sim.dtype).max 116 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 117 | sim.masked_fill_(~mask, max_neg_value) 118 | 119 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross': 120 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 121 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross': 122 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 123 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross': 124 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 125 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross': 126 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts 127 | 128 | sim = sim.softmax(dim=-1) 129 | out = einsum('b i j, b j d -> b i d', sim, v) 130 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 131 | 132 | if is_cross: 133 | return self.to_out(out) 134 | 135 | return self.to_out(out), v, sim_before 136 | 137 | 138 | def introvert_rescale(y, self_v, self_sim, cross_sim, self_h, to_out): 139 | mask = share.introvert_mask.get_res(self_v) 140 | shape = share.introvert_mask.get_shape(self_v) 141 | res = share.introvert_mask.get_res_val(self_v) 142 | # print (res, shape) 143 | 144 | ################# Introvert Attention V4 ################ 145 | # Use this with 50% of DDIM steps 146 | 147 | # TODO: LOOK INTO THIS. WHY WITHOUT BINARY WORKS BETTER???? 
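# As written, the commented-out binarization and Gaussian smoothing just below are both skipped, so the soft, resolution-resized mask is used directly when building the rescaling weights.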
148 | # mask = (mask > 0.5).to(torch.float32) 149 | m = mask.to(self_v.device) 150 | # mask_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).cuda() 151 | # m = mask_smoothing(m) # Smoothing on binary mask also works 152 | m = rearrange(m, 'b c h w -> b (h w) c').contiguous() 153 | mo = m 154 | m = torch.matmul(m, m.permute(0, 2, 1)) + (1-m) 155 | 156 | cross_sim = cross_sim[:, token_idx].sum(dim=1) 157 | cross_sim = cross_sim.reshape(shape) 158 | # TODO: comment this out if it is not neccessary 159 | heatmap_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).cuda() 160 | cross_sim = heatmap_smoothing(cross_sim.unsqueeze(0))[0] 161 | cross_sim = cross_sim.reshape(-1) 162 | cross_sim = ((cross_sim - torch.median(cross_sim.ravel())) / torch.max(cross_sim.ravel())).clip(0, 1) 163 | 164 | # If introvert attention is off, return original SA result 165 | if introvert_on and res in introvert_res: 166 | w = (1 - m) * cross_sim.reshape(1, 1, -1) + m 167 | # On 64 resolution make scaling with constant, as cross_sim doesn't contain semantic meaning 168 | if res == 64: w = m 169 | self_sim = self_sim * w 170 | self_sim_viz = self_sim # Keep for viz purposes 171 | self_sim = self_sim.softmax(dim=-1) 172 | out = einsum('b i j, b j d -> b i d', self_sim, self_v) 173 | out = rearrange(out, '(b h) n d -> b n (h d)', h=self_h) 174 | out = to_out(out) 175 | else: 176 | self_sim_viz = self_sim # Keep for viz purposes 177 | out = y 178 | ################## END OF Introvert Attention V4 ########################### 179 | 180 | 181 | ################# VISUALIZE CROSS ATTENTION ############################### 182 | if visualize_crossattn and res == visualize_resolution: 183 | cross_vis = cross_sim.reshape(shape) 184 | up = (64 // res) * 8 185 | if viz_image is not None: 186 | heatmap = IImage(cross_vis, vmin=0).heatmap((cross_vis.shape[0]*up, cross_vis.shape[1]*up)) 187 | video_frames_crossattn.append(viz_image + heatmap) 188 | else: 189 | heatmap = IImage(cross_vis, vmin=0).heatmap((cross_vis.shape[0]*up, cross_vis.shape[1]*up)) 190 | video_frames_crossattn.append(heatmap) 191 | ############################### 192 | 193 | ################# VISUALIZE SELF ATTENTION ############################### 194 | if visualize_selfattn and res == visualize_resolution: 195 | selected = [] 196 | up = (64 // res) * 8 197 | for i in range(mo.shape[1]): 198 | if mo[0, i, 0]: 199 | selected.append(self_sim_viz[:, i, :]) 200 | selected = torch.stack(selected, dim=1) 201 | selected_vis = selected.mean(0).mean(0) 202 | 203 | selected_vis = selected_vis.reshape(shape) 204 | 205 | if viz_image is not None: 206 | heatmap = IImage(selected_vis, vmin=0, vmax=1).heatmap((selected_vis.shape[0]*up, selected_vis.shape[1]*up)) 207 | video_frames_selfattn.append(viz_image+heatmap) 208 | else: 209 | heatmap = IImage(selected_vis, vmin=0).heatmap((selected_vis.shape[0]*up, selected_vis.shape[1]*up)) 210 | video_frames_selfattn.append(heatmap) 211 | ############################### 212 | 213 | return out 214 | 215 | -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/maskscore.py: -------------------------------------------------------------------------------- 1 | from ... 
import share 2 | import torch 3 | 4 | # import xformers 5 | # import xformers.ops 6 | 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch import nn, einsum 10 | from inspect import isfunction 11 | from einops import rearrange, repeat 12 | 13 | qkv_reduce_dims = [-1] 14 | increase_indices = [1,2] 15 | 16 | def forward(self, x, context=None, mask=None): 17 | h = self.heads 18 | q = self.to_q(x) 19 | is_cross = context is not None 20 | context = x if context is None else context 21 | k = self.to_k(context) 22 | v = self.to_v(context) 23 | 24 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 25 | 26 | scale = self.dim_head**-0.5 27 | sim = einsum("b i d, b j d -> b i j", q, k) * scale 28 | 29 | if is_cross: 30 | if q.shape[1] in [share.input_shape.res16]: 31 | # For simplicity, assume token 1 is target, token 2 is (1 word label) 32 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096) 33 | N = sim.shape[0]//2 34 | sim[N:] += ( 35 | share.w 36 | * share.noise_signal_ratio[share.timestep] 37 | * share.input_mask.get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1) 38 | * sim[N:].amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1) 39 | ) * (torch.eye(77)[increase_indices].sum(0).cuda()) # (1,1,77) 40 | 41 | del q, k 42 | sim = sim.softmax(dim=-1) 43 | 44 | out = einsum("b i j, b j d -> b i d", sim, v) 45 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 46 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/other.py: -------------------------------------------------------------------------------- 1 | # import xformers 2 | # import xformers.ops 3 | 4 | def forward(self, x, context=None, mask=None): 5 | q = self.to_q(x) 6 | context = x if context is None else context 7 | k = self.to_k(context) 8 | v = self.to_v(context) 9 | 10 | b, _, _ = q.shape 11 | q, k, v = map( 12 | lambda t: t.unsqueeze(3) 13 | .reshape(b, t.shape[1], self.heads, self.dim_head) 14 | .permute(0, 2, 1, 3) 15 | .reshape(b * self.heads, t.shape[1], self.dim_head) 16 | .contiguous(), 17 | (q, k, v), 18 | ) 19 | 20 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 21 | 22 | if mask is not None: 23 | raise NotImplementedError 24 | out = ( 25 | out.unsqueeze(0) 26 | .reshape(b, self.heads, out.shape[1], self.dim_head) 27 | .permute(0, 2, 1, 3) 28 | .reshape(b, out.shape[1], self.heads * self.dim_head) 29 | ) 30 | return self.to_out(1 - out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/shuffled.py: -------------------------------------------------------------------------------- 1 | from ... 
import share 2 | 3 | # import xformers 4 | # import xformers.ops 5 | 6 | import torch 7 | from torch import nn, einsum 8 | import torchvision.transforms.functional as TF 9 | from einops import rearrange, repeat 10 | 11 | layer_mask = share.LayerMask() 12 | 13 | def forward(self, x, context=None, mask=None): 14 | att_type = "self" if context is None else "cross" 15 | 16 | h = self.heads 17 | q = self.to_q(x) 18 | context = x if context is None else context 19 | k = self.to_k(context) 20 | v = self.to_v(context) 21 | 22 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 23 | 24 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 25 | 26 | # attention, what we cannot get enough of 27 | sim = sim.softmax(dim=-1) 28 | out = einsum("b i j, b j d -> b i d", sim, v) 29 | 30 | if att_type == 'self' and q.shape[1] in [share.input_shape.res16]: 31 | out = out[:,torch.randperm(out.shape[1]),:] 32 | 33 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 34 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/attentionpatch/zeropaint.py: -------------------------------------------------------------------------------- 1 | from ... import share 2 | import torch 3 | 4 | # import xformers 5 | # import xformers.ops 6 | 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch import nn, einsum 10 | from inspect import isfunction 11 | from einops import rearrange, repeat 12 | 13 | qkv_reduce_dims = [-1, -2] 14 | 15 | def forward(self, x, context=None, mask=None): 16 | h = self.heads 17 | q = self.to_q(x) 18 | is_cross = context is not None 19 | context = x if context is None else context 20 | k = self.to_k(context) 21 | v = self.to_v(context) 22 | 23 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 24 | scale = self.dim_head**-0.5 25 | sim = einsum("b i d, b j d -> b i j", q, k) 26 | steps_40 = [999, 974, 949, 924, 899, 874, 849, 824, 799, 774, 749, 724, 699, 674, 649, 624, 599, 574, 549, 524, 499, 474, 449, 424, 399, 374, 349, 324, 299, 274, 249, 224, 199, 174, 149, 124, 99, 74, 49, 24] 27 | test_ind = steps_40.index(share.timestep) 28 | if hasattr(share, 'list_of_masks') and is_cross: 29 | if q.shape[1] in [share.input_shape.res16, share.input_shape.res32] or True: 30 | for masked_object in share.list_of_masks: 31 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096) 32 | zp_condition = ( 33 | masked_object['w'] 34 | # * share.zp_sigmas[share.timestep] 35 | * share.zp_sigmas[test_ind] 36 | * masked_object['mask'].get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1) 37 | * sim.amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1) 38 | ) * (torch.eye(77)[masked_object['token_idx']].sum(0).cuda()) # (1,1,77) 39 | 40 | zp_zeros = torch.zeros_like(zp_condition).cuda() 41 | final_condition = torch.concat([zp_zeros[:8,:,:],zp_condition[8:,:,:]]) 42 | 43 | sim = sim+final_condition 44 | # print(sim.shape,"wwwww") 45 | del q, k 46 | 47 | sim = sim*scale 48 | sim = sim.softmax(dim=-1) 49 | 50 | out = einsum("b i j, b j d -> b i d", sim, v) 51 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 52 | return self.to_out(out) -------------------------------------------------------------------------------- /src/smplfusion/patches/router.py: -------------------------------------------------------------------------------- 1 | from . import attentionpatch 2 | from . 
import transformerpatch 3 | 4 | VERBOSE = False 5 | attention_forward = attentionpatch.default.forward 6 | basic_transformer_forward = transformerpatch.default.forward 7 | 8 | def reset(): 9 | global attention_forward, basic_transformer_forward 10 | attention_forward = attentionpatch.default.forward 11 | basic_transformer_forward = transformerpatch.default.forward 12 | if VERBOSE: print ("Resetting Diffusion Model") 13 | 14 | print ("RELOADING ROUTER") 15 | print (attention_forward) -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__init__.py: -------------------------------------------------------------------------------- 1 | from . import default 2 | from . import guided 3 | from . import weighting_versions 4 | from . import introvert 5 | -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__pycache__/default.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/default.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__pycache__/guided.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/guided.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__pycache__/introvert.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/introvert.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/__pycache__/weighting_versions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/weighting_versions.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/default.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ... import share 3 | 4 | def set_wsa(value): 5 | global wsa 6 | wsa = torch.tensor([1., value])[:,None,None].cuda() 7 | def set_wca(value): 8 | global wca 9 | wca = torch.tensor([1., value])[:,None,None].cuda() 10 | def set_wff(value): 11 | global wff 12 | wff = torch.tensor([1., value])[:,None,None].cuda() 13 | 14 | # set_wca(1.) 15 | # set_wsa(1.) 
16 | # set_wff(1.) 17 | 18 | wca = 1.0 19 | wsa = 1.0 20 | wff = 1.0 21 | 22 | def forward(self, x, context=None): 23 | x = x + wsa * self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) # Self Attn. 24 | x = x + wca * self.attn2(self.norm2(x), context=context) # Cross Attn. 25 | x = x + wff * self.ff(self.norm3(x)) 26 | return x 27 | 28 | def forward_and_save(self, x, context=None): 29 | val = [x] 30 | val.append(wsa * self.attn1(self.norm1(x), context=context if self.disable_self_attn else None)) # Self Attn. 31 | x = x + val[-1] 32 | val.append(wca * self.attn2(self.norm2(x), context=context)) # Cross Attn. 33 | x = x + val[-1] 34 | val.append(wff * self.ff(self.norm3(x))) 35 | x = x + val[-1] 36 | 37 | # Save outputs 38 | if hasattr(share, '_basic_transformer_input') and x.shape[1] == share.input_shape.res16: 39 | share._basic_transformer_input.append(share.reshape(val[0])) 40 | if hasattr(share, '_basic_transformer_selfattn') and x.shape[1] == share.input_shape.res16: 41 | share._basic_transformer_selfattn.append(share.reshape(val[1])) 42 | if hasattr(share, '_basic_transformer_crossattn') and x.shape[1] == share.input_shape.res16: 43 | share._basic_transformer_crossattn.append(share.reshape(val[2])) 44 | if hasattr(share, '_basic_transformer_ff') and x.shape[1] == share.input_shape.res16: 45 | share._basic_transformer_ff.append(share.reshape(val[3])) 46 | return x 47 | -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/guided.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ... import share 3 | 4 | def set_wsa(value): 5 | global wsa 6 | wsa = torch.tensor([1., value])[:,None,None].cuda() 7 | def set_wca(value): 8 | global wca 9 | wca = torch.tensor([1., value])[:,None,None].cuda() 10 | def set_wff(value): 11 | global wff 12 | wff = torch.tensor([1., value])[:,None,None].cuda() 13 | 14 | set_wca(1.) 15 | set_wsa(1.) 16 | set_wff(1.) 17 | 18 | guidance_scale = 7.5 19 | 20 | def forward(self, x, context=None): 21 | # print (x.shape) 22 | mask = share.input_mask.get_res(x).reshape(-1,1).cuda() 23 | 24 | _out_sa = wsa * self.attn1(self.norm1(x), None) # Self Attn. 25 | 26 | # _out_ca_uncond = wca * self.attn2(self.norm2(x), context=context) # Cross Attn. "Unconditional" 27 | _out_ca = wca * self.attn2(self.norm2(x + _out_sa), context=context) # Cross Attn. 
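# NOTE: the `share.timestep < 0` branch below references `_out_ca_uncond`, which is only defined in the commented-out line above; entering that branch would raise a NameError. With the usual non-negative diffusion timesteps the condition is never true, so only the plain residual path in the `else` branch runs.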
28 | 29 | if share.timestep < 0: 30 | # x = x + (1 - mask) * _out_sa + mask * (guidance_scale * _out_ca_uncond + (1 - guidance_scale) * (_out_sa + _out_ca)) 31 | x = x + (1 - mask) * (_out_sa + _out_ca) + mask * (guidance_scale * _out_ca_uncond + (1 - guidance_scale) * (_out_sa + _out_ca)) 32 | else: 33 | x = x + _out_sa + _out_ca 34 | 35 | _out_ff = wff * self.ff(self.norm3(x)) 36 | x = x + _out_ff 37 | 38 | if False: 39 | save([_out_ca, _out_sa, _out_ff]) 40 | 41 | return x 42 | 43 | forward_and_save = forward 44 | 45 | def save(val):# Save outputs 46 | if hasattr(share, '_basic_transformer_selfattn'): 47 | share._basic_transformer_selfattn.append(share.reshape(val[0])) 48 | if hasattr(share, '_basic_transformer_crossattn'): 49 | share._basic_transformer_crossattn.append(share.reshape(val[1])) 50 | if hasattr(share, '_basic_transformer_ff'): 51 | share._basic_transformer_ff.append(share.reshape(val[1])) -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/introvert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange, repeat 4 | from ... import share 5 | from ..attentionpatch import introvert 6 | 7 | w_ff = 1. 8 | w_sa = 1. 9 | w_ca = 1. 10 | 11 | use_grad = True 12 | 13 | 14 | def forward(self, x, context=None): 15 | # with torch.no_grad(): 16 | if use_grad: 17 | y, self_v, self_sim = self.attn1(self.norm1(x), None) # Self Attn. 18 | 19 | x_uncond, x_cond = x.chunk(2) 20 | context_uncond, context_cond = context.chunk(2) 21 | 22 | y_uncond, y_cond = y.chunk(2) 23 | self_sim_uncond, self_sim_cond = self_sim.chunk(2) 24 | self_v_uncond, self_v_cond = self_v.chunk(2) 25 | 26 | # Calculate CA similarities with conditional context 27 | cross_h = self.attn2.heads 28 | cross_q = self.attn2.to_q(self.norm2(x_cond+y_cond)) 29 | cross_k = self.attn2.to_k(context_cond) 30 | cross_v = self.attn2.to_v(context_cond) 31 | 32 | cross_q, cross_k, cross_v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=cross_h), (cross_q, cross_k, cross_v)) 33 | 34 | with torch.autocast(enabled=False, device_type = 'cuda'): 35 | cross_q, cross_k = cross_q.float(), cross_k.float() 36 | cross_sim = einsum('b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale 37 | 38 | del cross_q, cross_k 39 | cross_sim = cross_sim.softmax(dim=-1) # Up to this point cross_sim is regular cross_sim in CA layer 40 | 41 | cross_sim = cross_sim.mean(dim=0) # Calculate mean across heads 42 | 43 | # Introvert Attention rescale heppens here 44 | y_cond = introvert.introvert_rescale( 45 | y_cond, self_v_cond, self_sim_cond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale cond 46 | y_uncond = introvert.introvert_rescale( 47 | y_uncond, self_v_uncond, self_sim_uncond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale uncond 48 | 49 | y = torch.cat([y_uncond, y_cond], dim=0) 50 | 51 | x = x + w_sa * y 52 | x = x + w_ca * self.attn2(self.norm2(x), context=context) # Cross Attn. 53 | x = x + w_ff * self.ff(self.norm3(x)) 54 | return x -------------------------------------------------------------------------------- /src/smplfusion/patches/transformerpatch/weighting_versions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ... 
import share 3 | 4 | def set_wsa(value): 5 | global wsa 6 | wsa = torch.tensor([1., value])[:,None,None].cuda() 7 | def set_wca(value): 8 | global wca 9 | wca = torch.tensor([1., value])[:,None,None].cuda() 10 | def set_wff(value): 11 | global wff 12 | wff = torch.tensor([1., value])[:,None,None].cuda() 13 | 14 | set_wca(1.) 15 | set_wsa(1.) 16 | set_wff(1.) 17 | 18 | def forward(self, x, context=None): 19 | x = x + wca * self.attn2(self.norm2(x), context=context) # Cross Attn. 20 | x = x + wsa * self.attn1(self.norm1(x), None) # Self Attn. 21 | x = x + wff * self.ff(self.norm3(x)) 22 | return x 23 | 24 | def forward_and_save(self, x, context=None): 25 | val = [x] 26 | val.append(wsa * self.attn1(self.norm1(x), None)) # Self Attn. 27 | x = x + val[-1] 28 | val.append(wca * self.attn2(self.norm2(x), context=context)) # Cross Attn. 29 | x = x + val[-1] 30 | val.append(wff * self.ff(self.norm3(x))) 31 | x = x + val[-1] 32 | 33 | # Save outputs 34 | if hasattr(share, '_basic_transformer_input') and x.shape[1] == share.input_shape.res16: 35 | share._basic_transformer_input.append(share.reshape(val[0])) 36 | if hasattr(share, '_basic_transformer_selfattn') and x.shape[1] == share.input_shape.res16: 37 | share._basic_transformer_selfattn.append(share.reshape(val[1])) 38 | if hasattr(share, '_basic_transformer_crossattn') and x.shape[1] == share.input_shape.res16: 39 | share._basic_transformer_crossattn.append(share.reshape(val[2])) 40 | if hasattr(share, '_basic_transformer_ff') and x.shape[1] == share.input_shape.res16: 41 | share._basic_transformer_ff.append(share.reshape(val[3])) 42 | return x 43 | 44 | 45 | # ========= MODIFICATIONS ============= # 46 | 47 | 48 | def forward_and_save3(self, x, context=None): 49 | val = [] 50 | val.append(self.attn1(self.norm1(x), None)) # Self Attn. 51 | modify_res = [share.input_mask.res16, share.input_mask.res32, share.input_mask.res64] 52 | 53 | if x.shape[1] in modify_res: 54 | val[-1] = val[-1] + (wsa - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1] 55 | x = x + val[-1] 56 | 57 | val.append(self.attn2(self.norm2(x), context=context)) # Cross Attn. 58 | 59 | if x.shape[1] in modify_res: 60 | val[-1] = val[-1] + (wca - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1] 61 | x = x + val[-1] 62 | 63 | val.append(self.ff(self.norm3(x))) 64 | 65 | if x.shape[1] in modify_res: 66 | val[-1] = val[-1] + (wff - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1] 67 | x = x + val[-1] 68 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16: 69 | share.out_basic_transformer_block.append(val) 70 | return x 71 | 72 | def forward_and_save2(self, x, context=None): 73 | val = [] 74 | val.append(self.attn1(self.norm1(x), None)) # Self Attn. 75 | x = x + wsa * val[-1] 76 | val.append(self.attn2(self.norm2(x), context=context)) # Cross Attn. 77 | print (val[-1].shape, share.input_mask.val16.shape) 78 | x = x + val[-1] + wca * share.input_mask.val16 * val[-1] 79 | val.append(self.ff(self.norm3(x))) 80 | x = x + wff * val[-1] 81 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16: 82 | share.out_basic_transformer_block.append(val) 83 | return x 84 | 85 | def forward_and_reweight(self, x, context=None): 86 | modify_res = [share.input_mask.res16, share.input_mask.res32, share.input_mask.res64] 87 | 88 | _attn1 = self.attn1(self.norm1(x), None) # Self Attn. 89 | _attn2 = self.attn2(self.norm2(x + _attn1), context=context) # Cross Attn. 
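# The three residual branches keep the standard pre-norm ordering (self-attention, then cross-attention on x + _attn1, then feed-forward on x + _attn1 + _attn2); at the resolutions listed in modify_res, each branch is afterwards rescaled inside the input mask by a factor of 1 + (w - 1) before being added back to x.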
90 | _ff = self.ff(self.norm3(x + _attn1 + _attn2)) 91 | 92 | if x.shape[1] in modify_res: 93 | lm1,lm2,lm3 = wsa - 1, wca - 1, wff - 1 94 | _attn1 *= (1 + lm1 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda()) 95 | _attn2 *= (1 + lm2 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda()) 96 | _ff *= (1 + lm3 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda()) 97 | 98 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16: 99 | share.out_basic_transformer_block.append([_attn1, _attn2, _ff]) 100 | 101 | return x + _attn1 + _attn2 + _ff 102 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) 103 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) / ((w_sa + w_ca + w_ff) / 3) 104 | 105 | def forward_and_reweight2(self, x, context=None): 106 | _attn1 = self.attn1(self.norm1(x), None) # Self Attn. 107 | _attn2 = self.attn2(self.norm2(x + _attn1), context=context) # Cross Attn. 108 | _ff = self.ff(self.norm3(x + _attn1 + _attn2)) 109 | 110 | _attn1 += (wsa - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _attn1 111 | _attn2 += (wca - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _attn2 112 | _ff += (wff - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _ff 113 | 114 | 115 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16: 116 | share.out_basic_transformer_block.append([_attn1, _attn2, _ff]) 117 | # share.out_basic_transformer_block.append([ 118 | # w_sa / ((w_sa + w_ca + w_ff) / 3) * _attn1, 119 | # w_ca / ((w_sa + w_ca + w_ff) / 3) * _attn2, 120 | # w_ff / ((w_sa + w_ca + w_ff) / 3) * _ff 121 | # ]) 122 | 123 | return x + _attn1 + _attn2 + _ff 124 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) 125 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) / ((w_sa + w_ca + w_ff) / 3) -------------------------------------------------------------------------------- /src/smplfusion/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def linear(n_timestep = 1000, start = 1e-4, end = 2e-2): 4 | return Schedule(torch.linspace(start ** 0.5, end ** 0.5, n_timestep, dtype = torch.float64) ** 2) 5 | 6 | class Schedule: 7 | def __init__(self, betas): 8 | self.betas = betas 9 | self._alphas = 1 - betas 10 | self.alphas = torch.cumprod(self._alphas, 0) 11 | self.one_minus_alphas = 1 - self.alphas 12 | self.sqrt_alphas = torch.sqrt(self.alphas) 13 | self.sqrt_one_minus_alphas = torch.sqrt(1 - self.alphas) 14 | self.sqrt_noise_signal_ratio = self.sqrt_one_minus_alphas / self.sqrt_alphas 15 | self.noise_signal_ratio = (1 - self.alphas) / self.alphas 16 | 17 | def range(self, dt): 18 | return range(len(self.betas)-1, 0, -dt) 19 | 20 | def sigma(self, t, dt): 21 | return torch.sqrt((1 - self._alphas[t - dt]) / (1 - self._alphas[t]) * (1 - self._alphas[t] / self._alphas[t - dt])) # Like I did initially 22 | # return torch.sqrt((1 - self.alphas[t - dt]) / (1 - self.alphas[t]) * (1 - self.alphas[t] / self.alphas[t - dt])) # Like in diffusers 23 | 24 | def sigma_(self, t, dt): 25 | return torch.sqrt((1 - self.alphas[t - dt]) / (1 - self.alphas[t]) * (1 - self.alphas[t] / self.alphas[t - dt])) # Like in diffusers -------------------------------------------------------------------------------- /src/smplfusion/share.py: -------------------------------------------------------------------------------- 1 | 
import torchvision.transforms.functional as TF 2 | import torch 3 | import sys 4 | from .utils import * 5 | 6 | input_mask = None 7 | input_shape = None 8 | timestep = None 9 | timestep_index = None 10 | 11 | class Seed: 12 | def __getitem__(self, idx): 13 | if isinstance(idx, slice): 14 | idx = list(range(*idx.indices(idx.stop))) 15 | if isinstance(idx, list) or isinstance(idx, tuple): 16 | return [self[_idx] for _idx in idx] 17 | return 12345 ** idx % 54321 18 | seed = Seed() 19 | 20 | lock = {} 21 | def get_lock(value): 22 | global lock 23 | if value not in lock: 24 | lock[value] = True 25 | if lock[value]: 26 | lock[value] = False 27 | return True 28 | return False 29 | 30 | 31 | class DDIMIterator: 32 | def __init__(self, iterator): 33 | self.iterator = iterator 34 | def __iter__(self): 35 | self.iterator = iter(self.iterator) 36 | global timestep_index 37 | timestep_index = 0 38 | return self 39 | def __next__(self): 40 | global timestep, timestep_index, lock 41 | for x in lock: lock[x] = True 42 | timestep = next(self.iterator) 43 | timestep_index += 1 44 | return timestep 45 | 46 | def reshape(x): 47 | return input_shape.reshape(x) 48 | 49 | def set_shape(image_or_shape): 50 | global input_shape 51 | if hasattr(image_or_shape, 'size'): 52 | input_shape = InputShape(image_or_shape.size) 53 | if isinstance(image_or_shape, torch.Tensor): 54 | input_shape = InputShape(image_or_shape.shape[-2:][::-1]) 55 | elif isinstance(image_or_shape, list) or isinstance(image_or_shape, tuple): 56 | input_shape = InputShape(image_or_shape) 57 | 58 | self = sys.modules[__name__] 59 | 60 | def set_mask(mask): 61 | global input_mask, mask64, mask32, mask16, mask8, introvert_mask 62 | input_mask = InputMask(mask) 63 | introvert_mask = InputMask(mask) 64 | 65 | mask64 = input_mask.val64[0,0] 66 | mask32 = input_mask.val32[0,0] 67 | mask16 = input_mask.val16[0,0] 68 | mask8 = input_mask.val8[0,0] 69 | set_shape(mask) 70 | 71 | def exists(name): 72 | return hasattr(self, name) and getattr(self, name) is not None -------------------------------------------------------------------------------- /src/smplfusion/util.py: -------------------------------------------------------------------------------- 1 | # TODO: remove everything below as long as it doesn't break anything! 2 | 3 | import importlib 4 | 5 | import torch 6 | from torch import optim 7 | import numpy as np 8 | 9 | from inspect import isfunction 10 | from PIL import Image, ImageDraw, ImageFont 11 | 12 | 13 | def autocast(f): 14 | def do_autocast(*args, **kwargs): 15 | with torch.cuda.amp.autocast(enabled=True, 16 | dtype=torch.get_autocast_gpu_dtype(), 17 | cache_enabled=torch.is_autocast_cache_enabled()): 18 | return f(*args, **kwargs) 19 | 20 | return do_autocast 21 | 22 | 23 | def log_txt_as_img(wh, xc, size=10): 24 | # wh a tuple of (width, height) 25 | # xc a list of captions to plot 26 | b = len(xc) 27 | txts = list() 28 | for bi in range(b): 29 | txt = Image.new("RGB", wh, color="white") 30 | draw = ImageDraw.Draw(txt) 31 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 32 | nc = int(40 * (wh[0] / 256)) 33 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 34 | 35 | try: 36 | draw.text((0, 0), lines, fill="black", font=font) 37 | except UnicodeEncodeError: 38 | print("Cant encode string for logging. 
Skipping.") 39 | 40 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 41 | txts.append(txt) 42 | txts = np.stack(txts) 43 | txts = torch.tensor(txts) 44 | return txts 45 | 46 | 47 | def ismap(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] > 3) 51 | 52 | 53 | def isimage(x): 54 | if not isinstance(x,torch.Tensor): 55 | return False 56 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 57 | 58 | 59 | def exists(x): 60 | return x is not None 61 | 62 | 63 | def default(val, d): 64 | if exists(val): 65 | return val 66 | return d() if isfunction(d) else d 67 | 68 | 69 | def mean_flat(tensor): 70 | """ 71 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 72 | Take the mean over all non-batch dimensions. 73 | """ 74 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 75 | 76 | 77 | def count_params(model, verbose=False): 78 | total_params = sum(p.numel() for p in model.parameters()) 79 | if verbose: 80 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 81 | return total_params 82 | 83 | 84 | def instantiate_from_config(config): 85 | if not "target" in config: 86 | if config == '__is_first_stage__': 87 | return None 88 | elif config == "__is_unconditional__": 89 | return None 90 | raise KeyError("Expected key `target` to instantiate.") 91 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 92 | 93 | 94 | def get_obj_from_str(string, reload=False): 95 | module, cls = string.rsplit(".", 1) 96 | if reload: 97 | module_imp = importlib.import_module(module) 98 | importlib.reload(module_imp) 99 | return getattr(importlib.import_module(module, package=None), cls) 100 | 101 | 102 | class AdamWwithEMAandWings(optim.Optimizer): 103 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 104 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 105 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 106 | ema_power=1., param_names=()): 107 | """AdamW that saves EMA versions of the parameters.""" 108 | if not 0.0 <= lr: 109 | raise ValueError("Invalid learning rate: {}".format(lr)) 110 | if not 0.0 <= eps: 111 | raise ValueError("Invalid epsilon value: {}".format(eps)) 112 | if not 0.0 <= betas[0] < 1.0: 113 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 114 | if not 0.0 <= betas[1] < 1.0: 115 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 116 | if not 0.0 <= weight_decay: 117 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 118 | if not 0.0 <= ema_decay <= 1.0: 119 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 120 | defaults = dict(lr=lr, betas=betas, eps=eps, 121 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 122 | ema_power=ema_power, param_names=param_names) 123 | super().__init__(params, defaults) 124 | 125 | def __setstate__(self, state): 126 | super().__setstate__(state) 127 | for group in self.param_groups: 128 | group.setdefault('amsgrad', False) 129 | 130 | @torch.no_grad() 131 | def step(self, closure=None): 132 | """Performs a single optimization step. 133 | Args: 134 | closure (callable, optional): A closure that reevaluates the model 135 | and returns the loss. 
136 | """ 137 | loss = None 138 | if closure is not None: 139 | with torch.enable_grad(): 140 | loss = closure() 141 | 142 | for group in self.param_groups: 143 | params_with_grad = [] 144 | grads = [] 145 | exp_avgs = [] 146 | exp_avg_sqs = [] 147 | ema_params_with_grad = [] 148 | state_sums = [] 149 | max_exp_avg_sqs = [] 150 | state_steps = [] 151 | amsgrad = group['amsgrad'] 152 | beta1, beta2 = group['betas'] 153 | ema_decay = group['ema_decay'] 154 | ema_power = group['ema_power'] 155 | 156 | for p in group['params']: 157 | if p.grad is None: 158 | continue 159 | params_with_grad.append(p) 160 | if p.grad.is_sparse: 161 | raise RuntimeError('AdamW does not support sparse gradients') 162 | grads.append(p.grad) 163 | 164 | state = self.state[p] 165 | 166 | # State initialization 167 | if len(state) == 0: 168 | state['step'] = 0 169 | # Exponential moving average of gradient values 170 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 171 | # Exponential moving average of squared gradient values 172 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 173 | if amsgrad: 174 | # Maintains max of all exp. moving avg. of sq. grad. values 175 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 176 | # Exponential moving average of parameter values 177 | state['param_exp_avg'] = p.detach().float().clone() 178 | 179 | exp_avgs.append(state['exp_avg']) 180 | exp_avg_sqs.append(state['exp_avg_sq']) 181 | ema_params_with_grad.append(state['param_exp_avg']) 182 | 183 | if amsgrad: 184 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 185 | 186 | # update the steps for each param group update 187 | state['step'] += 1 188 | # record the step after step update 189 | state_steps.append(state['step']) 190 | 191 | optim._functional.adamw(params_with_grad, 192 | grads, 193 | exp_avgs, 194 | exp_avg_sqs, 195 | max_exp_avg_sqs, 196 | state_steps, 197 | amsgrad=amsgrad, 198 | beta1=beta1, 199 | beta2=beta2, 200 | lr=group['lr'], 201 | weight_decay=group['weight_decay'], 202 | eps=group['eps'], 203 | maximize=False) 204 | 205 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) 206 | for param, ema_param in zip(params_with_grad, ema_params_with_grad): 207 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) 208 | 209 | return loss -------------------------------------------------------------------------------- /src/smplfusion/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .input_image import InputImage 2 | from .input_mask import InputMask 3 | from .input_shape import InputShape 4 | from .layer_mask import LayerMask -------------------------------------------------------------------------------- /src/smplfusion/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/utils/__pycache__/input_image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_image.cpython-39.pyc 
-------------------------------------------------------------------------------- /src/smplfusion/utils/__pycache__/input_mask.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_mask.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/utils/__pycache__/input_shape.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_shape.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/utils/__pycache__/layer_mask.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/layer_mask.cpython-39.pyc -------------------------------------------------------------------------------- /src/smplfusion/utils/input_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..libimage import IImage 3 | 4 | class InputImage: 5 | def to(self, device): return InputImage(self.image, device = device) 6 | def cuda(self): return InputImage(self.image, device = 'cuda') 7 | def cpu(self): return InputImage(self.image, device = 'cpu') 8 | 9 | def __init__(self, input_image): 10 | ''' 11 | args: 12 | input_image: (b,c,h,w) tensor 13 | ''' 14 | if hasattr(input_image, 'is_iimage'): 15 | self.image = input_image 16 | self.val512 = self.full = input_image.torch(0) 17 | elif isinstance(input_image, torch.Tensor): 18 | self.val512 = self.full = input_image.clone() 19 | self.image = IImage(input_image,0) 20 | 21 | self.h,self.w = h,w = self.val512.shape[-2:] 22 | self.shape = [self.h, self.w] 23 | self.shape64 = [self.h // 8, self.w // 8] 24 | self.shape32 = [self.h // 16, self.w // 16] 25 | self.shape16 = [self.h // 32, self.w // 32] 26 | self.shape8 = [self.h // 64, self.w // 64] 27 | 28 | self.res = self.h * self.w 29 | self.res64 = self.res // 64 30 | self.res32 = self.res // 64 // 4 31 | self.res16 = self.res // 64 // 16 32 | self.res8 = self.res // 64 // 64 33 | 34 | self.img = self.image 35 | self.img512 = self.image 36 | self.img64 = self.image.resize((h//8,w//8)) 37 | self.img32 = self.image.resize((h//16,w//16)) 38 | self.img16 = self.image.resize((h//32,w//32)) 39 | self.img8 = self.image.resize((h//64,w//64)) 40 | 41 | self.val64 = self.img64.torch() 42 | self.val32 = self.img32.torch() 43 | self.val16 = self.img16.torch() 44 | self.val8 = self.img8.torch() 45 | 46 | def get_res(self, q, device = 'cpu'): 47 | if q.shape[1] == self.res64: return self.val64.to(device) 48 | if q.shape[1] == self.res32: return self.val32.to(device) 49 | if q.shape[1] == self.res16: return self.val16.to(device) 50 | if q.shape[1] == self.res8: return self.val8.to(device) 51 | 52 | def get_shape(self, q, device = 'cpu'): 53 | if q.shape[1] == self.res64: return self.shape64 54 | if q.shape[1] == self.res32: return self.shape32 55 | if q.shape[1] == self.res16: return self.shape16 56 | if q.shape[1] == self.res8: return self.shape8 57 | 58 | def get_res_val(self, q, device = 'cpu'): 59 | if q.shape[1] == self.res64: 
return 64 60 | if q.shape[1] == self.res32: return 32 61 | if q.shape[1] == self.res16: return 16 62 | if q.shape[1] == self.res8: return 8 63 | 64 | def _repr_png_(self): 65 | return self.img16._repr_png_() -------------------------------------------------------------------------------- /src/smplfusion/utils/input_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..libimage import IImage 3 | 4 | class InputMask: 5 | def to(self, device): return InputMask(self.image, device = device) 6 | def cuda(self): return InputMask(self.image, device = 'cuda') 7 | def cpu(self): return InputMask(self.image, device = 'cpu') 8 | 9 | def __init__(self, input_image, device = 'cpu'): 10 | ''' 11 | args: 12 | input_image: (b,c,h,w) tensor 13 | ''' 14 | if hasattr(input_image, 'is_iimage'): 15 | self.image = input_image 16 | self.val512 = self.full = (input_image.torch(0) > 0.5).float() 17 | elif isinstance(input_image, torch.Tensor): 18 | self.val512 = self.full = input_image.clone() 19 | self.image = IImage(input_image,0) 20 | 21 | self.h,self.w = h,w = self.val512.shape[-2:] 22 | self.shape = [self.h, self.w] 23 | self.shape64 = [self.h // 8, self.w // 8] 24 | self.shape32 = [self.h // 16, self.w // 16] 25 | self.shape16 = [self.h // 32, self.w // 32] 26 | self.shape8 = [self.h // 64, self.w // 64] 27 | 28 | self.res = self.h * self.w 29 | self.res64 = self.res // 64 30 | self.res32 = self.res // 64 // 4 31 | self.res16 = self.res // 64 // 16 32 | self.res8 = self.res // 64 // 64 33 | 34 | self.img = self.image 35 | self.img512 = self.image 36 | self.img64 = self.image.resize((h//8,w//8)) 37 | self.img32 = self.image.resize((h//16,w//16)) 38 | self.img16 = self.image.resize((h//32,w//32)) 39 | self.img8 = self.image.resize((h//64,w//64)) 40 | 41 | self.val64 = (self.img64.torch(0) > 0.5).float() 42 | self.val32 = (self.img32.torch(0) > 0.5).float() 43 | self.val16 = (self.img16.torch(0) > 0.5).float() 44 | self.val8 = ( self.img8.torch(0) > 0.5).float() 45 | 46 | 47 | def get_res(self, q, device = 'cpu'): 48 | if q.shape[1] == self.res64: return self.val64.to(device) 49 | if q.shape[1] == self.res32: return self.val32.to(device) 50 | if q.shape[1] == self.res16: return self.val16.to(device) 51 | if q.shape[1] == self.res8: return self.val8.to(device) 52 | 53 | def _repr_png_(self): 54 | return self.img16._repr_png_() 55 | 56 | def get_res(self, q, device = 'cpu'): 57 | if q.shape[1] == self.res64: return self.val64.to(device) 58 | if q.shape[1] == self.res32: return self.val32.to(device) 59 | if q.shape[1] == self.res16: return self.val16.to(device) 60 | if q.shape[1] == self.res8: return self.val8.to(device) 61 | 62 | def get_shape(self, q, device = 'cpu'): 63 | if q.shape[1] == self.res64: return self.shape64 64 | if q.shape[1] == self.res32: return self.shape32 65 | if q.shape[1] == self.res16: return self.shape16 66 | if q.shape[1] == self.res8: return self.shape8 67 | 68 | def get_res_val(self, q, device = 'cpu'): 69 | if q.shape[1] == self.res64: return 64 70 | if q.shape[1] == self.res32: return 32 71 | if q.shape[1] == self.res16: return 16 72 | if q.shape[1] == self.res8: return 8 73 | 74 | 75 | # class InputMask2: 76 | # def to(self, device): return InputMask2(self.image, device = device) 77 | # def cuda(self): return InputMask2(self.image, device = 'cuda') 78 | # def cpu(self): return InputMask2(self.image, device = 'cpu') 79 | 80 | # def __init__(self, input_image, device = 'cpu'): 81 | # ''' 82 | # args: 83 | # input_image: 
(b,c,h,w) tensor 84 | # ''' 85 | # if hasattr(input_image, 'is_iimage'): 86 | # self.image = input_image 87 | # self.val512 = self.full = input_image.torch(0).bool().float() 88 | # elif isinstance(input_image, torch.Tensor): 89 | # self.val512 = self.full = input_image.clone() 90 | # self.image = IImage(input_image,0) 91 | 92 | # self.h,self.w = h,w = self.val512.shape[-2:] 93 | # self.shape = [self.h, self.w] 94 | # self.shape64 = [self.h // 8, self.w // 8] 95 | # self.shape32 = [self.h // 16, self.w // 16] 96 | # self.shape16 = [self.h // 32, self.w // 32] 97 | # self.shape8 = [self.h // 64, self.w // 64] 98 | 99 | # self.res = self.h * self.w 100 | # self.res64 = self.res // 64 101 | # self.res32 = self.res // 64 // 4 102 | # self.res16 = self.res // 64 // 16 103 | # self.res8 = self.res // 64 // 64 104 | 105 | # self.img = self.image 106 | # self.img512 = self.image 107 | # self.img64 = self.image.resize((h//8,w//8)) 108 | # self.img32 = self.image.resize((h//16,w//16)) 109 | # self.img16 = self.image.resize((h//32,w//32)).dilate(1) 110 | # self.img8 = self.image.resize((h//64,w//64)).dilate(1) 111 | 112 | # self.val64 = self.img64.torch(0).bool().float() 113 | # self.val32 = self.img32.torch(0).bool().float() 114 | # self.val16 = self.img16.torch(0).bool().float() 115 | # self.val8 = self.img8.torch(0).bool().float() 116 | 117 | 118 | # def get_res(self, q, device = 'cpu'): 119 | # if q.shape[1] == self.res64: return self.val64.to(device) 120 | # if q.shape[1] == self.res32: return self.val32.to(device) 121 | # if q.shape[1] == self.res16: return self.val16.to(device) 122 | # if q.shape[1] == self.res8: return self.val8.to(device) 123 | 124 | # def _repr_png_(self): 125 | # return self.img16._repr_png_() 126 | 127 | # def get_res(self, q, device = 'cpu'): 128 | # if q.shape[1] == self.res64: return self.val64.to(device) 129 | # if q.shape[1] == self.res32: return self.val32.to(device) 130 | # if q.shape[1] == self.res16: return self.val16.to(device) 131 | # if q.shape[1] == self.res8: return self.val8.to(device) 132 | 133 | # def get_shape(self, q, device = 'cpu'): 134 | # if q.shape[1] == self.res64: return self.shape64 135 | # if q.shape[1] == self.res32: return self.shape32 136 | # if q.shape[1] == self.res16: return self.shape16 137 | # if q.shape[1] == self.res8: return self.shape8 138 | 139 | # def get_res_val(self, q, device = 'cpu'): 140 | # if q.shape[1] == self.res64: return 64 141 | # if q.shape[1] == self.res32: return 32 142 | # if q.shape[1] == self.res16: return 16 143 | # if q.shape[1] == self.res8: return 8 -------------------------------------------------------------------------------- /src/smplfusion/utils/input_shape.py: -------------------------------------------------------------------------------- 1 | class InputShape: 2 | def __init__(self, image_size): 3 | self.h,self.w = image_size[::-1] 4 | self.res = self.h * self.w 5 | self.res64 = self.res // 64 6 | self.res32 = self.res // 64 // 4 7 | self.res16 = self.res // 64 // 16 8 | self.res8 = self.res // 64 // 64 9 | self.shape = [self.h, self.w] 10 | self.shape64 = [self.h // 8, self.w // 8] 11 | self.shape32 = [self.h // 16, self.w // 16] 12 | self.shape16 = [self.h // 32, self.w // 32] 13 | self.shape8 = [self.h // 64, self.w // 64] 14 | 15 | def reshape(self, x): 16 | assert len(x.shape) == 3 17 | if x.shape[1] == self.res64: return x.reshape([x.shape[0]] + self.shape64 + [x.shape[-1]]) 18 | if x.shape[1] == self.res32: return x.reshape([x.shape[0]] + self.shape32 + [x.shape[-1]]) 19 | if x.shape[1] == 
self.res16: return x.reshape([x.shape[0]] + self.shape16 + [x.shape[-1]]) 20 | if x.shape[1] == self.res8: return x.reshape([x.shape[0]] + self.shape8 + [x.shape[-1]]) 21 | raise Exception("Unknown shape") 22 | 23 | def get_res(self, q, device = 'cpu'): 24 | if q.shape[1] == self.res64: return 64 25 | if q.shape[1] == self.res32: return 32 26 | if q.shape[1] == self.res16: return 16 27 | if q.shape[1] == self.res8: return 8 -------------------------------------------------------------------------------- /src/smplfusion/utils/layer_mask.py: -------------------------------------------------------------------------------- 1 | class LayerMask: 2 | def __init__(self, self64 = False, self32 = False, self16 = False, self8 = False, cross64 = False, cross32 = False, cross16 = False, cross8 = False): 3 | self.self64 = self64 4 | self.self32 = self32 5 | self.self16 = self16 6 | self.self8 = self8 7 | self.cross64 = cross64 8 | self.cross32 = cross32 9 | self.cross16 = cross16 10 | self.cross8 = cross8 11 | -------------------------------------------------------------------------------- /src/zeropainter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/zeropainter/__init__.py -------------------------------------------------------------------------------- /src/zeropainter/convert_diffusers.py: -------------------------------------------------------------------------------- 1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. 2 | # *Only* converts the UNet, VAE, and Text Encoder. 3 | # Does not convert optimizer state or any other thing. 4 | # Written by jachiam https://github.com/jachiam 5 | 6 | import argparse 7 | import os.path as osp 8 | 9 | import torch 10 | 11 | 12 | # =================# 13 | # UNet Conversion # 14 | # =================# 15 | 16 | unet_conversion_map = [ 17 | # (stable-diffusion, HF Diffusers) 18 | ("time_embed.0.weight", "time_embedding.linear_1.weight"), 19 | ("time_embed.0.bias", "time_embedding.linear_1.bias"), 20 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), 21 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), 22 | ("input_blocks.0.0.weight", "conv_in.weight"), 23 | ("input_blocks.0.0.bias", "conv_in.bias"), 24 | ("out.0.weight", "conv_norm_out.weight"), 25 | ("out.0.bias", "conv_norm_out.bias"), 26 | ("out.2.weight", "conv_out.weight"), 27 | ("out.2.bias", "conv_out.bias"), 28 | ] 29 | 30 | unet_conversion_map_resnet = [ 31 | # (stable-diffusion, HF Diffusers) 32 | ("in_layers.0", "norm1"), 33 | ("in_layers.2", "conv1"), 34 | ("out_layers.0", "norm2"), 35 | ("out_layers.3", "conv2"), 36 | ("emb_layers.1", "time_emb_proj"), 37 | ("skip_connection", "conv_shortcut"), 38 | ] 39 | 40 | unet_conversion_map_layer = [] 41 | # hardcoded number of downblocks and resnets/attentions... 42 | # would need smarter logic for other networks. 43 | for i in range(4): 44 | # loop over downblocks/upblocks 45 | 46 | for j in range(2): 47 | # loop over resnets/attentions for downblocks 48 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." 49 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 50 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) 51 | 52 | if i < 3: 53 | # no attention layers in down_blocks.3 54 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." 55 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." 
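# Illustrative example of the pairs built above: with i=0 and j=1, the HF prefix "down_blocks.0.attentions.1." is mapped to the SD prefix "input_blocks.2.1.".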
56 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) 57 | 58 | for j in range(3): 59 | # loop over resnets/attentions for upblocks 60 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." 61 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0." 62 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) 63 | 64 | if i > 0: 65 | # no attention layers in up_blocks.0 66 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." 67 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." 68 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) 69 | 70 | if i < 3: 71 | # no downsample in down_blocks.3 72 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." 73 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." 74 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) 75 | 76 | # no upsample in up_blocks.3 77 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 78 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." 79 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) 80 | 81 | hf_mid_atn_prefix = "mid_block.attentions.0." 82 | sd_mid_atn_prefix = "middle_block.1." 83 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) 84 | 85 | for j in range(2): 86 | hf_mid_res_prefix = f"mid_block.resnets.{j}." 87 | sd_mid_res_prefix = f"middle_block.{2*j}." 88 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) 89 | 90 | 91 | def convert_unet_state_dict(unet_state_dict): 92 | # buyer beware: this is a *brittle* function, 93 | # and correct output requires that all of these pieces interact in 94 | # the exact order in which I have arranged them. 95 | mapping = {k: k for k in unet_state_dict.keys()} 96 | for sd_name, hf_name in unet_conversion_map: 97 | mapping[hf_name] = sd_name 98 | for k, v in mapping.items(): 99 | if "resnets" in k: 100 | for sd_part, hf_part in unet_conversion_map_resnet: 101 | v = v.replace(hf_part, sd_part) 102 | mapping[k] = v 103 | for k, v in mapping.items(): 104 | for sd_part, hf_part in unet_conversion_map_layer: 105 | v = v.replace(hf_part, sd_part) 106 | mapping[k] = v 107 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} 108 | return new_state_dict 109 | 110 | 111 | # ================# 112 | # VAE Conversion # 113 | # ================# 114 | 115 | vae_conversion_map = [ 116 | # (stable-diffusion, HF Diffusers) 117 | ("nin_shortcut", "conv_shortcut"), 118 | ("norm_out", "conv_norm_out"), 119 | ("mid.attn_1.", "mid_block.attentions.0."), 120 | ] 121 | 122 | for i in range(4): 123 | # down_blocks have two resnets 124 | for j in range(2): 125 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." 126 | sd_down_prefix = f"encoder.down.{i}.block.{j}." 127 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) 128 | 129 | if i < 3: 130 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." 131 | sd_downsample_prefix = f"down.{i}.downsample." 132 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) 133 | 134 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." 135 | sd_upsample_prefix = f"up.{3-i}.upsample." 136 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) 137 | 138 | # up_blocks have three resnets 139 | # also, up blocks in hf are numbered in reverse from sd 140 | for j in range(3): 141 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." 142 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}." 
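# Illustrative example: with i=0 and j=0, the HF prefix "decoder.up_blocks.0.resnets.0." is mapped to the SD prefix "decoder.up.3.block.0.", reflecting the reversed up-block numbering noted above.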
143 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) 144 | 145 | # this part accounts for mid blocks in both the encoder and the decoder 146 | for i in range(2): 147 | hf_mid_res_prefix = f"mid_block.resnets.{i}." 148 | sd_mid_res_prefix = f"mid.block_{i+1}." 149 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) 150 | 151 | 152 | vae_conversion_map_attn = [ 153 | # (stable-diffusion, HF Diffusers) 154 | ("norm.", "group_norm."), 155 | ("q.", "query."), 156 | ("k.", "key."), 157 | ("v.", "value."), 158 | ("proj_out.", "proj_attn."), 159 | ] 160 | 161 | 162 | def reshape_weight_for_sd(w): 163 | # convert HF linear weights to SD conv2d weights 164 | return w.reshape(*w.shape, 1, 1) 165 | 166 | 167 | def convert_vae_state_dict(vae_state_dict): 168 | mapping = {k: k for k in vae_state_dict.keys()} 169 | for k, v in mapping.items(): 170 | for sd_part, hf_part in vae_conversion_map: 171 | v = v.replace(hf_part, sd_part) 172 | mapping[k] = v 173 | for k, v in mapping.items(): 174 | if "attentions" in k: 175 | for sd_part, hf_part in vae_conversion_map_attn: 176 | v = v.replace(hf_part, sd_part) 177 | mapping[k] = v 178 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} 179 | weights_to_convert = ["q", "k", "v", "proj_out"] 180 | for k, v in new_state_dict.items(): 181 | for weight_name in weights_to_convert: 182 | if f"mid.attn_1.{weight_name}.weight" in k: 183 | print(f"Reshaping {k} for SD format") 184 | new_state_dict[k] = reshape_weight_for_sd(v) 185 | return new_state_dict 186 | 187 | 188 | # =========================# 189 | # Text Encoder Conversion # 190 | # =========================# 191 | # pretty much a no-op 192 | 193 | 194 | def convert_text_enc_state_dict(text_enc_dict): 195 | return text_enc_dict 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | 201 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") 202 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") 203 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 204 | 205 | args = parser.parse_args() 206 | 207 | assert args.model_path is not None, "Must provide a model path!" 208 | 209 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!" 210 | 211 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") 212 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") 213 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") 214 | 215 | # Convert the UNet model 216 | unet_state_dict = torch.load(unet_path, map_location='cpu') 217 | unet_state_dict = convert_unet_state_dict(unet_state_dict) 218 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} 219 | 220 | # Convert the VAE model 221 | vae_state_dict = torch.load(vae_path, map_location='cpu') 222 | vae_state_dict = convert_vae_state_dict(vae_state_dict) 223 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} 224 | 225 | # Convert the text encoder model 226 | text_enc_dict = torch.load(text_enc_path, map_location='cpu') 227 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict) 228 | text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} 229 | 230 | # Put together new checkpoint 231 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} 232 | if args.half: 233 | state_dict = {k:v.half() for k,v in state_dict.items()} 234 | state_dict = {"state_dict": state_dict} 235 | torch.save(state_dict, args.checkpoint_path) -------------------------------------------------------------------------------- /src/zeropainter/dreamshaper.py: -------------------------------------------------------------------------------- 1 | import src.smplfusion 2 | from src.smplfusion import scheduler 3 | from src.smplfusion.common import * 4 | from types import SimpleNamespace 5 | 6 | import importlib 7 | from omegaconf import OmegaConf 8 | import torch 9 | import safetensors 10 | import safetensors.torch 11 | 12 | from os.path import dirname 13 | 14 | print("Loading model: Dreamshaper Inpainting V8") 15 | PROJECT_DIR = dirname(dirname(dirname(dirname(__file__)))) 16 | print(PROJECT_DIR) 17 | CONFIG_FOLDER = f"{PROJECT_DIR}/lib/smplfusion/config/" 18 | MODEL_FOLDER = f"{PROJECT_DIR}/assets/models/" 19 | ASSETS_FOLDER = f"{PROJECT_DIR}/assets/" 20 | 21 | 22 | def get_inpainting_condition(model, image, mask): 23 | latent_size = [x // 8 for x in image.size] 24 | condition_x0 = ( 25 | model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean 26 | * model.config.scale_factor 27 | ) 28 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float() 29 | return torch.cat([condition_mask, condition_x0], 1) 30 | 31 | 32 | def get_obj_from_str(string): 33 | module, cls = string.rsplit(".", 1) 34 | return getattr(importlib.import_module(module, package=None), cls) 35 | 36 | def load_obj(path): 37 | objyaml = OmegaConf.load(path) 38 | return get_obj_from_str(objyaml["__class__"])(**objyaml.get("__init__", {})) 39 | 40 | 41 | def get_t2i_model(): 42 | model_t2i = SimpleNamespace() 43 | state_dict = safetensors.torch.load_file( 44 | f"{MODEL_FOLDER}/dreamshaper/dreamshaper_8.safetensors", device="cuda" 45 | ) 46 | 47 | model_t2i.config = OmegaConf.load(f"{CONFIG_FOLDER}/ddpm/v1.yaml") 48 | model_t2i.unet = load_obj(f"{CONFIG_FOLDER}/unet/v1.yaml").eval().cuda() 49 | model_t2i.vae = load_obj(f"{CONFIG_FOLDER}/vae.yaml").eval().cuda() 50 | model_t2i.encoder = load_obj(f"{CONFIG_FOLDER}/encoders/clip.yaml").eval().cuda() 51 | 52 | extract = lambda state_dict, model: { 53 | x[len(model) + 1 :]: y for x, y in state_dict.items() if model in x 54 | } 55 | unet_state = extract(state_dict, "model.diffusion_model") 56 | encoder_state = extract(state_dict, "cond_stage_model") 57 | vae_state = extract(state_dict, "first_stage_model") 58 | 59 | model_t2i.unet.load_state_dict(unet_state, strict=False) 60 | model_t2i.encoder.load_state_dict(encoder_state, strict=False) 61 | model_t2i.vae.load_state_dict(vae_state, strict=False) 62 | model_t2i.unet = model_t2i.unet.requires_grad_(False) 63 | model_t2i.encoder = model_t2i.encoder.requires_grad_(False) 64 | model_t2i.vae = model_t2i.vae.requires_grad_(False) 65 | 66 | model_t2i.schedule = scheduler.linear( 67 | model_t2i.config.timesteps, 68 | model_t2i.config.linear_start, 69 | model_t2i.config.linear_end, 70 | ) 71 | return model_t2i, unet_state 72 | 73 | 74 | def get_inpainting_model(): 75 | model_inp = SimpleNamespace() 76 | state_dict = safetensors.torch.load_file( 77 | f"{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors", device="cuda" 78 | ) 79 | 80 | model_inp.config = OmegaConf.load(f"{CONFIG_FOLDER}/ddpm/v1.yaml") 81 | 
model_inp.unet = load_obj(f"{CONFIG_FOLDER}/unet/inpainting/v1.yaml").eval().cuda() 82 | model_inp.vae = load_obj(f"{CONFIG_FOLDER}/vae.yaml").eval().cuda() 83 | model_inp.encoder = load_obj(f"{CONFIG_FOLDER}/encoders/clip.yaml").eval().cuda() 84 | 85 | extract = lambda state_dict, model: { 86 | x[len(model) + 1 :]: y for x, y in state_dict.items() if model in x 87 | } 88 | unet_state = extract(state_dict, "model.diffusion_model") 89 | encoder_state = extract(state_dict, "cond_stage_model") 90 | vae_state = extract(state_dict, "first_stage_model") 91 | 92 | model_inp.unet.load_state_dict(unet_state, strict=False) 93 | model_inp.encoder.load_state_dict(encoder_state, strict=False) 94 | model_inp.vae.load_state_dict(vae_state, strict=False) 95 | model_inp.unet = model_inp.unet.requires_grad_(False) 96 | model_inp.encoder = model_inp.encoder.requires_grad_(False) 97 | model_inp.vae = model_inp.vae.requires_grad_(False) 98 | 99 | model_inp.schedule = scheduler.linear( 100 | model_inp.config.timesteps, 101 | model_inp.config.linear_start, 102 | model_inp.config.linear_end, 103 | ) 104 | return model_inp, unet_state 105 | -------------------------------------------------------------------------------- /src/zeropainter/inpainting.py: -------------------------------------------------------------------------------- 1 | from src.smplfusion.common import * 2 | import torch 3 | from src.smplfusion import share 4 | from src.smplfusion.patches import attentionpatch 5 | from tqdm import tqdm 6 | from pytorch_lightning import seed_everything 7 | from src.smplfusion.patches import router 8 | from src.smplfusion import IImage 9 | 10 | # negative_prompt = "worst quality, ugly, gross, disfigured, deformed, dehydrated, extra limbs, fused body parts, mutilated, malformed, mutated, bad anatomy, bad proportions, low quality, cropped, low resolution, out of frame, poorly drawn, text, watermark, letters, jpeg artifacts" 11 | # positive_prompt = ", realistic, HD, Full HD, 4K, high quality, high resolution, masterpiece, trending on artstation, realistic lighting" 12 | negative_prompt = '' 13 | positive_prompt = '' 14 | VERBOSE = True 15 | 16 | class AttnForward: 17 | def __init__(self, masks, object_context, object_uc_context): 18 | self.masks = masks 19 | self.object_context = object_context 20 | self.object_uc_context = object_uc_context 21 | 22 | def __call__(data, self, x, context=None, mask=None): 23 | att_type = "self" if context is None else "cross" 24 | batch_size = x.shape[0] 25 | 26 | if att_type == 'cross' and x.shape[1] in [share.input_shape.res16]: # For cross attention 27 | out = torch.zeros_like(x) 28 | for i in range(len(data.masks)): 29 | if data.masks[i].sum()>0: 30 | if batch_size == 1: 31 | out[:,data.masks[i]] = attentionpatch.default.forward( 32 | self, 33 | x[:,data.masks[i] > 0], 34 | data.object_context[i][None] 35 | ).float() 36 | elif batch_size == 2: 37 | out[:,data.masks[i]] = attentionpatch.default.forward( 38 | self, 39 | x[:,data.masks[i] > 0], 40 | torch.stack([data.object_uc_context[i],data.object_context[i]]) 41 | ).float() 42 | else: 43 | raise NotImplementedError("Batch Size > 1 not yet supported!") 44 | return out 45 | else: 46 | return attentionpatch.default.forward(self, x, context, mask) 47 | 48 | def gen_filled_image(model, prompt, image, mask, zp_masks, seed, T = 899, dt = 20, model_t2i = None, guidance_scale = 7.5, use_lcm_multistep = False): 49 | masks = [x.modified_mask.val16.flatten()>0 for x in zp_masks] 50 | masks.append(torch.stack(masks).sum(0) == 0) 51 | 52 | 
context = model.encoder.encode(['', f"realistic photo of a {prompt}" + positive_prompt]) 53 | object_context = model.encoder.encode([x.local_prompt for x in zp_masks] + [f"realistic photo of a {prompt}" + positive_prompt]) 54 | object_uc_context = model.encoder.encode([negative_prompt] * len(zp_masks) + [', '.join([x.local_prompt for x in zp_masks])]) 55 | 56 | seed_everything(seed) 57 | 58 | eps = torch.randn((1,4,64,64)).cuda() 59 | # zT = schedule.sqrt_alphas[799] * condition_mask + schedule.sqrt_one_minus_alphas[799] * eps 60 | # zT = (condition_mask < 0) * zT + (condition_mask > 0) * eps 61 | # zT = schedule.sqrt_one_minus_alphas[899] * eps 62 | # zT = schedule.sqrt_alphas[899] * IImage(255*orig_masks_all).resize(64).alpha().torch().cuda() + schedule.sqrt_one_minus_alphas[899] * eps 63 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * model.config.scale_factor 64 | condition_xT = model.schedule.sqrt_alphas[T] * condition_x0 + model.schedule.sqrt_one_minus_alphas[T] * eps 65 | condition_mask = mask.resize(64).cuda().torch(0).bool().float() 66 | zT = (condition_mask == 0) * condition_xT + (condition_mask > 0) * eps 67 | # zT = eps 68 | 69 | router.attention_forward = AttnForward(masks, object_context, object_uc_context) 70 | # router.attention_forward = attentionpatch.default.forward 71 | 72 | with torch.autocast('cuda'), torch.no_grad(): 73 | zt = zT 74 | timesteps = list(range(T, 0, -dt)) 75 | pbar = tqdm(timesteps) if VERBOSE else timesteps 76 | for index,t in enumerate(pbar): 77 | if index == 0: 78 | current_mask = ~(~mask).dilate(2) 79 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor 80 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float() 81 | condition = torch.cat([condition_mask, condition_x0], 1) 82 | if index == 5: 83 | current_mask = ~(~mask).dilate(0) 84 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor 85 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float() 86 | condition = torch.cat([condition_mask, condition_x0], 1) 87 | if index == len(timesteps) - 5: 88 | if model_t2i is not None: 89 | model = model_t2i 90 | condition = None 91 | else: 92 | current_mask = mask.dilate(512) 93 | condition_x0 = model.vae.encode(image.torch().cuda() * ~current_mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor 94 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float() 95 | condition = torch.cat([condition_mask, condition_x0], 1) 96 | 97 | _zt = zt if condition is None else torch.cat([zt, condition], 1) 98 | 99 | if use_lcm_multistep: 100 | eps = model.unet( 101 | _zt, 102 | timesteps = torch.tensor([t]).cuda(), 103 | context = context[1][None] 104 | ) 105 | else: 106 | eps_uncond, eps = model.unet( 107 | torch.cat([_zt, _zt]), 108 | timesteps = torch.tensor([t, t]).cuda(), 109 | context = context 110 | ).chunk(2) 111 | eps = (eps_uncond + guidance_scale * (eps - eps_uncond)) 112 | 113 | z0 = (zt - model.schedule.sqrt_one_minus_alphas[t] * eps) / model.schedule.sqrt_alphas[t] 114 | zt = model.schedule.sqrt_alphas[t - dt] * z0 + model.schedule.sqrt_one_minus_alphas[t - dt] * eps 115 | 116 | out = IImage(model.vae.decode(z0 / model.config.scale_factor)) 117 | return out 118 | # # out.save('../../assets/paper_data/visuals_for_paper/'+str(seed_num)+'_'+name) 119 | 120 | # np_out = np.array(out.data[0]) 121 | # 
print(np_out.max(), orig_masks_all.max()) 122 | # blending_result_helper = 0.4*(np_out/255)+0.6*orig_masks_all 123 | # res = np.hstack([np_out,blending_result_helper*255]) 124 | # n = name.split('.')[0] 125 | # cv2.imwrite('../../assets/paper_data/visuals_for_paper/'+n+'_'+str(my_seed_gen)+'__'+str(seed_num)+'_.png',res[:,:,::-1]) 126 | # (IImage(255 * blending_result_helper) | IImage(np_out)) -------------------------------------------------------------------------------- /src/zeropainter/models.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from . import convert_diffusers 3 | import safetensors.torch 4 | 5 | import torch 6 | from src.smplfusion import scheduler 7 | from os.path import dirname 8 | 9 | from omegaconf import OmegaConf 10 | import importlib 11 | 12 | def load_obj(objyaml): 13 | if "__init__" in objyaml: 14 | return get_obj_from_str(objyaml["__class__"])(**objyaml["__init__"]) 15 | else: 16 | return get_obj_from_str(objyaml["__class__"])() 17 | 18 | 19 | def get_obj_from_str(string): 20 | module, cls = string.rsplit(".", 1) 21 | try: 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | except Exception as e: 24 | return getattr(importlib.import_module("src." + module, package=None), cls) 25 | 26 | 27 | def get_inpainting_condition(model, image, mask): 28 | latent_size = [x // 8 for x in image.size] 29 | condition_x0 = ( 30 | model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean 31 | * model.config.scale_factor 32 | ) 33 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float() 34 | return torch.cat([condition_mask, condition_x0], 1) 35 | 36 | 37 | def get_t2i_model(config_folder, model_folder): 38 | model_t2i = SimpleNamespace() 39 | 40 | model_t2i.config = OmegaConf.load( 41 | f"{config_folder}/ddpm/v1.yaml" 42 | ) # smplfusion.options.ddpm.v1_yaml 43 | model_t2i.schedule = scheduler.linear( 44 | model_t2i.config.timesteps, 45 | model_t2i.config.linear_start, 46 | model_t2i.config.linear_end, 47 | ) 48 | 49 | model_t2i.unet = load_obj( 50 | OmegaConf.load(f"{config_folder}/unet/v1.yaml") 51 | ).cuda() # smplfusion.options.unet.v1_yaml.cuda() 52 | model_t2i.vae = load_obj( 53 | OmegaConf.load(f"{config_folder}/vae.yaml") 54 | ).cuda() # smplfusion.options.vae_yaml.cuda() 55 | model_t2i.encoder = load_obj( 56 | OmegaConf.load(f"{config_folder}/encoders/clip.yaml") 57 | ).cuda() # smplfusion.options.encoders.clip_yaml.cuda() 58 | 59 | unet_state_dict = torch.load( 60 | f"{model_folder}/unet.ckpt" 61 | ) # assets.models.sd_1_5_inpainting.unet_ckpt 62 | vae_state_dict = torch.load( 63 | f"{model_folder}/vae.ckpt" 64 | ) # assets.models.sd_1_5_inpainting.vae_ckpt 65 | encoder_state_dict = torch.load( 66 | f"{model_folder}/encoder.ckpt" 67 | ) # assets.models.sd_1_5_inpainting.encoder_ckpt 68 | 69 | model_t2i.unet.load_state_dict(unet_state_dict) 70 | model_t2i.vae.load_state_dict(vae_state_dict) 71 | model_t2i.encoder.load_state_dict(encoder_state_dict, strict=False) 72 | 73 | model_t2i.unet = model_t2i.unet.requires_grad_(False).eval() 74 | model_t2i.vae = model_t2i.vae.requires_grad_(False).eval() 75 | model_t2i.encoder = model_t2i.encoder.requires_grad_(False).eval() 76 | return model_t2i, unet_state_dict 77 | 78 | 79 | def get_inpainting_model(config_folder, model_folder): 80 | model_inp = SimpleNamespace() 81 | model_inp.config = OmegaConf.load( 82 | f"{config_folder}/ddpm/v1.yaml" 83 | ) # smplfusion.options.ddpm.v1_yaml 84 | 
model_inp.schedule = scheduler.linear( 85 | model_inp.config.timesteps, 86 | model_inp.config.linear_start, 87 | model_inp.config.linear_end, 88 | ) 89 | 90 | model_inp.unet = load_obj( 91 | OmegaConf.load(f"{config_folder}/unet/inpainting/v1.yaml") 92 | ).cuda() # smplfusion.options.unet.inpainting.v1_yaml.cuda() 93 | model_inp.vae = load_obj( 94 | OmegaConf.load(f"{config_folder}/vae.yaml") 95 | ).cuda() # smplfusion.options.vae_yaml.cuda() 96 | model_inp.encoder = load_obj( 97 | OmegaConf.load(f"{config_folder}/encoders/clip.yaml") 98 | ).cuda() # smplfusion.options.encoders.clip_yaml.cuda() 99 | 100 | unet_state_dict = torch.load( 101 | f"{model_folder}/unet.ckpt" 102 | ) # assets.models.sd_1_5_inpainting.unet_ckpt 103 | vae_state_dict = torch.load( 104 | f"{model_folder}/vae.ckpt" 105 | ) # assets.models.sd_1_5_inpainting.vae_ckpt 106 | encoder_state_dict = torch.load( 107 | f"{model_folder}/encoder.ckpt" 108 | ) # assets.models.sd_1_5_inpainting.encoder_ckpt 109 | 110 | model_inp.unet.load_state_dict(unet_state_dict) 111 | model_inp.vae.load_state_dict(vae_state_dict) 112 | model_inp.encoder.load_state_dict(encoder_state_dict, strict=False) 113 | 114 | model_inp.unet = model_inp.unet.requires_grad_(False).eval() 115 | model_inp.vae = model_inp.vae.requires_grad_(False).eval() 116 | model_inp.encoder = model_inp.encoder.requires_grad_(False).eval() 117 | 118 | return model_inp, unet_state_dict 119 | 120 | 121 | def get_lora(lora_path): 122 | _lora_state_dict = safetensors.torch.load_file(lora_path) 123 | 124 | # Dictionary for converting 125 | unet_conversion_dict = {} 126 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map_layer) 127 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map_resnet) 128 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map) 129 | unet_conversion_dict = { 130 | y.replace(".", "_"): x for x, y in unet_conversion_dict.items() 131 | } 132 | unet_conversion_dict["lora_unet_"] = "" 133 | 134 | lora_state_dict = {} 135 | for key in _lora_state_dict: 136 | key_converted = key 137 | for x, y in unet_conversion_dict.items(): 138 | key_converted = key_converted.replace(x, y) 139 | lora_state_dict[key_converted] = _lora_state_dict[key] 140 | 141 | return lora_state_dict 142 | -------------------------------------------------------------------------------- /src/zeropainter/segmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import dirname 3 | 4 | 5 | from segment_anything import SamPredictor, sam_model_registry 6 | 7 | def get_bounding_box(mask): 8 | all_y,all_x = np.where(mask == 1) 9 | all_y,all_x = sorted(all_y),sorted(all_x) 10 | x1,y1,x2,y2 = all_x[0],all_y[0],all_x[-1],all_y[-1] 11 | x,y,w,h = x1,y1,x2-x1,y2-y1 12 | return x,y,w,h 13 | 14 | def get_segmentation_model(path_to_check): 15 | # path_to_check = "sam_vit_h_4b8939.pth" 16 | sam = sam_model_registry["vit_h"](checkpoint=path_to_check) 17 | predictor = SamPredictor(sam) 18 | predictor.model.cuda(); 19 | return predictor 20 | 21 | def get_segmentation(predictor, im, mask): 22 | im_np = im.data[0] 23 | 24 | # SAM prediction 25 | # mask = obj['data_512'][i]['mask'] 26 | x,y,w,h = get_bounding_box(mask) 27 | predictor.set_image(im_np) 28 | input_box = np.array([x,y,x+w,y+h]) 29 | masks, scores, other = predictor.predict(box=input_box) 30 | 31 | _masks = np.concatenate([masks, ~masks]) 32 | _scores = np.concatenate([scores, scores]) 33 | _scores = np.array([y for x,y in zip(_masks, 
_scores) if ((1 - mask) * x).sum() <= (mask * x).sum()]) 34 | _masks = np.array([x for x in _masks if ((1 - mask) * x).sum() <= (mask * x).sum()]) 35 | if len(_masks) > 0: 36 | masks,scores = _masks,_scores 37 | 38 | pred_seg_mask = masks[scores.argmax()] 39 | pred_seg_mask[pred_seg_mask > 0] = 1 40 | pred_seg_mask = np.stack([pred_seg_mask] * 3, axis=-1) * 1 41 | 42 | return pred_seg_mask * mask[...,None] 43 | -------------------------------------------------------------------------------- /src/zeropainter/zero_painter_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from src.smplfusion import share 6 | 7 | class ZeroPainterMask: 8 | def __init__(self, color, local_prompt, img_grey,rgb,image_w,image_h): 9 | self.img_grey = img_grey 10 | self.color = eval(color) 11 | self.local_prompt = local_prompt 12 | self.bbox = None 13 | self.bbox_64 = None 14 | self.area = None 15 | self.mask = None 16 | self.mask_64 = None 17 | self.token_idx = None 18 | self.modified_mask = None 19 | self.inverse_mask = None 20 | self.modified_indexses_of_prompt = None 21 | self.sot_index = [0] 22 | self.w_positive = 1 23 | self.w_negative = 1 24 | self.image_w = image_w 25 | self.image_h = image_h 26 | self.rgb = rgb 27 | self.modify_info() 28 | 29 | def get_bounding_box(self, mask): 30 | mask = mask*255.0 31 | all_y, all_x = np.where(mask == 255.0) 32 | all_y, all_x = sorted(all_y), sorted(all_x) 33 | x1, y1, x2, y2 = all_x[0], all_y[0], all_x[-1], all_y[-1] 34 | x, y, w, h = x1, y1, x2 - x1, y2 - y1 35 | return x, y, w, h 36 | 37 | 38 | def modify_info(self): 39 | # Define the color to be searched for 40 | if self.rgb: 41 | color_to_search = np.array([self.color[0],self.color[1],self.color[2]]) 42 | mask_1d = np.all(self.img_grey == color_to_search, axis=2)*1.0 43 | else: 44 | mask_1d = (self.img_grey == self.color) * 1.0 45 | mask_1d_64 = cv2.resize(mask_1d, (64, 64), interpolation=cv2.INTER_NEAREST) 46 | self.bbox = self.get_bounding_box(mask_1d) 47 | self.bbox_64 = self.get_bounding_box(mask_1d_64) 48 | self.area = self.bbox[2] * self.bbox[3] 49 | self.mask = mask_1d.copy() 50 | self.mask_64 = mask_1d_64 // 255 51 | 52 | splited_prompt = self.local_prompt.split() 53 | self.token_idx = np.arange(1, len(splited_prompt) + 1, 1, dtype=int) 54 | 55 | #dif part 56 | mask_1d = mask_1d[None,None,:,:] 57 | mask_1d = torch.from_numpy(mask_1d) 58 | 59 | self.modified_indexses_of_prompt = np.arange(5, len(splited_prompt)+1) 60 | self.modified_mask = share.InputMask(mask_1d) 61 | self.inverse_mask = share.InputMask(1-mask_1d) 62 | 63 | return 64 | 65 | class ZeroPainterSample: 66 | #Example of sample 67 | #{'prompt': 'Brown gift box beside red candle.', 68 | # 'color_context_dict': {'1': 'Brown gift box', '2': 'red candle'}} 69 | def __init__(self, item, img_grey, image_h=512,image_w=512): 70 | 71 | self.item = item 72 | self.global_prompt = self.item['prompt'] 73 | self.image_w = image_w 74 | self.image_h = image_h 75 | self.img_grey = img_grey 76 | self.masks = self.load_masks() 77 | 78 | 79 | def load_masks(self): 80 | data_samples = [] 81 | 82 | self.img_grey = cv2.resize(self.img_grey, (self.image_w, self.image_h),interpolation=cv2.INTER_NEAREST) 83 | for color, local_prompt in self.item['color_context_dict'].items(): 84 | data_samples.append(ZeroPainterMask(color,local_prompt,self.img_grey,self.item['rgb'],self.image_w,self.image_h)) 85 | 86 | return data_samples 87 | 88 | 89 | class 
ZeroPainterDataset: 90 | def __init__(self, root_path_img, json_path,rgb=True): 91 | self.root_path_img = root_path_img 92 | self.json_path = json_path 93 | self.rgb = rgb 94 | 95 | if isinstance(json_path, dict) or isinstance(json_path, list): 96 | self.json_data = json_path 97 | else: 98 | with open(self.json_path, 'r') as file: 99 | self.json_data = json.load(file) 100 | def __len__(self): 101 | return len(self.json_data) 102 | 103 | def __getitem__(self, index): 104 | item = self.json_data[index] 105 | item['rgb'] = self.rgb 106 | if isinstance(self.root_path_img, str): 107 | self.img_path = self.root_path_img 108 | if self.rgb: 109 | self.img_grey = cv2.imread(self.img_path) 110 | else: 111 | self.img_grey = cv2.imread(self.img_path,0) 112 | else: 113 | self.img_grey = np.array(self.root_path_img) 114 | 115 | return ZeroPainterSample(item, self.img_grey) 116 | 117 | 118 | -------------------------------------------------------------------------------- /src/zeropainter/zero_painter_pipline.py: -------------------------------------------------------------------------------- 1 | from . import segmentation, generation, inpainting 2 | 3 | import torch 4 | import numpy as np 5 | from src.smplfusion import libimage 6 | from src.smplfusion import IImage 7 | 8 | 9 | class ZeroPainter: 10 | def __init__(self, model_t2i, model_inp, model_sam): 11 | self.model_t2i = model_t2i 12 | self.model_inp = model_inp 13 | self.model_sam = model_sam 14 | 15 | 16 | def gen_sample( 17 | self, 18 | sample, 19 | object_seed, 20 | image_seed, 21 | num_ddim_steps_t2i, 22 | num_ddim_steps_inp, 23 | cfg_scale_t2i=7.5, 24 | cfg_scale_inp=7.5, 25 | use_lcm_multistep_t2i=False, 26 | use_lcm_multistep_inp=False, 27 | ): 28 | 29 | 30 | gen_obj_list = [] 31 | gen_mask_list = [] 32 | real_mask_list = [] 33 | 34 | if isinstance(object_seed, int): 35 | object_seed = [object_seed] * len(sample.masks) 36 | 37 | for i in range(len(sample.masks)): 38 | eps_list, z0_list, zt_list = generation.gen_single_object( 39 | self.model_t2i, 40 | sample.masks[i], 41 | object_seed[i], 42 | dt=1000 // num_ddim_steps_t2i, 43 | guidance_scale=cfg_scale_t2i, 44 | use_lcm_multistep=use_lcm_multistep_t2i, 45 | ) 46 | mask = sample.masks[i].mask 47 | 48 | gen_image = IImage( 49 | self.model_t2i.vae.decode(z0_list[-1] / self.model_t2i.config.scale_factor) 50 | ) 51 | gen_mask = segmentation.get_segmentation(self.model_sam, gen_image, mask) 52 | gen_object = gen_image * IImage(255 * gen_mask) 53 | 54 | gen_obj_list.append(gen_object) 55 | gen_mask_list.append(gen_mask) 56 | real_mask_list.append(mask) 57 | 58 | gen_image = IImage(libimage.stack(gen_obj_list).data.sum(0)) 59 | gen_mask = IImage(255 * np.sum(gen_mask_list, 0)) 60 | real_mask = IImage(255 * np.sum(real_mask_list, 0)) 61 | 62 | output = inpainting.gen_filled_image( 63 | self.model_inp, 64 | sample.global_prompt, 65 | gen_image, 66 | (~gen_mask.alpha()).dilate(3), 67 | sample.masks, 68 | image_seed, 69 | dt=1000 // num_ddim_steps_inp, 70 | guidance_scale=cfg_scale_inp, 71 | use_lcm_multistep=use_lcm_multistep_inp, 72 | ) 73 | 74 | return output 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /zero_painter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src.zeropainter.zero_painter_pipline import ZeroPainter 3 | from src.zeropainter import models, dreamshaper,segmentation 4 | from src.zeropainter import zero_painter_dataset 5 | 6 | def get_args(): 7 | parser = 
argparse.ArgumentParser() 8 | 9 | parser.add_argument('--mask-path', default='data/masks/1_rgb.png', help='Path to the color-coded mask image.') 10 | parser.add_argument('--metadata', default='data/metadata/1.json', type=str, help='Path to the metadata JSON with the global prompt and per-color local prompts.') 11 | parser.add_argument('--output-dir', default='data/outputs/', help='Output directory.') 12 | 13 | # Model loading paths 14 | parser.add_argument('--config-folder-for-models', type=str, default='config', help='Path to model configs.') 15 | parser.add_argument('--model-folder-inpiting', type=str, default='models/sd-1-5-inpainting', help='Path to the inpainting model.') 16 | parser.add_argument('--model-folder-generation', type=str, default='models/sd-1-4', help='Path to the generation model.') 17 | parser.add_argument('--segment-anything-model', type=str, default='models/sam_vit_h_4b8939.pth', help='Path to the Segment Anything checkpoint.') 18 | 19 | return parser.parse_args() 20 | 21 | 22 | def main(): 23 | args = get_args() 24 | 25 | model_inp, _ = models.get_inpainting_model(args.config_folder_for_models, args.model_folder_inpiting) 26 | model_t2i, _ = models.get_t2i_model(args.config_folder_for_models, args.model_folder_generation) 27 | model_sam = segmentation.get_segmentation_model(args.segment_anything_model) 28 | zero_painter_model = ZeroPainter(model_t2i, model_inp, model_sam) 29 | 30 | data = zero_painter_dataset.ZeroPainterDataset(args.mask_path, args.metadata) 31 | name = args.mask_path.split('/')[-1] 32 | result = zero_painter_model.gen_sample(data[0], 42, 42, 30, 30) # sample, object seed, image seed, DDIM steps (t2i), DDIM steps (inpainting) 33 | result.save(args.output_dir + name) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | --------------------------------------------------------------------------------
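Example usage of the entry point above (illustrative; it assumes the SD 1.5 inpainting, SD 1.4, and SAM ViT-H checkpoints are already placed at the default locations under models/ and that the configs live in config/):

    python zero_painter.py --mask-path data/masks/1_rgb.png --metadata data/metadata/1.json --output-dir data/outputs/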