├── LICENSE
├── README.md
├── __assets__
│   └── github
│       ├── method_arch.png
│       └── teaser.png
├── config
│   ├── ddpm
│   │   ├── v1.yaml
│   │   └── v2-upsample.yaml
│   ├── encoders
│   │   ├── clip.yaml
│   │   └── openclip.yaml
│   ├── unet
│   │   ├── inpainting
│   │   │   ├── v1.yaml
│   │   │   └── v2.yaml
│   │   ├── upsample
│   │   │   └── v2.yaml
│   │   ├── v
│   │   │   └── v2.yaml
│   │   ├── v1.yaml
│   │   └── v2.yaml
│   ├── vae-upsample.yaml
│   └── vae.yaml
├── data
│   ├── masks
│   │   ├── 0_rgb.png
│   │   ├── 1_rgb.png
│   │   └── 2_rgb.png
│   ├── metadata
│   │   ├── 0.json
│   │   ├── 1.json
│   │   └── 2.json
│   └── outputs
│       ├── 0_rgb.png
│       ├── 1_rgb.png
│       └── 2_rgb.png
├── requirements.txt
├── src
│   ├── smplfusion
│   │   ├── __init__.py
│   │   ├── animation.py
│   │   ├── api
│   │   │   └── __init__.py
│   │   ├── common.py
│   │   ├── config.py
│   │   ├── ddim.py
│   │   ├── libimage
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── iimage.cpython-310.pyc
│   │   │   │   ├── iimage.cpython-39.pyc
│   │   │   │   ├── iimage_gallery.cpython-310.pyc
│   │   │   │   ├── iimage_gallery.cpython-39.pyc
│   │   │   │   ├── utils.cpython-310.pyc
│   │   │   │   └── utils.cpython-39.pyc
│   │   │   ├── iimage.py
│   │   │   ├── iimage_gallery.py
│   │   │   └── utils.py
│   │   ├── libpath.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── unet.cpython-39.pyc
│   │   │   │   ├── util.cpython-39.pyc
│   │   │   │   └── vae.cpython-39.pyc
│   │   │   ├── encoders
│   │   │   │   ├── __pycache__
│   │   │   │   │   └── clip_embedder.cpython-39.pyc
│   │   │   │   ├── clip_embedder.py
│   │   │   │   ├── clip_image_embedder.py
│   │   │   │   ├── modules.py
│   │   │   │   ├── open_clip_embedder.py
│   │   │   │   ├── open_clip_image_embedder.py
│   │   │   │   └── t5_mebedder.py
│   │   │   ├── unet.py
│   │   │   ├── util.py
│   │   │   └── vae.py
│   │   ├── modelzoo
│   │   │   ├── __init__.py
│   │   │   ├── dreamshaper8.py
│   │   │   └── dreamshaper8_inpainting.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── autoencoder.cpython-39.pyc
│   │   │   │   ├── distributions.cpython-39.pyc
│   │   │   │   ├── ema.cpython-39.pyc
│   │   │   │   └── util.cpython-39.pyc
│   │   │   ├── attention
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   │   ├── basic_transformer_block.cpython-39.pyc
│   │   │   │   │   ├── cross_attention.cpython-39.pyc
│   │   │   │   │   ├── feed_forward.cpython-39.pyc
│   │   │   │   │   ├── memory_efficient_cross_attention.cpython-39.pyc
│   │   │   │   │   └── spatial_transformer.cpython-39.pyc
│   │   │   │   ├── basic_transformer_block.py
│   │   │   │   ├── cross_attention.py
│   │   │   │   ├── feed_forward.py
│   │   │   │   ├── memory_efficient_cross_attention.py
│   │   │   │   └── spatial_transformer.py
│   │   │   ├── autoencoder.py
│   │   │   ├── distributions.py
│   │   │   ├── ema.py
│   │   │   ├── partial_conv2d.py
│   │   │   └── util.py
│   │   ├── patches
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   └── router.cpython-39.pyc
│   │   │   ├── attentionpatch
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   │   ├── attstore.cpython-39.pyc
│   │   │   │   │   ├── boosted_maskscore.cpython-39.pyc
│   │   │   │   │   ├── default.cpython-39.pyc
│   │   │   │   │   ├── ediff.cpython-39.pyc
│   │   │   │   │   ├── inpaint.cpython-39.pyc
│   │   │   │   │   ├── introvert.cpython-39.pyc
│   │   │   │   │   ├── shuffled.cpython-39.pyc
│   │   │   │   │   └── zeropaint.cpython-39.pyc
│   │   │   │   ├── attstore.py
│   │   │   │   ├── boosted_maskscore copy.py
│   │   │   │   ├── boosted_maskscore.py
│   │   │   │   ├── default.py
│   │   │   │   ├── inpaint.py
│   │   │   │   ├── introvert.py
│   │   │   │   ├── maskscore.py
│   │   │   │   ├── other.py
│   │   │   │   ├── shuffled.py
│   │   │   │   └── zeropaint.py
│   │   │   ├── router.py
│   │   │   └── transformerpatch
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── __init__.cpython-39.pyc
│   │   │       │   ├── default.cpython-39.pyc
│   │   │       │   ├── guided.cpython-39.pyc
│   │   │       │   ├── introvert.cpython-39.pyc
│   │   │       │   └── weighting_versions.cpython-39.pyc
│   │   │       ├── default.py
│   │   │       ├── guided.py
│   │   │       ├── introvert.py
│   │   │       └── weighting_versions.py
│   │   ├── scheduler.py
│   │   ├── share.py
│   │   ├── util.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-39.pyc
│   │       │   ├── input_image.cpython-39.pyc
│   │       │   ├── input_mask.cpython-39.pyc
│   │       │   ├── input_shape.cpython-39.pyc
│   │       │   └── layer_mask.cpython-39.pyc
│   │       ├── input_image.py
│   │       ├── input_mask.py
│   │       ├── input_shape.py
│   │       └── layer_mask.py
│   └── zeropainter
│       ├── __init__.py
│       ├── convert_diffusers.py
│       ├── dreamshaper.py
│       ├── generation.py
│       ├── inpainting.py
│       ├── models.py
│       ├── segmentation.py
│       ├── zero_painter_dataset.py
│       └── zero_painter_pipline.py
└── zero_painter.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Picsart AI Research (PAIR)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis [CVPR 2024]
2 |
3 | This repository is the official implementation of [Zero-Painter](https://arxiv.org/abs/2406.04032).
4 |
5 |
6 | **[Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis](https://arxiv.org/abs/2406.04032)**
7 |
8 | Marianna Ohanyan*,
9 | Hayk Manukyan*,
10 | Zhangyang Wang,
11 | Shant Navasardyan,
12 | [Humphrey Shi](https://www.humphreyshi.com)
13 |
14 |
15 | [Arxiv](https://arxiv.org/abs/2406.04032)
16 |
17 |
18 |
19 |
20 |
21 | We present Zero-Painter, a novel training-free framework for layout-conditional text-to-image synthesis that facilitates the creation of detailed and controlled imagery from textual prompts. Our method utilizes object masks and individual descriptions, coupled with a global text prompt, to generate images with high fidelity. Zero-Painter employs a two-stage process involving our novel Prompt-Adjusted Cross-Attention (PACA) and Region-Grouped Cross-Attention (ReGCA) blocks, ensuring precise alignment of generated objects with textual prompts and mask shapes. Our extensive experiments demonstrate that Zero-Painter surpasses current state-of-the-art methods in preserving textual details and adhering to mask shapes.
22 |
23 |
24 |
25 |
26 | ## 🔥 News
27 | - [2024.06.06] Zero-Painter paper and code are released.
28 | - [2024.02.27] Paper is accepted to CVPR 2024.
29 |
30 |
31 | ## ⚒️ Installation
32 |
33 |
38 | Install with `pip`:
39 | ```bash
40 | pip3 install -r requirements.txt
41 | ```
42 |
43 | ## 💃 Inference: Generate images with Zero-Painter
44 |
45 | 1. Download [models](https://huggingface.co/PAIR/Zero-Painter) and put them in the `models` folder.
46 | 2. Use the following script to run inference on a given mask and metadata pair:
47 | ```
48 | python zero_painter.py \
49 | --mask-path data/masks/1_rgb.png \
50 | --metadata data/metadata/1.json \
51 | --output-dir data/outputs/
52 | ```
53 |
54 | The `metadata` file should be in the following format:
55 | ```
56 | [{
57 | "prompt": "Brown gift box beside red candle.",
58 | "color_context_dict": {
59 | "(244, 54, 32)": "Brown gift box",
60 | "(54, 245, 32)": "red candle"
61 | }
62 | }]
63 | ```
64 |
73 |
74 | ## Method
75 |
76 |
77 |
78 | ---
79 |
80 | ## 🎓 Citation
81 | If you use our work in your research, please cite our publication:
82 | ```
83 | @article{Zeropainter,
84 | title={Zero-Painter: Training-Free Layout Control for Text-to-Image Synthesis},
85 | url={http://arxiv.org/abs/2406.04032},
86 | publisher={arXiv},
87 | author={Ohanyan, Marianna and Manukyan, Hayk and Wang, Zhangyang and Navasardyan, Shant and Shi, Humphrey},
88 | year={2024}}
89 |
90 | ```
--------------------------------------------------------------------------------
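
Note: the snippet below is a minimal sketch (not part of the repository) of writing a metadata file in the format shown in the README above. The RGB keys are assumed to match the colors used in the corresponding mask image in `data/masks/`, and the output path is hypothetical.

```python
import json

# One global prompt plus a color -> per-object prompt mapping, as in data/metadata/0.json.
metadata = [{
    "prompt": "Brown gift box beside red candle.",
    "color_context_dict": {
        "(244, 54, 32)": "Brown gift box",
        "(54, 245, 32)": "red candle",
    },
}]

with open("data/metadata/custom.json", "w") as f:
    json.dump(metadata, f, indent=2)
```
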
/__assets__/github/method_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/__assets__/github/method_arch.png
--------------------------------------------------------------------------------
/__assets__/github/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/__assets__/github/teaser.png
--------------------------------------------------------------------------------
/config/ddpm/v1.yaml:
--------------------------------------------------------------------------------
1 | linear_start: 0.00085
2 | linear_end: 0.0120
3 | num_timesteps_cond: 1
4 | log_every_t: 200
5 | timesteps: 1000
6 | first_stage_key: "jpg"
7 | cond_stage_key: "txt"
8 | image_size: 64
9 | channels: 4
10 | cond_stage_trainable: false
11 | conditioning_key: crossattn
12 | monitor: val/loss_simple_ema
13 | scale_factor: 0.18215
14 | use_ema: False # we set this to false because this is an inference only config
--------------------------------------------------------------------------------
/config/ddpm/v2-upsample.yaml:
--------------------------------------------------------------------------------
1 | parameterization: "v"
2 | low_scale_key: "lr"
3 | linear_start: 0.0001
4 | linear_end: 0.02
5 | num_timesteps_cond: 1
6 | log_every_t: 200
7 | timesteps: 1000
8 | first_stage_key: "jpg"
9 | cond_stage_key: "txt"
10 | image_size: 128
11 | channels: 4
12 | cond_stage_trainable: false
13 | conditioning_key: "hybrid-adm"
14 | monitor: val/loss_simple_ema
15 | scale_factor: 0.08333
16 | use_ema: False
17 |
18 | low_scale_config:
19 | target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
20 | params:
21 | noise_schedule_config: # image space
22 | linear_start: 0.0001
23 | linear_end: 0.02
24 | max_noise_level: 350
25 |
--------------------------------------------------------------------------------
/config/encoders/clip.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder
--------------------------------------------------------------------------------
/config/encoders/openclip.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder
2 | __init__:
3 | freeze: True
4 | layer: "penultimate"
--------------------------------------------------------------------------------
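
Note: configs under `config/` that carry a `__class__` / `__init__` pair are instantiated by `load_obj` in `src/smplfusion/libpath.py`. The sketch below is not part of the repository; it assumes `src/` is on `PYTHONPATH` and that the OpenCLIP weights can be downloaded.

```python
from omegaconf import OmegaConf
from smplfusion.libpath import load_obj

# Load the YAML above and instantiate the class it names, passing __init__ as kwargs.
cfg = OmegaConf.load("config/encoders/openclip.yaml")
encoder = load_obj(cfg)  # -> FrozenOpenCLIPEmbedder(freeze=True, layer="penultimate")
```
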
/config/unet/inpainting/v1.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.unet.UNetModel
2 | __init__:
3 | image_size: 32 # unused
4 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask
5 | out_channels: 4
6 | model_channels: 320
7 | attention_resolutions: [ 4, 2, 1 ]
8 | num_res_blocks: 2
9 | channel_mult: [ 1, 2, 4, 4 ]
10 | num_heads: 8
11 | use_spatial_transformer: True
12 | transformer_depth: 1
13 | context_dim: 768
14 | use_checkpoint: False
15 | legacy: False
--------------------------------------------------------------------------------
/config/unet/inpainting/v2.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.unet.UNetModel
2 | __init__:
3 | use_checkpoint: False
4 | image_size: 32 # unused
5 | in_channels: 9
6 | out_channels: 4
7 | model_channels: 320
8 | attention_resolutions: [ 4, 2, 1 ]
9 | num_res_blocks: 2
10 | channel_mult: [ 1, 2, 4, 4 ]
11 | num_head_channels: 64 # need to fix for flash-attn
12 | use_spatial_transformer: True
13 | use_linear_in_transformer: True
14 | transformer_depth: 1
15 | context_dim: 1024
16 | legacy: False
--------------------------------------------------------------------------------
/config/unet/upsample/v2.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.unet.UNetModel
2 | __init__:
3 | use_checkpoint: False
4 | num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
5 | image_size: 128
6 | in_channels: 7
7 | out_channels: 4
8 | model_channels: 256
9 | attention_resolutions: [ 2,4,8]
10 | num_res_blocks: 2
11 | channel_mult: [ 1, 2, 2, 4]
12 | disable_self_attentions: [True, True, True, False]
13 | disable_middle_self_attn: False
14 | num_heads: 8
15 | use_spatial_transformer: True
16 | transformer_depth: 1
17 | context_dim: 1024
18 | legacy: False
19 | use_linear_in_transformer: True
--------------------------------------------------------------------------------
/config/unet/v/v2.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/config/unet/v/v2.yaml
--------------------------------------------------------------------------------
/config/unet/v1.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.unet.UNetModel
2 | __init__:
3 | image_size: 32 # unused
4 | in_channels: 4
5 | out_channels: 4
6 | model_channels: 320
7 | attention_resolutions: [ 4, 2, 1 ]
8 | num_res_blocks: 2
9 | channel_mult: [ 1, 2, 4, 4 ]
10 | num_heads: 8
11 | use_spatial_transformer: True
12 | transformer_depth: 1
13 | context_dim: 768
14 | use_checkpoint: False
15 | legacy: False
--------------------------------------------------------------------------------
/config/unet/v2.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.unet.UNetModel
2 | __init__:
3 | use_checkpoint: False
4 | use_fp16: True
5 | image_size: 32 # unused
6 | in_channels: 4
7 | out_channels: 4
8 | model_channels: 320
9 | attention_resolutions: [ 4, 2, 1 ]
10 | num_res_blocks: 2
11 | channel_mult: [ 1, 2, 4, 4 ]
12 | num_head_channels: 64
13 | use_spatial_transformer: True
14 | use_linear_in_transformer: True
15 | transformer_depth: 1
16 | context_dim: 1024
17 | legacy: False
--------------------------------------------------------------------------------
/config/vae-upsample.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.vae.AutoencoderKL
2 | __init__:
3 | embed_dim: 4
4 | ddconfig:
5 | double_z: True
6 | z_channels: 4
7 | resolution: 256
8 | in_channels: 3
9 | out_ch: 3
10 | ch: 128
11 | ch_mult: [ 1,2,4 ]
12 | num_res_blocks: 2
13 | attn_resolutions: [ ]
14 | dropout: 0.0
15 | lossconfig:
16 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/config/vae.yaml:
--------------------------------------------------------------------------------
1 | __class__: smplfusion.models.vae.AutoencoderKL
2 | __init__:
3 | embed_dim: 4
4 | monitor: val/rec_loss
5 | ddconfig:
6 | double_z: true
7 | z_channels: 4
8 | resolution: 256
9 | in_channels: 3
10 | out_ch: 3
11 | ch: 128
12 | ch_mult: [1,2,4,4]
13 | num_res_blocks: 2
14 | attn_resolutions: []
15 | dropout: 0.0
16 | lossconfig:
17 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/data/masks/0_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/0_rgb.png
--------------------------------------------------------------------------------
/data/masks/1_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/1_rgb.png
--------------------------------------------------------------------------------
/data/masks/2_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/masks/2_rgb.png
--------------------------------------------------------------------------------
/data/metadata/0.json:
--------------------------------------------------------------------------------
1 | [{
2 | "prompt": "Brown gift box beside red candle.",
3 | "color_context_dict": {
4 | "(244, 54, 32)": "Brown gift box",
5 | "(54, 245, 32)": "red candle"
6 | }
7 | }]
8 |
--------------------------------------------------------------------------------
/data/metadata/1.json:
--------------------------------------------------------------------------------
1 | [ {
2 | "prompt": "Brown tabby cat on white stairs photo",
3 | "color_context_dict": {
4 | "(244, 54, 32)": "Brown tabby cat",
5 | "(54, 245, 32)": "blue vase",
6 | "(4,54,200)": "red apple"
7 | }
8 | }]
--------------------------------------------------------------------------------
/data/metadata/2.json:
--------------------------------------------------------------------------------
1 | [{
2 | "prompt": "Green succulent on white and pink pot photo",
3 | "color_context_dict": {
4 | "(213, 24, 207)": "white and pink pot",
5 | "(78, 213, 24)": "Green succulent "
6 | }
7 | }]
--------------------------------------------------------------------------------
/data/outputs/0_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/0_rgb.png
--------------------------------------------------------------------------------
/data/outputs/1_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/1_rgb.png
--------------------------------------------------------------------------------
/data/outputs/2_rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/data/outputs/2_rgb.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision
3 | einops
4 | pytorch_lightning
5 | xformers
6 | open_clip_torch
7 | ipywidgets
8 | transformers==4.40.2
9 | segment-anything
10 | imageio
11 | scipy
12 | opencv-python
13 | matplotlib
14 | omegaconf
--------------------------------------------------------------------------------
/src/smplfusion/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO: Remove farancia.Path dependency!
2 | from .libpath import Path
3 | from .libimage import IImage
4 |
5 | print (Path.cwd(__file__))
6 | config = Path.cwd(__file__).config
7 | models = Path.cwd(__file__).config
8 | options = Path.cwd(__file__).config
--------------------------------------------------------------------------------
/src/smplfusion/animation.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib import animation
3 | from IPython.display import HTML, Image, display
4 |
5 | class Animation:
6 | JS = 0
7 | HTML = 1
8 | ANIMATION_MODE = HTML
9 | def __init__(self, frames, fps = 30):
10 | """_summary_
11 |
12 | Args:
13 | frames (np.ndarray): _description_
14 | """
15 | self.frames = frames
16 | self.fps = fps
17 | self.anim_obj = None
18 | self.anim_str = None
19 | def render(self):
20 | size = (self.frames.shape[2],self.frames.shape[1])
21 | self.fig = plt.figure(figsize = size, dpi = 1)
22 | plt.axis('off')
23 | img = plt.imshow(self.frames[0], cmap = 'gray')
24 | self.fig.subplots_adjust(0,0,1,1)
25 | self.anim_obj = animation.FuncAnimation(
26 | self.fig,
27 | lambda i: img.set_data(self.frames[i,:,:,:]),
28 | frames=self.frames.shape[0],
29 | interval = 1000 / self.fps
30 | )
31 | plt.close()
32 | if Animation.ANIMATION_MODE == Animation.HTML:
33 | self.anim_str = self.anim_obj.to_html5_video()
34 | elif Animation.ANIMATION_MODE == Animation.JS:
35 | self.anim_str = self.anim_obj.to_jshtml()
36 | return self.anim_obj
37 | def _repr_html_(self):
38 | if self.anim_obj is None: self.render()
39 | return self.anim_str
--------------------------------------------------------------------------------
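
Note: a minimal usage sketch (not part of the repository) for the `Animation` helper above. It assumes a Jupyter notebook with ffmpeg available for `to_html5_video`, and `src/` on `PYTHONPATH`.

```python
import numpy as np
from smplfusion.animation import Animation

# Frames are expected as a (T, H, W, 3) uint8 array.
frames = (np.random.rand(24, 64, 64, 3) * 255).astype(np.uint8)
anim = Animation(frames, fps=12)
anim  # as the last expression in a notebook cell, rendered via _repr_html_ as an HTML5 video
```
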
/src/smplfusion/api/__init__.py:
--------------------------------------------------------------------------------
1 | # API goes here!
--------------------------------------------------------------------------------
/src/smplfusion/common.py:
--------------------------------------------------------------------------------
1 | from . import share, scheduler
2 | from .ddim import DDIM
3 | from .patches import router, attentionpatch, transformerpatch
4 | from . import options
5 | from pytorch_lightning import seed_everything
6 | import open_clip
7 |
8 | def count_tokens(prompt):
9 | tokens = open_clip.tokenize(prompt)[0]
10 | return (tokens > 0).sum()
11 |
12 | def tokenize(prompt):
13 | tokens = open_clip.tokenize(prompt)[0]
14 | return [open_clip.tokenizer._tokenizer.decoder[x.item()] for x in tokens]
15 |
16 | def get_token_idx(prompt, prefix, positive_prompt):
17 | prompt = prefix.format(prompt)
18 | return list(range(1 + prefix.split(' ').index('{}'), tokenize(prompt).index(''))) + [tokenize(prompt + positive_prompt).index('')]
19 |
20 | def load_model_v2_inpainting(folder):
21 | model_config = options.ddpm.v1_yaml
22 |
23 | unet = options.unet.inpainting.v2_yaml.eval().cuda()
24 | unet.load_state_dict(folder.unet_ckpt)
25 | unet = unet.requires_grad_(False)
26 |
27 | vae = options.vae_yaml.eval().cuda()
28 | vae.load_state_dict(folder.vae_ckpt)
29 | vae = vae.requires_grad_(False)
30 |
31 | encoder = options.encoders.openclip_yaml.eval().cuda()
32 | encoder.load_state_dict(folder.encoder_ckpt)
33 | encoder = encoder.requires_grad_(False)
34 |
35 | return model_config, unet, vae, encoder
36 |
37 | def load_model_v15_inpainting(folder):
38 | model_config = options.ddpm.v1_yaml
39 |
40 | unet = options.unet.inpainting.v1_yaml.eval().cuda()
41 | unet.load_state_dict(folder.unet_ckpt)
42 | unet = unet.requires_grad_(False)
43 |
44 | vae = options.vae_yaml.eval().cuda()
45 | vae.load_state_dict(folder.vae_ckpt)
46 | vae = vae.requires_grad_(False)
47 |
48 | encoder = options.encoders.clip_yaml.eval().cuda()
49 | encoder.load_state_dict(folder.encoder_ckpt)
50 | encoder = encoder.requires_grad_(False)
51 |
52 | return model_config, unet, vae, encoder
53 |
54 | def load_model_v2(folder):
55 | model_config = options.ddpm.v1_yaml
56 | unet = options.unet.v2_yaml.eval().cuda()
57 | unet.load_state_dict(folder.unet_ckpt)
58 | unet = unet.requires_grad_(False)
59 | vae = options.vae_yaml.eval().cuda()
60 | vae.load_state_dict(folder.vae_ckpt)
61 | vae = vae.requires_grad_(False)
62 | encoder = options.encoders.openclip_yaml.eval().cuda()
63 | encoder.load_state_dict(folder.encoder_ckpt)
64 | encoder = encoder.requires_grad_(False)
65 | return model_config, unet, vae, encoder
66 |
67 | def load_model_v1(folder):
68 | model_config = options.ddpm.v1_yaml
69 | unet = options.unet.v1_yaml.eval().cuda()
70 | unet.load_state_dict(folder.unet_ckpt)
71 | unet = unet.requires_grad_(False)
72 | vae = options.vae_yaml.eval().cuda()
73 | vae.load_state_dict(folder.vae_ckpt)
74 | vae = vae.requires_grad_(False)
75 | encoder = options.encoders.clip_yaml.eval().cuda()
76 | encoder.load_state_dict(folder.encoder_ckpt)
77 | encoder = encoder.requires_grad_(False)
78 | return model_config, unet, vae, encoder
79 |
80 | def load_unet_v2(folder):
81 | unet = options.unet.v2_yaml.eval().cuda()
82 | unet.load_state_dict(folder.unet_ckpt)
83 | unet = unet.requires_grad_(False)
84 | return unet
--------------------------------------------------------------------------------
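
Note: a small usage sketch (not part of the repository) for the tokenizer helpers above, assuming the packages from `requirements.txt` are installed and `src/` is on `PYTHONPATH`.

```python
from smplfusion.common import count_tokens, tokenize

prompt = "Brown gift box beside red candle."
print(int(count_tokens(prompt)))  # number of non-zero CLIP token ids (includes start/end tokens)
print(tokenize(prompt)[:5])       # decoded BPE tokens, starting with <start_of_text>
```
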
/src/smplfusion/config.py:
--------------------------------------------------------------------------------
1 | IMG_THUMBSIZE = None
--------------------------------------------------------------------------------
/src/smplfusion/ddim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm.notebook import tqdm
3 | from . import scheduler
4 | from . import share
5 |
6 | from .libimage import IImage
7 |
8 | class DDIM:
9 | def __init__(self, config, vae, encoder, unet):
10 | self.vae = vae
11 | self.encoder = encoder
12 | self.unet = unet
13 | self.config = config
14 | self.schedule = scheduler.linear(1000, config.linear_start, config.linear_end)
15 |
16 | def __call__(
17 | self, prompt = '', dt = 50, shape = (1,4,64,64), seed = None, negative_prompt = '', unet_condition = None,
18 | context = None, verbose = True):
19 | if seed is not None: torch.manual_seed(seed)
20 | if unet_condition is not None:
21 | zT = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
22 | else:
23 | zT = torch.randn(shape).cuda()
24 |
25 | with torch.autocast('cuda'), torch.no_grad():
26 | if context is None: context = self.encoder.encode([negative_prompt, prompt])
27 |
28 | zt = zT
29 | pbar = tqdm(range(999, 0, -dt)) if verbose else range(999, 0, -dt)
30 | for timestep in share.DDIMIterator(pbar):
31 | _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
32 | eps_uncond, eps = self.unet(
33 | torch.cat([_zt, _zt]),
34 | timesteps = torch.tensor([timestep, timestep]).cuda(),
35 | context = context
36 | ).chunk(2)
37 |
38 | eps = (eps_uncond + 7.5 * (eps - eps_uncond))
39 |
40 | z0 = (zt - self.schedule.sqrt_one_minus_alphas[timestep] * eps) / self.schedule.sqrt_alphas[timestep]
41 | zt = self.schedule.sqrt_alphas[timestep - dt] * z0 + self.schedule.sqrt_one_minus_alphas[timestep - dt] * eps
42 | return IImage(self.vae.decode(z0 / self.config.scale_factor))
43 |
44 | def encode(self, image):
45 | return self.vae.encode(image.padx(64).torch().cuda()).mean * self.config.scale_factor
46 | def decode(self, latent):
47 | return IImage(self.vae.decode(latent / self.config.scale_factor))
48 |
49 | def get_inpainting_condition(self, image, mask):
50 | latent_size = [x//8 for x in image.size]
51 | condition_x0 = self.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * self.config.scale_factor
52 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float()
53 |
54 | # condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
55 | return torch.cat([condition_mask, condition_x0], 1)
56 |
57 | inpainting_condition = get_inpainting_condition
58 |
59 |
--------------------------------------------------------------------------------
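
Note: an untested sketch (not part of the repository) of wiring `DDIM` together with the loaders from `common.py`. The checkpoint folder layout (`unet.ckpt` / `vae.ckpt` / `encoder.ckpt`, read via `libpath.Path` attribute access) is an assumption about how the downloaded models are arranged, and a CUDA GPU is required.

```python
from smplfusion.libpath import Path
from smplfusion.common import load_model_v1
from smplfusion.ddim import DDIM

folder = Path("models/sd-1-5")  # hypothetical folder containing unet.ckpt, vae.ckpt, encoder.ckpt
config, unet, vae, encoder = load_model_v1(folder)

ddim = DDIM(config, vae, encoder, unet)
result = ddim(prompt="a red candle on a wooden table", seed=1, dt=50)  # returns an IImage
```
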
/src/smplfusion/libimage/__init__.py:
--------------------------------------------------------------------------------
1 | from .iimage import IImage
2 | from .iimage_gallery import ImageGallery
3 | from .utils import bytes2html
4 |
5 | import math
6 | import numpy as np
7 | import warnings
8 |
9 | # ========= STATIC FUNCTIONS =============
10 | def find_max_h(images):
11 | return max([x.size[1] for x in images])
12 | def find_max_w(images):
13 | return max([x.size[0] for x in images])
14 | def find_max_size(images):
15 | return find_max_w(images), find_max_h(images)
16 |
17 |
18 | def stack(images, axis = 0):
19 | return IImage(np.concatenate([x.data for x in images], axis))
20 | def tstack(images):
21 | w,h = find_max_size(images)
22 | images = [x.pad2wh(w,h) for x in images]
23 | return IImage(np.concatenate([x.data for x in images], 0))
24 | def hstack(images):
25 | h = find_max_h(images)
26 | images = [x.pad2wh(h = h) for x in images]
27 | return IImage(np.concatenate([x.data for x in images], 2))
28 | def vstack(images):
29 | w = find_max_w(images)
30 | images = [x.pad2wh(w = w) for x in images]
31 | return IImage(np.concatenate([x.data for x in images], 1))
32 |
33 | def grid(images, nrows = None, ncols = None):
34 | combined = stack(images)
35 | if nrows is not None:
36 | ncols = math.ceil(combined.data.shape[0] / nrows)
37 | elif ncols is not None:
38 | nrows = math.ceil(combined.data.shape[0] / ncols)
39 | else:
40 | warnings.warn("No dimensions specified, creating a grid with 5 columns (default)")
41 | ncols = 5
42 | nrows = math.ceil(combined.data.shape[0] / ncols)
43 |
44 | pad = nrows * ncols - combined.data.shape[0]
45 | data = np.pad(combined.data, ((0,pad),(0,0),(0,0),(0,0)))
46 | rows = [np.concatenate(x,1,dtype=np.uint8) for x in np.array_split(data, nrows)]
47 | return IImage(np.concatenate(rows, 0, dtype = np.uint8)[None])
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/iimage.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage.cpython-310.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/iimage.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-310.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/iimage_gallery.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/libimage/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/libimage/iimage_gallery.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .utils import bytes2html
3 |
4 | class ImageGallery:
5 | def __init__(self, images, captions = None, size='auto', max_rows = -1, root_path = '/', caption_font_size = '2em'):
6 | self.size = size
7 | self.images = images
8 | self.display_str = None
9 | self.captions = captions if captions is not None else [''] * len(self.images)
10 | self.max_rows = max_rows
11 | self.root_path = root_path
12 | self.caption_font_size = caption_font_size
13 |
14 | def generate_display(self):
15 | """Shows a set of images in a gallery that flexes with the width of the notebook.
16 |
17 | Parameters
18 | ----------
19 | images: list of str or bytes
20 | URLs or bytes of images to display
21 |
22 | row_height: str
23 | CSS height value to assign to all images. Set to 'auto' by default to show images
24 | with their native dimensions. Set to a value e.g. '250px' to make all rows
25 | in the gallery equal height.
26 | """
27 | figures = []
28 | row_figures = 0
29 | for image, caption in zip(self.images, self.captions):
30 | if isinstance(image, str):
31 | with open(image,'rb') as f:
32 | link = os.path.relpath(image, self.root_path)
33 | # src = bytes2html(f.read(), width = self.size)
34 | src = bytes2html(f.read())
35 | src = f'{src} '
36 | else:
37 | if image.display_str is None: image.generate_display()
38 | if image.is_video():
39 | src = image.display_str
40 | else:
41 | # src = _src_from_data(image.display_str)
42 | src = image.to_html(width = "100%", root_path = self.root_path)
43 | if caption != '':
44 | caption = f'{caption} '
45 | figures.append(f'''
46 |
47 | {caption}
48 | {src}
49 |
50 | ''')
51 | row_figures += 1
52 | if row_figures == self.max_rows:
53 | row_figures = 0
54 | figures.append('<br>')
55 | self.display_str = f'''
56 |
57 | {''.join(figures)}
58 |
59 | '''
60 |
61 | def _repr_html_(self):
62 | if self.display_str is None: self.generate_display()
63 | return self.display_str
64 |
65 | def save(self, path):
66 | if self.display_str is None: self.generate_display()
67 | with open(path, 'w') as f:
68 | f.write(self.display_str)
--------------------------------------------------------------------------------
/src/smplfusion/libimage/utils.py:
--------------------------------------------------------------------------------
1 | from IPython.display import Image as IpyImage
2 |
3 | def bytes2html(data, width='auto'):
4 | img_obj = IpyImage(data=data, format='JPG')
5 | for bundle in img_obj._repr_mimebundle_():
6 | for mimetype, b64value in bundle.items():
7 | if mimetype.startswith('image/'):
8 | return f'<img src="data:{mimetype};base64,{b64value}" style="width: {width}">'
--------------------------------------------------------------------------------
/src/smplfusion/libpath.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import *
3 | import sys
4 |
5 | import pickle
6 | import inspect
7 | import json
8 | import importlib.util
9 | import importlib
10 | import random
11 | import csv
12 |
13 | # Requirements
14 | import yaml
15 | from omegaconf import OmegaConf
16 | import torch
17 | import numpy as np
18 | from PIL import Image
19 | from IPython.display import Markdown # TODO: Remove this requirement
20 |
21 | from .libimage import IImage
22 | # from .experiment import SmartFolder
23 |
24 | def instantiate_from_config(config):
25 | if not "target" in config:
26 | raise KeyError("Expected key `target` to instantiate.")
27 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
28 |
29 | def load_obj(objyaml):
30 | if "__init__" in objyaml:
31 | return get_obj_from_str(objyaml['__class__'])(**objyaml["__init__"])
32 | else:
33 | return get_obj_from_str(objyaml['__class__'])()
34 |
35 | def get_obj_from_str(string):
36 | module, cls = string.rsplit(".", 1)
37 | try:
38 | return getattr(importlib.import_module(module, package=None), cls)
39 | except:
40 | return getattr(importlib.import_module('lib.' + module, package=None), cls)
41 |
42 |
43 | def path2var(path):
44 | return path.replace(" ", "_").replace(".", "_").replace("-", "_").replace(":", "_")
45 |
46 |
47 | def var2path(path):
48 | return path.replace("ß", " ").replace("ä", ".").replace("ö", "-")
49 |
50 |
51 | def save(obj, path):
52 | pass
53 |
54 | class FileType:
55 | pass
56 |
57 | def cvt(x: str):
58 | try: return int(x)
59 | except:
60 | try: return float(x)
61 | except: return x
62 |
63 | class Path:
64 | FOLDER = FileType()
65 |
66 | def cwd(path):
67 | return Path(os.path.dirname(os.path.abspath(path)))
68 | def __init__(self, path):
69 | self.path = path
70 | if os.path.isdir(self.path):
71 | self.reload()
72 | else:
73 | self.extension = self.path.split(".")[-1]
74 | self.subpaths = {}
75 | self.filename = os.path.basename(self.path)
76 |
77 | def reload(self):
78 | if os.path.isdir(self.path):
79 | self.subpaths = {path2var(x): x for x in os.listdir(self.path)}
80 | self.keyorder = sorted(list(self.subpaths.keys()))
81 | self.extension = None
82 |
83 | def read(self, is_preview=False):
84 | if self.extension is None or is_preview:
85 | # if self.isdir and os.path.exists(f'{self}/__smart__'):
86 | # return SmartFolder.open(str(self))
87 | return self
88 | elif self.extension == "yaml":
89 | yaml = OmegaConf.load(self.path)
90 | if "__class__" in yaml:
91 | return load_obj(yaml)
92 | if "model" in yaml:
93 | return instantiate_from_config(yaml.model)
94 | return yaml
95 | elif self.extension == "yamlobj":
96 | return load_obj(OmegaConf.load(self.path))
97 | elif self.extension in ["jpg", "jpeg", "png"]:
98 | return IImage.open(self.path)
99 | elif self.extension in ["ckpt", "pt"]:
100 | return torch.load(self.path)
101 | elif self.extension in ["pkl", "bin"]:
102 | with open(self.path, "rb") as f:
103 | return pickle.load(f)
104 | elif self.extension in ["txt"]:
105 | with open(self.path) as f:
106 | return f.read()
107 | elif self.extension in ["lst", "list"]:
108 | with open(self.path) as f:
109 | return [cvt(x) for x in f.read().split("\n")]
110 | elif self.extension in ["md"]:
111 | with open(self.path) as f:
112 | return Markdown(f.read())
113 | elif self.extension in ["json"]:
114 | with open(self.path) as f:
115 | return json.load(f)
116 | elif self.extension in ["npy", "npz"]:
117 | return np.load(self.path)
118 | elif self.extension in ['csv', 'table']:
119 | with open(self.path) as f:
120 | return [[cvt(x) for x in row] for row in csv.reader(f, delimiter=' ')]
121 | elif self.extension in ["py"]:
122 | spec = importlib.util.spec_from_file_location(f"autoload.module", self.path)
123 | module = importlib.util.module_from_spec(spec)
124 | spec.loader.exec_module(module)
125 | return module
126 | else:
127 | return self
128 |
129 | def exists(self):
130 | return os.path.exists(str(self))
131 |
132 | def sample(self):
133 | return getattr(self, random.choice(list(self.subpaths.keys())))
134 |
135 | def __dir__(self):
136 | self.reload()
137 | return list(self.subpaths.keys())
138 |
139 | def __len__(self):
140 | return len(self.subpaths)
141 |
142 | def ls(self):
143 | self.reload()
144 | return list(self.subpaths.values())
145 | # return [str(self + x) for x in self.subpaths.values()]
146 | # return dir(self)
147 |
148 | @property
149 | def parent(self): return Path(os.path.dirname(self.path))
150 | @property
151 | def isdir(self): return os.path.isdir(os.path.abspath(self.path))
152 |
153 | def __getattr__(self, __name: str):
154 | # print ("Func", inspect.stack()[1].function, "name", __name)
155 | is_preview = inspect.stack()[1].function == "getattr_paths"
156 |
157 | self.reload()
158 | if __name in self.subpaths and self.subpaths[__name] in os.listdir(self.path):
159 | return Path(join(self.path, self.subpaths[__name])).read(is_preview)
160 | elif __name in ['__wrapped__']:
161 | raise AttributeError()
162 | else:
163 | return Path(join(self.path, __name))
164 |
165 | def __setattr__(self, __name: str, __value: any):
166 | if __value == Path.FOLDER:
167 | os.makedirs(f'{self}/{__name}', exist_ok=True)
168 | self.reload()
169 | else:
170 | return super().__setattr__(__name, __value)
171 |
172 | def __hasattr__(self, __name: str):
173 | self.reload()
174 | return self.subpaths is not None and __name in self.subpaths
175 |
176 | # def __getitem__(self, __name):
177 | # if __name in self.subpaths and self.subpaths[__name] in os.listdir(self.path):
178 | # return Path(join(self.path, self.subpaths[__name])).read()
179 |
180 | def __add__(self, other):
181 | assert other is not str
182 | return Path(join(self.path, other))
183 |
184 | def __call__(self, idx=None):
185 | if idx is None:
186 | return str(self)
187 | if isinstance(idx, str):
188 | return Path(join(self.path, idx))
189 | else:
190 | return Path(join(self.path, self.subpaths[self.keyorder[idx]]))
191 |
192 | def __getitem__(self, idx):
193 | if isinstance(idx, str):
194 | return Path(join(self.path, idx)).read()
195 | else:
196 | return Path(join(self.path, self.subpaths[self.keyorder[idx]])).read()
197 |
198 | def __str__(self):
199 | return os.path.abspath(self.path)
200 |
201 | def __repr__(self):
202 | return f"Path reference to: {os.path.abspath(self.path)}"
203 |
204 | def new_child(self, ext = ''):
205 | idx = 0
206 | while os.path.exists(join(str(self), f"file_{idx}{ext}")):
207 | idx += 1
208 | return join(str(self), f"file_{idx}{ext}")
--------------------------------------------------------------------------------
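
Note: a usage sketch (not part of the repository) for the `Path` helper above. Attribute access walks the file system, and `read()` dispatches on the extension (YAML with `__class__` -> instantiated object, plain YAML -> OmegaConf config, images -> IImage, checkpoints -> `torch.load`, ...). It assumes the script runs from the repository root with `src/` on `PYTHONPATH`.

```python
from smplfusion.libpath import Path

config = Path("config")          # a directory becomes a navigable Path
print(config.ls())               # entries of config/
ddpm_cfg = config.ddpm.v1_yaml   # plain OmegaConf config (no __class__ key)
vae = config.vae_yaml            # has __class__, so load_obj instantiates AutoencoderKL
```
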
/src/smplfusion/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__init__.py
--------------------------------------------------------------------------------
/src/smplfusion/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/models/__pycache__/unet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/unet.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/models/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/models/__pycache__/vae.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/__pycache__/vae.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/__pycache__/clip_embedder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/models/encoders/__pycache__/clip_embedder.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/clip_embedder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from transformers import CLIPTokenizer, CLIPTextModel
3 |
4 | class FrozenCLIPEmbedder(nn.Module):
5 | """Uses the CLIP transformer encoder for text (from huggingface)"""
6 | LAYERS = [
7 | "last",
8 | "pooled",
9 | "hidden"
10 | ]
11 |
12 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
13 | freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
14 | super().__init__()
15 | assert layer in self.LAYERS
16 | self.tokenizer = CLIPTokenizer.from_pretrained(version)
17 | self.transformer = CLIPTextModel.from_pretrained(version)
18 | self.device = device
19 | self.max_length = max_length
20 | if freeze:
21 | self.freeze()
22 | self.layer = layer
23 | self.layer_idx = layer_idx
24 | if layer == "hidden":
25 | assert layer_idx is not None
26 | assert 0 <= abs(layer_idx) <= 12
27 |
28 | def freeze(self):
29 | self.transformer = self.transformer.eval()
30 | # self.train = disabled_train
31 | for param in self.parameters():
32 | param.requires_grad = False
33 |
34 | def forward(self, text):
35 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
36 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
37 | tokens = batch_encoding["input_ids"].to(self.device)
38 | outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
39 | if self.layer == "last":
40 | z = outputs.last_hidden_state
41 | elif self.layer == "pooled":
42 | z = outputs.pooler_output[:, None, :]
43 | else:
44 | z = outputs.hidden_states[self.layer_idx]
45 | return z
46 |
47 | def encode(self, text):
48 | return self(text)
--------------------------------------------------------------------------------
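
Note: a minimal usage sketch (not part of the repository) for `FrozenCLIPEmbedder`; the first call downloads `openai/clip-vit-large-patch14` from the Hugging Face hub.

```python
import torch
from smplfusion.models.encoders.clip_embedder import FrozenCLIPEmbedder

encoder = FrozenCLIPEmbedder(device="cpu")  # defaults: layer="last", max_length=77
with torch.no_grad():
    z = encoder.encode(["a red candle", "a brown gift box"])
print(z.shape)  # torch.Size([2, 77, 768]) -- one embedding per token position
```
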
/src/smplfusion/models/encoders/clip_image_embedder.py:
--------------------------------------------------------------------------------
1 | class ClipImageEmbedder(nn.Module):
2 | def __init__(
3 | self,
4 | model,
5 | jit=False,
6 | device='cuda' if torch.cuda.is_available() else 'cpu',
7 | antialias=True,
8 | ucg_rate=0.
9 | ):
10 | super().__init__()
11 | from clip import load as load_clip
12 | self.model, _ = load_clip(name=model, device=device, jit=jit)
13 |
14 | self.antialias = antialias
15 |
16 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
17 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
18 | self.ucg_rate = ucg_rate
19 |
20 | def preprocess(self, x):
21 | # normalize to [0,1]
22 | x = kornia.geometry.resize(x, (224, 224),
23 | interpolation='bicubic', align_corners=True,
24 | antialias=self.antialias)
25 | x = (x + 1.) / 2.
26 | # re-normalize according to clip
27 | x = kornia.enhance.normalize(x, self.mean, self.std)
28 | return x
29 |
30 | def forward(self, x, no_dropout=False):
31 | # x is assumed to be in range [-1,1]
32 | out = self.model.encode_image(self.preprocess(x))
33 | out = out.to(x.dtype)
34 | if self.ucg_rate > 0. and not no_dropout:
35 | out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
36 | return out
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import kornia
4 | from torch.utils.checkpoint import checkpoint
5 |
6 | from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
7 |
8 | import open_clip
9 | from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
10 | from ldm.modules.diffusionmodules.openaimodel import Timestep
11 | from ...util import default, count_params, autocast
12 |
13 |
14 | class ClassEmbedder(nn.Module):
15 | def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
16 | super().__init__()
17 | self.key = key
18 | self.embedding = nn.Embedding(n_classes, embed_dim)
19 | self.n_classes = n_classes
20 | self.ucg_rate = ucg_rate
21 |
22 | def forward(self, batch, key=None, disable_dropout=False):
23 | if key is None:
24 | key = self.key
25 | # this is for use in crossattn
26 | c = batch[key][:, None]
27 | if self.ucg_rate > 0. and not disable_dropout:
28 | mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
29 | c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
30 | c = c.long()
31 | c = self.embedding(c)
32 | return c
33 |
34 | def get_unconditional_conditioning(self, bs, device="cuda"):
35 | uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
36 | uc = torch.ones((bs,), device=device) * uc_class
37 | uc = {self.key: uc}
38 | return uc
39 |
40 |
41 | def disabled_train(self, mode=True):
42 | """Overwrite model.train with this function to make sure train/eval mode
43 | does not change anymore."""
44 | return self
45 |
46 |
47 | class FrozenCLIPT5Encoder(AbstractEncoder):
48 | def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
49 | clip_max_length=77, t5_max_length=77):
50 | super().__init__()
51 | self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
52 | self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
53 | print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
54 | f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
55 |
56 | def encode(self, text):
57 | return self(text)
58 |
59 | def forward(self, text):
60 | clip_z = self.clip_encoder.encode(text)
61 | t5_z = self.t5_encoder.encode(text)
62 | return [clip_z, t5_z]
63 |
64 | class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
65 | def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
66 | super().__init__(*args, **kwargs)
67 | if clip_stats_path is None:
68 | clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
69 | else:
70 | clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
71 | self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
72 | self.register_buffer("data_std", clip_std[None, :], persistent=False)
73 | self.time_embed = Timestep(timestep_dim)
74 |
75 | def scale(self, x):
76 | # re-normalize to centered mean and unit variance
77 | x = (x - self.data_mean) * 1. / self.data_std
78 | return x
79 |
80 | def unscale(self, x):
81 | # back to original data stats
82 | x = (x * self.data_std) + self.data_mean
83 | return x
84 |
85 | def forward(self, x, noise_level=None):
86 | if noise_level is None:
87 | noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
88 | else:
89 | assert isinstance(noise_level, torch.Tensor)
90 | x = self.scale(x)
91 | z = self.q_sample(x, noise_level)
92 | z = self.unscale(z)
93 | noise_level = self.time_embed(noise_level)
94 | return z, noise_level
95 |
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/open_clip_embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | import open_clip
6 |
7 | class FrozenOpenCLIPEmbedder(nn.Module):
8 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
9 | freeze=True, layer="last"):
10 | super().__init__()
11 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
12 | del model.visual
13 | self.model = model
14 |
15 | self.device = device
16 | self.max_length = max_length
17 | if freeze: self.freeze()
18 | self.layer = layer
19 | if self.layer == "last":
20 | self.layer_idx = 0
21 | elif self.layer == "penultimate":
22 | self.layer_idx = 1
23 | else:
24 | raise NotImplementedError()
25 |
26 | def freeze(self):
27 | self.model = self.model.eval()
28 | for param in self.parameters():
29 | param.requires_grad = False
30 |
31 | def forward(self, text):
32 | tokens = open_clip.tokenize(text)
33 | z = self.encode_with_transformer(tokens.to(self.device))
34 | return z
35 |
36 | def encode_with_transformer(self, text):
37 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
38 | x = x + self.model.positional_embedding
39 | x = x.permute(1, 0, 2) # NLD -> LND
40 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
41 | x = x.permute(1, 0, 2) # LND -> NLD
42 | x = self.model.ln_final(x)
43 | return x
44 |
45 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
46 | for i, r in enumerate(self.model.transformer.resblocks):
47 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
48 | break
49 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
50 | x = checkpoint(r, x, attn_mask)
51 | else:
52 | x = r(x, attn_mask=attn_mask)
53 | return x
54 |
55 | def encode(self, text):
56 | return self(text)
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/open_clip_image_embedder.py:
--------------------------------------------------------------------------------
1 | class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
2 | """
3 | Uses the OpenCLIP vision transformer encoder for images
4 | """
5 |
6 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
7 | freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
8 | super().__init__()
9 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
10 | pretrained=version, )
11 | del model.transformer
12 | self.model = model
13 |
14 | self.device = device
15 | self.max_length = max_length
16 | if freeze:
17 | self.freeze()
18 | self.layer = layer
19 | if self.layer == "penultimate":
20 | raise NotImplementedError()
21 | self.layer_idx = 1
22 |
23 | self.antialias = antialias
24 |
25 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
26 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
27 | self.ucg_rate = ucg_rate
28 |
29 | def preprocess(self, x):
30 | x = kornia.geometry.resize(x, (224, 224),
31 | interpolation='bicubic', align_corners=True,
32 | antialias=self.antialias)
33 | x = (x + 1.) / 2.
34 | x = kornia.enhance.normalize(x, self.mean, self.std)
35 | return x
36 |
37 | def freeze(self):
38 | self.model = self.model.eval()
39 | for param in self.parameters():
40 | param.requires_grad = False
41 |
42 | @autocast
43 | def forward(self, image, no_dropout=False):
44 | z = self.encode_with_vision_transformer(image)
45 | if self.ucg_rate > 0. and not no_dropout:
46 | z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
47 | return z
48 |
49 | def encode_with_vision_transformer(self, img):
50 | img = self.preprocess(img)
51 | x = self.model.visual(img)
52 | return x
53 |
54 | def encode(self, text):
55 | return self(text)
--------------------------------------------------------------------------------
/src/smplfusion/models/encoders/t5_mebedder.py:
--------------------------------------------------------------------------------
1 | class FrozenT5Embedder(AbstractEncoder):
2 | """Uses the T5 transformer encoder for text"""
3 |
4 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
5 | freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
6 | super().__init__()
7 | self.tokenizer = T5Tokenizer.from_pretrained(version)
8 | self.transformer = T5EncoderModel.from_pretrained(version)
9 | self.device = device
10 | self.max_length = max_length # TODO: typical value?
11 | if freeze:
12 | self.freeze()
13 |
14 | def freeze(self):
15 | self.transformer = self.transformer.eval()
16 | # self.train = disabled_train
17 | for param in self.parameters():
18 | param.requires_grad = False
19 |
20 | def forward(self, text):
21 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
22 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
23 | tokens = batch_encoding["input_ids"].to(self.device)
24 | outputs = self.transformer(input_ids=tokens)
25 |
26 | z = outputs.last_hidden_state
27 | return z
28 |
29 | def encode(self, text):
30 | return self(text)
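
A minimal usage sketch, assuming the FrozenT5Embedder above and the transformers package are importable: every prompt is padded/truncated to max_length tokens, so encode() always returns a fixed-size sequence of token embeddings. Note that __init__ only stores the device string, so the module would still need to be moved explicitly if a GPU were used; the example stays on CPU.

embedder = FrozenT5Embedder(version="google/t5-v1_1-large", device="cpu")
z = embedder.encode(["a photo of a cat", "an empty room"])
print(z.shape)   # torch.Size([2, 77, 1024]); hidden size 1024 for t5-v1_1-large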
--------------------------------------------------------------------------------
/src/smplfusion/models/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pytorch_lightning as pl
4 | from contextlib import contextmanager
5 |
6 | from ..modules.autoencoder import Encoder, Decoder
7 | from ..modules.distributions import DiagonalGaussianDistribution
8 |
9 | from ..util import instantiate_from_config
10 | from ..modules.ema import LitEma
11 |
12 | class AutoencoderKL(pl.LightningModule):
13 | def __init__(self,
14 | ddconfig,
15 | lossconfig,
16 | embed_dim,
17 | ckpt_path=None,
18 | ignore_keys=[],
19 | image_key="image",
20 | colorize_nlabels=None,
21 | monitor=None,
22 | ema_decay=None,
23 | learn_logvar=False
24 | ):
25 | super().__init__()
26 | self.learn_logvar = learn_logvar
27 | self.image_key = image_key
28 | self.encoder = Encoder(**ddconfig)
29 | self.decoder = Decoder(**ddconfig)
30 | self.loss = instantiate_from_config(lossconfig)
31 | assert ddconfig["double_z"]
32 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
33 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
34 | self.embed_dim = embed_dim
35 | if colorize_nlabels is not None:
36 | assert type(colorize_nlabels)==int
37 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
38 | if monitor is not None:
39 | self.monitor = monitor
40 |
41 | self.use_ema = ema_decay is not None
42 | if self.use_ema:
43 | self.ema_decay = ema_decay
44 | assert 0. < ema_decay < 1.
45 | self.model_ema = LitEma(self, decay=ema_decay)
46 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")
47 |
48 | if ckpt_path is not None:
49 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
50 |
51 | def init_from_ckpt(self, path, ignore_keys=list()):
52 | sd = torch.load(path, map_location="cpu")["state_dict"]
53 | keys = list(sd.keys())
54 | for k in keys:
55 | for ik in ignore_keys:
56 | if k.startswith(ik):
57 | print("Deleting key {} from state_dict.".format(k))
58 | del sd[k]
59 | self.load_state_dict(sd, strict=False)
60 | print(f"Restored from {path}")
61 |
62 | @contextmanager
63 | def ema_scope(self, context=None):
64 | if self.use_ema:
65 | self.model_ema.store(self.parameters())
66 | self.model_ema.copy_to(self)
67 | if context is not None:
68 | print(f"{context}: Switched to EMA weights")
69 | try:
70 | yield None
71 | finally:
72 | if self.use_ema:
73 | self.model_ema.restore(self.parameters())
74 | if context is not None:
75 | print(f"{context}: Restored training weights")
76 |
77 | def on_train_batch_end(self, *args, **kwargs):
78 | if self.use_ema:
79 | self.model_ema(self)
80 |
81 | def encode(self, x):
82 | h = self.encoder(x)
83 | moments = self.quant_conv(h)
84 | posterior = DiagonalGaussianDistribution(moments)
85 | return posterior
86 |
87 | def decode(self, z):
88 | z = self.post_quant_conv(z)
89 | dec = self.decoder(z)
90 | return dec
91 |
92 | def forward(self, input, sample_posterior=True):
93 | posterior = self.encode(input)
94 | if sample_posterior:
95 | z = posterior.sample()
96 | else:
97 | z = posterior.mode()
98 | dec = self.decode(z)
99 | return dec, posterior
100 |
101 | def get_input(self, batch, k):
102 | x = batch[k]
103 | if len(x.shape) == 3:
104 | x = x[..., None]
105 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
106 | return x
107 |
108 | def training_step(self, batch, batch_idx, optimizer_idx):
109 | inputs = self.get_input(batch, self.image_key)
110 | reconstructions, posterior = self(inputs)
111 |
112 | if optimizer_idx == 0:
113 | # train encoder+decoder+logvar
114 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
115 | last_layer=self.get_last_layer(), split="train")
116 | self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
117 | self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
118 | return aeloss
119 |
120 | if optimizer_idx == 1:
121 | # train the discriminator
122 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
123 | last_layer=self.get_last_layer(), split="train")
124 |
125 | self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
126 | self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
127 | return discloss
128 |
129 | def validation_step(self, batch, batch_idx):
130 | log_dict = self._validation_step(batch, batch_idx)
131 | with self.ema_scope():
132 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
133 | return log_dict
134 |
135 | def _validation_step(self, batch, batch_idx, postfix=""):
136 | inputs = self.get_input(batch, self.image_key)
137 | reconstructions, posterior = self(inputs)
138 | aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
139 | last_layer=self.get_last_layer(), split="val"+postfix)
140 |
141 | discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
142 | last_layer=self.get_last_layer(), split="val"+postfix)
143 |
144 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
145 | self.log_dict(log_dict_ae)
146 | self.log_dict(log_dict_disc)
147 | return self.log_dict
148 |
149 | def configure_optimizers(self):
150 | lr = self.learning_rate
151 | ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
152 | self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
153 | if self.learn_logvar:
154 | print(f"{self.__class__.__name__}: Learning logvar")
155 | ae_params_list.append(self.loss.logvar)
156 | opt_ae = torch.optim.Adam(ae_params_list,
157 | lr=lr, betas=(0.5, 0.9))
158 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
159 | lr=lr, betas=(0.5, 0.9))
160 | return [opt_ae, opt_disc], []
161 |
162 | def get_last_layer(self):
163 | return self.decoder.conv_out.weight
164 |
165 | @torch.no_grad()
166 | def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
167 | log = dict()
168 | x = self.get_input(batch, self.image_key)
169 | x = x.to(self.device)
170 | if not only_inputs:
171 | xrec, posterior = self(x)
172 | if x.shape[1] > 3:
173 | # colorize with random projection
174 | assert xrec.shape[1] > 3
175 | x = self.to_rgb(x)
176 | xrec = self.to_rgb(xrec)
177 | log["samples"] = self.decode(torch.randn_like(posterior.sample()))
178 | log["reconstructions"] = xrec
179 | if log_ema or self.use_ema:
180 | with self.ema_scope():
181 | xrec_ema, posterior_ema = self(x)
182 | if x.shape[1] > 3:
183 | # colorize with random projection
184 | assert xrec_ema.shape[1] > 3
185 | xrec_ema = self.to_rgb(xrec_ema)
186 | log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
187 | log["reconstructions_ema"] = xrec_ema
188 | log["inputs"] = x
189 | return log
190 |
191 | def to_rgb(self, x):
192 | assert self.image_key == "segmentation"
193 | if not hasattr(self, "colorize"):
194 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
195 | x = F.conv2d(x, weight=self.colorize)
196 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
197 | return x
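
A roundtrip sketch, assuming `vae` is an AutoencoderKL built from config/vae.yaml and `x` is a (B, 3, H, W) image batch scaled to [-1, 1] (both are assumptions, not objects defined in this file): encode() returns a DiagonalGaussianDistribution over latents rather than a tensor, so the caller chooses between sampling and taking the mode.

posterior = vae.encode(x)        # DiagonalGaussianDistribution over the latent space
z = posterior.sample()           # stochastic latent; posterior.mode() returns the mean instead
                                 # typically (B, 4, H/8, W/8) with the standard SD configuration
x_rec = vae.decode(z)            # decoded image batch, same spatial size as x

# forward() bundles the two steps and also returns the posterior:
x_rec, posterior = vae(x, sample_posterior=True)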
--------------------------------------------------------------------------------
/src/smplfusion/modelzoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modelzoo/__init__.py
--------------------------------------------------------------------------------
/src/smplfusion/modelzoo/dreamshaper8.py:
--------------------------------------------------------------------------------
1 | print ("Loading model: Dreamshaper V8")
2 |
3 | from os.path import dirname
4 | import importlib
5 | from omegaconf import OmegaConf
6 | import torch
7 | import safetensors
8 | import safetensors.torch
9 | import open_clip
10 |
11 | PROJECT_DIR = dirname(dirname(dirname(dirname(__file__))))
12 | LIB_DIR = dirname(dirname(__file__))
13 | print (PROJECT_DIR)
14 |
15 | CONFIG_FOLDER = f'{LIB_DIR}/config/'
16 | ASSETS_FOLDER = f'{PROJECT_DIR}/assets/'
17 | MODEL_FOLDER = f'{ASSETS_FOLDER}/models/'
18 |
19 | def get_obj_from_str(string):
20 | module, cls = string.rsplit(".", 1)
21 | try:
22 | return getattr(importlib.import_module(module, package=None), cls)
23 | except (ImportError, AttributeError):
24 | return getattr(importlib.import_module('lib.' + module, package=None), cls)
25 | def load_obj(path):
26 | objyaml = OmegaConf.load(path)
27 | return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
28 |
29 | state_dict = safetensors.torch.load_file(f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8.safetensors')
30 |
31 | config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
32 | unet = load_obj(f'{CONFIG_FOLDER}/unet/v1.yaml').eval().cuda()
33 | vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
34 | encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
35 |
36 | extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
37 | unet_state = extract(state_dict, 'model.diffusion_model')
38 | encoder_state = extract(state_dict, 'cond_stage_model')
39 | vae_state = extract(state_dict, 'first_stage_model')
40 |
41 | unet.load_state_dict(unet_state, strict=False);
42 | encoder.load_state_dict(encoder_state, strict=False);
43 | vae.load_state_dict(vae_state, strict=False);
44 |
45 | unet = unet.requires_grad_(False)
46 | encoder = encoder.requires_grad_(False)
47 | vae = vae.requires_grad_(False)
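
For reference, the `extract` lambda above splits the flat checkpoint by key prefix and strips the prefix plus the following dot, which is what lets the per-module load_state_dict calls line up. A small illustration with a made-up key (assuming only the `extract` helper defined above):

import torch

sd = {"model.diffusion_model.input_blocks.0.0.weight": torch.zeros(3)}
print(extract(sd, "model.diffusion_model"))
# {'input_blocks.0.0.weight': tensor([0., 0., 0.])}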
--------------------------------------------------------------------------------
/src/smplfusion/modelzoo/dreamshaper8_inpainting.py:
--------------------------------------------------------------------------------
1 | print ("Loading model: Dreamshaper Inpainting V8")
2 |
3 | from os.path import dirname
4 | import importlib
5 | from omegaconf import OmegaConf
6 | import torch
7 | import safetensors
8 | import safetensors.torch
9 | import open_clip
10 |
11 | PROJECT_DIR = dirname(dirname(dirname(dirname(__file__))))
12 | LIB_DIR = dirname(dirname(__file__))
13 | print (PROJECT_DIR)
14 |
15 | CONFIG_FOLDER = f'{LIB_DIR}/config/'
16 | ASSETS_FOLDER = f'{PROJECT_DIR}/assets/'
17 | MODEL_FOLDER = f'{ASSETS_FOLDER}/models/'
18 |
19 | def get_obj_from_str(string):
20 | module, cls = string.rsplit(".", 1)
21 | try:
22 | return getattr(importlib.import_module(module, package=None), cls)
23 | except (ImportError, AttributeError):
24 | return getattr(importlib.import_module('lib.' + module, package=None), cls)
25 | def load_obj(path):
26 | objyaml = OmegaConf.load(path)
27 | return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
28 |
29 | state_dict = safetensors.torch.load_file(f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors')
30 |
31 | config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
32 | unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
33 | vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
34 | encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
35 |
36 | extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
37 | unet_state = extract(state_dict, 'model.diffusion_model')
38 | encoder_state = extract(state_dict, 'cond_stage_model')
39 | vae_state = extract(state_dict, 'first_stage_model')
40 |
41 | unet.load_state_dict(unet_state);
42 | encoder.load_state_dict(encoder_state);
43 | vae.load_state_dict(vae_state);
44 |
45 | unet = unet.requires_grad_(False)
46 | encoder = encoder.requires_grad_(False)
47 | vae = vae.requires_grad_(False)
--------------------------------------------------------------------------------
/src/smplfusion/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__init__.py
--------------------------------------------------------------------------------
/src/smplfusion/modules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/__pycache__/autoencoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/autoencoder.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/__pycache__/distributions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/distributions.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/__pycache__/ema.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/ema.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__init__.py
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/basic_transformer_block.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/basic_transformer_block.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/cross_attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/cross_attention.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/feed_forward.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/feed_forward.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/memory_efficient_cross_attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/memory_efficient_cross_attention.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/__pycache__/spatial_transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/modules/attention/__pycache__/spatial_transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/basic_transformer_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .feed_forward import FeedForward
4 |
5 | try:
6 | from .cross_attention import PatchedCrossAttention as CrossAttention
7 | except ImportError:
8 | try:
9 | from .memory_efficient_cross_attention import MemoryEfficientCrossAttention as CrossAttention
10 | except ImportError:
11 | from .cross_attention import CrossAttention
12 | from ..util import checkpoint
13 | from ...patches import router
14 |
15 | class BasicTransformerBlock(nn.Module):
16 | def __init__(
17 | self,dim,n_heads,d_head,dropout=0.0,context_dim=None,
18 | gated_ff=True,checkpoint=True,disable_self_attn=False,
19 | ):
20 | super().__init__()
21 | self.disable_self_attn = disable_self_attn
22 | # is a self-attention if not self.disable_self_attn
23 | self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None)
24 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
25 | # is self-attn if context is none
26 | self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout)
27 | self.norm1 = nn.LayerNorm(dim)
28 | self.norm2 = nn.LayerNorm(dim)
29 | self.norm3 = nn.LayerNorm(dim)
30 | self.checkpoint = checkpoint
31 |
32 | def forward(self, x, context=None):
33 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
34 |
35 | def _forward(self, x, context=None):
36 | x = x + self.attn1(self.norm1(x), context=context if self.disable_self_attn else None)
37 | x = x + self.attn2(self.norm2(x), context=context)
38 | x = x + self.ff(self.norm3(x))
39 | return x
40 |
41 | class PatchedBasicTransformerBlock(nn.Module):
42 | def __init__(
43 | self,dim,n_heads,d_head,dropout=0.0,context_dim=None,
44 | gated_ff=True,checkpoint=True,disable_self_attn=False,
45 | ):
46 | super().__init__()
47 | self.disable_self_attn = disable_self_attn
48 | # is a self-attention if not self.disable_self_attn
49 | self.attn1 = CrossAttention(query_dim=dim,heads=n_heads,dim_head=d_head,dropout=dropout,context_dim=context_dim if self.disable_self_attn else None)
50 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
51 | # is self-attn if context is none
52 | self.attn2 = CrossAttention(query_dim=dim,context_dim=context_dim,heads=n_heads,dim_head=d_head,dropout=dropout)
53 | self.norm1 = nn.LayerNorm(dim)
54 | self.norm2 = nn.LayerNorm(dim)
55 | self.norm3 = nn.LayerNorm(dim)
56 | self.checkpoint = checkpoint
57 |
58 | def forward(self, x, context=None):
59 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
60 |
61 | def _forward(self, x, context=None):
62 | return router.basic_transformer_forward(self, x, context)
63 |
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/cross_attention.py:
--------------------------------------------------------------------------------
1 | # CrossAttn precision handling
2 | import os
3 |
4 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from torch import einsum
10 | from einops import rearrange, repeat
11 | import torch
12 | from torch import nn
13 | from typing import Optional, Any
14 | from ...patches import router
15 |
16 | class CrossAttention(nn.Module):
17 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
18 | super().__init__()
19 | inner_dim = dim_head * heads
20 | context_dim = context_dim or query_dim
21 |
22 | self.scale = dim_head**-0.5
23 | self.heads = heads
24 |
25 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
26 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
27 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
28 |
29 | self.to_out = nn.Sequential(
30 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
31 | )
32 |
33 | def forward(self, x, context=None, mask=None):
34 | h = self.heads
35 |
36 | q = self.to_q(x)
37 | context = x if context is None else context
38 | k = self.to_k(context)
39 | v = self.to_v(context)
40 |
41 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
42 |
43 | # force cast to fp32 to avoid overflowing
44 | if _ATTN_PRECISION == "fp32":
45 | with torch.autocast(enabled=False, device_type="cuda"):
46 | q, k = q.float(), k.float()
47 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
48 | else:
49 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
50 |
51 | del q, k
52 |
53 | if mask is not None:
54 | mask = rearrange(mask, "b ... -> b (...)")
55 | max_neg_value = -torch.finfo(sim.dtype).max
56 | mask = repeat(mask, "b j -> (b h) () j", h=h)
57 | sim.masked_fill_(~mask, max_neg_value)
58 |
59 | # attention, what we cannot get enough of
60 | sim = sim.softmax(dim=-1)
61 |
62 | out = einsum("b i j, b j d -> b i d", sim, v)
63 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
64 | return self.to_out(out)
65 |
66 | class PatchedCrossAttention(nn.Module):
67 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
68 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
69 | super().__init__()
70 | inner_dim = dim_head * heads
71 | context_dim = context_dim or query_dim
72 |
73 | self.heads = heads
74 | self.dim_head = dim_head
75 | self.scale = dim_head**-0.5
76 |
77 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
78 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
79 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
80 |
81 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
82 | self.attention_op: Optional[Any] = None
83 |
84 | def forward(self, x, context=None, mask=None):
85 | return router.attention_forward(self, x, context, mask)
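
A shape sketch for the plain CrossAttention defined above (the PatchedCrossAttention variant routes through the patch router and is not exercised here): queries come from the spatial tokens and keys/values from the text context, so the output keeps the query's sequence length and query_dim. The dimensions are illustrative, not taken from a config.

import torch

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
x = torch.randn(2, 4096, 320)      # 4096 = 64x64 spatial tokens
ctx = torch.randn(2, 77, 768)      # text token embeddings
out = attn(x, context=ctx)
print(out.shape)                   # torch.Size([2, 4096, 320])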
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GEGLU(nn.Module):
7 | def __init__(self, dim_in, dim_out):
8 | super().__init__()
9 | self.proj = nn.Linear(dim_in, dim_out * 2)
10 |
11 | def forward(self, x):
12 | x, gate = self.proj(x).chunk(2, dim=-1)
13 | return x * F.gelu(gate)
14 |
15 |
16 | class FeedForward(nn.Module):
17 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
18 | super().__init__()
19 | inner_dim = int(dim * mult)
20 | dim_out = dim_out or dim
21 | project_in = nn.Sequential(
22 | nn.Linear(dim, inner_dim),
23 | nn.GELU()
24 | ) if not glu else GEGLU(dim, inner_dim)
25 |
26 | self.net = nn.Sequential(
27 | project_in,
28 | nn.Dropout(dropout),
29 | nn.Linear(inner_dim, dim_out)
30 | )
31 |
32 | def forward(self, x):
33 | return self.net(x)
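
A small sketch of the feed-forward block above: with glu=True it uses the GEGLU gate, i.e. a single Linear producing 2*inner_dim features that are split into a value half and a GELU-gated half, and the output dimension matches the input unless dim_out is given. Numbers are illustrative.

import torch

ff = FeedForward(dim=320, mult=4, glu=True)   # gated (GEGLU) variant used by the transformer blocks
x = torch.randn(2, 4096, 320)
print(ff(x).shape)                            # torch.Size([2, 4096, 320])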
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/memory_efficient_cross_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from typing import Optional, Any
4 |
5 | try:
6 | import xformers
7 | import xformers.ops
8 | XFORMERS_IS_AVAILBLE = True
9 | except ImportError:
10 | XFORMERS_IS_AVAILBLE = False
11 |
12 | class MemoryEfficientCrossAttention(nn.Module):
13 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
14 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
15 | super().__init__()
16 | # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.")
17 | inner_dim = dim_head * heads
18 | context_dim = context_dim or query_dim
19 |
20 | self.heads = heads
21 | self.dim_head = dim_head
22 |
23 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
24 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
25 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
26 |
27 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
28 | self.attention_op: Optional[Any] = None
29 |
30 | def forward(self, x, context=None, mask=None):
31 | q = self.to_q(x)
32 | context = x if context is None else context
33 | k = self.to_k(context)
34 | v = self.to_v(context)
35 |
36 | b, _, _ = q.shape
37 | q, k, v = map(
38 | lambda t: t.unsqueeze(3)
39 | .reshape(b, t.shape[1], self.heads, self.dim_head)
40 | .permute(0, 2, 1, 3)
41 | .reshape(b * self.heads, t.shape[1], self.dim_head)
42 | .contiguous(),
43 | (q, k, v),
44 | )
45 |
46 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
47 |
48 | if mask is not None:
49 | raise NotImplementedError
50 | out = (
51 | out.unsqueeze(0)
52 | .reshape(b, self.heads, out.shape[1], self.dim_head)
53 | .permute(0, 2, 1, 3)
54 | .reshape(b, out.shape[1], self.heads * self.dim_head)
55 | )
56 | return self.to_out(out)
--------------------------------------------------------------------------------
/src/smplfusion/modules/attention/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | from torch import einsum
6 | from einops import rearrange, repeat
7 | from .basic_transformer_block import PatchedBasicTransformerBlock as BasicTransformerBlock
8 |
9 | def Normalize(in_channels):
10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
11 |
12 | def init_(tensor):
13 | dim = tensor.shape[-1]
14 | std = 1 / math.sqrt(dim)
15 | tensor.uniform_(-std, std)
16 | return tensor
17 |
18 |
19 | def zero_module(module):
20 | for p in module.parameters():
21 | p.detach().zero_()
22 | return module
23 |
24 |
25 | class SpatialTransformer(nn.Module):
26 | """
27 | Transformer block for image-like data.
28 | First, project the input (i.e., embed it)
29 | and reshape it to (b, t, d).
30 | Then apply standard transformer blocks.
31 | Finally, reshape back to the image layout.
32 | NEW: use_linear replaces the 1x1 convs with linear layers for efficiency.
33 | """
34 | def __init__(self, in_channels, n_heads, d_head,
35 | depth=1, dropout=0., context_dim=None,
36 | disable_self_attn=False, use_linear=False,
37 | use_checkpoint=True):
38 | super().__init__()
39 | if context_dim is not None and not isinstance(context_dim, list):
40 | context_dim = [context_dim]
41 | self.in_channels = in_channels
42 | inner_dim = n_heads * d_head
43 | self.norm = Normalize(in_channels)
44 | if not use_linear:
45 | self.proj_in = nn.Conv2d(in_channels,inner_dim,kernel_size=1,stride=1,padding=0)
46 | else:
47 | self.proj_in = nn.Linear(in_channels, inner_dim)
48 |
49 | self.transformer_blocks = nn.ModuleList(
50 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
51 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
52 | for d in range(depth)]
53 | )
54 | if not use_linear:
55 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
56 | in_channels,
57 | kernel_size=1,
58 | stride=1,
59 | padding=0))
60 | else:
61 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))  # project back from inner_dim to in_channels
62 | self.use_linear = use_linear
63 |
64 | def forward(self, x, context=None):
65 | # note: if no context is given, cross-attention defaults to self-attention
66 | if not isinstance(context, list):
67 | context = [context]
68 | b, c, h, w = x.shape
69 | x_in = x
70 | x = self.norm(x)
71 | if not self.use_linear:
72 | x = self.proj_in(x)
73 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
74 | if self.use_linear:
75 | x = self.proj_in(x)
76 | for i, block in enumerate(self.transformer_blocks):
77 | x = block(x, context=context[i])
78 | if self.use_linear:
79 | x = self.proj_out(x)
80 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
81 | if not self.use_linear:
82 | x = self.proj_out(x)
83 | return x + x_in
84 |
85 |
--------------------------------------------------------------------------------
/src/smplfusion/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class DiagonalGaussianDistribution(object):
6 | def __init__(self, parameters, deterministic=False):
7 | self.parameters = parameters
8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
10 | self.deterministic = deterministic
11 | self.std = torch.exp(0.5 * self.logvar)
12 | self.var = torch.exp(self.logvar)
13 | if self.deterministic:
14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
15 |
16 | def sample(self):
17 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
18 | return x
19 |
20 | def kl(self, other=None):
21 | if self.deterministic:
22 | return torch.Tensor([0.])
23 | else:
24 | if other is None:
25 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
26 | + self.var - 1.0 - self.logvar,
27 | dim=[1, 2, 3])
28 | else:
29 | return 0.5 * torch.sum(
30 | torch.pow(self.mean - other.mean, 2) / other.var
31 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
32 | dim=[1, 2, 3])
33 |
34 | def nll(self, sample, dims=[1,2,3]):
35 | if self.deterministic:
36 | return torch.Tensor([0.])
37 | logtwopi = np.log(2.0 * np.pi)
38 | return 0.5 * torch.sum(
39 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
40 | dim=dims)
41 |
42 | def mode(self):
43 | return self.mean
44 |
45 |
46 | def normal_kl(mean1, logvar1, mean2, logvar2):
47 | """
48 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
49 | Compute the KL divergence between two gaussians.
50 | Shapes are automatically broadcasted, so batches can be compared to
51 | scalars, among other use cases.
52 | """
53 | tensor = None
54 | for obj in (mean1, logvar1, mean2, logvar2):
55 | if isinstance(obj, torch.Tensor):
56 | tensor = obj
57 | break
58 | assert tensor is not None, "at least one argument must be a Tensor"
59 |
60 | # Force variances to be Tensors. Broadcasting helps convert scalars to
61 | # Tensors, but it does not work for torch.exp().
62 | logvar1, logvar2 = [
63 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
64 | for x in (logvar1, logvar2)
65 | ]
66 |
67 | return 0.5 * (
68 | -1.0
69 | + logvar2
70 | - logvar1
71 | + torch.exp(logvar1 - logvar2)
72 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
73 | )
74 |
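
A numeric sketch for DiagonalGaussianDistribution above: the channel dimension of the parameter tensor is a concatenation of mean and log-variance, so it must have twice the latent channel count. Shapes below are illustrative.

import torch

params = torch.randn(2, 8, 16, 16)            # 2 * 4 latent channels: [mean | logvar]
dist = DiagonalGaussianDistribution(params)
z = dist.sample()                             # torch.Size([2, 4, 16, 16])
kl = dist.kl()                                # per-sample KL to N(0, I), shape torch.Size([2])
nll = dist.nll(z)                             # per-sample negative log-likelihood, shape torch.Size([2])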
--------------------------------------------------------------------------------
/src/smplfusion/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1, dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | # remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.', '')
20 | self.m_name2s_name.update({name: s_name})
21 | self.register_buffer(s_name, p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def reset_num_updates(self):
26 | del self.num_updates
27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28 |
29 | def forward(self, model):
30 | decay = self.decay
31 |
32 | if self.num_updates >= 0:
33 | self.num_updates += 1
34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35 |
36 | one_minus_decay = 1.0 - decay
37 |
38 | with torch.no_grad():
39 | m_param = dict(model.named_parameters())
40 | shadow_params = dict(self.named_buffers())
41 |
42 | for key in m_param:
43 | if m_param[key].requires_grad:
44 | sname = self.m_name2s_name[key]
45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47 | else:
48 | assert not key in self.m_name2s_name
49 |
50 | def copy_to(self, model):
51 | m_param = dict(model.named_parameters())
52 | shadow_params = dict(self.named_buffers())
53 | for key in m_param:
54 | if m_param[key].requires_grad:
55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56 | else:
57 | assert not key in self.m_name2s_name
58 |
59 | def store(self, parameters):
60 | """
61 | Save the current parameters for restoring later.
62 | Args:
63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64 | temporarily stored.
65 | """
66 | self.collected_params = [param.clone() for param in parameters]
67 |
68 | def restore(self, parameters):
69 | """
70 | Restore the parameters stored with the `store` method.
71 | Useful to validate the model with EMA parameters without affecting the
72 | original optimization process. Store the parameters before the
73 | `copy_to` method. After validation (or model saving), use this to
74 | restore the former parameters.
75 | Args:
76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77 | updated with the stored parameters.
78 | """
79 | for c_param, param in zip(self.collected_params, parameters):
80 | param.data.copy_(c_param.data)
81 |
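
A usage sketch mirroring how AutoencoderKL drives LitEma above (update the shadow weights after each optimizer step, then temporarily swap them in for evaluation). The Linear module is a stand-in, not anything from this repository.

import torch
from torch import nn

model = nn.Linear(4, 4)                 # stand-in for a real module
ema = LitEma(model, decay=0.999)

# after every optimizer step:
ema(model)                              # updates the shadow (EMA) copies of the trainable weights

# evaluate with EMA weights, then put the training weights back:
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation ...
ema.restore(model.parameters())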
--------------------------------------------------------------------------------
/src/smplfusion/modules/partial_conv2d.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # BSD 3-Clause License
3 | #
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Author & Contact: Guilin Liu (guilinl@nvidia.com)
7 | ###############################################################################
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn, cuda
12 |
13 | from .. import share
14 |
15 | partial_res = [8, 16, 32, 64]
16 |
17 | # class PartialConv2d(nn.Conv2d):
18 | # def __init__(self, *args, **kwargs):
19 |
20 | # # whether the mask is multi-channel or not
21 | # if 'multi_channel' in kwargs:
22 | # self.multi_channel = kwargs['multi_channel']
23 | # kwargs.pop('multi_channel')
24 | # else:
25 | # self.multi_channel = False
26 |
27 | # if 'return_mask' in kwargs:
28 | # self.return_mask = kwargs['return_mask']
29 | # kwargs.pop('return_mask')
30 | # else:
31 | # self.return_mask = False
32 |
33 | # super(PartialConv2d, self).__init__(*args, **kwargs)
34 |
35 | # if self.multi_channel:
36 | # self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
37 | # else:
38 | # self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
39 |
40 | # self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]
41 |
42 | # self.last_size = (None, None, None, None)
43 | # self.update_mask = None
44 | # self.mask_ratio = None
45 |
46 | # def forward(self, input, mask_in=None):
47 | # assert len(input.shape) == 4
48 | # if mask_in is not None or self.last_size != tuple(input.shape):
49 | # self.last_size = tuple(input.shape)
50 |
51 | # with torch.no_grad():
52 | # if self.weight_maskUpdater.type() != input.type():
53 | # self.weight_maskUpdater = self.weight_maskUpdater.to(input)
54 |
55 | # if mask_in is None:
56 | # # if mask is not provided, create a mask
57 | # if self.multi_channel:
58 | # mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
59 | # else:
60 | # mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
61 | # else:
62 | # mask = mask_in
63 |
64 | # self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)
65 |
66 | # # for mixed precision training, change 1e-8 to 1e-6
67 | # self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
68 | # # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
69 | # self.update_mask = torch.clamp(self.update_mask, 0, 1)
70 | # self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
71 |
72 |
73 | # raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
74 |
75 | # if self.bias is not None:
76 | # bias_view = self.bias.view(1, self.out_channels, 1, 1)
77 | # output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
78 | # output = torch.mul(output, self.update_mask)
79 | # else:
80 | # output = torch.mul(raw_out, self.mask_ratio)
81 |
82 |
83 | # if self.return_mask:
84 | # return output, self.update_mask
85 | # else:
86 | # return output
87 |
88 |
89 | class PartialConv2d(nn.Conv2d):
90 | """
91 | NOTE: call share.set_mask(original_mask) before running inference with PartialConv2d.
92 | """
93 |
94 | def __init__(self, *args, **kwargs):
95 | super(PartialConv2d, self).__init__(*args, **kwargs)
96 |
97 | self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
98 | self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]
99 |
100 | def forward(self, input):
101 | if share.input_mask is None:
102 | raise Exception('Please call share.set_mask(original_mask) before running inference with PartialConv2d.')
103 |
104 | bias_view = torch.zeros((1, self.out_channels, 1, 1), dtype=input.dtype, device=input.device)
105 | if self.bias is not None:
106 | bias_view = self.bias.view(1, self.out_channels, 1, 1)
107 |
108 | # Get the resized mask for the current resolution
109 | res = max(input.shape[2:])
110 |
111 | if res == max(share.input_mask.shape64):
112 | mask = share.input_mask.val64
113 | mask_down = share.input_mask.val32
114 | elif res == max(share.input_mask.shape32):
115 | mask = share.input_mask.val32
116 | mask_down = share.input_mask.val16
117 | elif res == max(share.input_mask.shape16):
118 | mask = share.input_mask.val16
119 | mask_down = share.input_mask.val8
120 | elif res == max(share.input_mask.shape8):
121 | mask = share.input_mask.val8
122 |
123 | mask = mask.to(input.device)
124 |
125 | # Separately perform the convolution operation on masked and known regions
126 | masked_input = torch.mul(input, mask)
127 | known_input = torch.mul(input, 1-mask)
128 |
129 | input_out = super(PartialConv2d, self).forward(input)
130 | masked_out = super(PartialConv2d, self).forward(masked_input)
131 | known_out = super(PartialConv2d, self).forward(known_input)
132 |
133 | # Calculate the rescaling weights for known and unknown regions
134 |
135 | # ############ Weighting strategy No 1 #################
136 |
137 | # pixel_counts = F.conv2d(
138 | # F.pad(mask, [*self.padding, *self.padding], mode='reflect'),
139 | # self.weight_maskUpdater.to(mask.device),
140 | # bias=None,
141 | # stride=self.stride,
142 | # padding=(0, 0),
143 | # dilation=self.dilation,
144 | # groups=1
145 | # )
146 |
147 | # mask_ratio_unknown = self.slide_winsize/ (pixel_counts)
148 | # mask_ratio_known = self.slide_winsize / (self.slide_winsize - pixel_counts)
149 |
150 | # ################## End of No 1 #########################
151 |
152 | # ################ Weighting strategy No 2 ###############
153 |
154 | # ones_input = torch.ones_like(input)
155 | # masks_input = mask.repeat(1, ones_input.shape[1], 1, 1)
156 |
157 | # sum_overall = F.conv2d(
158 | # F.pad(ones_input, [*self.padding, *self.padding], mode='reflect'),
159 | # torch.abs(self.weight),
160 | # bias=None,
161 | # stride=self.stride,
162 | # padding=(0, 0),
163 | # dilation=self.dilation,
164 | # groups=1
165 | # )
166 |
167 | # sum_masked = F.conv2d(
168 | # F.pad(masks_input, [*self.padding, *self.padding], mode='reflect'),
169 | # torch.abs(self.weight),
170 | # bias=None,
171 | # stride=self.stride,
172 | # padding=(0, 0),
173 | # dilation=self.dilation,
174 | # groups=1
175 | # )
176 |
177 | # mask_ratio_unknown = sum_overall / (sum_masked)
178 | # mask_ratio_known = sum_overall / (sum_overall - sum_masked)
179 |
180 | # ################## End of No 2 #########################
181 |
182 | ################ Weighting strategy No 3 ###############
183 | if res not in partial_res:
184 | return input_out
185 |
186 | input_norm = torch.norm(input_out - bias_view, dim=1, keepdim=True)
187 | known_norm = torch.norm(known_out - bias_view, dim=1, keepdim=True)
188 | masked_norm = torch.norm(masked_out - bias_view, dim=1, keepdim=True)
189 |
190 | mask_ratio_unknown = input_norm / masked_norm
191 | mask_ratio_known = input_norm / known_norm
192 |
193 | ################## End of No 3 #########################
194 |
195 | # Replace nan and inf with 0.0
196 | mask_ratio_unknown = torch.nan_to_num(mask_ratio_unknown, nan=0.0, posinf=0.0, neginf=0.0)
197 | mask_ratio_known = torch.nan_to_num(mask_ratio_known, nan=0.0, posinf=0.0, neginf=0.0)
198 |
199 | ###################### DEBUG ############################
200 | # if res == 8:
201 | # print(mask_ratio_unknown[0][0], mask_ratio_unknown.shape)
202 | # print(mask_ratio_known[0][0], mask_ratio_known.shape)
203 |
204 | ################### END OF DEBUG ########################
205 |
206 | # If set to true, doesn't rescale the convolution outputs
207 | ignore_mask_ratio = False
208 | if ignore_mask_ratio:
209 | mask_ratio_known = 1.0
210 | mask_ratio_unknown = 1.0
211 |
212 | known_out = known_out - bias_view
213 | masked_out = masked_out - bias_view
214 |
215 | if max(self.stride) > 1:
216 | mask_down = mask_down.to(input.device)
217 | out = masked_out * mask_down * mask_ratio_unknown + known_out * (1-mask_down) * mask_ratio_known
218 | else:
219 | out = masked_out * mask * mask_ratio_unknown + known_out * (1-mask) * mask_ratio_known
220 |
221 | out = out + bias_view
222 |
223 | return out
--------------------------------------------------------------------------------
/src/smplfusion/patches/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__init__.py
--------------------------------------------------------------------------------
/src/smplfusion/patches/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/__pycache__/router.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/__pycache__/router.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__init__.py:
--------------------------------------------------------------------------------
1 | from . import zeropaint
2 | from . import default
3 | from . import attstore
4 | from . import inpaint
5 | from . import shuffled
6 | from . import introvert
7 | from . import boosted_maskscore
8 |
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/attstore.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/attstore.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/boosted_maskscore.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/boosted_maskscore.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/default.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/default.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/ediff.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/ediff.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/inpaint.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/inpaint.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/introvert.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/introvert.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/shuffled.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/shuffled.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/__pycache__/zeropaint.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/attentionpatch/__pycache__/zeropaint.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/attstore.py:
--------------------------------------------------------------------------------
1 | # Attention patch that stores cross-attention similarity maps in share.sim
2 | import os
3 | import torch
4 | from torch import nn
5 |
6 | from torch import einsum
7 | from einops import rearrange, repeat
8 | from ... import share
9 |
10 | att_res = [16 * 16]
11 | force_idx = 1
12 |
13 | def forward(self, x, context=None, mask=None):
14 | h = self.heads
15 |
16 | q = self.to_q(x)
17 | att_type = "self" if context is None else "cross"
18 | context = x if context is None else context
19 | k = self.to_k(context)
20 | v = self.to_v(context)
21 |
22 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
23 |
24 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
25 |
26 | if att_type == "cross" and q.shape[1] in att_res:
27 | share.sim.append(sim)
28 |
29 | # attention, what we cannot get enough of
30 | sim = sim.softmax(dim=-1)
31 |
32 | out = einsum("b i j, b j d -> b i d", sim, v)
33 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
34 | return self.to_out(out)
35 |
36 | def forward_force(self, x, context=None, mask=None):
37 | h = self.heads
38 |
39 | q = self.to_q(x)
40 | att_type = "self" if context is None else "cross"
41 | context = x if context is None else context
42 | k = self.to_k(context)
43 | v = self.to_v(context)
44 |
45 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
46 |
47 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
48 |
49 | if att_type == "cross":
50 | context_dim = context.shape[1] # Number of tokens, might not be 77
51 |
52 | # if q.shape[1] == share.input_mask.res64:
53 | # _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1)
54 | # _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1)
55 | # sim[:,share.input_mask.val64.reshape(-1) > 0,:] = _sim[:,share.input_mask.val64.reshape(-1) > 0,:]
56 | # if q.shape[1] == share.input_mask.res32:
57 | # _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1)
58 | # _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0], sim.shape[1], 1)
59 | # sim[:,share.input_mask.val32.reshape(-1) > 0,:] = _sim[:,share.input_mask.val32.reshape(-1) > 0,:]
60 | if q.shape[1] == share.input_mask.res16:
61 | # print (sim.shape)
62 | _sim = 100 * torch.eye(context_dim)[force_idx].half().cuda()[None].repeat(sim.shape[0]//2, sim.shape[1], 1)
63 | _sim += 100 * torch.eye(context_dim)[0].half().cuda()[None].repeat(sim.shape[0]//2, sim.shape[1], 1)
64 | # sim[sim.shape[0]//2:,share.input_mask.val16.reshape(-1) > 0,:] = _sim[:,share.input_mask.val16.reshape(-1) > 0,:]
65 | # sim[sim.shape[0]//2:, share.input_mask.val16.reshape(-1) > 0, 0] *= 0.8
66 | share.sim.append(sim[sim.shape[0]//2:])
67 |
68 | # attention, what we cannot get enough of
69 | sim = sim.softmax(dim=-1)
70 |
71 | out = einsum("b i j, b j d -> b i d", sim, v)
72 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
73 | if q.shape[1] == share.input_mask.res16:
74 | if att_type == "cross":
75 | share.cross_out.append(out.detach())
76 | if att_type == "self":
77 | share.self_out.append(out.detach())
78 |
79 | return self.to_out(out)
80 |
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/boosted_maskscore copy.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 | import torch
3 |
4 | # import xformers
5 | # import xformers.ops
6 |
7 | import torchvision.transforms.functional as TF
8 |
9 | from torch import nn, einsum
10 | from inspect import isfunction
11 | from einops import rearrange, repeat
12 |
13 | qkv_reduce_dims = [-1]
14 | increase_indices = [1,2]
15 |
16 | w8 = 0.
17 | w16 = 0.
18 | w32 = 0.
19 | w64 = 0.
20 |
21 | def forward_and_save(self, x, context=None, mask=None):
22 | att_type = "self" if context is None else "cross"
23 |
24 | h = self.heads
25 | q = self.to_q(x)
26 | is_cross = context is not None
27 | context = x if context is None else context
28 | k = self.to_k(context)
29 | v = self.to_v(context)
30 |
31 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
32 |
33 | scale = self.dim_head**-0.5
34 | sim = einsum("b i d, b j d -> b i j", q, k) * scale
35 |
36 | if is_cross:
37 | N = sim.shape[0] // 2
38 |
39 | delta_uncond = torch.zeros_like(sim[:N])
40 |
41 | if q.shape[1] == share.input_shape.res8:
42 | W = w8 + torch.zeros_like(sim[:N])
43 | if q.shape[1] == share.input_shape.res16:
44 | W = w16 + torch.zeros_like(sim[:N])
45 | if q.shape[1] == share.input_shape.res32:
46 | W = w32 + torch.zeros_like(sim[:N])
47 | if q.shape[1] == share.input_shape.res64:
48 | W = w64 + torch.zeros_like(sim[:N])
49 |
50 | tokens = torch.eye(77)[increase_indices].sum(0).cuda()
51 |
52 | sim[N:].argmax()
53 |
54 | max_sim = sim[N:].detach().amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1)
55 | sigma = share.schedule.sqrt_noise_signal_ratio[share.timestep]
56 | mask = share.input_mask.get_res(q, 'cuda').reshape(1,-1,1)
57 |
58 | delta_cond = W * max_sim * mask * tokens
59 | sim += torch.cat([delta_uncond, delta_cond])
60 |
61 | if hasattr(share, '_crossattn_similarity') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
62 | share._crossattn_similarity.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
63 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross':
64 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
65 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
66 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
67 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross':
68 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
69 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross':
70 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
71 |
72 | sim = sim.softmax(dim=-1)
73 | out = einsum("b i j, b j d -> b i d", sim, v)
74 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
75 | return self.to_out(out)
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/boosted_maskscore.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 | import torch
3 |
4 | # import xformers
5 | # import xformers.ops
6 |
7 | import torchvision.transforms.functional as TF
8 |
9 | from torch import nn, einsum
10 | from inspect import isfunction
11 | from einops import rearrange, repeat
12 | import numpy as np
13 |
14 | qkv_reduce_dims = [-1]
15 | increase_indices = [1,2]
16 | topk_heads = 0.5
17 |
18 | w8 = 0.
19 | w16 = 0.
20 | w32 = 0.
21 | w64 = 0.
22 |
23 | def forward(self, x, context=None, mask=None):
24 | att_type = "self" if context is None else "cross"
25 |
26 | h = self.heads
27 | q = self.to_q(x)
28 | is_cross = context is not None
29 | context = x if context is None else context
30 | k = self.to_k(context)
31 | v = self.to_v(context)
32 |
33 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
34 |
35 | scale = self.dim_head**-0.5
36 | sim = einsum("b i d, b j d -> b i j", q, k) * scale
37 |
38 | if is_cross:
39 | if share.timestep < 0:
40 | # Choose the weight
41 | if q.shape[1] == share.input_shape.res8: w = w8
42 | if q.shape[1] == share.input_shape.res16: w = w16
43 | if q.shape[1] == share.input_shape.res32: w = w32
44 | if q.shape[1] == share.input_shape.res64: w = w64
45 | mask = share.input_mask.get_res(q, 'cuda').reshape(-1)
46 |
47 | sim = rearrange(sim, "(b h) n d -> b h n d", h=h) # (2x10x4096x77)
48 | sim[1,...,0] -= 0.2 * sim[1, ..., 0] * mask
49 |
50 | # # Get coefficients
51 | # sigma = (1 - share.schedule.alphas[share.timestep])
52 | # sim_max = sim.detach().amax(dim=qkv_reduce_dims, keepdim = True) # (2x10x1x1)
53 |
54 | # # Use to modify only the conditional part
55 | # batch_one_hot = torch.tensor([0,1.])[:, None, None, None].cuda() # (2x1x1x1)
56 |
57 | # for token_idx in increase_indices:
58 | # # Get the index of the head to be modified
59 | # n_topk = int(sim.shape[1] * topk_heads)
60 | # head_indices = (sim[1,...,1]).amax(dim = 1).topk(n_topk).indices.cpu()
61 | # head_one_hot = torch.eye(h)[head_indices]
62 | # # head_one_hot[n_topk // 2:] *= 0.5
63 | # head_one_hot = head_one_hot.sum(0).cuda()[None,:,None,None] # (1x10x1x1)
64 |
65 | # # head_one_hot = (1 / (1 + torch.arange(sim.shape[1]))) ** 0.5
66 | # # head_one_hot = head_one_hot.cuda()[None,:,None,None]
67 |
68 | # # Get the one hot token index
69 | # token_one_hot = torch.eye(77)[token_idx].cuda()[None,None,None,:] # (1x1x1x77)
70 |
71 | # sim += w * sim_max * mask * token_one_hot * head_one_hot * batch_one_hot
72 |
73 | sim = rearrange(sim, "b h n d -> (b h) n d", h=h) # (2x10x4096x77)
74 | if not is_cross and q.shape[1] == share.input_shape.res16:
75 | # shape of sim: 20 x 4096 x 4096
76 | if share.timestep < 0:
77 | sim_max = sim.detach().amax(dim=-1, keepdim = True) # (20x1x1)
78 |
79 | mask = share.input_mask.get_res(q, 'cuda').reshape(-1)
80 | delta = (mask[:,None] @ mask[None,:])[None]
81 | gamma = ((1 - mask[:,None]) @ (mask[None,:]))[None]
82 |
83 | sim = sim * (1 + 1.0 * share.schedule.sqrt_one_minus_alphas[share.timestep] * delta)
84 | sim = sim * (1 - 1.0 * share.schedule.sqrt_one_minus_alphas[share.timestep] * gamma)
85 | # sim = sim + 1.0 * share.schedule.sqrt_noise_signal_ratio[share.timestep] * sim_max * delta
86 |
87 | # Chunk into 2 parts to differentiate the unconditional and conditional parts
88 | if hasattr(share, '_crossattn_similarity') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
89 | share._crossattn_similarity.append(torch.stack(share.reshape(sim).chunk(2)))
90 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross':
91 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2)))
92 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
93 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2)))
94 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross':
95 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2)))
96 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross':
97 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2)))
98 |
99 | sim = sim.softmax(dim=-1)
100 | out = einsum("b i j, b j d -> b i d", sim, v)
101 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
102 | return self.to_out(out)
103 |
104 | forward_and_save = forward
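A minimal sketch (not part of the repository) of the self-attention rescaling used above at res16: `delta` marks query/key pairs that both fall inside the inpainting mask and gets boosted, while `gamma` marks outside-queries attending to inside-keys and gets damped. Toy sizes, illustrative names:

    import torch

    mask = torch.tensor([0., 1., 1., 0.])                  # toy binary mask over 4 spatial tokens
    delta = (mask[:, None] @ mask[None, :])[None]           # (1,4,4): 1 where query AND key are masked
    gamma = ((1 - mask[:, None]) @ mask[None, :])[None]     # (1,4,4): 1 where only the key is masked

    sim = torch.randn(1, 4, 4)                              # stand-in self-attention scores
    s = 0.7                                                 # stand-in for sqrt_one_minus_alphas[t]
    sim = sim * (1 + 1.0 * s * delta)                       # boost inside-inside interactions
    sim = sim * (1 - 1.0 * s * gamma)                       # damp outside-queries looking at inside-keys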
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/default.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 |
3 | try:
4 | import xformers
5 | import xformers.ops
6 | XFORMERS_IS_AVAILBLE = True
7 | except ImportError:
8 | XFORMERS_IS_AVAILBLE = False
9 | print("No module 'xformers'. Proceeding without it.")
10 |
11 |
12 | import torch
13 | from torch import nn, einsum
14 | import torchvision.transforms.functional as TF
15 | from einops import rearrange, repeat
16 |
17 | _ATTN_PRECISION = None
18 |
19 | def forward_sd2(self, x, context=None, mask=None):
20 | h = self.heads
21 | q = self.to_q(x)
22 | context = x if context is None else context
23 | k = self.to_k(context)
24 | v = self.to_v(context)
25 |
26 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
27 |
28 | if _ATTN_PRECISION =="fp32": # force cast to fp32 to avoid overflowing
29 | with torch.autocast(enabled=False, device_type = 'cuda'):
30 | q, k = q.float(), k.float()
31 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
32 | else:
33 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
34 | del q, k
35 |
36 | if mask is not None:
37 | mask = rearrange(mask, 'b ... -> b (...)')
38 | max_neg_value = -torch.finfo(sim.dtype).max
39 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
40 | sim.masked_fill_(~mask, max_neg_value)
41 |
42 | # attention, what we cannot get enough of
43 | sim = sim.softmax(dim=-1)
44 |
45 | out = einsum('b i j, b j d -> b i d', sim, v)
46 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
47 | return self.to_out(out)
48 |
49 | def forward_xformers(self, x, context=None, mask=None):
50 | q = self.to_q(x)
51 | context = x if context is None else context
52 | k = self.to_k(context)
53 | v = self.to_v(context)
54 |
55 | b, _, _ = q.shape
56 | q, k, v = map(
57 | lambda t: t.unsqueeze(3)
58 | .reshape(b, t.shape[1], self.heads, self.dim_head)
59 | .permute(0, 2, 1, 3)
60 | .reshape(b * self.heads, t.shape[1], self.dim_head)
61 | .contiguous(),
62 | (q, k, v),
63 | )
64 |
65 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
66 |
67 | if mask is not None:
68 | raise NotImplementedError
69 | out = (
70 | out.unsqueeze(0)
71 | .reshape(b, self.heads, out.shape[1], self.dim_head)
72 | .permute(0, 2, 1, 3)
73 | .reshape(b, out.shape[1], self.heads * self.dim_head)
74 | )
75 | return self.to_out(out)
76 |
77 | if XFORMERS_IS_AVAILBLE:
78 | forward = forward_xformers
79 | else:
80 | forward = forward_sd2
81 |
82 | def forward_and_save(self, x, context=None, mask=None):
83 | att_type = "self" if context is None else "cross"
84 |
85 | h = self.heads
86 | q = self.to_q(x)
87 | context = x if context is None else context
88 | k = self.to_k(context)
89 | v = self.to_v(context)
90 |
91 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
92 |
93 | # with torch.autocast(enabled=False, device_type = 'cuda'):
94 | # q, k = q.float(), k.float()
95 | # sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
96 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
97 | # del q,k
98 |
99 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross':
100 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
101 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
102 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
103 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross':
104 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
105 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross':
106 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
107 |
108 | # attention, what we cannot get enough of
109 | sim = sim.softmax(dim=-1)
110 | out = einsum("b i j, b j d -> b i d", sim, v)
111 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
112 | return self.to_out(out)
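`forward_and_save` only records a map when the matching `share._crossattn_similarity_res*` list exists, so callers opt in by creating the list before a pass and reading it afterwards. A hedged usage sketch; the import paths and the way the router is wired into the UNet are assumptions here, not something shown in this file:

    from src.smplfusion import share                         # illustrative import path
    from src.smplfusion.patches import router, attentionpatch

    share._crossattn_similarity_res16 = []                   # opt in to recording 16x16-level maps
    router.attention_forward = attentionpatch.default.forward_and_save

    # ... run the sampling loop here (not shown); each cross-attention call at that level
    # appends torch.stack(share.reshape(sim).chunk(2)), i.e. the unconditional and
    # conditional halves of the similarity map ...

    recorded = share._crossattn_similarity_res16
    del share._crossattn_similarity_res16                    # stop recording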
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/inpaint.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 | import torch
3 |
4 | # import xformers
5 | # import xformers.ops
6 |
7 | import torchvision.transforms.functional as TF
8 |
9 | from torch import nn, einsum
10 | from inspect import isfunction
11 | from einops import rearrange, repeat
12 |
13 | qkv_reduce_dims = [-1]
14 | increase_indices = [1,2]
15 |
16 | def forward(self, x, context=None, mask=None):
17 | h = self.heads
18 | q = self.to_q(x)
19 | is_cross = context is not None
20 | context = x if context is None else context
21 | k = self.to_k(context)
22 | v = self.to_v(context)
23 |
24 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
25 |
26 | scale = self.dim_head**-0.5
27 | sim = einsum("b i d, b j d -> b i j", q, k) * scale
28 |
29 | if is_cross:
30 | if q.shape[1] in [share.input_shape.res16]:
31 | # For simplicity, assume token 1 is the target and token 2 is its (one-word) label
32 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096)
33 | N = sim.shape[0]//2
34 | sim[N:] += (
35 | share.w
36 | * share.noise_signal_ratio[share.timestep]
37 | * share.input_mask.get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1)
38 | * sim[N:].amax(dim=qkv_reduce_dims, keepdim = True) # (16, 4096, 1) since qkv_reduce_dims = [-1]
39 | ) * (torch.eye(77)[increase_indices].sum(0).cuda()) # (1,1,77)
40 |
41 | del q, k
42 | sim = sim.softmax(dim=-1)
43 |
44 | out = einsum("b i j, b j d -> b i d", sim, v)
45 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
46 | return self.to_out(out)
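A shape-only sketch (toy sizes, nothing from the repository) of the boost above: the conditional half of `sim` gets an additive term gated by the mask over queries and a one-hot selector over prompt tokens, scaled by the per-row maximum score:

    import torch

    heads, n_queries, n_tokens = 8, 16, 77                   # toy sizes
    sim = torch.randn(2 * heads, n_queries, n_tokens)        # unconditional half followed by conditional half
    N = sim.shape[0] // 2

    w, noise_signal_ratio = 1.0, 0.5                         # stand-ins for share.w and share.noise_signal_ratio[t]
    mask = (torch.rand(n_queries) > 0.5).float().reshape(1, -1, 1)   # (1, n_queries, 1)
    token_one_hot = torch.eye(n_tokens)[[1, 2]].sum(0)               # (n_tokens,), broadcasts over the last dim

    sim[N:] += (w * noise_signal_ratio
                * mask
                * sim[N:].amax(dim=[-1], keepdim=True)       # per-row max score, (N, n_queries, 1)
                ) * token_one_hot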
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/introvert.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numbers
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from torch import nn, einsum
9 | from einops import rearrange, repeat
10 |
11 | from ... import share
12 | from ...libimage import IImage
13 |
14 |
15 | # Default for version 4
16 | introvert_res = [16, 32, 64]
17 | introvert_on = True
18 | token_idx = [1,2]
19 |
20 | # Visualization purpose
21 | viz_image = None
22 | viz_mask = None
23 |
24 | video_frames_selfattn = []
25 | video_frames_crossattn = []
26 | visualize_resolution = 16
27 | visualize_selfattn = False
28 | visualize_crossattn = False
29 |
30 | class GaussianSmoothing(nn.Module):
31 | """
32 | Apply gaussian smoothing on a
33 | 1d, 2d or 3d tensor. Filtering is performed separately for each channel
34 | in the input using a depthwise convolution.
35 | Arguments:
36 | channels (int, sequence): Number of channels of the input tensors. Output will
37 | have this number of channels as well.
38 | kernel_size (int, sequence): Size of the gaussian kernel.
39 | sigma (float, sequence): Standard deviation of the gaussian kernel.
40 | dim (int, optional): The number of dimensions of the data.
41 | Default value is 2 (spatial).
42 | """
43 | def __init__(self, channels, kernel_size, sigma, dim=2):
44 | super(GaussianSmoothing, self).__init__()
45 | if isinstance(kernel_size, numbers.Number):
46 | kernel_size = [kernel_size] * dim
47 | if isinstance(sigma, numbers.Number):
48 | sigma = [sigma] * dim
49 |
50 | # The gaussian kernel is the product of the
51 | # gaussian function of each dimension.
52 | kernel = 1
53 | meshgrids = torch.meshgrid(
54 | [
55 | torch.arange(size, dtype=torch.float32)
56 | for size in kernel_size
57 | ]
58 | )
59 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
60 | mean = (size - 1) / 2
61 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
62 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2)  # NOTE: denominator is (2*std), not 2*std**2, so the effective sigma is std*sqrt(2); the prefactor is absorbed by the normalization below
63 |
64 | # Make sure sum of values in gaussian kernel equals 1.
65 | kernel = kernel / torch.sum(kernel)
66 |
67 | # Reshape to depthwise convolutional weight
68 | kernel = kernel.view(1, 1, *kernel.size())
69 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
70 |
71 | self.register_buffer('weight', kernel)
72 | self.groups = channels
73 |
74 | if dim == 1:
75 | self.conv = F.conv1d
76 | elif dim == 2:
77 | self.conv = F.conv2d
78 | elif dim == 3:
79 | self.conv = F.conv3d
80 | else:
81 | raise RuntimeError(
82 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
83 | )
84 |
85 | def forward(self, input):
86 | """
87 | Apply gaussian filter to input.
88 | Arguments:
89 | input (torch.Tensor): Input to apply gaussian filter on.
90 | Returns:
91 | filtered (torch.Tensor): Filtered output.
92 | """
93 | return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding='same')
94 |
95 |
96 | def forward(self, x, context=None, mask=None):
97 | is_cross = context is not None
98 | att_type = "self" if context is None else "cross"
99 |
100 | h = self.heads
101 |
102 | q = self.to_q(x)
103 | context = x if context is None else context
104 | k = self.to_k(context)
105 | v = self.to_v(context)
106 |
107 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
108 |
109 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
110 | sim_before = sim
111 | del q, k
112 |
113 | if mask is not None:
114 | mask = rearrange(mask, 'b ... -> b (...)')
115 | max_neg_value = -torch.finfo(sim.dtype).max
116 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
117 | sim.masked_fill_(~mask, max_neg_value)
118 |
119 | if hasattr(share, '_crossattn_similarity_res8') and x.shape[1] == share.input_shape.res8 and att_type == 'cross':
120 | share._crossattn_similarity_res8.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
121 | if hasattr(share, '_crossattn_similarity_res16') and x.shape[1] == share.input_shape.res16 and att_type == 'cross':
122 | share._crossattn_similarity_res16.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
123 | if hasattr(share, '_crossattn_similarity_res32') and x.shape[1] == share.input_shape.res32 and att_type == 'cross':
124 | share._crossattn_similarity_res32.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
125 | if hasattr(share, '_crossattn_similarity_res64') and x.shape[1] == share.input_shape.res64 and att_type == 'cross':
126 | share._crossattn_similarity_res64.append(torch.stack(share.reshape(sim).chunk(2))) # Chunk into 2 parts to differentiate the unconditional and conditional parts
127 |
128 | sim = sim.softmax(dim=-1)
129 | out = einsum('b i j, b j d -> b i d', sim, v)
130 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
131 |
132 | if is_cross:
133 | return self.to_out(out)
134 |
135 | return self.to_out(out), v, sim_before
136 |
137 |
138 | def introvert_rescale(y, self_v, self_sim, cross_sim, self_h, to_out):
139 | mask = share.introvert_mask.get_res(self_v)
140 | shape = share.introvert_mask.get_shape(self_v)
141 | res = share.introvert_mask.get_res_val(self_v)
142 | # print (res, shape)
143 |
144 | ################# Introvert Attention V4 ################
145 | # Use this with 50% of DDIM steps
146 |
147 | # TODO: LOOK INTO THIS. WHY DOES IT WORK BETTER WITHOUT A BINARY MASK?
148 | # mask = (mask > 0.5).to(torch.float32)
149 | m = mask.to(self_v.device)
150 | # mask_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).cuda()
151 | # m = mask_smoothing(m) # Smoothing on binary mask also works
152 | m = rearrange(m, 'b c h w -> b (h w) c').contiguous()
153 | mo = m
154 | m = torch.matmul(m, m.permute(0, 2, 1)) + (1-m)
155 |
156 | cross_sim = cross_sim[:, token_idx].sum(dim=1)
157 | cross_sim = cross_sim.reshape(shape)
158 | # TODO: comment this out if it is not necessary
159 | heatmap_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).cuda()
160 | cross_sim = heatmap_smoothing(cross_sim.unsqueeze(0))[0]
161 | cross_sim = cross_sim.reshape(-1)
162 | cross_sim = ((cross_sim - torch.median(cross_sim.ravel())) / torch.max(cross_sim.ravel())).clip(0, 1)
163 |
164 | # If introvert attention is off, return original SA result
165 | if introvert_on and res in introvert_res:
166 | w = (1 - m) * cross_sim.reshape(1, 1, -1) + m
167 | # At resolution 64, scale with a constant instead, since cross_sim carries no semantic meaning at that resolution
168 | if res == 64: w = m
169 | self_sim = self_sim * w
170 | self_sim_viz = self_sim # Keep for viz purposes
171 | self_sim = self_sim.softmax(dim=-1)
172 | out = einsum('b i j, b j d -> b i d', self_sim, self_v)
173 | out = rearrange(out, '(b h) n d -> b n (h d)', h=self_h)
174 | out = to_out(out)
175 | else:
176 | self_sim_viz = self_sim # Keep for viz purposes
177 | out = y
178 | ################## END OF Introvert Attention V4 ###########################
179 |
180 |
181 | ################# VISUALIZE CROSS ATTENTION ###############################
182 | if visualize_crossattn and res == visualize_resolution:
183 | cross_vis = cross_sim.reshape(shape)
184 | up = (64 // res) * 8
185 | if viz_image is not None:
186 | heatmap = IImage(cross_vis, vmin=0).heatmap((cross_vis.shape[0]*up, cross_vis.shape[1]*up))
187 | video_frames_crossattn.append(viz_image + heatmap)
188 | else:
189 | heatmap = IImage(cross_vis, vmin=0).heatmap((cross_vis.shape[0]*up, cross_vis.shape[1]*up))
190 | video_frames_crossattn.append(heatmap)
191 | ###############################
192 |
193 | ################# VISUALIZE SELF ATTENTION ###############################
194 | if visualize_selfattn and res == visualize_resolution:
195 | selected = []
196 | up = (64 // res) * 8
197 | for i in range(mo.shape[1]):
198 | if mo[0, i, 0]:
199 | selected.append(self_sim_viz[:, i, :])
200 | selected = torch.stack(selected, dim=1)
201 | selected_vis = selected.mean(0).mean(0)
202 |
203 | selected_vis = selected_vis.reshape(shape)
204 |
205 | if viz_image is not None:
206 | heatmap = IImage(selected_vis, vmin=0, vmax=1).heatmap((selected_vis.shape[0]*up, selected_vis.shape[1]*up))
207 | video_frames_selfattn.append(viz_image+heatmap)
208 | else:
209 | heatmap = IImage(selected_vis, vmin=0).heatmap((selected_vis.shape[0]*up, selected_vis.shape[1]*up))
210 | video_frames_selfattn.append(heatmap)
211 | ###############################
212 |
213 | return out
214 |
215 |
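A minimal usage sketch for the `GaussianSmoothing` module defined above (CPU, toy sizes; in `introvert_rescale` it is built with the same arguments and moved to CUDA):

    import torch

    smoother = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2)
    heatmap = torch.rand(1, 1, 16, 16)     # e.g. a 16x16 cross-attention heatmap, (B, C, H, W)
    smoothed = smoother(heatmap)           # same spatial size thanks to padding='same'
    print(smoothed.shape)                  # torch.Size([1, 1, 16, 16])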
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/maskscore.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 | import torch
3 |
4 | # import xformers
5 | # import xformers.ops
6 |
7 | import torchvision.transforms.functional as TF
8 |
9 | from torch import nn, einsum
10 | from inspect import isfunction
11 | from einops import rearrange, repeat
12 |
13 | qkv_reduce_dims = [-1]
14 | increase_indices = [1,2]
15 |
16 | def forward(self, x, context=None, mask=None):
17 | h = self.heads
18 | q = self.to_q(x)
19 | is_cross = context is not None
20 | context = x if context is None else context
21 | k = self.to_k(context)
22 | v = self.to_v(context)
23 |
24 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
25 |
26 | scale = self.dim_head**-0.5
27 | sim = einsum("b i d, b j d -> b i j", q, k) * scale
28 |
29 | if is_cross:
30 | if q.shape[1] in [share.input_shape.res16]:
31 | # For simplicity, assume token 1 is the target and token 2 is its (one-word) label
32 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096)
33 | N = sim.shape[0]//2
34 | sim[N:] += (
35 | share.w
36 | * share.noise_signal_ratio[share.timestep]
37 | * share.input_mask.get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1)
38 | * sim[N:].amax(dim=qkv_reduce_dims, keepdim = True) # (16, 4096, 1) since qkv_reduce_dims = [-1]
39 | ) * (torch.eye(77)[increase_indices].sum(0).cuda()) # (1,1,77)
40 |
41 | del q, k
42 | sim = sim.softmax(dim=-1)
43 |
44 | out = einsum("b i j, b j d -> b i d", sim, v)
45 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
46 | return self.to_out(out)
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/other.py:
--------------------------------------------------------------------------------
1 | # import xformers      # NOTE: forward() below calls xformers.ops.memory_efficient_attention,
2 | # import xformers.ops  # so these imports must be re-enabled (with xformers installed) before using this patch
3 |
4 | def forward(self, x, context=None, mask=None):
5 | q = self.to_q(x)
6 | context = x if context is None else context
7 | k = self.to_k(context)
8 | v = self.to_v(context)
9 |
10 | b, _, _ = q.shape
11 | q, k, v = map(
12 | lambda t: t.unsqueeze(3)
13 | .reshape(b, t.shape[1], self.heads, self.dim_head)
14 | .permute(0, 2, 1, 3)
15 | .reshape(b * self.heads, t.shape[1], self.dim_head)
16 | .contiguous(),
17 | (q, k, v),
18 | )
19 |
20 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
21 |
22 | if mask is not None:
23 | raise NotImplementedError
24 | out = (
25 | out.unsqueeze(0)
26 | .reshape(b, self.heads, out.shape[1], self.dim_head)
27 | .permute(0, 2, 1, 3)
28 | .reshape(b, out.shape[1], self.heads * self.dim_head)
29 | )
30 | return self.to_out(1 - out)
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/shuffled.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 |
3 | # import xformers
4 | # import xformers.ops
5 |
6 | import torch
7 | from torch import nn, einsum
8 | import torchvision.transforms.functional as TF
9 | from einops import rearrange, repeat
10 |
11 | layer_mask = share.LayerMask()
12 |
13 | def forward(self, x, context=None, mask=None):
14 | att_type = "self" if context is None else "cross"
15 |
16 | h = self.heads
17 | q = self.to_q(x)
18 | context = x if context is None else context
19 | k = self.to_k(context)
20 | v = self.to_v(context)
21 |
22 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
23 |
24 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
25 |
26 | # attention, what we cannot get enough of
27 | sim = sim.softmax(dim=-1)
28 | out = einsum("b i j, b j d -> b i d", sim, v)
29 |
30 | if att_type == 'self' and q.shape[1] in [share.input_shape.res16]:
31 | out = out[:,torch.randperm(out.shape[1]),:]
32 |
33 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
34 | return self.to_out(out)
--------------------------------------------------------------------------------
/src/smplfusion/patches/attentionpatch/zeropaint.py:
--------------------------------------------------------------------------------
1 | from ... import share
2 | import torch
3 |
4 | # import xformers
5 | # import xformers.ops
6 |
7 | import torchvision.transforms.functional as TF
8 |
9 | from torch import nn, einsum
10 | from inspect import isfunction
11 | from einops import rearrange, repeat
12 |
13 | qkv_reduce_dims = [-1, -2]
14 |
15 | def forward(self, x, context=None, mask=None):
16 | h = self.heads
17 | q = self.to_q(x)
18 | is_cross = context is not None
19 | context = x if context is None else context
20 | k = self.to_k(context)
21 | v = self.to_v(context)
22 |
23 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
24 | scale = self.dim_head**-0.5
25 | sim = einsum("b i d, b j d -> b i j", q, k)
26 | steps_40 = [999, 974, 949, 924, 899, 874, 849, 824, 799, 774, 749, 724, 699, 674, 649, 624, 599, 574, 549, 524, 499, 474, 449, 424, 399, 374, 349, 324, 299, 274, 249, 224, 199, 174, 149, 124, 99, 74, 49, 24]
27 | test_ind = steps_40.index(share.timestep)
28 | if hasattr(share, 'list_of_masks') and is_cross:
29 | if q.shape[1] in [share.input_shape.res16, share.input_shape.res32] or True:
30 | for masked_object in share.list_of_masks:
31 | # sim: (16x4096x77); mask: (64x64) -> mask.reshape(-1): (4096)
32 | zp_condition = (
33 | masked_object['w']
34 | # * share.zp_sigmas[share.timestep]
35 | * share.zp_sigmas[test_ind]
36 | * masked_object['mask'].get_res(q, 'cuda').reshape(1,-1,1)# (1,4096,1)
37 | * sim.amax(dim=qkv_reduce_dims, keepdim = True) # (16x1x1)
38 | ) * (torch.eye(77)[masked_object['token_idx']].sum(0).cuda()) # (1,1,77)
39 |
40 | zp_zeros = torch.zeros_like(zp_condition).cuda()
41 | final_condition = torch.concat([zp_zeros[:8,:,:],zp_condition[8:,:,:]])
42 |
43 | sim = sim+final_condition
44 | # print(sim.shape,"wwwww")
45 | del q, k
46 |
47 | sim = sim*scale
48 | sim = sim.softmax(dim=-1)
49 |
50 | out = einsum("b i j, b j d -> b i d", sim, v)
51 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
52 | return self.to_out(out)
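`forward` reads three fields from every entry of `share.list_of_masks`; judging from how they are used above, each entry carries a per-object weight, an `InputMask`-style object (anything exposing `get_res`), and that object's token indices in the 77-token prompt. A hypothetical entry, only to illustrate the expected keys and types:

    # Hypothetical structure inferred from the code above; values are illustrative only.
    masked_object = {
        "w": 1.0,                 # per-object boost weight
        "mask": object_mask,      # e.g. an InputMask built from that object's binary mask
        "token_idx": [4, 5],      # prompt-token indices belonging to this object
    }
    share.list_of_masks = [masked_object]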
--------------------------------------------------------------------------------
/src/smplfusion/patches/router.py:
--------------------------------------------------------------------------------
1 | from . import attentionpatch
2 | from . import transformerpatch
3 |
4 | VERBOSE = False
5 | attention_forward = attentionpatch.default.forward
6 | basic_transformer_forward = transformerpatch.default.forward
7 |
8 | def reset():
9 | global attention_forward, basic_transformer_forward
10 | attention_forward = attentionpatch.default.forward
11 | basic_transformer_forward = transformerpatch.default.forward
12 | if VERBOSE: print ("Resetting Diffusion Model")
13 |
14 | print ("RELOADING ROUTER")
15 | print (attention_forward)
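The router simply holds two function pointers that the patched attention and transformer blocks call, so selecting a patch is a matter of reassigning these module attributes. A sketch of that pattern (import paths are an assumption; the repository's own pipelines may wire this differently):

    from src.smplfusion.patches import router, attentionpatch, transformerpatch

    # Route attention through one of the experimental patches for this run ...
    router.attention_forward = attentionpatch.inpaint.forward
    router.basic_transformer_forward = transformerpatch.default.forward

    # ... and restore the defaults afterwards.
    router.reset()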
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__init__.py:
--------------------------------------------------------------------------------
1 | from . import default
2 | from . import guided
3 | from . import weighting_versions
4 | from . import introvert
5 |
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__pycache__/default.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/default.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__pycache__/guided.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/guided.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__pycache__/introvert.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/introvert.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/__pycache__/weighting_versions.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/patches/transformerpatch/__pycache__/weighting_versions.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/default.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ... import share
3 |
4 | def set_wsa(value):
5 | global wsa
6 | wsa = torch.tensor([1., value])[:,None,None].cuda()
7 | def set_wca(value):
8 | global wca
9 | wca = torch.tensor([1., value])[:,None,None].cuda()
10 | def set_wff(value):
11 | global wff
12 | wff = torch.tensor([1., value])[:,None,None].cuda()
13 |
14 | # set_wca(1.)
15 | # set_wsa(1.)
16 | # set_wff(1.)
17 |
18 | wca = 1.0
19 | wsa = 1.0
20 | wff = 1.0
21 |
22 | def forward(self, x, context=None):
23 | x = x + wsa * self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) # Self Attn.
24 | x = x + wca * self.attn2(self.norm2(x), context=context) # Cross Attn.
25 | x = x + wff * self.ff(self.norm3(x))
26 | return x
27 |
28 | def forward_and_save(self, x, context=None):
29 | val = [x]
30 | val.append(wsa * self.attn1(self.norm1(x), context=context if self.disable_self_attn else None)) # Self Attn.
31 | x = x + val[-1]
32 | val.append(wca * self.attn2(self.norm2(x), context=context)) # Cross Attn.
33 | x = x + val[-1]
34 | val.append(wff * self.ff(self.norm3(x)))
35 | x = x + val[-1]
36 |
37 | # Save outputs
38 | if hasattr(share, '_basic_transformer_input') and x.shape[1] == share.input_shape.res16:
39 | share._basic_transformer_input.append(share.reshape(val[0]))
40 | if hasattr(share, '_basic_transformer_selfattn') and x.shape[1] == share.input_shape.res16:
41 | share._basic_transformer_selfattn.append(share.reshape(val[1]))
42 | if hasattr(share, '_basic_transformer_crossattn') and x.shape[1] == share.input_shape.res16:
43 | share._basic_transformer_crossattn.append(share.reshape(val[2]))
44 | if hasattr(share, '_basic_transformer_ff') and x.shape[1] == share.input_shape.res16:
45 | share._basic_transformer_ff.append(share.reshape(val[3]))
46 | return x
47 |
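`set_wsa` / `set_wca` / `set_wff` turn the scalar weights into `(2, 1, 1)` tensors so that, with a classifier-free-guidance batch of size 2, the unconditional half keeps weight 1 and only the conditional half is rescaled. A small broadcasting sketch (CPU instead of `.cuda()`, toy tensor):

    import torch

    value = 0.5
    w = torch.tensor([1., value])[:, None, None]     # what set_wsa(value) builds (here without .cuda())
    branch_out = torch.ones(2, 4, 8)                 # (uncond+cond, tokens, channels) toy branch output
    scaled = w * branch_out                          # row 0 (uncond) unchanged, row 1 (cond) scaled by 0.5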
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/guided.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ... import share
3 |
4 | def set_wsa(value):
5 | global wsa
6 | wsa = torch.tensor([1., value])[:,None,None].cuda()
7 | def set_wca(value):
8 | global wca
9 | wca = torch.tensor([1., value])[:,None,None].cuda()
10 | def set_wff(value):
11 | global wff
12 | wff = torch.tensor([1., value])[:,None,None].cuda()
13 |
14 | set_wca(1.)
15 | set_wsa(1.)
16 | set_wff(1.)
17 |
18 | guidance_scale = 7.5
19 |
20 | def forward(self, x, context=None):
21 | # print (x.shape)
22 | mask = share.input_mask.get_res(x).reshape(-1,1).cuda()
23 |
24 | _out_sa = wsa * self.attn1(self.norm1(x), None) # Self Attn.
25 |
26 | # _out_ca_uncond = wca * self.attn2(self.norm2(x), context=context) # Cross Attn. "Unconditional"
27 | _out_ca = wca * self.attn2(self.norm2(x + _out_sa), context=context) # Cross Attn.
28 |
29 | if share.timestep < 0:  # NOTE: effectively disabled (timesteps are never negative); enabling it also requires _out_ca_uncond from the commented-out line above
30 | # x = x + (1 - mask) * _out_sa + mask * (guidance_scale * _out_ca_uncond + (1 - guidance_scale) * (_out_sa + _out_ca))
31 | x = x + (1 - mask) * (_out_sa + _out_ca) + mask * (guidance_scale * _out_ca_uncond + (1 - guidance_scale) * (_out_sa + _out_ca))
32 | else:
33 | x = x + _out_sa + _out_ca
34 |
35 | _out_ff = wff * self.ff(self.norm3(x))
36 | x = x + _out_ff
37 |
38 | if False:
39 | save([_out_ca, _out_sa, _out_ff])
40 |
41 | return x
42 |
43 | forward_and_save = forward
44 |
45 | def save(val):  # Save outputs; val = [_out_ca, _out_sa, _out_ff] as passed in forward() above
46 | if hasattr(share, '_basic_transformer_selfattn'):
47 | share._basic_transformer_selfattn.append(share.reshape(val[1]))
48 | if hasattr(share, '_basic_transformer_crossattn'):
49 | share._basic_transformer_crossattn.append(share.reshape(val[0]))
50 | if hasattr(share, '_basic_transformer_ff'):
51 | share._basic_transformer_ff.append(share.reshape(val[2]))
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/introvert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from einops import rearrange, repeat
4 | from ... import share
5 | from ..attentionpatch import introvert
6 |
7 | w_ff = 1.
8 | w_sa = 1.
9 | w_ca = 1.
10 |
11 | use_grad = True
12 |
13 |
14 | def forward(self, x, context=None):
15 | # with torch.no_grad():
16 | if use_grad:
17 | y, self_v, self_sim = self.attn1(self.norm1(x), None) # Self Attn.
18 |
19 | x_uncond, x_cond = x.chunk(2)
20 | context_uncond, context_cond = context.chunk(2)
21 |
22 | y_uncond, y_cond = y.chunk(2)
23 | self_sim_uncond, self_sim_cond = self_sim.chunk(2)
24 | self_v_uncond, self_v_cond = self_v.chunk(2)
25 |
26 | # Calculate CA similarities with conditional context
27 | cross_h = self.attn2.heads
28 | cross_q = self.attn2.to_q(self.norm2(x_cond+y_cond))
29 | cross_k = self.attn2.to_k(context_cond)
30 | cross_v = self.attn2.to_v(context_cond)
31 |
32 | cross_q, cross_k, cross_v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=cross_h), (cross_q, cross_k, cross_v))
33 |
34 | with torch.autocast(enabled=False, device_type = 'cuda'):
35 | cross_q, cross_k = cross_q.float(), cross_k.float()
36 | cross_sim = einsum('b i d, b j d -> b i j', cross_q, cross_k) * self.attn2.scale
37 |
38 | del cross_q, cross_k
39 | cross_sim = cross_sim.softmax(dim=-1) # Up to this point, cross_sim is the standard cross-attention similarity of the CA layer
40 |
41 | cross_sim = cross_sim.mean(dim=0) # Calculate mean across heads
42 |
43 | # Introvert Attention rescaling happens here
44 | y_cond = introvert.introvert_rescale(
45 | y_cond, self_v_cond, self_sim_cond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale cond
46 | y_uncond = introvert.introvert_rescale(
47 | y_uncond, self_v_uncond, self_sim_uncond, cross_sim, self.attn1.heads, self.attn1.to_out) # Rescale uncond
48 |
49 | y = torch.cat([y_uncond, y_cond], dim=0)
50 |
51 | x = x + w_sa * y
52 | x = x + w_ca * self.attn2(self.norm2(x), context=context) # Cross Attn.
53 | x = x + w_ff * self.ff(self.norm3(x))
54 | return x
--------------------------------------------------------------------------------
/src/smplfusion/patches/transformerpatch/weighting_versions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ... import share
3 |
4 | def set_wsa(value):
5 | global wsa
6 | wsa = torch.tensor([1., value])[:,None,None].cuda()
7 | def set_wca(value):
8 | global wca
9 | wca = torch.tensor([1., value])[:,None,None].cuda()
10 | def set_wff(value):
11 | global wff
12 | wff = torch.tensor([1., value])[:,None,None].cuda()
13 |
14 | set_wca(1.)
15 | set_wsa(1.)
16 | set_wff(1.)
17 |
18 | def forward(self, x, context=None):
19 | x = x + wca * self.attn2(self.norm2(x), context=context) # Cross Attn.
20 | x = x + wsa * self.attn1(self.norm1(x), None) # Self Attn.
21 | x = x + wff * self.ff(self.norm3(x))
22 | return x
23 |
24 | def forward_and_save(self, x, context=None):
25 | val = [x]
26 | val.append(wsa * self.attn1(self.norm1(x), None)) # Self Attn.
27 | x = x + val[-1]
28 | val.append(wca * self.attn2(self.norm2(x), context=context)) # Cross Attn.
29 | x = x + val[-1]
30 | val.append(wff * self.ff(self.norm3(x)))
31 | x = x + val[-1]
32 |
33 | # Save outputs
34 | if hasattr(share, '_basic_transformer_input') and x.shape[1] == share.input_shape.res16:
35 | share._basic_transformer_input.append(share.reshape(val[0]))
36 | if hasattr(share, '_basic_transformer_selfattn') and x.shape[1] == share.input_shape.res16:
37 | share._basic_transformer_selfattn.append(share.reshape(val[1]))
38 | if hasattr(share, '_basic_transformer_crossattn') and x.shape[1] == share.input_shape.res16:
39 | share._basic_transformer_crossattn.append(share.reshape(val[2]))
40 | if hasattr(share, '_basic_transformer_ff') and x.shape[1] == share.input_shape.res16:
41 | share._basic_transformer_ff.append(share.reshape(val[3]))
42 | return x
43 |
44 |
45 | # ========= MODIFICATIONS ============= #
46 |
47 |
48 | def forward_and_save3(self, x, context=None):
49 | val = []
50 | val.append(self.attn1(self.norm1(x), None)) # Self Attn.
51 | modify_res = [share.input_mask.res16, share.input_mask.res32, share.input_mask.res64]
52 |
53 | if x.shape[1] in modify_res:
54 | val[-1] = val[-1] + (wsa - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1]
55 | x = x + val[-1]
56 |
57 | val.append(self.attn2(self.norm2(x), context=context)) # Cross Attn.
58 |
59 | if x.shape[1] in modify_res:
60 | val[-1] = val[-1] + (wca - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1]
61 | x = x + val[-1]
62 |
63 | val.append(self.ff(self.norm3(x)))
64 |
65 | if x.shape[1] in modify_res:
66 | val[-1] = val[-1] + (wff - 1) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * val[-1]
67 | x = x + val[-1]
68 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16:
69 | share.out_basic_transformer_block.append(val)
70 | return x
71 |
72 | def forward_and_save2(self, x, context=None):
73 | val = []
74 | val.append(self.attn1(self.norm1(x), None)) # Self Attn.
75 | x = x + wsa * val[-1]
76 | val.append(self.attn2(self.norm2(x), context=context)) # Cross Attn.
77 | print (val[-1].shape, share.input_mask.val16.shape)
78 | x = x + val[-1] + wca * share.input_mask.val16 * val[-1]
79 | val.append(self.ff(self.norm3(x)))
80 | x = x + wff * val[-1]
81 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16:
82 | share.out_basic_transformer_block.append(val)
83 | return x
84 |
85 | def forward_and_reweight(self, x, context=None):
86 | modify_res = [share.input_mask.res16, share.input_mask.res32, share.input_mask.res64]
87 |
88 | _attn1 = self.attn1(self.norm1(x), None) # Self Attn.
89 | _attn2 = self.attn2(self.norm2(x + _attn1), context=context) # Cross Attn.
90 | _ff = self.ff(self.norm3(x + _attn1 + _attn2))
91 |
92 | if x.shape[1] in modify_res:
93 | lm1,lm2,lm3 = wsa - 1, wca - 1, wff - 1
94 | _attn1 *= (1 + lm1 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda())
95 | _attn2 *= (1 + lm2 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda())
96 | _ff *= (1 + lm3 * share.input_mask.get_res(x).reshape(-1)[:,None].cuda())
97 |
98 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16:
99 | share.out_basic_transformer_block.append([_attn1, _attn2, _ff])
100 |
101 | return x + _attn1 + _attn2 + _ff
102 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff)
103 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) / ((w_sa + w_ca + w_ff) / 3)
104 |
105 | def forward_and_reweight2(self, x, context=None):
106 | _attn1 = self.attn1(self.norm1(x), None) # Self Attn.
107 | _attn2 = self.attn2(self.norm2(x + _attn1), context=context) # Cross Attn.
108 | _ff = self.ff(self.norm3(x + _attn1 + _attn2))
109 |
110 | _attn1 += (wsa - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _attn1
111 | _attn2 += (wca - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _attn2
112 | _ff += (wff - 1) / ((wsa + wca + wff) / 3) * share.input_mask.get_res(x).reshape(-1)[:,None].cuda() * _ff
113 |
114 |
115 | if hasattr(share, 'out_basic_transformer_block') and x.shape[1] == share.input_mask.res16:
116 | share.out_basic_transformer_block.append([_attn1, _attn2, _ff])
117 | # share.out_basic_transformer_block.append([
118 | # w_sa / ((w_sa + w_ca + w_ff) / 3) * _attn1,
119 | # w_ca / ((w_sa + w_ca + w_ff) / 3) * _attn2,
120 | # w_ff / ((w_sa + w_ca + w_ff) / 3) * _ff
121 | # ])
122 |
123 | return x + _attn1 + _attn2 + _ff
124 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff)
125 | # return x + (w_sa * _attn1 + w_ca * _attn2 + w_ff * _ff) / ((w_sa + w_ca + w_ff) / 3)
--------------------------------------------------------------------------------
/src/smplfusion/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def linear(n_timestep = 1000, start = 1e-4, end = 2e-2):
4 | return Schedule(torch.linspace(start ** 0.5, end ** 0.5, n_timestep, dtype = torch.float64) ** 2)
5 |
6 | class Schedule:
7 | def __init__(self, betas):
8 | self.betas = betas
9 | self._alphas = 1 - betas
10 | self.alphas = torch.cumprod(self._alphas, 0)
11 | self.one_minus_alphas = 1 - self.alphas
12 | self.sqrt_alphas = torch.sqrt(self.alphas)
13 | self.sqrt_one_minus_alphas = torch.sqrt(1 - self.alphas)
14 | self.sqrt_noise_signal_ratio = self.sqrt_one_minus_alphas / self.sqrt_alphas
15 | self.noise_signal_ratio = (1 - self.alphas) / self.alphas
16 |
17 | def range(self, dt):
18 | return range(len(self.betas)-1, 0, -dt)
19 |
20 | def sigma(self, t, dt):
21 | return torch.sqrt((1 - self._alphas[t - dt]) / (1 - self._alphas[t]) * (1 - self._alphas[t] / self._alphas[t - dt])) # Like I did initially
22 | # return torch.sqrt((1 - self.alphas[t - dt]) / (1 - self.alphas[t]) * (1 - self.alphas[t] / self.alphas[t - dt])) # Like in diffusers
23 |
24 | def sigma_(self, t, dt):
25 | return torch.sqrt((1 - self.alphas[t - dt]) / (1 - self.alphas[t]) * (1 - self.alphas[t] / self.alphas[t - dt])) # Like in diffusers
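Note that `Schedule.alphas` stores the cumulative product of `1 - beta` (what is usually written as alpha-bar), while `_alphas` keeps the per-step values. A short sketch of building and inspecting the linear schedule (the import path is illustrative):

    from src.smplfusion.scheduler import linear       # illustrative import path

    schedule = linear()                                # 1000-step linear beta schedule
    t = 999
    print(float(schedule.betas[t]))                    # 0.02, the `end` value by construction
    print(float(schedule.sqrt_noise_signal_ratio[t]))  # sqrt((1 - alpha_bar_t) / alpha_bar_t)
    print(list(schedule.range(dt=25))[:3])             # reverse-time walk: [999, 974, 949]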
--------------------------------------------------------------------------------
/src/smplfusion/share.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as TF
2 | import torch
3 | import sys
4 | from .utils import *
5 |
6 | input_mask = None
7 | input_shape = None
8 | timestep = None
9 | timestep_index = None
10 |
11 | class Seed:
12 | def __getitem__(self, idx):
13 | if isinstance(idx, slice):
14 | idx = list(range(*idx.indices(idx.stop)))
15 | if isinstance(idx, list) or isinstance(idx, tuple):
16 | return [self[_idx] for _idx in idx]
17 | return pow(12345, idx, 54321)  # same as (12345 ** idx) % 54321, without the huge intermediate power
18 | seed = Seed()
19 |
20 | lock = {}
21 | def get_lock(value):
22 | global lock
23 | if value not in lock:
24 | lock[value] = True
25 | if lock[value]:
26 | lock[value] = False
27 | return True
28 | return False
29 |
30 |
31 | class DDIMIterator:
32 | def __init__(self, iterator):
33 | self.iterator = iterator
34 | def __iter__(self):
35 | self.iterator = iter(self.iterator)
36 | global timestep_index
37 | timestep_index = 0
38 | return self
39 | def __next__(self):
40 | global timestep, timestep_index, lock
41 | for x in lock: lock[x] = True
42 | timestep = next(self.iterator)
43 | timestep_index += 1
44 | return timestep
45 |
46 | def reshape(x):
47 | return input_shape.reshape(x)
48 |
49 | def set_shape(image_or_shape):
50 | global input_shape
51 | if isinstance(image_or_shape, torch.Tensor):
52 | input_shape = InputShape(image_or_shape.shape[-2:][::-1])
53 | elif hasattr(image_or_shape, 'size'):  # PIL / IImage style: .size holds (width, height)
54 | input_shape = InputShape(image_or_shape.size)
55 | elif isinstance(image_or_shape, (list, tuple)):
56 | input_shape = InputShape(image_or_shape)
57 |
58 | self = sys.modules[__name__]  # reference to this module itself, used by exists() below
59 |
60 | def set_mask(mask):
61 | global input_mask, mask64, mask32, mask16, mask8, introvert_mask
62 | input_mask = InputMask(mask)
63 | introvert_mask = InputMask(mask)
64 |
65 | mask64 = input_mask.val64[0,0]
66 | mask32 = input_mask.val32[0,0]
67 | mask16 = input_mask.val16[0,0]
68 | mask8 = input_mask.val8[0,0]
69 | set_shape(mask)
70 |
71 | def exists(name):
72 | return hasattr(self, name) and getattr(self, name) is not None
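This module acts as a mutable global namespace: `set_shape`/`set_mask` populate `input_shape`, `input_mask` and the per-resolution masks, `DDIMIterator` keeps `timestep`/`timestep_index` in sync with the sampling loop, and `Seed` is an indexable table of deterministic seeds. A small sketch of the last two (no model required; `share` below refers to this module):

    print(share.seed[0], share.seed[5])      # deterministic per-index seeds, pow(12345, idx, 54321)
    print(share.seed[0:3])                   # slicing returns a list of seeds

    for t in share.DDIMIterator([999, 974, 949]):
        print(share.timestep_index, t)       # timestep_index counts 1, 2, 3 and share.timestep == t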
--------------------------------------------------------------------------------
/src/smplfusion/util.py:
--------------------------------------------------------------------------------
1 | # TODO: remove everything below, provided it doesn't break anything!
2 |
3 | import importlib
4 |
5 | import torch
6 | from torch import optim
7 | import numpy as np
8 |
9 | from inspect import isfunction
10 | from PIL import Image, ImageDraw, ImageFont
11 |
12 |
13 | def autocast(f):
14 | def do_autocast(*args, **kwargs):
15 | with torch.cuda.amp.autocast(enabled=True,
16 | dtype=torch.get_autocast_gpu_dtype(),
17 | cache_enabled=torch.is_autocast_cache_enabled()):
18 | return f(*args, **kwargs)
19 |
20 | return do_autocast
21 |
22 |
23 | def log_txt_as_img(wh, xc, size=10):
24 | # wh a tuple of (width, height)
25 | # xc a list of captions to plot
26 | b = len(xc)
27 | txts = list()
28 | for bi in range(b):
29 | txt = Image.new("RGB", wh, color="white")
30 | draw = ImageDraw.Draw(txt)
31 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
32 | nc = int(40 * (wh[0] / 256))
33 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
34 |
35 | try:
36 | draw.text((0, 0), lines, fill="black", font=font)
37 | except UnicodeEncodeError:
38 | print("Can't encode string for logging. Skipping.")
39 |
40 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
41 | txts.append(txt)
42 | txts = np.stack(txts)
43 | txts = torch.tensor(txts)
44 | return txts
45 |
46 |
47 | def ismap(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] > 3)
51 |
52 |
53 | def isimage(x):
54 | if not isinstance(x,torch.Tensor):
55 | return False
56 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
57 |
58 |
59 | def exists(x):
60 | return x is not None
61 |
62 |
63 | def default(val, d):
64 | if exists(val):
65 | return val
66 | return d() if isfunction(d) else d
67 |
68 |
69 | def mean_flat(tensor):
70 | """
71 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
72 | Take the mean over all non-batch dimensions.
73 | """
74 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
75 |
76 |
77 | def count_params(model, verbose=False):
78 | total_params = sum(p.numel() for p in model.parameters())
79 | if verbose:
80 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
81 | return total_params
82 |
83 |
84 | def instantiate_from_config(config):
85 | if "target" not in config:
86 | if config == '__is_first_stage__':
87 | return None
88 | elif config == "__is_unconditional__":
89 | return None
90 | raise KeyError("Expected key `target` to instantiate.")
91 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
92 |
93 |
94 | def get_obj_from_str(string, reload=False):
95 | module, cls = string.rsplit(".", 1)
96 | if reload:
97 | module_imp = importlib.import_module(module)
98 | importlib.reload(module_imp)
99 | return getattr(importlib.import_module(module, package=None), cls)
100 |
101 |
102 | class AdamWwithEMAandWings(optim.Optimizer):
103 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
104 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
105 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
106 | ema_power=1., param_names=()):
107 | """AdamW that saves EMA versions of the parameters."""
108 | if not 0.0 <= lr:
109 | raise ValueError("Invalid learning rate: {}".format(lr))
110 | if not 0.0 <= eps:
111 | raise ValueError("Invalid epsilon value: {}".format(eps))
112 | if not 0.0 <= betas[0] < 1.0:
113 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
114 | if not 0.0 <= betas[1] < 1.0:
115 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
116 | if not 0.0 <= weight_decay:
117 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
118 | if not 0.0 <= ema_decay <= 1.0:
119 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
120 | defaults = dict(lr=lr, betas=betas, eps=eps,
121 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
122 | ema_power=ema_power, param_names=param_names)
123 | super().__init__(params, defaults)
124 |
125 | def __setstate__(self, state):
126 | super().__setstate__(state)
127 | for group in self.param_groups:
128 | group.setdefault('amsgrad', False)
129 |
130 | @torch.no_grad()
131 | def step(self, closure=None):
132 | """Performs a single optimization step.
133 | Args:
134 | closure (callable, optional): A closure that reevaluates the model
135 | and returns the loss.
136 | """
137 | loss = None
138 | if closure is not None:
139 | with torch.enable_grad():
140 | loss = closure()
141 |
142 | for group in self.param_groups:
143 | params_with_grad = []
144 | grads = []
145 | exp_avgs = []
146 | exp_avg_sqs = []
147 | ema_params_with_grad = []
148 | state_sums = []
149 | max_exp_avg_sqs = []
150 | state_steps = []
151 | amsgrad = group['amsgrad']
152 | beta1, beta2 = group['betas']
153 | ema_decay = group['ema_decay']
154 | ema_power = group['ema_power']
155 |
156 | for p in group['params']:
157 | if p.grad is None:
158 | continue
159 | params_with_grad.append(p)
160 | if p.grad.is_sparse:
161 | raise RuntimeError('AdamW does not support sparse gradients')
162 | grads.append(p.grad)
163 |
164 | state = self.state[p]
165 |
166 | # State initialization
167 | if len(state) == 0:
168 | state['step'] = 0
169 | # Exponential moving average of gradient values
170 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
171 | # Exponential moving average of squared gradient values
172 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
173 | if amsgrad:
174 | # Maintains max of all exp. moving avg. of sq. grad. values
175 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
176 | # Exponential moving average of parameter values
177 | state['param_exp_avg'] = p.detach().float().clone()
178 |
179 | exp_avgs.append(state['exp_avg'])
180 | exp_avg_sqs.append(state['exp_avg_sq'])
181 | ema_params_with_grad.append(state['param_exp_avg'])
182 |
183 | if amsgrad:
184 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
185 |
186 | # update the steps for each param group update
187 | state['step'] += 1
188 | # record the step after step update
189 | state_steps.append(state['step'])
190 |
191 | optim._functional.adamw(params_with_grad,
192 | grads,
193 | exp_avgs,
194 | exp_avg_sqs,
195 | max_exp_avg_sqs,
196 | state_steps,
197 | amsgrad=amsgrad,
198 | beta1=beta1,
199 | beta2=beta2,
200 | lr=group['lr'],
201 | weight_decay=group['weight_decay'],
202 | eps=group['eps'],
203 | maximize=False)
204 |
205 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
206 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
207 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
208 |
209 | return loss
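`instantiate_from_config` resolves the dotted class path in `config["target"]` and constructs it with `config["params"]`. A minimal sketch with a stock PyTorch class, assuming the helpers from this module are in scope (the config values are made up):

    config = {
        "target": "torch.nn.Linear",                       # dotted path resolved by get_obj_from_str
        "params": {"in_features": 4, "out_features": 2},
    }
    layer = instantiate_from_config(config)                # same as torch.nn.Linear(in_features=4, out_features=2)
    count_params(layer, verbose=True)                      # prints "Linear has 0.00 M params." and returns 10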
--------------------------------------------------------------------------------
/src/smplfusion/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .input_image import InputImage
2 | from .input_mask import InputMask
3 | from .input_shape import InputShape
4 | from .layer_mask import LayerMask
--------------------------------------------------------------------------------
/src/smplfusion/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/utils/__pycache__/input_image.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_image.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/utils/__pycache__/input_mask.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_mask.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/utils/__pycache__/input_shape.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/input_shape.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/utils/__pycache__/layer_mask.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/smplfusion/utils/__pycache__/layer_mask.cpython-39.pyc
--------------------------------------------------------------------------------
/src/smplfusion/utils/input_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..libimage import IImage
3 |
4 | class InputImage:
5 | def to(self, device): return InputImage(self.image, device = device)
6 | def cuda(self): return InputImage(self.image, device = 'cuda')
7 | def cpu(self): return InputImage(self.image, device = 'cpu')
8 |
9 | def __init__(self, input_image, device = 'cpu'):  # device accepted so that to()/cuda()/cpu() above do not fail
10 | '''
11 | args:
12 | input_image: (b,c,h,w) tensor
13 | '''
14 | if hasattr(input_image, 'is_iimage'):
15 | self.image = input_image
16 | self.val512 = self.full = input_image.torch(0)
17 | elif isinstance(input_image, torch.Tensor):
18 | self.val512 = self.full = input_image.clone()
19 | self.image = IImage(input_image,0)
20 |
21 | self.h,self.w = h,w = self.val512.shape[-2:]
22 | self.shape = [self.h, self.w]
23 | self.shape64 = [self.h // 8, self.w // 8]
24 | self.shape32 = [self.h // 16, self.w // 16]
25 | self.shape16 = [self.h // 32, self.w // 32]
26 | self.shape8 = [self.h // 64, self.w // 64]
27 |
28 | self.res = self.h * self.w
29 | self.res64 = self.res // 64
30 | self.res32 = self.res // 64 // 4
31 | self.res16 = self.res // 64 // 16
32 | self.res8 = self.res // 64 // 64
33 |
34 | self.img = self.image
35 | self.img512 = self.image
36 | self.img64 = self.image.resize((h//8,w//8))
37 | self.img32 = self.image.resize((h//16,w//16))
38 | self.img16 = self.image.resize((h//32,w//32))
39 | self.img8 = self.image.resize((h//64,w//64))
40 |
41 | self.val64 = self.img64.torch()
42 | self.val32 = self.img32.torch()
43 | self.val16 = self.img16.torch()
44 | self.val8 = self.img8.torch()
45 |
46 | def get_res(self, q, device = 'cpu'):
47 | if q.shape[1] == self.res64: return self.val64.to(device)
48 | if q.shape[1] == self.res32: return self.val32.to(device)
49 | if q.shape[1] == self.res16: return self.val16.to(device)
50 | if q.shape[1] == self.res8: return self.val8.to(device)
51 |
52 | def get_shape(self, q, device = 'cpu'):
53 | if q.shape[1] == self.res64: return self.shape64
54 | if q.shape[1] == self.res32: return self.shape32
55 | if q.shape[1] == self.res16: return self.shape16
56 | if q.shape[1] == self.res8: return self.shape8
57 |
58 | def get_res_val(self, q, device = 'cpu'):
59 | if q.shape[1] == self.res64: return 64
60 | if q.shape[1] == self.res32: return 32
61 | if q.shape[1] == self.res16: return 16
62 | if q.shape[1] == self.res8: return 8
63 |
64 | def _repr_png_(self):
65 | return self.img16._repr_png_()
--------------------------------------------------------------------------------
/src/smplfusion/utils/input_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ..libimage import IImage
3 |
4 | class InputMask:
5 | def to(self, device): return InputMask(self.image, device = device)
6 | def cuda(self): return InputMask(self.image, device = 'cuda')
7 | def cpu(self): return InputMask(self.image, device = 'cpu')
8 |
9 | def __init__(self, input_image, device = 'cpu'):
10 | '''
11 | args:
12 | input_image: (b,c,h,w) tensor
13 | '''
14 | if hasattr(input_image, 'is_iimage'):
15 | self.image = input_image
16 | self.val512 = self.full = (input_image.torch(0) > 0.5).float()
17 | elif isinstance(input_image, torch.Tensor):
18 | self.val512 = self.full = input_image.clone()
19 | self.image = IImage(input_image,0)
20 |
21 | self.h,self.w = h,w = self.val512.shape[-2:]
22 | self.shape = [self.h, self.w]
23 | self.shape64 = [self.h // 8, self.w // 8]
24 | self.shape32 = [self.h // 16, self.w // 16]
25 | self.shape16 = [self.h // 32, self.w // 32]
26 | self.shape8 = [self.h // 64, self.w // 64]
27 |
28 | self.res = self.h * self.w
29 | self.res64 = self.res // 64
30 | self.res32 = self.res // 64 // 4
31 | self.res16 = self.res // 64 // 16
32 | self.res8 = self.res // 64 // 64
33 |
34 | self.img = self.image
35 | self.img512 = self.image
36 | self.img64 = self.image.resize((h//8,w//8))
37 | self.img32 = self.image.resize((h//16,w//16))
38 | self.img16 = self.image.resize((h//32,w//32))
39 | self.img8 = self.image.resize((h//64,w//64))
40 |
41 | self.val64 = (self.img64.torch(0) > 0.5).float()
42 | self.val32 = (self.img32.torch(0) > 0.5).float()
43 | self.val16 = (self.img16.torch(0) > 0.5).float()
44 | self.val8 = ( self.img8.torch(0) > 0.5).float()
45 |
46 |
47 | def get_res(self, q, device = 'cpu'):
48 | if q.shape[1] == self.res64: return self.val64.to(device)
49 | if q.shape[1] == self.res32: return self.val32.to(device)
50 | if q.shape[1] == self.res16: return self.val16.to(device)
51 | if q.shape[1] == self.res8: return self.val8.to(device)
52 |
53 |     def get_shape(self, q, device = 'cpu'):
54 |         if q.shape[1] == self.res64: return self.shape64
55 |         if q.shape[1] == self.res32: return self.shape32
56 |         if q.shape[1] == self.res16: return self.shape16
57 |         if q.shape[1] == self.res8: return self.shape8
58 | 
59 |     def get_res_val(self, q, device = 'cpu'):
60 |         if q.shape[1] == self.res64: return 64
61 |         if q.shape[1] == self.res32: return 32
62 |         if q.shape[1] == self.res16: return 16
63 |         if q.shape[1] == self.res8: return 8
64 | 
65 |     def _repr_png_(self):
66 |         return self.img16._repr_png_()
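
Usage note: InputMask mirrors InputImage but re-thresholds the mask (> 0.5) after every resize, so the cached tensors stay strictly binary at all resolutions. A minimal sketch under the same IImage assumption (import path assumed):

    import torch
    from src.smplfusion.utils.input_mask import InputMask

    m = torch.zeros(1, 1, 512, 512)
    m[:, :, 128:384, 128:384] = 1.0
    mask = InputMask(m)
    print(mask.val16.shape)  # the mask downsampled to 16x16 and re-binarized to {0, 1}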
--------------------------------------------------------------------------------
/src/smplfusion/utils/input_shape.py:
--------------------------------------------------------------------------------
1 | class InputShape:
2 | def __init__(self, image_size):
3 | self.h,self.w = image_size[::-1]
4 | self.res = self.h * self.w
5 | self.res64 = self.res // 64
6 | self.res32 = self.res // 64 // 4
7 | self.res16 = self.res // 64 // 16
8 | self.res8 = self.res // 64 // 64
9 | self.shape = [self.h, self.w]
10 | self.shape64 = [self.h // 8, self.w // 8]
11 | self.shape32 = [self.h // 16, self.w // 16]
12 | self.shape16 = [self.h // 32, self.w // 32]
13 | self.shape8 = [self.h // 64, self.w // 64]
14 |
15 | def reshape(self, x):
16 | assert len(x.shape) == 3
17 | if x.shape[1] == self.res64: return x.reshape([x.shape[0]] + self.shape64 + [x.shape[-1]])
18 | if x.shape[1] == self.res32: return x.reshape([x.shape[0]] + self.shape32 + [x.shape[-1]])
19 | if x.shape[1] == self.res16: return x.reshape([x.shape[0]] + self.shape16 + [x.shape[-1]])
20 | if x.shape[1] == self.res8: return x.reshape([x.shape[0]] + self.shape8 + [x.shape[-1]])
21 | raise Exception("Unknown shape")
22 |
23 | def get_res(self, q, device = 'cpu'):
24 | if q.shape[1] == self.res64: return 64
25 | if q.shape[1] == self.res32: return 32
26 | if q.shape[1] == self.res16: return 16
27 | if q.shape[1] == self.res8: return 8
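
Usage note: InputShape does the same resolution bookkeeping from a bare (width, height) pair, and reshape() folds a flattened (batch, tokens, channels) attention tensor back into a spatial map:

    import torch
    from src.smplfusion.utils.input_shape import InputShape   # import path assumed

    shape = InputShape((512, 512))
    attn = torch.rand(2, shape.res16, 77)   # e.g. cross-attention over 77 text tokens at 16x16
    print(shape.reshape(attn).shape)        # torch.Size([2, 16, 16, 77])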
--------------------------------------------------------------------------------
/src/smplfusion/utils/layer_mask.py:
--------------------------------------------------------------------------------
1 | class LayerMask:
2 | def __init__(self, self64 = False, self32 = False, self16 = False, self8 = False, cross64 = False, cross32 = False, cross16 = False, cross8 = False):
3 | self.self64 = self64
4 | self.self32 = self32
5 | self.self16 = self16
6 | self.self8 = self8
7 | self.cross64 = cross64
8 | self.cross32 = cross32
9 | self.cross16 = cross16
10 | self.cross8 = cross8
11 |
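
Usage note: LayerMask is only a container of boolean flags, one per attention type and resolution. A typical (assumed) use is to switch a patch on for selected layers only:

    from src.smplfusion.utils.layer_mask import LayerMask   # import path assumed

    flags = LayerMask(self16=True, cross16=True)  # act only on the 16x16 attention layers
    if flags.cross16:
        pass  # e.g. apply the patched cross-attention at this resolution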
--------------------------------------------------------------------------------
/src/zeropainter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Picsart-AI-Research/Zero-Painter/9d391845e633ef69d85ed41ab07ff026d2ce722a/src/zeropainter/__init__.py
--------------------------------------------------------------------------------
/src/zeropainter/convert_diffusers.py:
--------------------------------------------------------------------------------
1 | # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
2 | # *Only* converts the UNet, VAE, and Text Encoder.
3 | # Does not convert optimizer state or any other thing.
4 | # Written by jachiam https://github.com/jachiam
5 |
6 | import argparse
7 | import os.path as osp
8 |
9 | import torch
10 |
11 |
12 | # =================#
13 | # UNet Conversion #
14 | # =================#
15 |
16 | unet_conversion_map = [
17 | # (stable-diffusion, HF Diffusers)
18 | ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19 | ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20 | ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21 | ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22 | ("input_blocks.0.0.weight", "conv_in.weight"),
23 | ("input_blocks.0.0.bias", "conv_in.bias"),
24 | ("out.0.weight", "conv_norm_out.weight"),
25 | ("out.0.bias", "conv_norm_out.bias"),
26 | ("out.2.weight", "conv_out.weight"),
27 | ("out.2.bias", "conv_out.bias"),
28 | ]
29 |
30 | unet_conversion_map_resnet = [
31 | # (stable-diffusion, HF Diffusers)
32 | ("in_layers.0", "norm1"),
33 | ("in_layers.2", "conv1"),
34 | ("out_layers.0", "norm2"),
35 | ("out_layers.3", "conv2"),
36 | ("emb_layers.1", "time_emb_proj"),
37 | ("skip_connection", "conv_shortcut"),
38 | ]
39 |
40 | unet_conversion_map_layer = []
41 | # hardcoded number of downblocks and resnets/attentions...
42 | # would need smarter logic for other networks.
43 | for i in range(4):
44 | # loop over downblocks/upblocks
45 |
46 | for j in range(2):
47 | # loop over resnets/attentions for downblocks
48 | hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49 | sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50 | unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51 |
52 | if i < 3:
53 | # no attention layers in down_blocks.3
54 | hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55 | sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56 | unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57 |
58 | for j in range(3):
59 | # loop over resnets/attentions for upblocks
60 | hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61 | sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62 | unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63 |
64 | if i > 0:
65 | # no attention layers in up_blocks.0
66 | hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67 | sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68 | unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69 |
70 | if i < 3:
71 | # no downsample in down_blocks.3
72 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73 | sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74 | unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75 |
76 | # no upsample in up_blocks.3
77 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78 | sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79 | unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80 |
81 | hf_mid_atn_prefix = "mid_block.attentions.0."
82 | sd_mid_atn_prefix = "middle_block.1."
83 | unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84 |
85 | for j in range(2):
86 | hf_mid_res_prefix = f"mid_block.resnets.{j}."
87 | sd_mid_res_prefix = f"middle_block.{2*j}."
88 | unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89 |
90 |
91 | def convert_unet_state_dict(unet_state_dict):
92 | # buyer beware: this is a *brittle* function,
93 | # and correct output requires that all of these pieces interact in
94 | # the exact order in which I have arranged them.
95 | mapping = {k: k for k in unet_state_dict.keys()}
96 | for sd_name, hf_name in unet_conversion_map:
97 | mapping[hf_name] = sd_name
98 | for k, v in mapping.items():
99 | if "resnets" in k:
100 | for sd_part, hf_part in unet_conversion_map_resnet:
101 | v = v.replace(hf_part, sd_part)
102 | mapping[k] = v
103 | for k, v in mapping.items():
104 | for sd_part, hf_part in unet_conversion_map_layer:
105 | v = v.replace(hf_part, sd_part)
106 | mapping[k] = v
107 | new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108 | return new_state_dict
109 |
110 |
111 | # ================#
112 | # VAE Conversion #
113 | # ================#
114 |
115 | vae_conversion_map = [
116 | # (stable-diffusion, HF Diffusers)
117 | ("nin_shortcut", "conv_shortcut"),
118 | ("norm_out", "conv_norm_out"),
119 | ("mid.attn_1.", "mid_block.attentions.0."),
120 | ]
121 |
122 | for i in range(4):
123 | # down_blocks have two resnets
124 | for j in range(2):
125 | hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126 | sd_down_prefix = f"encoder.down.{i}.block.{j}."
127 | vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128 |
129 | if i < 3:
130 | hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131 | sd_downsample_prefix = f"down.{i}.downsample."
132 | vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133 |
134 | hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135 | sd_upsample_prefix = f"up.{3-i}.upsample."
136 | vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137 |
138 | # up_blocks have three resnets
139 | # also, up blocks in hf are numbered in reverse from sd
140 | for j in range(3):
141 | hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142 | sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143 | vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144 |
145 | # this part accounts for mid blocks in both the encoder and the decoder
146 | for i in range(2):
147 | hf_mid_res_prefix = f"mid_block.resnets.{i}."
148 | sd_mid_res_prefix = f"mid.block_{i+1}."
149 | vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150 |
151 |
152 | vae_conversion_map_attn = [
153 | # (stable-diffusion, HF Diffusers)
154 | ("norm.", "group_norm."),
155 | ("q.", "query."),
156 | ("k.", "key."),
157 | ("v.", "value."),
158 | ("proj_out.", "proj_attn."),
159 | ]
160 |
161 |
162 | def reshape_weight_for_sd(w):
163 | # convert HF linear weights to SD conv2d weights
164 | return w.reshape(*w.shape, 1, 1)
165 |
166 |
167 | def convert_vae_state_dict(vae_state_dict):
168 | mapping = {k: k for k in vae_state_dict.keys()}
169 | for k, v in mapping.items():
170 | for sd_part, hf_part in vae_conversion_map:
171 | v = v.replace(hf_part, sd_part)
172 | mapping[k] = v
173 | for k, v in mapping.items():
174 | if "attentions" in k:
175 | for sd_part, hf_part in vae_conversion_map_attn:
176 | v = v.replace(hf_part, sd_part)
177 | mapping[k] = v
178 | new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179 | weights_to_convert = ["q", "k", "v", "proj_out"]
180 | for k, v in new_state_dict.items():
181 | for weight_name in weights_to_convert:
182 | if f"mid.attn_1.{weight_name}.weight" in k:
183 | print(f"Reshaping {k} for SD format")
184 | new_state_dict[k] = reshape_weight_for_sd(v)
185 | return new_state_dict
186 |
187 |
188 | # =========================#
189 | # Text Encoder Conversion #
190 | # =========================#
191 | # pretty much a no-op
192 |
193 |
194 | def convert_text_enc_state_dict(text_enc_dict):
195 | return text_enc_dict
196 |
197 |
198 | if __name__ == "__main__":
199 | parser = argparse.ArgumentParser()
200 |
201 | parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
202 | parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
203 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
204 |
205 | args = parser.parse_args()
206 |
207 | assert args.model_path is not None, "Must provide a model path!"
208 |
209 | assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
210 |
211 | unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
212 | vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
213 | text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
214 |
215 | # Convert the UNet model
216 | unet_state_dict = torch.load(unet_path, map_location='cpu')
217 | unet_state_dict = convert_unet_state_dict(unet_state_dict)
218 | unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
219 |
220 | # Convert the VAE model
221 | vae_state_dict = torch.load(vae_path, map_location='cpu')
222 | vae_state_dict = convert_vae_state_dict(vae_state_dict)
223 | vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
224 |
225 | # Convert the text encoder model
226 | text_enc_dict = torch.load(text_enc_path, map_location='cpu')
227 | text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
228 | text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
229 |
230 | # Put together new checkpoint
231 | state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
232 | if args.half:
233 | state_dict = {k:v.half() for k,v in state_dict.items()}
234 | state_dict = {"state_dict": state_dict}
235 | torch.save(state_dict, args.checkpoint_path)
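
Usage note: the script is run on a locally saved Diffusers pipeline folder, e.g. python src/zeropainter/convert_diffusers.py --model_path <pipeline_dir> --checkpoint_path model.ckpt --half. The conversion helpers can also be called directly; a minimal sketch for the UNet only (paths are illustrative):

    import torch
    from src.zeropainter.convert_diffusers import convert_unet_state_dict

    unet_sd = torch.load("/path/to/pipeline/unet/diffusion_pytorch_model.bin", map_location="cpu")
    sd_unet = {"model.diffusion_model." + k: v for k, v in convert_unet_state_dict(unet_sd).items()}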
--------------------------------------------------------------------------------
/src/zeropainter/dreamshaper.py:
--------------------------------------------------------------------------------
1 | import src.smplfusion
2 | from src.smplfusion import scheduler
3 | from src.smplfusion.common import *
4 | from types import SimpleNamespace
5 |
6 | import importlib
7 | from omegaconf import OmegaConf
8 | import torch
9 | import safetensors
10 | import safetensors.torch
11 |
12 | from os.path import dirname
13 |
14 | print("Loading model: Dreamshaper Inpainting V8")
15 | PROJECT_DIR = dirname(dirname(dirname(__file__)))
16 | print(PROJECT_DIR)
17 | CONFIG_FOLDER = f"{PROJECT_DIR}/config/"
18 | MODEL_FOLDER = f"{PROJECT_DIR}/assets/models/"
19 | ASSETS_FOLDER = f"{PROJECT_DIR}/assets/"
20 |
21 |
22 | def get_inpainting_condition(model, image, mask):
23 | latent_size = [x // 8 for x in image.size]
24 | condition_x0 = (
25 | model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean
26 | * model.config.scale_factor
27 | )
28 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float()
29 | return torch.cat([condition_mask, condition_x0], 1)
30 |
31 |
32 | def get_obj_from_str(string):
33 | module, cls = string.rsplit(".", 1)
34 | return getattr(importlib.import_module(module, package=None), cls)
35 |
36 | def load_obj(path):
37 | objyaml = OmegaConf.load(path)
38 | return get_obj_from_str(objyaml["__class__"])(**objyaml.get("__init__", {}))
39 |
40 |
41 | def get_t2i_model():
42 | model_t2i = SimpleNamespace()
43 | state_dict = safetensors.torch.load_file(
44 | f"{MODEL_FOLDER}/dreamshaper/dreamshaper_8.safetensors", device="cuda"
45 | )
46 |
47 | model_t2i.config = OmegaConf.load(f"{CONFIG_FOLDER}/ddpm/v1.yaml")
48 | model_t2i.unet = load_obj(f"{CONFIG_FOLDER}/unet/v1.yaml").eval().cuda()
49 | model_t2i.vae = load_obj(f"{CONFIG_FOLDER}/vae.yaml").eval().cuda()
50 | model_t2i.encoder = load_obj(f"{CONFIG_FOLDER}/encoders/clip.yaml").eval().cuda()
51 |
52 | extract = lambda state_dict, model: {
53 | x[len(model) + 1 :]: y for x, y in state_dict.items() if model in x
54 | }
55 | unet_state = extract(state_dict, "model.diffusion_model")
56 | encoder_state = extract(state_dict, "cond_stage_model")
57 | vae_state = extract(state_dict, "first_stage_model")
58 |
59 | model_t2i.unet.load_state_dict(unet_state, strict=False)
60 | model_t2i.encoder.load_state_dict(encoder_state, strict=False)
61 | model_t2i.vae.load_state_dict(vae_state, strict=False)
62 | model_t2i.unet = model_t2i.unet.requires_grad_(False)
63 | model_t2i.encoder = model_t2i.encoder.requires_grad_(False)
64 | model_t2i.vae = model_t2i.vae.requires_grad_(False)
65 |
66 | model_t2i.schedule = scheduler.linear(
67 | model_t2i.config.timesteps,
68 | model_t2i.config.linear_start,
69 | model_t2i.config.linear_end,
70 | )
71 | return model_t2i, unet_state
72 |
73 |
74 | def get_inpainting_model():
75 | model_inp = SimpleNamespace()
76 | state_dict = safetensors.torch.load_file(
77 | f"{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors", device="cuda"
78 | )
79 |
80 | model_inp.config = OmegaConf.load(f"{CONFIG_FOLDER}/ddpm/v1.yaml")
81 | model_inp.unet = load_obj(f"{CONFIG_FOLDER}/unet/inpainting/v1.yaml").eval().cuda()
82 | model_inp.vae = load_obj(f"{CONFIG_FOLDER}/vae.yaml").eval().cuda()
83 | model_inp.encoder = load_obj(f"{CONFIG_FOLDER}/encoders/clip.yaml").eval().cuda()
84 |
85 | extract = lambda state_dict, model: {
86 | x[len(model) + 1 :]: y for x, y in state_dict.items() if model in x
87 | }
88 | unet_state = extract(state_dict, "model.diffusion_model")
89 | encoder_state = extract(state_dict, "cond_stage_model")
90 | vae_state = extract(state_dict, "first_stage_model")
91 |
92 | model_inp.unet.load_state_dict(unet_state, strict=False)
93 | model_inp.encoder.load_state_dict(encoder_state, strict=False)
94 | model_inp.vae.load_state_dict(vae_state, strict=False)
95 | model_inp.unet = model_inp.unet.requires_grad_(False)
96 | model_inp.encoder = model_inp.encoder.requires_grad_(False)
97 | model_inp.vae = model_inp.vae.requires_grad_(False)
98 |
99 | model_inp.schedule = scheduler.linear(
100 | model_inp.config.timesteps,
101 | model_inp.config.linear_start,
102 | model_inp.config.linear_end,
103 | )
104 | return model_inp, unet_state
105 |
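
Usage note: get_inpainting_condition builds the five extra UNet input channels used by the inpainting model: one binary mask channel at latent resolution plus the four-channel VAE latent of the masked image. A shape-only sketch for a 512x512 input (dummy tensors; the real values come from the VAE and the mask):

    import torch

    condition_mask = torch.zeros(1, 1, 64, 64)  # binary mask at latent resolution (h/8, w/8)
    condition_x0 = torch.zeros(1, 4, 64, 64)    # VAE latent of the masked image, scaled by scale_factor
    condition = torch.cat([condition_mask, condition_x0], 1)   # (1, 5, 64, 64)

    z_t = torch.zeros(1, 4, 64, 64)
    unet_input = torch.cat([z_t, condition], 1)  # (1, 9, 64, 64): the inpainting UNet's input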
--------------------------------------------------------------------------------
/src/zeropainter/inpainting.py:
--------------------------------------------------------------------------------
1 | from src.smplfusion.common import *
2 | import torch
3 | from src.smplfusion import share
4 | from src.smplfusion.patches import attentionpatch
5 | from tqdm import tqdm
6 | from pytorch_lightning import seed_everything
7 | from src.smplfusion.patches import router
8 | from src.smplfusion import IImage
9 |
10 | # negative_prompt = "worst quality, ugly, gross, disfigured, deformed, dehydrated, extra limbs, fused body parts, mutilated, malformed, mutated, bad anatomy, bad proportions, low quality, cropped, low resolution, out of frame, poorly drawn, text, watermark, letters, jpeg artifacts"
11 | # positive_prompt = ", realistic, HD, Full HD, 4K, high quality, high resolution, masterpiece, trending on artstation, realistic lighting"
12 | negative_prompt = ''
13 | positive_prompt = ''
14 | VERBOSE = True
15 |
16 | class AttnForward:
17 | def __init__(self, masks, object_context, object_uc_context):
18 | self.masks = masks
19 | self.object_context = object_context
20 | self.object_uc_context = object_uc_context
21 |
22 |     def __call__(data, self, x, context=None, mask=None):  # NB: 'data' is this AttnForward instance; 'self' is the attention module passed in by the router
23 | att_type = "self" if context is None else "cross"
24 | batch_size = x.shape[0]
25 |
26 | if att_type == 'cross' and x.shape[1] in [share.input_shape.res16]: # For cross attention
27 | out = torch.zeros_like(x)
28 | for i in range(len(data.masks)):
29 | if data.masks[i].sum()>0:
30 | if batch_size == 1:
31 | out[:,data.masks[i]] = attentionpatch.default.forward(
32 | self,
33 | x[:,data.masks[i] > 0],
34 | data.object_context[i][None]
35 | ).float()
36 | elif batch_size == 2:
37 | out[:,data.masks[i]] = attentionpatch.default.forward(
38 | self,
39 | x[:,data.masks[i] > 0],
40 | torch.stack([data.object_uc_context[i],data.object_context[i]])
41 | ).float()
42 | else:
43 | raise NotImplementedError("Batch Size > 1 not yet supported!")
44 | return out
45 | else:
46 | return attentionpatch.default.forward(self, x, context, mask)
47 |
48 | def gen_filled_image(model, prompt, image, mask, zp_masks, seed, T = 899, dt = 20, model_t2i = None, guidance_scale = 7.5, use_lcm_multistep = False):
49 | masks = [x.modified_mask.val16.flatten()>0 for x in zp_masks]
50 | masks.append(torch.stack(masks).sum(0) == 0)
51 |
52 | context = model.encoder.encode(['', f"realistic photo of a {prompt}" + positive_prompt])
53 | object_context = model.encoder.encode([x.local_prompt for x in zp_masks] + [f"realistic photo of a {prompt}" + positive_prompt])
54 | object_uc_context = model.encoder.encode([negative_prompt] * len(zp_masks) + [', '.join([x.local_prompt for x in zp_masks])])
55 |
56 | seed_everything(seed)
57 |
58 | eps = torch.randn((1,4,64,64)).cuda()
59 | # zT = schedule.sqrt_alphas[799] * condition_mask + schedule.sqrt_one_minus_alphas[799] * eps
60 | # zT = (condition_mask < 0) * zT + (condition_mask > 0) * eps
61 | # zT = schedule.sqrt_one_minus_alphas[899] * eps
62 | # zT = schedule.sqrt_alphas[899] * IImage(255*orig_masks_all).resize(64).alpha().torch().cuda() + schedule.sqrt_one_minus_alphas[899] * eps
63 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * model.config.scale_factor
64 | condition_xT = model.schedule.sqrt_alphas[T] * condition_x0 + model.schedule.sqrt_one_minus_alphas[T] * eps
65 | condition_mask = mask.resize(64).cuda().torch(0).bool().float()
66 | zT = (condition_mask == 0) * condition_xT + (condition_mask > 0) * eps
67 | # zT = eps
68 |
69 | router.attention_forward = AttnForward(masks, object_context, object_uc_context)
70 | # router.attention_forward = attentionpatch.default.forward
71 |
72 | with torch.autocast('cuda'), torch.no_grad():
73 | zt = zT
74 | timesteps = list(range(T, 0, -dt))
75 | pbar = tqdm(timesteps) if VERBOSE else timesteps
76 | for index,t in enumerate(pbar):
77 | if index == 0:
78 |                 current_mask = ~(~mask).dilate(2)  # shrink (erode) the mask: invert, dilate, invert back
79 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor
80 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float()
81 | condition = torch.cat([condition_mask, condition_x0], 1)
82 | if index == 5:
83 | current_mask = ~(~mask).dilate(0)
84 | condition_x0 = model.vae.encode(image.torch().cuda() * ~mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor
85 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float()
86 | condition = torch.cat([condition_mask, condition_x0], 1)
87 | if index == len(timesteps) - 5:
88 | if model_t2i is not None:
89 | model = model_t2i
90 | condition = None
91 | else:
92 | current_mask = mask.dilate(512)
93 | condition_x0 = model.vae.encode(image.torch().cuda() * ~current_mask.cuda().torch(0).bool().cuda()).mean * model.config.scale_factor
94 | condition_mask = current_mask.resize(64).cuda().torch(0).bool().float()
95 | condition = torch.cat([condition_mask, condition_x0], 1)
96 |
97 | _zt = zt if condition is None else torch.cat([zt, condition], 1)
98 |
99 | if use_lcm_multistep:
100 | eps = model.unet(
101 | _zt,
102 | timesteps = torch.tensor([t]).cuda(),
103 | context = context[1][None]
104 | )
105 | else:
106 | eps_uncond, eps = model.unet(
107 | torch.cat([_zt, _zt]),
108 | timesteps = torch.tensor([t, t]).cuda(),
109 | context = context
110 | ).chunk(2)
111 | eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
112 |
113 | z0 = (zt - model.schedule.sqrt_one_minus_alphas[t] * eps) / model.schedule.sqrt_alphas[t]
114 | zt = model.schedule.sqrt_alphas[t - dt] * z0 + model.schedule.sqrt_one_minus_alphas[t - dt] * eps
115 |
116 | out = IImage(model.vae.decode(z0 / model.config.scale_factor))
117 | return out
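
Usage note: the core of the loop above is a classifier-free-guidance combination followed by a predict-x0 / re-noise step. A toy sketch with dummy schedule values (the real coefficients come from model.schedule):

    import torch

    guidance_scale, t, dt = 7.5, 899, 20
    sqrt_alphas = {899: 0.30, 879: 0.32}             # dummy values
    sqrt_one_minus_alphas = {899: 0.95, 879: 0.947}  # dummy values

    zt = torch.randn(1, 4, 64, 64)
    eps_uncond, eps_cond = torch.randn(2, 1, 4, 64, 64)

    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)          # classifier-free guidance
    z0 = (zt - sqrt_one_minus_alphas[t] * eps) / sqrt_alphas[t]          # predict the clean latent
    zt = sqrt_alphas[t - dt] * z0 + sqrt_one_minus_alphas[t - dt] * eps  # re-noise to the next step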
--------------------------------------------------------------------------------
/src/zeropainter/models.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from . import convert_diffusers
3 | import safetensors.torch
4 |
5 | import torch
6 | from src.smplfusion import scheduler
7 | from os.path import dirname
8 |
9 | from omegaconf import OmegaConf
10 | import importlib
11 |
12 | def load_obj(objyaml):
13 | if "__init__" in objyaml:
14 | return get_obj_from_str(objyaml["__class__"])(**objyaml["__init__"])
15 | else:
16 | return get_obj_from_str(objyaml["__class__"])()
17 |
18 |
19 | def get_obj_from_str(string):
20 | module, cls = string.rsplit(".", 1)
21 | try:
22 | return getattr(importlib.import_module(module, package=None), cls)
23 | except Exception as e:
24 | return getattr(importlib.import_module("src." + module, package=None), cls)
25 |
26 |
27 | def get_inpainting_condition(model, image, mask):
28 | latent_size = [x // 8 for x in image.size]
29 | condition_x0 = (
30 | model.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean
31 | * model.config.scale_factor
32 | )
33 | condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float()
34 | return torch.cat([condition_mask, condition_x0], 1)
35 |
36 |
37 | def get_t2i_model(config_folder, model_folder):
38 | model_t2i = SimpleNamespace()
39 |
40 | model_t2i.config = OmegaConf.load(
41 | f"{config_folder}/ddpm/v1.yaml"
42 | ) # smplfusion.options.ddpm.v1_yaml
43 | model_t2i.schedule = scheduler.linear(
44 | model_t2i.config.timesteps,
45 | model_t2i.config.linear_start,
46 | model_t2i.config.linear_end,
47 | )
48 |
49 | model_t2i.unet = load_obj(
50 | OmegaConf.load(f"{config_folder}/unet/v1.yaml")
51 | ).cuda() # smplfusion.options.unet.v1_yaml.cuda()
52 | model_t2i.vae = load_obj(
53 | OmegaConf.load(f"{config_folder}/vae.yaml")
54 | ).cuda() # smplfusion.options.vae_yaml.cuda()
55 | model_t2i.encoder = load_obj(
56 | OmegaConf.load(f"{config_folder}/encoders/clip.yaml")
57 | ).cuda() # smplfusion.options.encoders.clip_yaml.cuda()
58 |
59 | unet_state_dict = torch.load(
60 | f"{model_folder}/unet.ckpt"
61 | ) # assets.models.sd_1_5_inpainting.unet_ckpt
62 | vae_state_dict = torch.load(
63 | f"{model_folder}/vae.ckpt"
64 | ) # assets.models.sd_1_5_inpainting.vae_ckpt
65 | encoder_state_dict = torch.load(
66 | f"{model_folder}/encoder.ckpt"
67 | ) # assets.models.sd_1_5_inpainting.encoder_ckpt
68 |
69 | model_t2i.unet.load_state_dict(unet_state_dict)
70 | model_t2i.vae.load_state_dict(vae_state_dict)
71 | model_t2i.encoder.load_state_dict(encoder_state_dict, strict=False)
72 |
73 | model_t2i.unet = model_t2i.unet.requires_grad_(False).eval()
74 | model_t2i.vae = model_t2i.vae.requires_grad_(False).eval()
75 | model_t2i.encoder = model_t2i.encoder.requires_grad_(False).eval()
76 | return model_t2i, unet_state_dict
77 |
78 |
79 | def get_inpainting_model(config_folder, model_folder):
80 | model_inp = SimpleNamespace()
81 | model_inp.config = OmegaConf.load(
82 | f"{config_folder}/ddpm/v1.yaml"
83 | ) # smplfusion.options.ddpm.v1_yaml
84 | model_inp.schedule = scheduler.linear(
85 | model_inp.config.timesteps,
86 | model_inp.config.linear_start,
87 | model_inp.config.linear_end,
88 | )
89 |
90 | model_inp.unet = load_obj(
91 | OmegaConf.load(f"{config_folder}/unet/inpainting/v1.yaml")
92 | ).cuda() # smplfusion.options.unet.inpainting.v1_yaml.cuda()
93 | model_inp.vae = load_obj(
94 | OmegaConf.load(f"{config_folder}/vae.yaml")
95 | ).cuda() # smplfusion.options.vae_yaml.cuda()
96 | model_inp.encoder = load_obj(
97 | OmegaConf.load(f"{config_folder}/encoders/clip.yaml")
98 | ).cuda() # smplfusion.options.encoders.clip_yaml.cuda()
99 |
100 | unet_state_dict = torch.load(
101 | f"{model_folder}/unet.ckpt"
102 | ) # assets.models.sd_1_5_inpainting.unet_ckpt
103 | vae_state_dict = torch.load(
104 | f"{model_folder}/vae.ckpt"
105 | ) # assets.models.sd_1_5_inpainting.vae_ckpt
106 | encoder_state_dict = torch.load(
107 | f"{model_folder}/encoder.ckpt"
108 | ) # assets.models.sd_1_5_inpainting.encoder_ckpt
109 |
110 | model_inp.unet.load_state_dict(unet_state_dict)
111 | model_inp.vae.load_state_dict(vae_state_dict)
112 | model_inp.encoder.load_state_dict(encoder_state_dict, strict=False)
113 |
114 | model_inp.unet = model_inp.unet.requires_grad_(False).eval()
115 | model_inp.vae = model_inp.vae.requires_grad_(False).eval()
116 | model_inp.encoder = model_inp.encoder.requires_grad_(False).eval()
117 |
118 | return model_inp, unet_state_dict
119 |
120 |
121 | def get_lora(lora_path):
122 | _lora_state_dict = safetensors.torch.load_file(lora_path)
123 |
124 | # Dictionary for converting
125 | unet_conversion_dict = {}
126 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map_layer)
127 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map_resnet)
128 | unet_conversion_dict.update(convert_diffusers.unet_conversion_map)
129 | unet_conversion_dict = {
130 | y.replace(".", "_"): x for x, y in unet_conversion_dict.items()
131 | }
132 | unet_conversion_dict["lora_unet_"] = ""
133 |
134 | lora_state_dict = {}
135 | for key in _lora_state_dict:
136 | key_converted = key
137 | for x, y in unet_conversion_dict.items():
138 | key_converted = key_converted.replace(x, y)
139 | lora_state_dict[key_converted] = _lora_state_dict[key]
140 |
141 | return lora_state_dict
142 |
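
Usage note: a minimal loading sketch, assuming the working directory is the repository root and the model folders contain the unet.ckpt / vae.ckpt / encoder.ckpt files loaded above (folder names follow the defaults in zero_painter.py; the LoRA path is illustrative):

    from src.zeropainter import models

    model_inp, _ = models.get_inpainting_model("config", "models/sd-1-5-inpainting")
    model_t2i, _ = models.get_t2i_model("config", "models/sd-1-4")
    lora = models.get_lora("models/some_lora.safetensors")  # optional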
--------------------------------------------------------------------------------
/src/zeropainter/segmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os.path import dirname
3 |
4 |
5 | from segment_anything import SamPredictor, sam_model_registry
6 |
7 | def get_bounding_box(mask):
8 | all_y,all_x = np.where(mask == 1)
9 | all_y,all_x = sorted(all_y),sorted(all_x)
10 | x1,y1,x2,y2 = all_x[0],all_y[0],all_x[-1],all_y[-1]
11 | x,y,w,h = x1,y1,x2-x1,y2-y1
12 | return x,y,w,h
13 |
14 | def get_segmentation_model(path_to_check):
15 | # path_to_check = "sam_vit_h_4b8939.pth"
16 | sam = sam_model_registry["vit_h"](checkpoint=path_to_check)
17 | predictor = SamPredictor(sam)
18 |     predictor.model.cuda()
19 | return predictor
20 |
21 | def get_segmentation(predictor, im, mask):
22 | im_np = im.data[0]
23 |
24 | # SAM prediction
25 | # mask = obj['data_512'][i]['mask']
26 | x,y,w,h = get_bounding_box(mask)
27 | predictor.set_image(im_np)
28 | input_box = np.array([x,y,x+w,y+h])
29 | masks, scores, other = predictor.predict(box=input_box)
30 |
31 | _masks = np.concatenate([masks, ~masks])
32 | _scores = np.concatenate([scores, scores])
33 | _scores = np.array([y for x,y in zip(_masks, _scores) if ((1 - mask) * x).sum() <= (mask * x).sum()])
34 | _masks = np.array([x for x in _masks if ((1 - mask) * x).sum() <= (mask * x).sum()])
35 | if len(_masks) > 0:
36 | masks,scores = _masks,_scores
37 |
38 | pred_seg_mask = masks[scores.argmax()]
39 | pred_seg_mask[pred_seg_mask > 0] = 1
40 | pred_seg_mask = np.stack([pred_seg_mask] * 3, axis=-1) * 1
41 |
42 | return pred_seg_mask * mask[...,None]
43 |
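
Usage note: get_segmentation prompts SAM with the bounding box of the drawn region, keeps the candidate masks that lie mostly inside it, and returns the highest-scoring one. The box helper on its own:

    import numpy as np
    from src.zeropainter import segmentation

    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[100:200, 50:300] = 1
    print(segmentation.get_bounding_box(mask))  # (x, y, w, h) -> (50, 100, 249, 99)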
--------------------------------------------------------------------------------
/src/zeropainter/zero_painter_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import cv2
4 | import numpy as np
5 | from src.smplfusion import share
6 |
7 | class ZeroPainterMask:
8 | def __init__(self, color, local_prompt, img_grey,rgb,image_w,image_h):
9 | self.img_grey = img_grey
10 | self.color = eval(color)
11 | self.local_prompt = local_prompt
12 | self.bbox = None
13 | self.bbox_64 = None
14 | self.area = None
15 | self.mask = None
16 | self.mask_64 = None
17 | self.token_idx = None
18 | self.modified_mask = None
19 | self.inverse_mask = None
20 | self.modified_indexses_of_prompt = None
21 | self.sot_index = [0]
22 | self.w_positive = 1
23 | self.w_negative = 1
24 | self.image_w = image_w
25 | self.image_h = image_h
26 | self.rgb = rgb
27 | self.modify_info()
28 |
29 | def get_bounding_box(self, mask):
30 | mask = mask*255.0
31 | all_y, all_x = np.where(mask == 255.0)
32 | all_y, all_x = sorted(all_y), sorted(all_x)
33 | x1, y1, x2, y2 = all_x[0], all_y[0], all_x[-1], all_y[-1]
34 | x, y, w, h = x1, y1, x2 - x1, y2 - y1
35 | return x, y, w, h
36 |
37 |
38 | def modify_info(self):
39 | # Define the color to be searched for
40 | if self.rgb:
41 | color_to_search = np.array([self.color[0],self.color[1],self.color[2]])
42 | mask_1d = np.all(self.img_grey == color_to_search, axis=2)*1.0
43 | else:
44 | mask_1d = (self.img_grey == self.color) * 1.0
45 | mask_1d_64 = cv2.resize(mask_1d, (64, 64), interpolation=cv2.INTER_NEAREST)
46 | self.bbox = self.get_bounding_box(mask_1d)
47 | self.bbox_64 = self.get_bounding_box(mask_1d_64)
48 | self.area = self.bbox[2] * self.bbox[3]
49 | self.mask = mask_1d.copy()
50 |         self.mask_64 = mask_1d_64.copy()  # mask_1d_64 is already a binary 0/1 mask
51 | 
52 |         split_prompt = self.local_prompt.split()
53 |         self.token_idx = np.arange(1, len(split_prompt) + 1, 1, dtype=int)
54 | 
55 |         # diffusion part: wrap the mask as torch tensors for the attention patches
56 |         mask_1d = mask_1d[None,None,:,:]
57 |         mask_1d = torch.from_numpy(mask_1d)
58 | 
59 |         self.modified_indexses_of_prompt = np.arange(5, len(split_prompt)+1)
60 | self.modified_mask = share.InputMask(mask_1d)
61 | self.inverse_mask = share.InputMask(1-mask_1d)
62 |
63 | return
64 |
65 | class ZeroPainterSample:
66 | #Example of sample
67 | #{'prompt': 'Brown gift box beside red candle.',
68 | # 'color_context_dict': {'1': 'Brown gift box', '2': 'red candle'}}
69 | def __init__(self, item, img_grey, image_h=512,image_w=512):
70 |
71 | self.item = item
72 | self.global_prompt = self.item['prompt']
73 | self.image_w = image_w
74 | self.image_h = image_h
75 | self.img_grey = img_grey
76 | self.masks = self.load_masks()
77 |
78 |
79 | def load_masks(self):
80 | data_samples = []
81 |
82 | self.img_grey = cv2.resize(self.img_grey, (self.image_w, self.image_h),interpolation=cv2.INTER_NEAREST)
83 | for color, local_prompt in self.item['color_context_dict'].items():
84 | data_samples.append(ZeroPainterMask(color,local_prompt,self.img_grey,self.item['rgb'],self.image_w,self.image_h))
85 |
86 | return data_samples
87 |
88 |
89 | class ZeroPainterDataset:
90 | def __init__(self, root_path_img, json_path,rgb=True):
91 | self.root_path_img = root_path_img
92 | self.json_path = json_path
93 | self.rgb = rgb
94 |
95 | if isinstance(json_path, dict) or isinstance(json_path, list):
96 | self.json_data = json_path
97 | else:
98 | with open(self.json_path, 'r') as file:
99 | self.json_data = json.load(file)
100 | def __len__(self):
101 | return len(self.json_data)
102 |
103 | def __getitem__(self, index):
104 | item = self.json_data[index]
105 | item['rgb'] = self.rgb
106 | if isinstance(self.root_path_img, str):
107 | self.img_path = self.root_path_img
108 | if self.rgb:
109 | self.img_grey = cv2.imread(self.img_path)
110 | else:
111 | self.img_grey = cv2.imread(self.img_path,0)
112 | else:
113 | self.img_grey = np.array(self.root_path_img)
114 |
115 | return ZeroPainterSample(item, self.img_grey)
116 |
117 |
118 |
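
Usage note: a minimal sketch using the sample inputs that ship in data/ (a color-coded mask image plus a metadata JSON with the global prompt and per-color local prompts):

    from src.zeropainter import zero_painter_dataset

    data = zero_painter_dataset.ZeroPainterDataset("data/masks/1_rgb.png", "data/metadata/1.json")
    sample = data[0]
    print(sample.global_prompt)
    for m in sample.masks:
        print(m.color, m.local_prompt, m.bbox)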
--------------------------------------------------------------------------------
/src/zeropainter/zero_painter_pipline.py:
--------------------------------------------------------------------------------
1 | from . import segmentation, generation, inpainting
2 |
3 | import torch
4 | import numpy as np
5 | from src.smplfusion import libimage
6 | from src.smplfusion import IImage
7 |
8 |
9 | class ZeroPainter:
10 | def __init__(self, model_t2i, model_inp, model_sam):
11 | self.model_t2i = model_t2i
12 | self.model_inp = model_inp
13 | self.model_sam = model_sam
14 |
15 |
16 | def gen_sample(
17 | self,
18 | sample,
19 | object_seed,
20 | image_seed,
21 | num_ddim_steps_t2i,
22 | num_ddim_steps_inp,
23 | cfg_scale_t2i=7.5,
24 | cfg_scale_inp=7.5,
25 | use_lcm_multistep_t2i=False,
26 | use_lcm_multistep_inp=False,
27 | ):
28 |
29 |
30 | gen_obj_list = []
31 | gen_mask_list = []
32 | real_mask_list = []
33 |
34 | if isinstance(object_seed, int):
35 | object_seed = [object_seed] * len(sample.masks)
36 |
37 | for i in range(len(sample.masks)):
38 | eps_list, z0_list, zt_list = generation.gen_single_object(
39 | self.model_t2i,
40 | sample.masks[i],
41 | object_seed[i],
42 | dt=1000 // num_ddim_steps_t2i,
43 | guidance_scale=cfg_scale_t2i,
44 | use_lcm_multistep=use_lcm_multistep_t2i,
45 | )
46 | mask = sample.masks[i].mask
47 |
48 | gen_image = IImage(
49 | self.model_t2i.vae.decode(z0_list[-1] / self.model_t2i.config.scale_factor)
50 | )
51 | gen_mask = segmentation.get_segmentation(self.model_sam, gen_image, mask)
52 | gen_object = gen_image * IImage(255 * gen_mask)
53 |
54 | gen_obj_list.append(gen_object)
55 | gen_mask_list.append(gen_mask)
56 | real_mask_list.append(mask)
57 |
58 | gen_image = IImage(libimage.stack(gen_obj_list).data.sum(0))
59 | gen_mask = IImage(255 * np.sum(gen_mask_list, 0))
60 | real_mask = IImage(255 * np.sum(real_mask_list, 0))
61 |
62 | output = inpainting.gen_filled_image(
63 | self.model_inp,
64 | sample.global_prompt,
65 | gen_image,
66 | (~gen_mask.alpha()).dilate(3),
67 | sample.masks,
68 | image_seed,
69 | dt=1000 // num_ddim_steps_inp,
70 | guidance_scale=cfg_scale_inp,
71 | use_lcm_multistep=use_lcm_multistep_inp,
72 | )
73 |
74 | return output
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/zero_painter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from src.zeropainter.zero_painter_pipline import ZeroPainter
3 | from src.zeropainter import models, dreamshaper, segmentation
4 | from src.zeropainter import zero_painter_dataset
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser()
8 |
9 |     parser.add_argument('--mask-path', default='data/masks/1_rgb.png', help='Path to the color-coded mask image.')
10 |     parser.add_argument('--metadata', default='data/metadata/1.json', type=str, help='Path to the metadata JSON (global prompt and per-color local prompts).')
11 |     parser.add_argument('--output-dir', default='data/outputs/', help='Output directory.')
12 | 
13 |     # load models
14 |     parser.add_argument('--config-folder-for-models', type=str, default='config', help='Path to configs')
15 |     parser.add_argument('--model-folder-inpainting', type=str, default='models/sd-1-5-inpainting', help='Path to the inpainting model folder')
16 |     parser.add_argument('--model-folder-generation', type=str, default='models/sd-1-4', help='Path to the generation (t2i) model folder')
17 |     parser.add_argument('--segment-anything-model', type=str, default='models/sam_vit_h_4b8939.pth', help='Path to the Segment Anything checkpoint')
18 |
19 | return parser.parse_args()
20 |
21 |
22 | def main():
23 | args = get_args()
24 |
25 |     model_inp, _ = models.get_inpainting_model(args.config_folder_for_models, args.model_folder_inpainting)
26 |     model_t2i, _ = models.get_t2i_model(args.config_folder_for_models, args.model_folder_generation)
27 | model_sam = segmentation.get_segmentation_model(args.segment_anything_model)
28 | zero_painter_model = ZeroPainter(model_t2i, model_inp, model_sam)
29 |
30 | data = zero_painter_dataset.ZeroPainterDataset(args.mask_path, args.metadata)
31 | name = args.mask_path.split('/')[-1]
32 |     result = zero_painter_model.gen_sample(data[0], 42, 42, 30, 30)
33 |     result.save(args.output_dir + name)
34 |
35 |
36 | if __name__ == '__main__':
37 | main()
38 |
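
Usage note: with the checkpoints in place (see the argparse defaults above), the bundled sample can be processed with:

    python zero_painter.py --mask-path data/masks/1_rgb.png --metadata data/metadata/1.json --output-dir data/outputs/

The output file name is taken from the mask file name, so this writes data/outputs/1_rgb.png.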
--------------------------------------------------------------------------------