├── .gitignore ├── LICENSE ├── README.md ├── img ├── VAE_arch.md └── ui.png └── scripts └── vae_tile.py /.gitignore: -------------------------------------------------------------------------------- 1 | # meta 2 | .vscode/ 3 | __pycache__/ 4 | 5 | # experiments 6 | exp/*/* 7 | !exp/*/*.py 8 | !exp/*/*.cmd 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Armit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stable-diffusion-webui-vae-tile-infer 2 | 3 | Yet another VAE tiling inferer that dramatically reduces VRAM usage; an extension script for AUTOMATIC1111/stable-diffusion-webui. 
4 | 5 | ---- 6 | 7 | ⚠ This repo is for **experiments & code study**, intended for developers who want to read about our idea. 😀 8 | ⚠ You should use [multidiffusion-upscaler-for-automatic1111](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111)'s implementation **in production**; we publish updates there. 9 | 10 | ℹ When processing large images, please **turn off previews** to really save time and resources! 11 | 12 | ⚠ We have set up a QQ feedback group: 616795645 (赤狐屿); feel free to make suggestions, give feedback, and report bugs! 13 | ⚠ Any suggestions, discussions and bug reports are highly welcome! 14 | 15 |  16 | 17 | 18 | ### Benchmark 19 | 20 | ``` 21 | device = NVIDIA GeForce RTX 3060 (12G VRAM) 22 | dtype = float16 23 | auto_adjust = True 24 | gn_sync = Approx 25 | skip_infer = None 26 | ``` 27 | 28 | ⚪ Encoding is cheap 29 | 30 | | Image Size | original | tile (tile_size=1024) | 31 | | :-: | :-: | :-: | 32 | | 512 x 512 | 0.009s / 2584.194MB | 0.417s / 2653.301MB / 1 tile | 33 | | 768 x 768 | 0.011s / 3227.944MB | 0.530s / 3332.989MB / 1 tile | 34 | | 1024 x 1024 | 0.012s / 4481.913MB | 0.758s / 4271.676MB / 1 tile | 35 | | 1600 x 1600 | 0.031s / 8512.850MB | 1.499s / 4301.680MB / 4 tiles | 36 | | 2048 x 2048 | 0.034s / 10309.194MB | 2.368s / 4319.680MB / 4 tiles | 37 | 38 | ⚪ Decoding is heavy 39 | 40 | - ablation on image size (tile_size=128) 41 | 42 | | Image Size | original | tile | 43 | | :-: | :-: | :-: | 44 | | 512 x 512 | 0.020s / 2616.033MB | 0.202s / 2685.320MB / 1 tile | 45 | | 768 x 768 | 0.030s / 3296.306MB | 0.427s / 3399.634MB / 1 tile | 46 | | 1024 x 768 | 0.024s / 3704.470MB | 0.561s / 3824.823MB / 1 tile | 47 | | 1280 x 720 | 0.023s / 3985.083MB | 1.510s / 4386.115MB / 2 tiles | 48 | | 1024 x 1024 | 0.017s / 4248.689MB | 0.747s / 4386.074MB / 1 tile | 49 | | 1920 x 1080 | 0.031s / 6375.797MB | 2.325s / 4387.078MB / 4 tiles | 50 | | 2048 x 1024 | 0.032s / 6425.564MB | 2.307s / 4387.107MB / 2 tiles | 51 | | 1600 x 1600 | 0.033s / 8373.138MB | 
2.649s / 4387.482MB / 4 tiles | 52 | | 2048 x 1536 | 2.252s / 8602.439MB | 3.041s / 4387.971MB / 4 tiles | 53 | | 2560 x 1440 | 3.899s / 9725.989MB | 3.453s / 4389.521MB / 6 tiles | 54 | | 2048 x 2048 | 2.582s / 10265.877MB | 3.814s / 4389.111MB / 4 tiles | 55 | | 2560 x 4096 | OOM | 8.446s / 4397.221MB / 12 tiles | 56 | | 4096 x 4096 | OOM | 12.998s / 4407.095MB / 16 tiles | 57 | | 4096 x 8192 | OOM | 24.900s / 4428.142MB / 32 tiles | 58 | | 8192 x 8192 | OOM | 49.069s / 4469.158MB / 64 tiles | 59 | 60 | - ablation on tile size (image_size=2048) 61 | 62 | ℹ Peak VRAM usage depends only on the tile size, so you can say goodbye to all OOMs :) 63 | 64 | | Tile Size | tile | 65 | | :-: | :-: | 66 | | 32 | 3.630s, max VRAM alloc 2247.986 MB / 64 tiles | 67 | | 48 | 3.500s, max VRAM alloc 2433.626 MB / 36 tiles | 68 | | 64 | 3.347s, max VRAM alloc 2689.111 MB / 16 tiles | 69 | | 96 | 3.636s, max VRAM alloc 3402.735 MB / 9 tiles | 70 | | 128 | 3.803s, max VRAM alloc 4389.111 MB / 4 tiles | 71 | | 160 | 4.273s, max VRAM alloc 5646.989 MB / 4 tiles | 72 | | 192 | 5.809s, max VRAM alloc 7930.127 MB / 4 tiles | 73 | 74 | 75 | ### How it works 
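The steps listed below boil down to a plain split / process / stitch loop. A minimal NumPy sketch of that idea (illustration only; `tiled_apply` is a hypothetical helper, and the real extension additionally overlaps tiles with padding and synchronizes GroupNorm statistics across them):

```python
import numpy as np

def tiled_apply(img: np.ndarray, fn, tile_size: int) -> np.ndarray:
    """Split an H x W array into tiles, run fn on each tile independently,
    and stitch the results back.  Pure illustration of the split / process /
    concatenate idea: no overlap padding, no GroupNorm sync."""
    H, W = img.shape[:2]
    out = np.empty_like(img)
    for y in range(0, H, tile_size):
        for x in range(0, W, tile_size):
            tile = img[y:y + tile_size, x:x + tile_size]  # edge tiles may be smaller
            out[y:y + tile_size, x:x + tile_size] = fn(tile)
    return out
```

Because `fn` only ever sees one tile, peak memory scales with the tile area rather than the full image area, which is why the decoding benchmarks above plateau around 4.4 GB regardless of image size.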
76 | 77 | - split the RGB image / latent image into overlapping tiles (not necessarily square) 78 | - VAE encode / decode each tile as usual 79 | - concatenate all tiles back together 80 | 81 | ⚪ settings tuning 82 | 83 | - `Encoder/Decoder tile size`: the image tile is the actual processing unit; **set it as large as possible without running OOM** :) 84 | - `Encoder/Decoder pad size`: overlapping padding of each tile; a larger value gives more seamless results 85 | - `Auto adjust real tile size`: automatically shrink the real tile size to match the tensor shape, avoiding a too-small trailing tile 86 | - `GroupNorm sync`: how to sync GroupNorm stats 87 | - `Approximated`: use stats from a pre-computed low-resolution image 88 | - `Full sync`: use accurate stats to sync globally 89 | - `No sync`: do not sync 90 | - `Skip infer (experimental)`: skip the computation of certain network blocks; faster, but lowers output quality 91 | 92 | 93 | #### Acknowledgement 94 | 95 | Thanks to the following project for the original idea: 96 | 97 | - multidiffusion-upscaler-for-automatic1111: [https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) 98 | 99 | ---- 100 | by Armit 101 | 2023/01/20 102 | -------------------------------------------------------------------------------- /img/VAE_arch.md: -------------------------------------------------------------------------------- 1 | ### Encoder 2 | 3 | ⚪ forward 4 | 5 | ```python 6 | x = self.conv_in(x) 7 | x = self.down[0].block[0](x) 8 | x = self.down[0].block[1](x) 9 | x = self.down[1].block[0](x) 10 | x = self.down[1].block[1](x) 11 | x = self.down[2].block[0](x) 12 | x = self.down[2].block[1](x) 13 | x = self.down[3].block[0](x) 14 | x = self.down[3].block[1](x) 15 | x = self.mid.block_1(x) 16 | x = self.mid.attn_1(x) 17 | x = self.mid.block_2(x) 18 | x = self.norm_out(x) 19 | x = self.conv_out(x) 20 | ``` 21 | 22 | ⚪ model 23 | 24 | ``` 25 | Encoder( 26 | (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1)) 27 | (down): ModuleList( 28 | (0): Module( 29 | (block): ModuleList( 30 | (0): ResnetBlock( 31 | (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) 32 | (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 33 | (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) 34 | (dropout): Dropout(p=0.0, inplace=False) 35 | (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 36 | ) 37 | (1): ResnetBlock( 38 | (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) 39 | (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 40 | (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) 41 | (dropout): Dropout(p=0.0, inplace=False) 42 | (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 43 | ) 44 | ) 45 | (attn): ModuleList() 46 | (downsample): Downsample( 47 | (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2)) 48 | ) 49 | ) 50 | (1): Module( 51 | (block): ModuleList( 52 | (0): ResnetBlock( 53 | (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) 54 | (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 55 | (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) 56 | (dropout): Dropout(p=0.0, inplace=False) 57 | (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 58 | (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1)) 59 | ) 60 | (1): ResnetBlock( 61 | (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) 62 | (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 63 | (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) 64 | (dropout): Dropout(p=0.0, inplace=False) 65 | (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 66 | ) 67 | ) 68 | (attn): ModuleList() 69 | (downsample): Downsample( 70 | (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2)) 71 | ) 72 | ) 73 | (2): Module( 74 | (block): ModuleList( 75 | (0): ResnetBlock( 76 | (norm1): GroupNorm(32, 
256, eps=1e-06, affine=True) 77 | (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 78 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 79 | (dropout): Dropout(p=0.0, inplace=False) 80 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 81 | (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1)) 82 | ) 83 | (1): ResnetBlock( 84 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 85 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 86 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 87 | (dropout): Dropout(p=0.0, inplace=False) 88 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 89 | ) 90 | ) 91 | (attn): ModuleList() 92 | (downsample): Downsample( 93 | (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2)) 94 | ) 95 | ) 96 | (3): Module( 97 | (block): ModuleList( 98 | (0): ResnetBlock( 99 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 100 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 101 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 102 | (dropout): Dropout(p=0.0, inplace=False) 103 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 104 | ) 105 | (1): ResnetBlock( 106 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 107 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 108 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 109 | (dropout): Dropout(p=0.0, inplace=False) 110 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 111 | ) 112 | ) 113 | (attn): ModuleList() 114 | ) 115 | ) 116 | (mid): Module( 117 | (block_1): ResnetBlock( 118 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 119 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 120 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 121 | (dropout): Dropout(p=0.0, inplace=False) 122 | (conv2): 
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 123 | ) 124 | (attn_1): AttnBlock( 125 | (norm): GroupNorm(32, 512, eps=1e-06, affine=True) 126 | (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 127 | (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 128 | (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 129 | (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 130 | ) 131 | (block_2): ResnetBlock( 132 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 133 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 134 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 135 | (dropout): Dropout(p=0.0, inplace=False) 136 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 137 | ) 138 | ) 139 | (norm_out): GroupNorm(32, 512, eps=1e-06, affine=True) 140 | (conv_out): Conv2d(512, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 141 | ) 142 | ``` 143 | 144 | ### Decoder 145 | 146 | ⚪ forward 147 | 148 | ```python 149 | x = self.conv_in(x) 150 | x = self.mid.block_1(x) 151 | x = self.mid.attn_1(x) 152 | x = self.mid.block_2(x) 153 | x = self.up[3].block[0](x) # NOTE: this is reversed 154 | x = self.up[3].block[1](x) 155 | x = self.up[3].block[2](x) 156 | x = self.up[2].block[0](x) 157 | x = self.up[2].block[1](x) 158 | x = self.up[2].block[2](x) 159 | x = self.up[1].block[0](x) 160 | x = self.up[1].block[1](x); x = self.up[1].block[2](x) 161 | x = self.up[0].block[0](x) 162 | x = self.up[0].block[1](x); x = self.up[0].block[2](x) 163 | x = self.norm_out(x) 164 | x = self.conv_out(x) 165 | ``` 166 | 167 | ⚪ model 168 | 169 | ``` 170 | Decoder( 171 | (conv_in): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 172 | (mid): Module( 173 | (block_1): ResnetBlock( 174 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 175 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 176 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 177 | (dropout): Dropout(p=0.0, inplace=False) 
178 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 179 | ) 180 | (attn_1): AttnBlock( 181 | (norm): GroupNorm(32, 512, eps=1e-06, affine=True) 182 | (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 183 | (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 184 | (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 185 | (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1)) 186 | ) 187 | (block_2): ResnetBlock( 188 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 189 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 190 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 191 | (dropout): Dropout(p=0.0, inplace=False) 192 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 193 | ) 194 | ) 195 | (up): ModuleList( 196 | (0): Module( 197 | (block): ModuleList( 198 | (0): ResnetBlock( 199 | (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) 200 | (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 201 | (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) 202 | (dropout): Dropout(p=0.0, inplace=False) 203 | (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 204 | (nin_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1)) 205 | ) 206 | (1): ResnetBlock( 207 | (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) 208 | (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 209 | (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) 210 | (dropout): Dropout(p=0.0, inplace=False) 211 | (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 212 | ) 213 | (2): ResnetBlock( 214 | (norm1): GroupNorm(32, 128, eps=1e-06, affine=True) 215 | (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 216 | (norm2): GroupNorm(32, 128, eps=1e-06, affine=True) 217 | (dropout): Dropout(p=0.0, inplace=False) 218 | (conv2): Conv2d(128, 128, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 219 | ) 220 | ) 221 | (attn): ModuleList() 222 | ) 223 | (1): Module( 224 | (block): ModuleList( 225 | (0): ResnetBlock( 226 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 227 | (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 228 | (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) 229 | (dropout): Dropout(p=0.0, inplace=False) 230 | (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 231 | (nin_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1)) 232 | ) 233 | (1): ResnetBlock( 234 | (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) 235 | (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 236 | (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) 237 | (dropout): Dropout(p=0.0, inplace=False) 238 | (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 239 | ) 240 | (2): ResnetBlock( 241 | (norm1): GroupNorm(32, 256, eps=1e-06, affine=True) 242 | (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 243 | (norm2): GroupNorm(32, 256, eps=1e-06, affine=True) 244 | (dropout): Dropout(p=0.0, inplace=False) 245 | (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 246 | ) 247 | ) 248 | (attn): ModuleList() 249 | (upsample): Upsample( 250 | (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 251 | ) 252 | ) 253 | (2): Module( 254 | (block): ModuleList( 255 | (0): ResnetBlock( 256 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 257 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 258 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 259 | (dropout): Dropout(p=0.0, inplace=False) 260 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 261 | ) 262 | (1): ResnetBlock( 263 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 264 | (conv1): Conv2d(512, 512, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 265 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 266 | (dropout): Dropout(p=0.0, inplace=False) 267 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 268 | ) 269 | (2): ResnetBlock( 270 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 271 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 272 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 273 | (dropout): Dropout(p=0.0, inplace=False) 274 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 275 | ) 276 | ) 277 | (attn): ModuleList() 278 | (upsample): Upsample( 279 | (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 280 | ) 281 | ) 282 | (3): Module( 283 | (block): ModuleList( 284 | (0): ResnetBlock( 285 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 286 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 287 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 288 | (dropout): Dropout(p=0.0, inplace=False) 289 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 290 | ) 291 | (1): ResnetBlock( 292 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 293 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 294 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 295 | (dropout): Dropout(p=0.0, inplace=False) 296 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 297 | ) 298 | (2): ResnetBlock( 299 | (norm1): GroupNorm(32, 512, eps=1e-06, affine=True) 300 | (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 301 | (norm2): GroupNorm(32, 512, eps=1e-06, affine=True) 302 | (dropout): Dropout(p=0.0, inplace=False) 303 | (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 304 | ) 305 | ) 306 | (attn): ModuleList() 307 | (upsample): Upsample( 308 | (conv): Conv2d(512, 512, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 309 | ) 310 | ) 311 | ) 312 | (norm_out): GroupNorm(32, 128, eps=1e-06, affine=True) 313 | (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 314 | ) 315 | ``` -------------------------------------------------------------------------------- /img/ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/stable-diffusion-webui-vae-tile-infer/a6856b99f38b43eff682e1df36fcc6b99920c665/img/ui.png -------------------------------------------------------------------------------- /scripts/vae_tile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2023/03/05 4 | 5 | import os 6 | import math 7 | from pathlib import Path 8 | from time import time 9 | from collections import defaultdict 10 | from enum import Enum 11 | from traceback import print_exc 12 | import gc 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from tqdm import tqdm 17 | import gradio as gr 18 | 19 | import modules.devices as devices 20 | from modules.scripts import Script, AlwaysVisible 21 | from modules.shared import state 22 | from modules.processing import opt_f 23 | from modules.sd_vae_approx import cheap_approximation 24 | from modules.ui import gr_show 25 | 26 | from typing import Tuple, List, Dict, Union, Generator 27 | from torch import Tensor 28 | from torch.nn import GroupNorm 29 | from modules.processing import StableDiffusionProcessing 30 | from ldm.models.autoencoder import AutoencoderKL 31 | from ldm.modules.diffusionmodules.model import Encoder, Decoder, ResnetBlock, AttnBlock 32 | 33 | Net = Union[Encoder, Decoder] 34 | Tile = Var = Mean = Tensor 35 | TaskRet = Union[Tuple[GroupNorm, Tile, Tuple[Var, Mean]], Tile] 36 | TaskGen = Generator[TaskRet, None, None] 37 | BBox = Tuple[int, int, int, int] 38 | 39 | 40 | # ↓↓↓ copied from 
https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111 ↓↓↓ 41 | 42 | def get_default_encoder_tile_size(): 43 | if torch.cuda.is_available(): 44 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 45 | if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 46 | elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 47 | elif total_memory > 10*1000: ENCODER_TILE_SIZE = 1536 48 | elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 49 | elif total_memory > 6*1000: ENCODER_TILE_SIZE = 1024 50 | elif total_memory > 4*1000: ENCODER_TILE_SIZE = 768 51 | else: ENCODER_TILE_SIZE = 512 52 | else: 53 | ENCODER_TILE_SIZE = 512 54 | return ENCODER_TILE_SIZE 55 | 56 | def get_default_decoder_tile_size(): 57 | if torch.cuda.is_available(): 58 | total_memory = torch.cuda.get_device_properties(devices.device).total_memory // 2**20 59 | if total_memory > 30*1000: DECODER_TILE_SIZE = 256 60 | elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 61 | elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 62 | elif total_memory > 10*1000: DECODER_TILE_SIZE = 96 63 | elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 64 | elif total_memory > 6*1000: DECODER_TILE_SIZE = 80 65 | elif total_memory > 4*1000: DECODER_TILE_SIZE = 64 66 | else: DECODER_TILE_SIZE = 64 67 | else: 68 | DECODER_TILE_SIZE = 64 69 | return DECODER_TILE_SIZE 70 | 71 | def get_var_mean(input, num_groups, eps=1e-6): 72 | """ 73 | Get mean and var for group norm 74 | """ 75 | b, c = input.size(0), input.size(1) 76 | channel_in_group = int(c/num_groups) 77 | input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) 78 | 79 | var, mean = torch.var_mean(input_reshaped, dim=[0, 2, 3, 4], unbiased=False) 80 | if torch.isinf(var).any(): 81 | var, mean = torch.var_mean(input_reshaped.float(), dim=[0, 2, 3, 4], unbiased=False) 82 | var, mean = var.to(input_reshaped.dtype), mean.to(input_reshaped.dtype) 83 | return var, mean 84 | 85 | def 
custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): 86 | """ 87 | Custom group norm with fixed mean and var 88 | 89 | @param input: input tensor 90 | @param num_groups: number of groups. by default, num_groups = 32 91 | @param mean: mean, must be pre-calculated by get_var_mean 92 | @param var: var, must be pre-calculated by get_var_mean 93 | @param weight: weight, should be fetched from the original group norm 94 | @param bias: bias, should be fetched from the original group norm 95 | @param eps: epsilon, by default, eps = 1e-6 to match the original group norm 96 | 97 | @return: normalized tensor 98 | """ 99 | b, c = input.size(0), input.size(1) 100 | channel_in_group = int(c/num_groups) 101 | input_reshaped = input.contiguous().view(1, int(b * num_groups), channel_in_group, *input.size()[2:]) 102 | 103 | out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) 104 | out = out.view(b, c, *input.size()[2:]) 105 | 106 | # post affine transform 107 | if weight is not None: out *= weight.view(1, -1, 1, 1) 108 | if bias is not None: out += bias.view(1, -1, 1, 1) 109 | return out 110 | 111 | # ↑↑↑ copied from https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111 ↑↑↑ 112 | 113 | 114 | class GroupNormSync(Enum): 115 | APPROX = 'Approximated' 116 | SYNC = 'Full sync' 117 | UNSYNC = 'No sync' 118 | 119 | if 'global const': 120 | DEFAULT_OPEN = False 121 | DEFAULT_ENABLED = False 122 | DEFAULT_SMART_IGNORE = True 123 | DEFAULT_ENCODER_PAD_SIZE = 16 124 | DEFAULT_DECODER_PAD_SIZE = 2 125 | DEFAULT_ENCODER_TILE_SIZE = get_default_encoder_tile_size() 126 | DEFAULT_DECODER_TILE_SIZE = get_default_decoder_tile_size() 127 | 128 | DEFAULT_AUTO_SHRINK = True 129 | DEFAULT_ZIGZAG_PROCESS = True 130 | DEFAULT_GN_SYNC = GroupNormSync.APPROX.value 131 | DEFAULT_SKIP_INFER = False 132 | 133 | DEBUG_SHAPE = False 134 | DEBUG_STAGE = False 135 | DEBUG_APPROX = False 136 | 137 | if 'global 
var': 138 | smart_ignore: bool = None 139 | auto_shrink: bool = None 140 | gn_sync: GroupNormSync = None 141 | skip_infer: bool = None 142 | 143 | zigzag_dir: bool = True # False: ->, True: <- 144 | zigzag_to_cpu: bool = False # stash to CPU and wait to apply `custom_group_norm` 145 | sync_approx: bool = False # True: apply, False: collect 146 | sync_approx_pc: int = 0 # program counter for 'sync_approx_plan' execution 147 | sync_approx_plan: List[Tuple[Var, Mean]] = [] 148 | skip_infer_plan = { 149 | Encoder: { 150 | 'down0.block0': False, 151 | 'down0.block1': False, 152 | 'down1.block0': False, 153 | 'down1.block1': False, 154 | 'down2.block0': False, 155 | 'down2.block1': False, 156 | 'down3.block0': False, 157 | 'down3.block1': False, 158 | 'mid.block_1': False, 159 | 'mid.attn_1': False, 160 | 'mid.block_2': False, 161 | }, 162 | Decoder: { 163 | 'mid.block_1': False, 164 | 'mid.attn_1': False, 165 | 'mid.block_2': False, 166 | 'up3.block0': False, 167 | 'up3.block1': False, 168 | 'up3.block2': False, 169 | 'up2.block0': False, 170 | 'up2.block1': False, 171 | 'up2.block2': False, 172 | 'up1.block0': False, 173 | 'up1.block1': False, 174 | 'up1.block2': False, 175 | 'up0.block0': False, 176 | 'up0.block1': False, 177 | 'up0.block2': False, 178 | } 179 | } 180 | skip_infer_plan_dummy = defaultdict(lambda: False) 181 | 182 | 183 | def _dbg_tensor(X:Tensor, name:str) -> None: 184 | var, mean = torch.var_mean(X) 185 | print(f'{name}: {list(X.shape)}, {X.max().item():.4f}, {X.min().item():.4f}, {mean.item():.4f}, {var.item():.4f}') 186 | 187 | def _dbg_to_image(X:Tensor, name:str) -> None: 188 | import numpy as np 189 | from PIL import Image 190 | 191 | im = X.permute([1, 2, 0]) 192 | im = (im + 1) / 2 193 | im = im.clamp_(0, 1) 194 | im = im.cpu().numpy() 195 | im = (im * 255).astype(np.uint8) 196 | img = Image.fromarray(im) 197 | img.save(Path(os.environ['TEMP']) / (name + '.png')) 198 | 199 | 200 | # ↓↓↓ modified from 'ldm/modules/diffusionmodules/model.py' 
↓↓↓ 201 | 202 | def nonlinearity(x:Tensor) -> Tensor: 203 | return F.silu(x, inplace=True) 204 | 205 | def GroupNorm_forward(gn:GroupNorm, h:Tensor) -> TaskGen: 206 | if gn_sync == GroupNormSync.SYNC: 207 | var, mean = get_var_mean(h, gn.num_groups, gn.eps) 208 | if zigzag_to_cpu: h = h.cpu() 209 | yield gn, h, (var, mean) 210 | h = h.to(devices.device) 211 | elif gn_sync == GroupNormSync.APPROX: 212 | if sync_approx: # apply 213 | global sync_approx_pc 214 | var, mean = sync_approx_plan[sync_approx_pc] 215 | h = custom_group_norm(h, gn.num_groups, mean, var, gn.weight, gn.bias, gn.eps) 216 | sync_approx_pc = (sync_approx_pc + 1) % len(sync_approx_plan) 217 | else: # collect 218 | var, mean = get_var_mean(h, gn.num_groups, gn.eps) 219 | sync_approx_plan.append((var, mean)) 220 | h = gn(h) 221 | elif gn_sync == GroupNormSync.UNSYNC: 222 | h = gn(h) 223 | yield h 224 | 225 | def Resblock_forward(self:ResnetBlock, x:Tensor) -> TaskGen: 226 | h = x.clone() if (gn_sync == GroupNormSync.SYNC and not zigzag_to_cpu) else x 227 | 228 | for item in GroupNorm_forward(self.norm1, h): 229 | if isinstance(item, Tensor): h = item 230 | else: yield item 231 | 232 | h = nonlinearity(h) 233 | h: Tensor = self.conv1(h) 234 | 235 | for item in GroupNorm_forward(self.norm2, h): 236 | if isinstance(item, Tensor): h = item 237 | else: yield item 238 | 239 | h = nonlinearity(h) 240 | #h = self.dropout(h) 241 | h = self.conv2(h) 242 | 243 | if self.in_channels != self.out_channels: 244 | if self.use_conv_shortcut: 245 | x = self.conv_shortcut(x) 246 | else: 247 | x = self.nin_shortcut(x) 248 | yield x + h 249 | 250 | def AttnBlock_forward(self:AttnBlock, x:Tensor) -> TaskGen: 251 | h = x.clone() if (gn_sync == GroupNormSync.SYNC and not zigzag_to_cpu) else x 252 | 253 | for item in GroupNorm_forward(self.norm, h): 254 | if isinstance(item, Tensor): h = item 255 | else: yield item 256 | 257 | q = self.q(h) 258 | k = self.k(h) 259 | v = self.v(h) 260 | 261 | # compute attention 262 | B, C, 
H, W = q.shape 263 | q = q.reshape(B, C, H * W) 264 | q = q.permute(0, 2, 1) # b,hw,c 265 | k = k.reshape(B, C, H * W) # b,c,hw 266 | w = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 267 | w = w * (int(C)**(-0.5)) 268 | w = torch.nn.functional.softmax(w, dim=2) 269 | 270 | # attend to values 271 | v = v.reshape(B, C, H * W) 272 | w = w.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 273 | h = torch.bmm(v, w) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 274 | h = h.reshape(B, C, H, W) 275 | 276 | h: Tensor = self.proj_out(h) 277 | yield x + h 278 | 279 | def _mid_forward(self:Net, x:Tensor, skip_plan:Dict[str, bool]) -> TaskGen: 280 | if not skip_plan['mid.block_1']: 281 | for item in Resblock_forward(self.mid.block_1, x): 282 | if isinstance(item, Tensor): x = item 283 | else: yield item 284 | if DEBUG_SHAPE: print('block_1:', x.shape) 285 | if not skip_plan['mid.attn_1']: 286 | for item in AttnBlock_forward(self.mid.attn_1, x): 287 | if isinstance(item, Tensor): x = item 288 | else: yield item 289 | if DEBUG_SHAPE: print('attn_1:', x.shape) 290 | if not skip_plan['mid.block_2']: 291 | for item in Resblock_forward(self.mid.block_2, x): 292 | if isinstance(item, Tensor): x = item 293 | else: yield item 294 | if DEBUG_SHAPE: print('block_2:', x.shape) 295 | yield x 296 | 297 | def Encoder_forward(self:Encoder, x:Tensor) -> TaskGen: 298 | # prenet 299 | x = self.conv_in(x) 300 | if DEBUG_SHAPE: print('conv_in:', x.shape) 301 | 302 | skip_enc = skip_infer_plan[Encoder] if skip_infer else skip_infer_plan_dummy 303 | 304 | # downsampling 305 | for i_level in range(self.num_resolutions): 306 | for i_block in range(self.num_res_blocks): 307 | if not skip_enc[f'down{i_level}.block{i_block}']: 308 | for item in Resblock_forward(self.down[i_level].block[i_block], x): 309 | if isinstance(item, Tensor): x = item 310 | else: yield item 311 | if DEBUG_SHAPE: print(f'down[{i_level}].block[{i_block}]:', x.shape) 312 | assert not 
len(self.down[i_level].attn) 313 | if i_level != self.num_resolutions-1: 314 | x = self.down[i_level].downsample(x) 315 | 316 | # middle 317 | for item in _mid_forward(self, x, skip_enc): 318 | if isinstance(item, Tensor): x = item 319 | else: yield item 320 | 321 | # end 322 | for item in GroupNorm_forward(self.norm_out, x): 323 | if isinstance(item, Tensor): x = item 324 | else: yield item 325 | 326 | x = nonlinearity(x) 327 | x = self.conv_out(x) 328 | yield x.cpu() 329 | 330 | def Decoder_forward(self:Decoder, x:Tensor) -> TaskGen: 331 | # prenet 332 | x = self.conv_in(x) # [B, C=4, H, W] => [B, C=512, H, W] 333 | if DEBUG_SHAPE: print('conv_in:', x.shape) 334 | 335 | skip_dec = skip_infer_plan[Decoder] if skip_infer else skip_infer_plan_dummy 336 | 337 | # middle 338 | for item in _mid_forward(self, x, skip_dec): 339 | if isinstance(item, Tensor): x = item 340 | else: yield item 341 | 342 | # upsampling 343 | for i_level in reversed(range(self.num_resolutions)): 344 | for i_block in range(self.num_res_blocks+1): 345 | if not skip_dec[f'up{i_level}.block{i_block}']: 346 | for item in Resblock_forward(self.up[i_level].block[i_block], x): 347 | if isinstance(item, Tensor): x = item 348 | else: yield item 349 | if DEBUG_SHAPE: print(f'up[{i_level}].block[{i_block}]:', x.shape) 350 | assert not len(self.up[i_level].attn) 351 | if i_level != 0: 352 | x = self.up[i_level].upsample(x) 353 | if DEBUG_SHAPE: print(f'up[{i_level}].upsample:', x.shape) 354 | 355 | # end 356 | if self.give_pre_end: yield x.cpu() 357 | 358 | for item in GroupNorm_forward(self.norm_out, x): 359 | if isinstance(item, Tensor): x = item 360 | else: yield item 361 | 362 | x = nonlinearity(x) 363 | x = self.conv_out(x) 364 | if DEBUG_SHAPE: print(f'conv_out:', x.shape) 365 | if self.tanh_out: x = torch.tanh(x) 366 | yield x.cpu() 367 | 368 | # ↑↑↑ modified from 'ldm/modules/diffusionmodules/model.py' ↑↑↑ 369 | 370 | 371 | def get_real_tile_config(z:Tensor, tile_size:int, is_decoder:bool) -> 
Tuple[int, int, int, int]:
    global gn_sync

    B, C, H, W = z.shape

    if auto_shrink:
        def auto_tile_size(low:int, high:int) -> int:
            ''' VRAM-saving when close to low, GPU-warp friendly when close to high '''
            align_size = 64 if is_decoder else 512
            while low < high:
                r = (align_size - low % align_size) % align_size   # fixed: pad up to the next multiple of align_size (was `low % align_size`, which is not aligned)
                if low + r > high:
                    align_size //= 2
                else:
                    return low + r
            return high

        n_tiles_H = math.ceil(H / tile_size)
        n_tiles_W = math.ceil(W / tile_size)
        tile_size_H = auto_tile_size(math.ceil(H / n_tiles_H), math.ceil(H / (n_tiles_H - 0.15)))  # ensure the last tile is at least 72.25% (= 0.85**2) filled
        tile_size_W = auto_tile_size(math.ceil(W / n_tiles_W), math.ceil(W / (n_tiles_W - 0.15)))
    else:
        tile_size_H = tile_size_W = tile_size

    n_tiles_H = math.ceil(H / tile_size_H)
    n_tiles_W = math.ceil(W / tile_size_W)
    n_tiles = n_tiles_H * n_tiles_W
    if n_tiles <= 1: gn_sync = GroupNormSync.UNSYNC   # trick: force unsync for a single tile (fixed: was `==`, a no-op comparison)
    fill_ratio = H * W / (n_tiles * tile_size_H * tile_size_W)

    suffix = ''
    if gn_sync == GroupNormSync.APPROX: suffix = '(apply)' if sync_approx else '(collect)'
    print(f'>> sync group norm: {gn_sync.value} {suffix}')
    print(f'>> input: {list(z.shape)}, {str(z.dtype)[len("torch."):]} on {z.device}')
    print(f'>> real tile size: {tile_size_H} x {tile_size_W}')
    print(f'>> split to {n_tiles_H} x {n_tiles_W} = {n_tiles} tiles (fill ratio: {fill_ratio:.3%})')

    return tile_size_H, tile_size_W, n_tiles_H, n_tiles_W

def make_bbox(n_tiles_H:int, n_tiles_W:int, tile_size_H:int, tile_size_W:int, H:int, W:int, P:int, scaler:Union[int, float]) -> Tuple[List[BBox], List[BBox]]:
    bbox_inputs:  List[BBox] = []
    bbox_outputs: List[BBox] = []

    x = 0
    for _ in range(n_tiles_H):
        y = 0
        for _ in range(n_tiles_W):
            bbox_inputs.append((
                x, min(x + tile_size_H, H) + 2 * P,
                y, min(y + tile_size_W, W) + 2 * P,
            ))
            bbox_outputs.append((
                int(x * scaler), int(min(x + tile_size_H, H) * scaler),
                int(y * scaler), int(min(y + tile_size_W, W) * scaler),
            ))
            y += tile_size_W
        x += tile_size_H

    if DEBUG_STAGE:
        print('bbox_inputs:')
        print(bbox_inputs)
        print('bbox_outputs:')
        print(bbox_outputs)

    return bbox_inputs, bbox_outputs

def get_n_sync(net:Net, is_decoder:bool) -> int:
    if gn_sync != GroupNormSync.SYNC: return 1

    n_sync = 31 if is_decoder else 23
    if skip_infer:
        for block, skip in skip_infer_plan[type(net)].items():
            if not skip: continue
            if   'attn'  in block: n_sync -= 1
            elif 'block' in block: n_sync -= 2
    return n_sync

def perfcount(fn):
    def wrapper(*args, **kwargs):
        device = devices.device
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)
        devices.torch_gc()
        gc.collect()

        ts = time()
        try:
            return fn(*args, **kwargs)
        finally:
            te = time()
            if torch.cuda.is_available():
                vram = torch.cuda.max_memory_allocated(device) / 2**20
                torch.cuda.reset_peak_memory_stats(device)
                print(f'Done in {te - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
            else:
                print(f'Done in {te - ts:.3f}s')
            devices.torch_gc()
            gc.collect()
    return wrapper


@torch.inference_mode()
def VAE_forward_tile(self:Net, z:Tensor, tile_size:int, pad_size:int) -> Tensor:
    global gn_sync, zigzag_dir, zigzag_to_cpu, sync_approx_plan

    B, C, H, W = z.shape
    P = pad_size
    is_decoder = isinstance(self, Decoder)
    scaler = opt_f if is_decoder else 1/opt_f
    ch = 3 if is_decoder else 8
    result = None   # z[:, :ch, :, :] if is_decoder else z[:, :ch, :H//opt_f, :W//opt_f]   # very cheap tmp result

    # modified: gn_sync
    tile_size_H, tile_size_W, n_tiles_H, n_tiles_W = get_real_tile_config(z, tile_size, is_decoder)

    if 'estimate max tensor shape':   # string toggle: always on
        if is_decoder: shape = torch.Size((B, 256, (tile_size_H+2*P)*opt_f, (tile_size_W+2*P)*opt_f))
        else:          shape = torch.Size((B, 128, tile_size_H+2*P, tile_size_W+2*P))
        size_t = 2 if self.conv_in.weight.dtype == torch.float16 else 4
        print(f'>> max tensor shape: {list(shape)}, estimated vram size: {shape.numel() * size_t / 2**20:.3f} MB')

    ''' split tiles '''
    if P != 0: z = F.pad(z, (P, P, P, P), mode='reflect')   # [B, C, H+2*P, W+2*P]

    bbox_inputs, bbox_outputs = make_bbox(n_tiles_H, n_tiles_W, tile_size_H, tile_size_W, H, W, P, scaler)
    workers: List[TaskGen] = []
    for bbox in bbox_inputs:
        Hs, He, Ws, We = bbox
        tile = z[:, :, Hs:He, Ws:We]
        workers.append(Decoder_forward(self, tile) if is_decoder else Encoder_forward(self, tile))
    del z
    n_workers = len(workers)
    if n_workers >= 3: workers = workers[1:] + [workers[0]]   # trick: move the two largest tiles to the end for full-sync zigzagging

    ''' start workers '''
    steps = get_n_sync(self, is_decoder) * n_workers
    pbar = tqdm(total=steps, desc=f'VAE tile {"decoding" if is_decoder else "encoding"}')
    while True:
        if state.interrupted: return

        # run one round
        try:
            outputs: List[TaskRet] = [None] * n_workers
            for i in (reversed if zigzag_dir else iter)(range(n_workers)):
                if state.interrupted: return

                zigzag_to_cpu = (i != 0) if zigzag_dir else (i != n_workers - 1)
                outputs[i] = next(workers[i])
                pbar.update()
                if isinstance(outputs[i], Tile): workers[i] = None   # trick: release resources once a worker is done
            zigzag_dir = not zigzag_dir

            if not 'check outputs type consistency':   # string toggle: the else-branch is the active one
                ret_type = type(outputs[0])
            else:
                ret_types = { type(o) for o in outputs }
                assert len(ret_types) == 1
                ret_type = ret_types.pop()
        except StopIteration:
            print_exc()
            raise ValueError('Error: workers stopped early !!')

        # handle intermediates
        if ret_type == tuple:   # GroupNorm full-sync barrier
            assert gn_sync == GroupNormSync.SYNC

            if not 'check gn object identity':   # string toggle: the else-branch is the active one
                gn: GroupNorm = outputs[0][0]
            else:
                gns = { gn for gn, _, _ in outputs }
                if len(gns) > 1:
                    print(f'group_norms: {gns}')
                    raise ValueError('Error: worker progress states are not synchronized !!')
                gn: GroupNorm = list(gns)[0]

            if DEBUG_STAGE:
                print('n_tiles:', len(outputs))
                print('tile.shape:', outputs[0][1].shape)
                print('tile[0].device:', outputs[0][1].device)    # 'cpu'
                print('tile[-1].device:', outputs[-1][1].device)  # 'cuda'

            var  = torch.stack([var  for _, _, (var, _)  in outputs], dim=-1).mean(dim=-1)   # [NG=32], float32, 'cuda'
            mean = torch.stack([mean for _, _, (_, mean) in outputs], dim=-1).mean(dim=-1)
            for _, tile, _ in outputs:
                if state.interrupted: return

                tile_n = custom_group_norm(tile.to(mean.device), gn.num_groups, mean, var, gn.weight, gn.bias, gn.eps)
                tile.data = tile_n.to(tile.device)

        elif ret_type == Tile:   # final Tensor splits
            if DEBUG_STAGE:
                print('n_outputs:', len(outputs))
                print('output.shape:', outputs[0].shape)
                print('output.device:', outputs[0].device)   # 'cpu'
            assert len(bbox_outputs) == len(outputs), 'n_tiles != n_bbox_outputs'

            if n_workers >= 3: outputs = [outputs[-1]] + outputs[:-1]   # trick: undo the reordering that moved the two largest tiles to the end

            result = torch.zeros([B, ch, int(H*scaler), int(W*scaler)], dtype=outputs[0].dtype)
            crop_pad = lambda x, P: x if P == 0 else x[:, :, P:-P, P:-P]
            for i, bbox in enumerate(bbox_outputs):
                Hs, He, Ws, We = bbox
                result[:, :, Hs:He, Ws:We] += crop_pad(outputs[i], int(P * scaler))

            pbar.close()
            break   # we're done!

        else:
            raise ValueError(f'Error: unknown ret_type: {ret_type} !!')

    ''' finish '''
    if not is_decoder: result = result.to(devices.device)
    return result

@perfcount
def VAE_hijack(enabled:bool, self:Net, z:Tensor, tile_size:int, pad_size:int) -> Tensor:
    if not enabled: return self.original_forward(z)

    global gn_sync, sync_approx, sync_approx_plan

    B, C, H, W = z.shape
    if max(H, W) <= tile_size:
        if smart_ignore:
            return self.original_forward(z)
        if gn_sync == GroupNormSync.APPROX:
            print('<< ignore gn_sync=APPROX: tensor is too small ;)')
            gn_sync = GroupNormSync.UNSYNC

    cached_gn_sync = gn_sync
    if gn_sync == GroupNormSync.APPROX and isinstance(self, Encoder):   # do not allow approx on the encoder
        gn_sync = GroupNormSync.SYNC

    if gn_sync == GroupNormSync.APPROX:
        # collect
        sync_approx = False
        sync_approx_plan.clear()

        z_hat: Tensor = F.interpolate(z, size=(tile_size, tile_size), mode='nearest')   # NOTE: mode='nearest' (no smoothing interpolation) so the value statistics are kept
        if DEBUG_APPROX:
            _dbg_tensor(z, 'z')
            _dbg_tensor(z_hat, 'z_hat')
            _dbg_to_image(cheap_approximation(z_hat[0].float()), 'z_hat')

        if 'stats shift':   # string toggle: always on
            std_src, mean_src = torch.std_mean(z_hat, dim=[0, 2, 3], keepdim=True)
            std_tgt, mean_tgt = torch.std_mean(z,     dim=[0, 2, 3], keepdim=True)
            z_hat = (z_hat - mean_src) / std_src
            z_hat = z_hat * std_tgt + mean_tgt
            z_hat = z_hat.clamp_(z.min(), z.max())
            if DEBUG_APPROX:
                _dbg_tensor(z_hat, 'z_hat_shift')
                _dbg_to_image(cheap_approximation(z_hat[0].float()), 'z_hat_shift')
            del std_src, mean_src, std_tgt, mean_tgt

        x_hat = VAE_forward_tile(self, z_hat, tile_size, pad_size)
        if DEBUG_APPROX:
            _dbg_to_image(x_hat[0].float(), 'z_approx')
        del z_hat, x_hat

        # apply
        sync_approx = True

    try:
        return VAE_forward_tile(self, z, tile_size, pad_size)
    except:
        print_exc()
        return torch.stack([cheap_approximation(sample.float()).to(sample) for sample in z], dim=0)
    finally:
        sync_approx_plan.clear()
        gn_sync = cached_gn_sync


class Script(Script):

    def title(self):
        return "Yet Another VAE Tiling"

    def show(self, is_img2img):
        return AlwaysVisible

    def ui(self, is_img2img):
        with gr.Accordion('Yet Another VAE Tiling', open=DEFAULT_OPEN):
            with gr.Row(variant='compact').style(equal_height=True):
                enabled = gr.Checkbox(label='Enabled', value=lambda: DEFAULT_ENABLED)
                reset = gr.Button(value='↻', variant='tool')

            with gr.Row(variant='compact').style(equal_height=True):
                encoder_tile_size = gr.Slider(label='Encoder tile size', minimum=512, maximum=4096, step=32, value=lambda: DEFAULT_ENCODER_TILE_SIZE)
                decoder_tile_size = gr.Slider(label='Decoder tile size', minimum=32, maximum=256, step=8, value=lambda: DEFAULT_DECODER_TILE_SIZE)

            with gr.Row(variant='compact').style(equal_height=True):
                encoder_pad_size = gr.Slider(label='Encoder pad size', minimum=0, maximum=64, step=8, value=lambda: DEFAULT_ENCODER_PAD_SIZE)
                decoder_pad_size = gr.Slider(label='Decoder pad size', minimum=0, maximum=8, step=1, value=lambda: DEFAULT_DECODER_PAD_SIZE)

            reset.click(fn=lambda: [DEFAULT_ENCODER_TILE_SIZE, DEFAULT_ENCODER_PAD_SIZE, DEFAULT_DECODER_TILE_SIZE, DEFAULT_DECODER_PAD_SIZE],
                        outputs=[encoder_tile_size, encoder_pad_size, decoder_tile_size, decoder_pad_size])

            with gr.Row(variant='compact').style(equal_height=True):
                ext_smart_ignore = gr.Checkbox(label='Do not process small images', value=lambda: DEFAULT_SMART_IGNORE)
                ext_auto_shrink = gr.Checkbox(label='Auto adjust real tile size', value=lambda: DEFAULT_AUTO_SHRINK)
                ext_gn_sync = gr.Dropdown(label='GroupNorm sync', value=lambda: DEFAULT_GN_SYNC, choices=[e.value for e in GroupNormSync])
                ext_skip_infer = gr.Checkbox(label='Skip infer (experimental)', value=lambda: DEFAULT_SKIP_INFER)

            with gr.Group(visible=DEFAULT_SKIP_INFER) as tab_skip_infer:
                with gr.Tab(label='Encoder skip infer'):
                    with gr.Row(variant='compact'):
                        skip_enc_down0_block0 = gr.Checkbox(label='down0.block0')
                        skip_enc_down0_block1 = gr.Checkbox(label='down0.block1')
                        skip_enc_down1_block0 = gr.Checkbox(label='down1.block0')
                        skip_enc_down1_block1 = gr.Checkbox(label='down1.block1')
                    with gr.Row(variant='compact'):
                        skip_enc_down2_block0 = gr.Checkbox(label='down2.block0')
                        skip_enc_down2_block1 = gr.Checkbox(label='down2.block1')
                        skip_enc_down3_block0 = gr.Checkbox(label='down3.block0')
                        skip_enc_down3_block1 = gr.Checkbox(label='down3.block1')
                    with gr.Row(variant='compact'):
                        skip_enc_mid_block_1 = gr.Checkbox(label='mid.block_1')
                        skip_enc_mid_attn_1  = gr.Checkbox(label='mid.attn_1')
                        skip_enc_mid_block_2 = gr.Checkbox(label='mid.block_2')

                with gr.Tab(label='Decoder skip infer'):
                    with gr.Row(variant='compact'):
                        skip_dec_mid_block_1 = gr.Checkbox(label='mid.block_1')
                        skip_dec_mid_attn_1  = gr.Checkbox(label='mid.attn_1')
                        skip_dec_mid_block_2 = gr.Checkbox(label='mid.block_2')
                    with gr.Row(variant='compact'):
                        skip_dec_up3_block0 = gr.Checkbox(label='up3.block0')
                        skip_dec_up3_block1 = gr.Checkbox(label='up3.block1')
                        skip_dec_up3_block2 = gr.Checkbox(label='up3.block2')
                        skip_dec_up2_block0 = gr.Checkbox(label='up2.block0')
                        skip_dec_up2_block1 = gr.Checkbox(label='up2.block1')
                        skip_dec_up2_block2 = gr.Checkbox(label='up2.block2')
                    with gr.Row(variant='compact'):
                        skip_dec_up1_block0 = gr.Checkbox(label='up1.block0 (no skip)', value=False, interactive=False)
                        skip_dec_up1_block1 = gr.Checkbox(label='up1.block1')
                        skip_dec_up1_block2 = gr.Checkbox(label='up1.block2')
                        skip_dec_up0_block0 = gr.Checkbox(label='up0.block0 (no skip)', value=False, interactive=False)
                        skip_dec_up0_block1 = gr.Checkbox(label='up0.block1')
                        skip_dec_up0_block2 = gr.Checkbox(label='up0.block2')

            with gr.Row():
                gr.HTML('=> see "img/VAE_arch.md" for model arch reference')

        ext_skip_infer.change(fn=lambda x: gr_show(x), inputs=ext_skip_infer, outputs=tab_skip_infer, show_progress=False)

        return [
            enabled,
            encoder_tile_size, encoder_pad_size,
            decoder_tile_size, decoder_pad_size,
            ext_smart_ignore,
            ext_auto_shrink,
            ext_gn_sync,
            ext_skip_infer,
            skip_enc_down0_block0,
            skip_enc_down0_block1,
            skip_enc_down1_block0,
            skip_enc_down1_block1,
            skip_enc_down2_block0,
            skip_enc_down2_block1,
            skip_enc_down3_block0,
            skip_enc_down3_block1,
            skip_enc_mid_block_1,
            skip_enc_mid_attn_1,
            skip_enc_mid_block_2,
            skip_dec_mid_block_1,
            skip_dec_mid_attn_1,
            skip_dec_mid_block_2,
            skip_dec_up3_block0,
            skip_dec_up3_block1,
            skip_dec_up3_block2,
            skip_dec_up2_block0,
            skip_dec_up2_block1,
            skip_dec_up2_block2,
            skip_dec_up1_block1,
            skip_dec_up1_block2,
            skip_dec_up0_block1,
            skip_dec_up0_block2,
        ]

    def process(self, p:StableDiffusionProcessing,
            enabled:bool,
            encoder_tile_size:int, encoder_pad_size:int,
            decoder_tile_size:int, decoder_pad_size:int,
            ext_smart_ignore:bool,
            ext_auto_shrink:bool,
            ext_gn_sync:str,
            ext_skip_infer:bool,
            skip_enc_down0_block0:bool,
            skip_enc_down0_block1:bool,
            skip_enc_down1_block0:bool,
            skip_enc_down1_block1:bool,
            skip_enc_down2_block0:bool,
            skip_enc_down2_block1:bool,
            skip_enc_down3_block0:bool,
            skip_enc_down3_block1:bool,
            skip_enc_mid_block_1:bool,
            skip_enc_mid_attn_1:bool,
            skip_enc_mid_block_2:bool,
            skip_dec_mid_block_1:bool,
            skip_dec_mid_attn_1:bool,
            skip_dec_mid_block_2:bool,
            skip_dec_up3_block0:bool,
            skip_dec_up3_block1:bool,
            skip_dec_up3_block2:bool,
            skip_dec_up2_block0:bool,
            skip_dec_up2_block1:bool,
            skip_dec_up2_block2:bool,
            skip_dec_up1_block1:bool,
            skip_dec_up1_block2:bool,
            skip_dec_up0_block1:bool,
            skip_dec_up0_block2:bool,
        ):

        vae: AutoencoderKL = p.sd_model.first_stage_model
        if vae.device == torch.device('cpu'): return

        encoder: Encoder = vae.encoder
        decoder: Decoder = vae.decoder

        # save original forward (only once)
        if not hasattr(encoder, 'original_forward'): encoder.original_forward = encoder.forward
        if not hasattr(decoder, 'original_forward'): decoder.original_forward = decoder.forward

        # undo hijack
        if not enabled:
            from inspect import isfunction, getfullargspec
            if isfunction(encoder.forward) and getfullargspec(encoder.forward).args[0] == 'x':
                encoder.forward = encoder.original_forward
            if isfunction(decoder.forward) and getfullargspec(decoder.forward).args[0] == 'x':
                decoder.forward = decoder.original_forward
            return

        # extra parameters
        if enabled:
            global smart_ignore, auto_shrink, gn_sync, skip_infer, skip_infer_plan

            # store settings to globals
            smart_ignore = ext_smart_ignore
            auto_shrink  = ext_auto_shrink
            gn_sync      = GroupNormSync(ext_gn_sync)
            skip_infer   = ext_skip_infer

            if ext_skip_infer:
                skip_infer_plan[Encoder].update({
                    'down0.block0': skip_enc_down0_block0,
                    'down0.block1': skip_enc_down0_block1,
                    'down1.block0': skip_enc_down1_block0,
                    'down1.block1': skip_enc_down1_block1,
                    'down2.block0': skip_enc_down2_block0,
                    'down2.block1': skip_enc_down2_block1,
                    'down3.block0': skip_enc_down3_block0,
                    'down3.block1': skip_enc_down3_block1,
                    'mid.block_1':  skip_enc_mid_block_1,
                    'mid.attn_1':   skip_enc_mid_attn_1,
                    'mid.block_2':  skip_enc_mid_block_2,
                })
                skip_infer_plan[Decoder].update({
                    'mid.block_1': skip_dec_mid_block_1,
                    'mid.attn_1':  skip_dec_mid_attn_1,
                    'mid.block_2': skip_dec_mid_block_2,
                    'up3.block0':  skip_dec_up3_block0,
                    'up3.block1':  skip_dec_up3_block1,
                    'up3.block2':  skip_dec_up3_block2,
                    'up2.block0':  skip_dec_up2_block0,
                    'up2.block1':  skip_dec_up2_block1,
                    'up2.block2':  skip_dec_up2_block2,
                    'up1.block1':  skip_dec_up1_block1,
                    'up1.block2':  skip_dec_up1_block2,
                    'up0.block1':  skip_dec_up0_block1,
                    'up0.block2':  skip_dec_up0_block2,
                })

        # apply hijack
        encoder.forward = lambda x: VAE_hijack(enabled, encoder, x, encoder_tile_size, encoder_pad_size)
        decoder.forward = lambda x: VAE_hijack(enabled, decoder, x, decoder_tile_size, decoder_pad_size)
--------------------------------------------------------------------------------
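The tile-splitting arithmetic in `get_real_tile_config` / `make_bbox` can be sketched standalone. The helper below (`split_tiles`, a hypothetical name, not part of the extension) shows the same idea: split an H x W latent plane into a grid, enlarge each input box by 2*P to account for the reflect padding applied to the whole tensor beforehand, and map each box into the scaled output plane (x8 when decoding, 1/8 when encoding).

```python
import math
from typing import List, Tuple

BBox = Tuple[int, int, int, int]  # (h_start, h_end, w_start, w_end)

def split_tiles(H: int, W: int, tile_h: int, tile_w: int, P: int, scaler: float) -> Tuple[List[BBox], List[BBox]]:
    """Hypothetical simplified sketch of make_bbox(): grid of padded input boxes + scaled output boxes."""
    n_h, n_w = math.ceil(H / tile_h), math.ceil(W / tile_w)
    bbox_in, bbox_out = [], []
    x = 0
    for _ in range(n_h):
        y = 0
        for _ in range(n_w):
            # input bbox is enlarged by 2*P on each axis (the tensor itself is reflect-padded by P first)
            bbox_in.append((x, min(x + tile_h, H) + 2 * P,
                            y, min(y + tile_w, W) + 2 * P))
            # output bbox lives in the scaled plane (x8 for decode, /8 for encode)
            bbox_out.append((int(x * scaler), int(min(x + tile_h, H) * scaler),
                             int(y * scaler), int(min(y + tile_w, W) * scaler)))
            y += tile_w
        x += tile_h
    return bbox_in, bbox_out
```

Because each input tile carries 2*P of padded context that is cropped away again (`crop_pad`) before the outputs are stitched, adjacent output tiles abut exactly and convolution borders do not produce visible seams.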
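The GroupNorm full-sync barrier works because normalization statistics can be aggregated across tiles: each worker yields its per-group mean and variance, the driver averages them (exact when all tiles have the same size, an approximation otherwise), and every tile is then normalized with the shared statistics via `custom_group_norm()`. This is what prevents brightness/contrast seams at tile borders. A minimal list-based sketch of the idea, single group and hypothetical helper names:

```python
from typing import List, Tuple

def tile_stats(tile: List[float]) -> Tuple[float, float]:
    # per-tile mean and (biased) variance, as each worker would report
    m = sum(tile) / len(tile)
    v = sum((t - m) ** 2 for t in tile) / len(tile)
    return m, v

def synced_group_norm(tiles: List[List[float]], eps: float = 1e-6) -> List[List[float]]:
    """Hypothetical sketch: average per-tile stats, then normalize every tile consistently."""
    stats = [tile_stats(t) for t in tiles]
    mean = sum(m for m, _ in stats) / len(stats)  # average of per-tile means
    var  = sum(v for _, v in stats) / len(stats)  # average of per-tile variances
    return [[(x - mean) / (var + eps) ** 0.5 for x in t] for t in tiles]
```

Note the averaged variance ignores the spread between tile means, which is why the extension also offers the APPROX mode: collect statistics once on a cheap downscaled pass, then reuse them on the full-resolution pass.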
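The worker scheduling itself is a cooperative-generator pattern: `Encoder_forward` / `Decoder_forward` are generators that yield at every sync barrier, and `VAE_forward_tile` advances all of them one barrier at a time (zigzagging the order so only the active tile needs to sit on the GPU). A toy model of that driver loop, with hypothetical names and string "stats" standing in for the yielded GroupNorm tuples:

```python
from typing import Generator, List, Union

def worker(tile: int) -> Generator[Union[str, int], None, None]:
    """One tile's pipeline: yield at two sync barriers, then yield the final result."""
    yield f'stats({tile})'       # barrier 1: hand intermediate stats to the driver
    yield f'stats({tile * 10})'  # barrier 2
    yield tile + 1               # final result (an int, standing in for the output tensor)

def drive(n: int) -> List[int]:
    """Toy driver: advance every worker exactly one barrier per round, like VAE_forward_tile."""
    workers = [worker(i) for i in range(n)]
    while True:
        outs = [next(w) for w in workers]          # one round across all tiles
        if all(isinstance(o, int) for o in outs):  # all workers emitted their final result
            return outs
```

Lock-step rounds are what make the SYNC mode possible: between two `next()` rounds the driver can merge the statistics every worker just yielded before any worker proceeds past the barrier.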