├── .gitignore ├── README.md ├── __init__.py ├── examples ├── stablesr_w_UltimateSDUpscale.json └── stablesr_w_color_fix.json ├── modules ├── attn.py ├── colorfix.py ├── spade.py ├── struct_cond.py └── util.py └── nodes.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # comfyui-stablesr plugin
2 | 
3 | ## Usage
4 | Download the two model files webui_768v_139.ckpt and stablesr_768v_000139.ckpt from the Hugging Face repository https://huggingface.co/Iceclear/StableSR/ or from Baidu Pan at https://pan.baidu.com/s/10OAff20AtNhNFfJYgN1F6Q?pwd=ox5o
5 | 
6 | Then put StableSR webui_768v_139.ckpt into ComfyUI/models/stablesr/
7 | and StableSR stablesr_768v_000139.ckpt into ComfyUI/models/checkpoints/
8 | If the folders do not exist, create them yourself.
9 | 
10 | If your ComfyUI model directory is configured to share the webui model directory, place the files in the corresponding webui directories instead.
11 | 
12 | There are example workflow JSON files in /examples/ that can be loaded into ComfyUI.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: WSJUSA
3 | @title: StableSR
4 | @nickname: StableSR
5 | @description: This module enables StableSR in ComfyUI. A port of sd-webui-stablesr. Original work on the Automatic1111 version of this module and on StableSR is credited to LightChaser and Jianyi Wang.
6 | """ 7 | 8 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 9 | 10 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /examples/stablesr_w_UltimateSDUpscale.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 42, 3 | "last_link_id": 99, 4 | "nodes": [ 5 | { 6 | "id": 12, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 137, 10 | -235 11 | ], 12 | "size": { 13 | "0": 453.4217529296875, 14 | "1": 469.52587890625 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 54 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "example_lowres.png", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 4, 46 | "type": "CheckpointLoaderSimple", 47 | "pos": [ 48 | -284, 49 | 436 50 | ], 51 | "size": { 52 | "0": 315, 53 | "1": 98 54 | }, 55 | "flags": {}, 56 | "order": 1, 57 | "mode": 0, 58 | "outputs": [ 59 | { 60 | "name": "MODEL", 61 | "type": "MODEL", 62 | "links": [ 63 | 63 64 | ], 65 | "slot_index": 0 66 | }, 67 | { 68 | "name": "CLIP", 69 | "type": "CLIP", 70 | "links": [ 71 | 19, 72 | 22 73 | ], 74 | "slot_index": 1 75 | }, 76 | { 77 | "name": "VAE", 78 | "type": "VAE", 79 | "links": [ 80 | 94 81 | ], 82 | "slot_index": 2 83 | } 84 | ], 85 | "properties": { 86 | "Node name for S&R": "CheckpointLoaderSimple" 87 | }, 88 | "widgets_values": [ 89 | "stablesr_768v_000139.ckpt" 90 | ] 91 | }, 92 | { 93 | "id": 34, 94 | "type": "StableSRColorFix", 95 | "pos": [ 96 | 1034, 97 | 9 98 | ], 99 | "size": { 100 | "0": 315, 101 | "1": 78 102 | }, 103 | "flags": {}, 104 | "order": 7, 105 | "mode": 0, 106 | "inputs": [ 107 | { 108 | "name": "image", 109 | "type": "IMAGE", 110 | "link": 88 111 | }, 112 | { 113 | "name": "color_map_image", 114 | "type": "IMAGE", 115 | "link": 82 116 | } 117 | ], 118 | "outputs": [ 119 | { 120 | "name": "IMAGE", 121 | "type": "IMAGE", 122 | "links": [ 123 | 74 124 | ], 125 | "shape": 3, 126 | "slot_index": 0 127 | } 128 | ], 129 | "properties": { 130 | "Node name for S&R": "StableSRColorFix" 131 | }, 132 | "widgets_values": [ 133 | "Wavelet" 134 | ] 135 | }, 136 | { 137 | "id": 35, 138 | "type": "PreviewImage", 139 | "pos": [ 140 | 1001, 141 | 141 142 | ], 143 | "size": { 144 | "0": 423.5451965332031, 145 | "1": 530.3379516601562 146 | }, 147 | "flags": {}, 148 | "order": 8, 149 | "mode": 0, 150 | "inputs": [ 151 | { 152 | "name": "images", 153 | "type": "IMAGE", 154 | "link": 74 155 | } 156 | ], 157 | "properties": { 158 | "Node name for S&R": "PreviewImage" 159 | } 160 | }, 161 | { 162 | "id": 29, 163 | "type": "ImageScaleBy", 164 | "pos": [ 165 | 639, 166 | -41 167 | ], 168 | "size": { 169 | "0": 315, 170 | "1": 82 171 | }, 172 | "flags": {}, 173 | "order": 2, 174 | "mode": 0, 175 | "inputs": [ 176 | { 177 | "name": "image", 178 | "type": "IMAGE", 179 | "link": 54 180 | } 181 | ], 182 | "outputs": [ 183 | { 184 | "name": "IMAGE", 185 | "type": "IMAGE", 186 | "links": [ 187 | 82, 188 | 87 189 | ], 190 | "shape": 3, 191 | "slot_index": 0 192 | } 193 | ], 194 | "properties": { 195 | "Node name for S&R": "ImageScaleBy" 196 | }, 197 | "widgets_values": [ 198 | "lanczos", 199 | 2 200 | ] 201 | }, 202 | { 203 | "id": 14, 204 | "type": 
"CLIPTextEncode", 205 | "pos": [ 206 | 174, 207 | 457 208 | ], 209 | "size": { 210 | "0": 400, 211 | "1": 200 212 | }, 213 | "flags": {}, 214 | "order": 4, 215 | "mode": 0, 216 | "inputs": [ 217 | { 218 | "name": "clip", 219 | "type": "CLIP", 220 | "link": 19 221 | } 222 | ], 223 | "outputs": [ 224 | { 225 | "name": "CONDITIONING", 226 | "type": "CONDITIONING", 227 | "links": [ 228 | 93 229 | ], 230 | "shape": 3, 231 | "slot_index": 0 232 | } 233 | ], 234 | "properties": { 235 | "Node name for S&R": "CLIPTextEncode" 236 | }, 237 | "widgets_values": [ 238 | "(masterpiece), (best quality), (realistic),(very clear)" 239 | ] 240 | }, 241 | { 242 | "id": 15, 243 | "type": "CLIPTextEncode", 244 | "pos": [ 245 | 178, 246 | 700 247 | ], 248 | "size": { 249 | "0": 400, 250 | "1": 200 251 | }, 252 | "flags": {}, 253 | "order": 5, 254 | "mode": 0, 255 | "inputs": [ 256 | { 257 | "name": "clip", 258 | "type": "CLIP", 259 | "link": 22 260 | } 261 | ], 262 | "outputs": [ 263 | { 264 | "name": "CONDITIONING", 265 | "type": "CONDITIONING", 266 | "links": [ 267 | 92 268 | ], 269 | "shape": 3, 270 | "slot_index": 0 271 | } 272 | ], 273 | "properties": { 274 | "Node name for S&R": "CLIPTextEncode" 275 | }, 276 | "widgets_values": [ 277 | "3d, cartoon, anime, sketches, (worst quality), (low quality)" 278 | ] 279 | }, 280 | { 281 | "id": 39, 282 | "type": "UltimateSDUpscaleNoUpscale", 283 | "pos": [ 284 | 645, 285 | 111 286 | ], 287 | "size": { 288 | "0": 315, 289 | "1": 570 290 | }, 291 | "flags": {}, 292 | "order": 6, 293 | "mode": 0, 294 | "inputs": [ 295 | { 296 | "name": "upscaled_image", 297 | "type": "IMAGE", 298 | "link": 87 299 | }, 300 | { 301 | "name": "model", 302 | "type": "MODEL", 303 | "link": 99 304 | }, 305 | { 306 | "name": "positive", 307 | "type": "CONDITIONING", 308 | "link": 93 309 | }, 310 | { 311 | "name": "negative", 312 | "type": "CONDITIONING", 313 | "link": 92 314 | }, 315 | { 316 | "name": "vae", 317 | "type": "VAE", 318 | "link": 94 319 | } 320 | ], 321 | "outputs": [ 322 | { 323 | "name": "IMAGE", 324 | "type": "IMAGE", 325 | "links": [ 326 | 88 327 | ], 328 | "shape": 3, 329 | "slot_index": 0 330 | } 331 | ], 332 | "properties": { 333 | "Node name for S&R": "UltimateSDUpscaleNoUpscale" 334 | }, 335 | "widgets_values": [ 336 | 0, 337 | "fixed", 338 | 20, 339 | 7, 340 | "euler_ancestral", 341 | "karras", 342 | 1, 343 | "Linear", 344 | 704, 345 | 704, 346 | 32, 347 | 32, 348 | "None", 349 | 1, 350 | 64, 351 | 8, 352 | 16, 353 | true, 354 | false 355 | ] 356 | }, 357 | { 358 | "id": 31, 359 | "type": "ApplyStableSRUpscaler", 360 | "pos": [ 361 | 206, 362 | 301 363 | ], 364 | "size": { 365 | "0": 315, 366 | "1": 102 367 | }, 368 | "flags": {}, 369 | "order": 3, 370 | "mode": 0, 371 | "inputs": [ 372 | { 373 | "name": "model", 374 | "type": "MODEL", 375 | "link": 63 376 | }, 377 | { 378 | "name": "latent_image", 379 | "type": "LATENT", 380 | "link": null 381 | } 382 | ], 383 | "outputs": [ 384 | { 385 | "name": "MODEL", 386 | "type": "MODEL", 387 | "links": [ 388 | 99 389 | ], 390 | "shape": 3, 391 | "slot_index": 0 392 | } 393 | ], 394 | "properties": { 395 | "Node name for S&R": "ApplyStableSRUpscaler" 396 | }, 397 | "widgets_values": [ 398 | "webui_768v_139.ckpt" 399 | ] 400 | } 401 | ], 402 | "links": [ 403 | [ 404 | 19, 405 | 4, 406 | 1, 407 | 14, 408 | 0, 409 | "CLIP" 410 | ], 411 | [ 412 | 22, 413 | 4, 414 | 1, 415 | 15, 416 | 0, 417 | "CLIP" 418 | ], 419 | [ 420 | 54, 421 | 12, 422 | 0, 423 | 29, 424 | 0, 425 | "IMAGE" 426 | ], 427 | [ 428 | 63, 429 | 4, 430 | 0, 431 | 31, 432 
| 0, 433 | "MODEL" 434 | ], 435 | [ 436 | 74, 437 | 34, 438 | 0, 439 | 35, 440 | 0, 441 | "IMAGE" 442 | ], 443 | [ 444 | 82, 445 | 29, 446 | 0, 447 | 34, 448 | 1, 449 | "IMAGE" 450 | ], 451 | [ 452 | 87, 453 | 29, 454 | 0, 455 | 39, 456 | 0, 457 | "IMAGE" 458 | ], 459 | [ 460 | 88, 461 | 39, 462 | 0, 463 | 34, 464 | 0, 465 | "IMAGE" 466 | ], 467 | [ 468 | 92, 469 | 15, 470 | 0, 471 | 39, 472 | 3, 473 | "CONDITIONING" 474 | ], 475 | [ 476 | 93, 477 | 14, 478 | 0, 479 | 39, 480 | 2, 481 | "CONDITIONING" 482 | ], 483 | [ 484 | 94, 485 | 4, 486 | 2, 487 | 39, 488 | 4, 489 | "VAE" 490 | ], 491 | [ 492 | 99, 493 | 31, 494 | 0, 495 | 39, 496 | 1, 497 | "MODEL" 498 | ] 499 | ], 500 | "groups": [], 501 | "config": {}, 502 | "extra": {}, 503 | "version": 0.4 504 | } -------------------------------------------------------------------------------- /examples/stablesr_w_color_fix.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 35, 3 | "last_link_id": 74, 4 | "nodes": [ 5 | { 6 | "id": 14, 7 | "type": "CLIPTextEncode", 8 | "pos": [ 9 | 57, 10 | 505 11 | ], 12 | "size": { 13 | "0": 400, 14 | "1": 200 15 | }, 16 | "flags": {}, 17 | "order": 2, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "clip", 22 | "type": "CLIP", 23 | "link": 19 24 | } 25 | ], 26 | "outputs": [ 27 | { 28 | "name": "CONDITIONING", 29 | "type": "CONDITIONING", 30 | "links": [ 31 | 65 32 | ], 33 | "shape": 3, 34 | "slot_index": 0 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "CLIPTextEncode" 39 | }, 40 | "widgets_values": [ 41 | "(masterpiece), (best quality), (realistic),(very clear)" 42 | ] 43 | }, 44 | { 45 | "id": 15, 46 | "type": "CLIPTextEncode", 47 | "pos": [ 48 | 56, 49 | 777 50 | ], 51 | "size": { 52 | "0": 400, 53 | "1": 200 54 | }, 55 | "flags": {}, 56 | "order": 3, 57 | "mode": 0, 58 | "inputs": [ 59 | { 60 | "name": "clip", 61 | "type": "CLIP", 62 | "link": 22 63 | } 64 | ], 65 | "outputs": [ 66 | { 67 | "name": "CONDITIONING", 68 | "type": "CONDITIONING", 69 | "links": [ 70 | 66 71 | ], 72 | "shape": 3, 73 | "slot_index": 0 74 | } 75 | ], 76 | "properties": { 77 | "Node name for S&R": "CLIPTextEncode" 78 | }, 79 | "widgets_values": [ 80 | "3d, cartoon, anime, sketches, (worst quality), (low quality)" 81 | ] 82 | }, 83 | { 84 | "id": 31, 85 | "type": "ApplyStableSRUpscaler", 86 | "pos": [ 87 | 550, 88 | 236 89 | ], 90 | "size": { 91 | "0": 315, 92 | "1": 102 93 | }, 94 | "flags": {}, 95 | "order": 6, 96 | "mode": 0, 97 | "inputs": [ 98 | { 99 | "name": "model", 100 | "type": "MODEL", 101 | "link": 63 102 | }, 103 | { 104 | "name": "latent_image", 105 | "type": "LATENT", 106 | "link": 67 107 | } 108 | ], 109 | "outputs": [ 110 | { 111 | "name": "MODEL", 112 | "type": "MODEL", 113 | "links": [ 114 | 64 115 | ], 116 | "shape": 3, 117 | "slot_index": 0 118 | } 119 | ], 120 | "properties": { 121 | "Node name for S&R": "ApplyStableSRUpscaler" 122 | }, 123 | "widgets_values": [ 124 | "webui_768v_139.ckpt", 125 | true 126 | ] 127 | }, 128 | { 129 | "id": 26, 130 | "type": "PreviewImage", 131 | "pos": [ 132 | 1247, 133 | 498 134 | ], 135 | "size": { 136 | "0": 426.760009765625, 137 | "1": 541.3356323242188 138 | }, 139 | "flags": {}, 140 | "order": 10, 141 | "mode": 0, 142 | "inputs": [ 143 | { 144 | "name": "images", 145 | "type": "IMAGE", 146 | "link": 41 147 | } 148 | ], 149 | "properties": { 150 | "Node name for S&R": "PreviewImage" 151 | } 152 | }, 153 | { 154 | "id": 13, 155 | "type": "VAEEncode", 156 | "pos": [ 157 | 593, 158 | 110 159 | ], 
160 | "size": { 161 | "0": 210, 162 | "1": 46 163 | }, 164 | "flags": {}, 165 | "order": 5, 166 | "mode": 0, 167 | "inputs": [ 168 | { 169 | "name": "pixels", 170 | "type": "IMAGE", 171 | "link": 56 172 | }, 173 | { 174 | "name": "vae", 175 | "type": "VAE", 176 | "link": 15 177 | } 178 | ], 179 | "outputs": [ 180 | { 181 | "name": "LATENT", 182 | "type": "LATENT", 183 | "links": [ 184 | 67, 185 | 68, 186 | 70 187 | ], 188 | "shape": 3, 189 | "slot_index": 0 190 | } 191 | ], 192 | "properties": { 193 | "Node name for S&R": "VAEEncode" 194 | } 195 | }, 196 | { 197 | "id": 4, 198 | "type": "CheckpointLoaderSimple", 199 | "pos": [ 200 | -406, 201 | 490 202 | ], 203 | "size": { 204 | "0": 315, 205 | "1": 98 206 | }, 207 | "flags": {}, 208 | "order": 0, 209 | "mode": 0, 210 | "outputs": [ 211 | { 212 | "name": "MODEL", 213 | "type": "MODEL", 214 | "links": [ 215 | 63 216 | ], 217 | "slot_index": 0 218 | }, 219 | { 220 | "name": "CLIP", 221 | "type": "CLIP", 222 | "links": [ 223 | 19, 224 | 22 225 | ], 226 | "slot_index": 1 227 | }, 228 | { 229 | "name": "VAE", 230 | "type": "VAE", 231 | "links": [ 232 | 8, 233 | 15, 234 | 71 235 | ], 236 | "slot_index": 2 237 | } 238 | ], 239 | "properties": { 240 | "Node name for S&R": "CheckpointLoaderSimple" 241 | }, 242 | "widgets_values": [ 243 | "stablesr_768v_000139.ckpt" 244 | ] 245 | }, 246 | { 247 | "id": 33, 248 | "type": "VAEDecode", 249 | "pos": [ 250 | 1467, 251 | 79 252 | ], 253 | "size": { 254 | "0": 210, 255 | "1": 46 256 | }, 257 | "flags": {}, 258 | "order": 7, 259 | "mode": 0, 260 | "inputs": [ 261 | { 262 | "name": "samples", 263 | "type": "LATENT", 264 | "link": 70 265 | }, 266 | { 267 | "name": "vae", 268 | "type": "VAE", 269 | "link": 71 270 | } 271 | ], 272 | "outputs": [ 273 | { 274 | "name": "IMAGE", 275 | "type": "IMAGE", 276 | "links": [ 277 | 72 278 | ], 279 | "shape": 3, 280 | "slot_index": 0 281 | } 282 | ], 283 | "properties": { 284 | "Node name for S&R": "VAEDecode" 285 | } 286 | }, 287 | { 288 | "id": 8, 289 | "type": "VAEDecode", 290 | "pos": [ 291 | 1468, 292 | 181 293 | ], 294 | "size": { 295 | "0": 210, 296 | "1": 46 297 | }, 298 | "flags": {}, 299 | "order": 9, 300 | "mode": 0, 301 | "inputs": [ 302 | { 303 | "name": "samples", 304 | "type": "LATENT", 305 | "link": 60 306 | }, 307 | { 308 | "name": "vae", 309 | "type": "VAE", 310 | "link": 8 311 | } 312 | ], 313 | "outputs": [ 314 | { 315 | "name": "IMAGE", 316 | "type": "IMAGE", 317 | "links": [ 318 | 41, 319 | 73 320 | ], 321 | "slot_index": 0 322 | } 323 | ], 324 | "properties": { 325 | "Node name for S&R": "VAEDecode" 326 | } 327 | }, 328 | { 329 | "id": 34, 330 | "type": "StableSRColorFix", 331 | "pos": [ 332 | 1732, 333 | 118 334 | ], 335 | "size": { 336 | "0": 315, 337 | "1": 78 338 | }, 339 | "flags": {}, 340 | "order": 11, 341 | "mode": 0, 342 | "inputs": [ 343 | { 344 | "name": "image", 345 | "type": "IMAGE", 346 | "link": 73 347 | }, 348 | { 349 | "name": "color_map_image", 350 | "type": "IMAGE", 351 | "link": 72 352 | } 353 | ], 354 | "outputs": [ 355 | { 356 | "name": "IMAGE", 357 | "type": "IMAGE", 358 | "links": [ 359 | 74 360 | ], 361 | "shape": 3, 362 | "slot_index": 0 363 | } 364 | ], 365 | "properties": { 366 | "Node name for S&R": "StableSRColorFix" 367 | }, 368 | "widgets_values": [ 369 | "Wavelet" 370 | ] 371 | }, 372 | { 373 | "id": 30, 374 | "type": "KSampler", 375 | "pos": [ 376 | 1012, 377 | 124 378 | ], 379 | "size": { 380 | "0": 315, 381 | "1": 262 382 | }, 383 | "flags": {}, 384 | "order": 8, 385 | "mode": 0, 386 | "inputs": [ 387 | { 388 | 
"name": "model", 389 | "type": "MODEL", 390 | "link": 64 391 | }, 392 | { 393 | "name": "positive", 394 | "type": "CONDITIONING", 395 | "link": 65 396 | }, 397 | { 398 | "name": "negative", 399 | "type": "CONDITIONING", 400 | "link": 66 401 | }, 402 | { 403 | "name": "latent_image", 404 | "type": "LATENT", 405 | "link": 68, 406 | "slot_index": 3 407 | } 408 | ], 409 | "outputs": [ 410 | { 411 | "name": "LATENT", 412 | "type": "LATENT", 413 | "links": [ 414 | 60 415 | ], 416 | "shape": 3, 417 | "slot_index": 0 418 | } 419 | ], 420 | "properties": { 421 | "Node name for S&R": "KSampler" 422 | }, 423 | "widgets_values": [ 424 | 175840193994180, 425 | "randomize", 426 | 20, 427 | 8, 428 | "euler", 429 | "normal", 430 | 1 431 | ] 432 | }, 433 | { 434 | "id": 29, 435 | "type": "ImageScaleBy", 436 | "pos": [ 437 | 197, 438 | 47 439 | ], 440 | "size": { 441 | "0": 315, 442 | "1": 82 443 | }, 444 | "flags": {}, 445 | "order": 4, 446 | "mode": 0, 447 | "inputs": [ 448 | { 449 | "name": "image", 450 | "type": "IMAGE", 451 | "link": 54 452 | } 453 | ], 454 | "outputs": [ 455 | { 456 | "name": "IMAGE", 457 | "type": "IMAGE", 458 | "links": [ 459 | 56 460 | ], 461 | "shape": 3, 462 | "slot_index": 0 463 | } 464 | ], 465 | "properties": { 466 | "Node name for S&R": "ImageScaleBy" 467 | }, 468 | "widgets_values": [ 469 | "lanczos", 470 | 1.9999993896484363 471 | ] 472 | }, 473 | { 474 | "id": 35, 475 | "type": "PreviewImage", 476 | "pos": [ 477 | 1740, 478 | 502 479 | ], 480 | "size": { 481 | "0": 423.5451965332031, 482 | "1": 530.3379516601562 483 | }, 484 | "flags": {}, 485 | "order": 12, 486 | "mode": 0, 487 | "inputs": [ 488 | { 489 | "name": "images", 490 | "type": "IMAGE", 491 | "link": 74 492 | } 493 | ], 494 | "properties": { 495 | "Node name for S&R": "PreviewImage" 496 | } 497 | }, 498 | { 499 | "id": 12, 500 | "type": "LoadImage", 501 | "pos": [ 502 | -382, 503 | -80 504 | ], 505 | "size": { 506 | "0": 453.4217529296875, 507 | "1": 469.52587890625 508 | }, 509 | "flags": {}, 510 | "order": 1, 511 | "mode": 0, 512 | "outputs": [ 513 | { 514 | "name": "IMAGE", 515 | "type": "IMAGE", 516 | "links": [ 517 | 54 518 | ], 519 | "shape": 3, 520 | "slot_index": 0 521 | }, 522 | { 523 | "name": "MASK", 524 | "type": "MASK", 525 | "links": null, 526 | "shape": 3 527 | } 528 | ], 529 | "properties": { 530 | "Node name for S&R": "LoadImage" 531 | }, 532 | "widgets_values": [ 533 | "1111.jpg", 534 | "image" 535 | ] 536 | } 537 | ], 538 | "links": [ 539 | [ 540 | 8, 541 | 4, 542 | 2, 543 | 8, 544 | 1, 545 | "VAE" 546 | ], 547 | [ 548 | 15, 549 | 4, 550 | 2, 551 | 13, 552 | 1, 553 | "VAE" 554 | ], 555 | [ 556 | 19, 557 | 4, 558 | 1, 559 | 14, 560 | 0, 561 | "CLIP" 562 | ], 563 | [ 564 | 22, 565 | 4, 566 | 1, 567 | 15, 568 | 0, 569 | "CLIP" 570 | ], 571 | [ 572 | 41, 573 | 8, 574 | 0, 575 | 26, 576 | 0, 577 | "IMAGE" 578 | ], 579 | [ 580 | 54, 581 | 12, 582 | 0, 583 | 29, 584 | 0, 585 | "IMAGE" 586 | ], 587 | [ 588 | 56, 589 | 29, 590 | 0, 591 | 13, 592 | 0, 593 | "IMAGE" 594 | ], 595 | [ 596 | 60, 597 | 30, 598 | 0, 599 | 8, 600 | 0, 601 | "LATENT" 602 | ], 603 | [ 604 | 63, 605 | 4, 606 | 0, 607 | 31, 608 | 0, 609 | "MODEL" 610 | ], 611 | [ 612 | 64, 613 | 31, 614 | 0, 615 | 30, 616 | 0, 617 | "MODEL" 618 | ], 619 | [ 620 | 65, 621 | 14, 622 | 0, 623 | 30, 624 | 1, 625 | "CONDITIONING" 626 | ], 627 | [ 628 | 66, 629 | 15, 630 | 0, 631 | 30, 632 | 2, 633 | "CONDITIONING" 634 | ], 635 | [ 636 | 67, 637 | 13, 638 | 0, 639 | 31, 640 | 1, 641 | "LATENT" 642 | ], 643 | [ 644 | 68, 645 | 13, 646 | 0, 647 | 30, 648 | 
3,
649 | "LATENT"
650 | ],
651 | [
652 | 70,
653 | 13,
654 | 0,
655 | 33,
656 | 0,
657 | "LATENT"
658 | ],
659 | [
660 | 71,
661 | 4,
662 | 2,
663 | 33,
664 | 1,
665 | "VAE"
666 | ],
667 | [
668 | 72,
669 | 33,
670 | 0,
671 | 34,
672 | 1,
673 | "IMAGE"
674 | ],
675 | [
676 | 73,
677 | 8,
678 | 0,
679 | 34,
680 | 0,
681 | "IMAGE"
682 | ],
683 | [
684 | 74,
685 | 34,
686 | 0,
687 | 35,
688 | 0,
689 | "IMAGE"
690 | ]
691 | ],
692 | "groups": [],
693 | "config": {},
694 | "extra": {},
695 | "version": 0.4
696 | }
--------------------------------------------------------------------------------
/modules/attn.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is modified from the multidiffusion-upscaler-for-automatic1111 TiledVAE attn.py, so that StableSR can save a lot of VRAM.
3 | Modified further for ComfyUI.
4 | '''
5 | import math
6 | import torch
7 | import comfy.model_management
8 | 
9 | '''
10 | # DELETE from sd-webui-stablesr version
11 | from modules import shared, sd_hijack
12 | from modules.sd_hijack_optimizations import get_available_vram, get_xformers_flash_attention_op, sub_quad_attention
13 | '''
14 | 
15 | try:
16 | import xformers
17 | import xformers.ops
18 | except ImportError:
19 | pass
20 | 
21 | # Simplified version of the attn_forward switch -- only xformers for now -- see the original get_attn_func below to add more backends
22 | def sr_get_attn_func():
23 | if comfy.model_management.xformers_enabled():
24 | return sr_xformers_attnblock_forward
25 | 
26 | return sr_attn_forward
27 | 
28 | # The following functions are all copied from stable-diffusion-webui modules.sd_hijack_optimizations
29 | # However, the residual & normalization are removed and computed separately.
30 | # And adapted for ComfyUI
31 | 
32 | # Note: no change from the original attn_forward
33 | def sr_attn_forward(q, k, v):
34 | # compute attention
35 | # q: b,hw,c
36 | k = k.permute(0, 2, 1) # b,c,hw
37 | c = k.shape[1]
38 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
39 | w_ = w_ * (int(c)**(-0.5))
40 | w_ = torch.nn.functional.softmax(w_, dim=2)
41 | 
42 | # attend to values
43 | v = v.permute(0, 2, 1) # b,c,hw
44 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
45 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
46 | h_ = torch.bmm(v, w_)
47 | 
48 | return h_.permute(0, 2, 1)
49 | 
50 | def sr_xformers_attnblock_forward(q, k, v):
51 | return xformers.ops.memory_efficient_attention(q, k, v, op=sr_get_xformers_flash_attention_op(q, k, v))
52 | 
53 | # borrowed get_xformers_flash_attention_op from auto1111
54 | def sr_get_xformers_flash_attention_op(q, k, v):
55 | '''
56 | # this may be a problem
57 | if not shared.cmd_opts.xformers_flash_attention:
58 | return None
59 | '''
60 | try:
61 | flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
62 | fw, bw = flash_attention_op
63 | if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
64 | return flash_attention_op
65 | except Exception as e:
66 | print(f'[StableSR] Error getting Flash attention handler: {e}')
67 | return None
68 | 
69 | 
70 | '''
71 | The original sd-webui get_attn_func() and mapped handlers, for reference
72 | def get_attn_func():
73 | method = sd_hijack.model_hijack.optimization_method
74 | print(f"[StableSR] in get_attn_function - optimization method: {method}")
75 | if method is None:
76 | return attn_forward
77 | method = method.lower()
78 | # The method should be one of the following:
79 | # ['none',
'sdp-no-mem', 'sdp', 'xformers', ''sub-quadratic', 'v1', 'invokeai', 'doggettx'] 80 | if method not in ['none', 'sdp-no-mem', 'sdp', 'xformers', 'sub-quadratic', 'v1', 'invokeai', 'doggettx']: 81 | print(f"[StableSR] Warning: Unknown attention optimization method {method}. Please try to update the extension.") 82 | return attn_forward 83 | 84 | if method == 'none': 85 | return attn_forward 86 | elif method == 'xformers': 87 | return xformers_attnblock_forward 88 | elif method == 'sdp-no-mem': 89 | return sdp_no_mem_attnblock_forward 90 | elif method == 'sdp': 91 | return sdp_attnblock_forward 92 | elif method == 'sub-quadratic': 93 | return sub_quad_attnblock_forward 94 | elif method == 'doggettx': 95 | return cross_attention_attnblock_forward 96 | 97 | return attn_forward 98 | 99 | 100 | # The following functions are all copied from modules.sd_hijack_optimizations 101 | # However, the residual & normalization are removed and computed separately. 102 | 103 | def attn_forward(q, k, v): 104 | # compute attention 105 | # q: b,hw,c 106 | k = k.permute(0, 2, 1) # b,c,hw 107 | c = k.shape[1] 108 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 109 | w_ = w_ * (int(c)**(-0.5)) 110 | w_ = torch.nn.functional.softmax(w_, dim=2) 111 | 112 | # attend to values 113 | v = v.permute(0, 2, 1) # b,c,hw 114 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 115 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 116 | h_ = torch.bmm(v, w_) 117 | 118 | return h_.permute(0, 2, 1) 119 | 120 | def xformers_attnblock_forward(q, k, v): 121 | return xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) 122 | 123 | 124 | def cross_attention_attnblock_forward(q, k, v): 125 | # compute attention 126 | k = k.permute(0, 2, 1)# b,c,hw 127 | v = v.permute(0, 2, 1)# b,c,hw 128 | c = k.shape[1] 129 | h_ = torch.zeros_like(k, device=q.device) 130 | 131 | mem_free_total = get_available_vram() 132 | 133 | tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() 134 | mem_required = tensor_size * 2.5 135 | steps = 1 136 | 137 | if mem_required > mem_free_total: 138 | steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) 139 | 140 | slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] 141 | for i in range(0, q.shape[1], slice_size): 142 | end = i + slice_size 143 | 144 | w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 145 | w2 = w1 * (int(c)**(-0.5)) 146 | del w1 147 | w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) 148 | del w2 149 | 150 | # attend to values 151 | w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 152 | del w3 153 | 154 | h_[:, :, i:end] = torch.bmm(v, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 155 | del w4 156 | 157 | return h_.permute(0, 2, 1) 158 | 159 | def sdp_no_mem_attnblock_forward(q, k, v): 160 | with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): 161 | return sdp_attnblock_forward(q, k, v) 162 | 163 | def sdp_attnblock_forward(q, k, v): 164 | return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) 165 | 166 | def sub_quad_attnblock_forward(q, k, v): 167 | return sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=True) 168 | ''' 
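A minimal sketch (not part of the file above) of how sr_get_attn_func could be extended with a fallback to PyTorch's scaled_dot_product_attention when xformers is unavailable, mirroring the 'sdp' branch of the original webui dispatcher quoted above. The helper name sr_sdp_attnblock_forward is hypothetical; sr_xformers_attnblock_forward and sr_attn_forward are the functions defined in modules/attn.py above:

import torch
import comfy.model_management

def sr_sdp_attnblock_forward(q, k, v):
    # q, k, v arrive as (batch, tokens, channels), the same layout sr_attn_forward expects
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

def sr_get_attn_func():
    if comfy.model_management.xformers_enabled():
        return sr_xformers_attnblock_forward
    # assumption: prefer PyTorch SDPA (torch >= 2.0) over the plain bmm attention
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return sr_sdp_attnblock_forward
    return sr_attn_forward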
--------------------------------------------------------------------------------
/modules/colorfix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from torch import Tensor
4 | from torch.nn import functional as F
5 | 
6 | from torchvision.transforms import ToTensor, ToPILImage
7 | 
8 | def adain_color_fix(target: Image, source: Image):
9 | # Convert images to tensors
10 | to_tensor = ToTensor()
11 | target_tensor = to_tensor(target).unsqueeze(0)
12 | source_tensor = to_tensor(source).unsqueeze(0)
13 | 
14 | # Apply adaptive instance normalization
15 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
16 | 
17 | # Convert tensor back to image
18 | to_image = ToPILImage()
19 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
20 | 
21 | return result_image
22 | 
23 | def wavelet_color_fix(target: Image, source: Image):
24 | # Convert images to tensors
25 | to_tensor = ToTensor()
26 | target_tensor = to_tensor(target).unsqueeze(0)
27 | source_tensor = to_tensor(source).unsqueeze(0)
28 | 
29 | # Apply wavelet reconstruction
30 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
31 | 
32 | # Convert tensor back to image
33 | to_image = ToPILImage()
34 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
35 | 
36 | return result_image
37 | 
38 | def calc_mean_std(feat: Tensor, eps=1e-5):
39 | """Calculate mean and std for adaptive_instance_normalization.
40 | Args:
41 | feat (Tensor): 4D tensor.
42 | eps (float): A small value added to the variance to avoid
43 | divide-by-zero. Default: 1e-5.
44 | """
45 | size = feat.size()
46 | assert len(size) == 4, 'The input feature should be 4D tensor.'
47 | b, c = size[:2]
48 | feat_var = feat.view(b, c, -1).var(dim=2) + eps
49 | feat_std = feat_var.sqrt().view(b, c, 1, 1)
50 | feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
51 | return feat_mean, feat_std
52 | 
53 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
54 | """Adaptive instance normalization.
55 | Adjust the reference features to have similar color and illumination to
56 | those in the degraded features.
57 | Args:
58 | content_feat (Tensor): The reference feature.
59 | style_feat (Tensor): The degraded features.
60 | """
61 | size = content_feat.size()
62 | style_mean, style_std = calc_mean_std(style_feat)
63 | content_mean, content_std = calc_mean_std(content_feat)
64 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
65 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
66 | 
67 | def wavelet_blur(image: Tensor, radius: int):
68 | """
69 | Apply wavelet blur to the input tensor.
70 | """
71 | # input shape: (1, 3, H, W)
72 | # convolution kernel
73 | kernel_vals = [
74 | [0.0625, 0.125, 0.0625],
75 | [0.125, 0.25, 0.125],
76 | [0.0625, 0.125, 0.0625],
77 | ]
78 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
79 | # add channel dimensions to the kernel to make it a 4D tensor
80 | kernel = kernel[None, None]
81 | # repeat the kernel across all input channels
82 | kernel = kernel.repeat(3, 1, 1, 1)
83 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
84 | # apply convolution
85 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
86 | return output
87 | 
88 | def wavelet_decomposition(image: Tensor, levels=5):
89 | """
90 | Apply wavelet decomposition to the input tensor.
91 | This function only returns the low frequency & the high frequency. 92 | """ 93 | high_freq = torch.zeros_like(image) 94 | for i in range(levels): 95 | radius = 2 ** i 96 | low_freq = wavelet_blur(image, radius) 97 | high_freq += (image - low_freq) 98 | image = low_freq 99 | 100 | return high_freq, low_freq 101 | 102 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 103 | """ 104 | Apply wavelet decomposition, so that the content will have the same color as the style. 105 | """ 106 | # calculate the wavelet decomposition of the content feature 107 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 108 | del content_low_freq 109 | # calculate the wavelet decomposition of the style feature 110 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 111 | del style_high_freq 112 | # reconstruct the content feature with the style's high frequency 113 | return content_high_freq + style_low_freq 114 | 115 | -------------------------------------------------------------------------------- /modules/spade.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel 12 | from comfy.ldm.modules.diffusionmodules.util import checkpoint 13 | 14 | from .util import normalization 15 | 16 | 17 | class SPADE(nn.Module): 18 | def __init__(self, norm_nc, label_nc=256, config_text='spadeinstance3x3'): 19 | super().__init__() 20 | assert config_text.startswith('spade') 21 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 22 | ks = int(parsed.group(2)) 23 | self.param_free_norm = normalization(norm_nc) 24 | 25 | # The dimension of the intermediate embedding space. Yes, hardcoded. 26 | nhidden = 128 27 | 28 | pw = ks // 2 29 | self.mlp_shared = nn.Sequential( 30 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 31 | nn.ReLU() 32 | ) 33 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 34 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 35 | 36 | def forward(self, x_dic, segmap_dic): 37 | return checkpoint( 38 | self._forward, (x_dic, segmap_dic), self.parameters(), True 39 | ) 40 | 41 | def _forward(self, x_dic, segmap_dic): 42 | segmap = segmap_dic[str(x_dic.size(-1))] 43 | x = x_dic 44 | 45 | # Part 1. generate parameter-free normalized activations 46 | normalized = self.param_free_norm(x) 47 | 48 | # Part 2. 
produce scaling and bias conditioned on semantic map 49 | # segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 50 | actv = self.mlp_shared(segmap) 51 | 52 | repeat_factor = normalized.shape[0]//segmap.shape[0] 53 | if repeat_factor > 1: 54 | out = normalized 55 | out *= (1 + self.mlp_gamma(actv).repeat_interleave(repeat_factor, dim=0)) 56 | out += self.mlp_beta(actv).repeat_interleave(repeat_factor, dim=0) 57 | else: 58 | out = normalized 59 | out *= (1 + self.mlp_gamma(actv)) 60 | out += self.mlp_beta(actv) 61 | return out 62 | 63 | def dual_resblock_forward(self: ResBlock, x, emb, spade: SPADE, get_struct_cond): 64 | if self.updown: 65 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 66 | h = in_rest(x) 67 | h = self.h_upd(h) 68 | x = self.x_upd(x) 69 | h = in_conv(h) 70 | else: 71 | h = self.in_layers(x) 72 | emb_out = self.emb_layers(emb).type(h.dtype) 73 | while len(emb_out.shape) < len(h.shape): 74 | emb_out = emb_out[..., None] 75 | if self.use_scale_shift_norm: 76 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 77 | scale, shift = torch.chunk(emb_out, 2, dim=1) 78 | h = out_norm(h) * (1 + scale) + shift 79 | h = out_rest(h) 80 | else: 81 | h = h + emb_out 82 | h = self.out_layers(h) 83 | h = spade(h, get_struct_cond()) 84 | return self.skip_connection(x) + h 85 | 86 | 87 | class SPADELayers(nn.Module): 88 | def __init__(self): 89 | ''' 90 | A container class for fast SPADE layer loading. 91 | params inferred from the official checkpoint 92 | ''' 93 | super().__init__() 94 | self.input_blocks = nn.ModuleList([ 95 | nn.Identity(), 96 | SPADE(320), 97 | SPADE(320), 98 | nn.Identity(), 99 | SPADE(640), 100 | SPADE(640), 101 | nn.Identity(), 102 | SPADE(1280), 103 | SPADE(1280), 104 | nn.Identity(), 105 | SPADE(1280), 106 | SPADE(1280), 107 | ]) 108 | self.middle_block = nn.ModuleList([ 109 | SPADE(1280), 110 | nn.Identity(), 111 | SPADE(1280), 112 | ]) 113 | self.output_blocks = nn.ModuleList([ 114 | SPADE(1280), 115 | SPADE(1280), 116 | SPADE(1280), 117 | SPADE(1280), 118 | SPADE(1280), 119 | SPADE(1280), 120 | SPADE(640), 121 | SPADE(640), 122 | SPADE(640), 123 | SPADE(320), 124 | SPADE(320), 125 | SPADE(320), 126 | ]) 127 | self.input_ids = [1,2,4,5,7,8,10,11] 128 | self.output_ids = list(range(12)) 129 | self.mid_ids = [0,2] 130 | self.forward_cache_name = 'org_forward_stablesr' 131 | self.unet = None 132 | 133 | 134 | def hook(self, unet: UNetModel, get_struct_cond): 135 | # hook all resblocks 136 | self.unet = unet 137 | resblock: ResBlock = None 138 | for i in self.input_ids: 139 | resblock = unet.input_blocks[i][0] 140 | # debug 141 | # assert isinstance(resblock, ResBlock) 142 | if not hasattr(resblock, self.forward_cache_name): 143 | setattr(resblock, self.forward_cache_name, resblock._forward) 144 | resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.input_blocks[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond) 145 | 146 | for i in self.output_ids: 147 | resblock = unet.output_blocks[i][0] 148 | # debug 149 | # assert isinstance(resblock, ResBlock) 150 | if not hasattr(resblock, self.forward_cache_name): 151 | setattr(resblock, self.forward_cache_name, resblock._forward) 152 | resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.output_blocks[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond) 153 | 154 | for i in self.mid_ids: 155 | resblock = unet.middle_block[i] 156 | # debug 157 | # assert isinstance(resblock, ResBlock) 158 | if not 
hasattr(resblock, self.forward_cache_name): 159 | setattr(resblock, self.forward_cache_name, resblock._forward) 160 | resblock._forward = lambda x, timesteps, resblock=resblock, spade=self.middle_block[i]: dual_resblock_forward(resblock, x, timesteps, spade, get_struct_cond) 161 | 162 | def unhook(self): 163 | unet = self.unet 164 | if unet is None: return 165 | resblock: ResBlock = None 166 | for i in self.input_ids: 167 | resblock = unet.input_blocks[i][0] 168 | if hasattr(resblock, self.forward_cache_name): 169 | resblock._forward = getattr(resblock, self.forward_cache_name) 170 | delattr(resblock, self.forward_cache_name) 171 | 172 | for i in self.output_ids: 173 | resblock = unet.output_blocks[i][0] 174 | if hasattr(resblock, self.forward_cache_name): 175 | resblock._forward = getattr(resblock, self.forward_cache_name) 176 | delattr(resblock, self.forward_cache_name) 177 | 178 | for i in self.mid_ids: 179 | resblock = unet.middle_block[i] 180 | if hasattr(resblock, self.forward_cache_name): 181 | resblock._forward = getattr(resblock, self.forward_cache_name) 182 | delattr(resblock, self.forward_cache_name) 183 | self.unet = None 184 | 185 | 186 | def load_from_dict(self, state_dict): 187 | """ 188 | Load model weights from a dictionary. 189 | :param state_dict: a dict of parameters. 190 | """ 191 | filtered_dict = {} 192 | for k, v in state_dict.items(): 193 | if k.startswith("model.diffusion_model."): 194 | key = k[len("model.diffusion_model.") :] 195 | # remove the '.0.spade' within the key 196 | if 'middle_block' not in key: 197 | key = key.replace('.0.spade', '') 198 | else: 199 | key = key.replace('.spade', '') 200 | filtered_dict[key] = v 201 | self.load_state_dict(filtered_dict) 202 | 203 | 204 | if __name__ == '__main__': 205 | path = '../models/stablesr_sd21.ckpt' 206 | state_dict = torch.load(path) 207 | model = SPADELayers() 208 | model.load_from_dict(state_dict) 209 | print(model) -------------------------------------------------------------------------------- /modules/struct_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from comfy.ldm.modules.diffusionmodules.openaimodel import ( 5 | Downsample, ResBlock, TimestepEmbedSequential) 6 | from comfy.ldm.modules.diffusionmodules.util import (checkpoint, 7 | timestep_embedding, 8 | zero_module) 9 | 10 | # NOTE only change in file for Comyfui 11 | from .attn import sr_get_attn_func as get_attn_func 12 | from .util import conv_nd, linear, normalization 13 | 14 | attn_func = None 15 | 16 | 17 | class QKVAttentionLegacy(nn.Module): 18 | """ 19 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 20 | """ 21 | 22 | def __init__(self, n_heads): 23 | super().__init__() 24 | self.n_heads = n_heads 25 | 26 | def forward(self, qkv): 27 | """ 28 | Apply QKV attention. 29 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 30 | :return: an [N x (H * C) x T] tensor after attention. 
31 | """ 32 | bs, width, length = qkv.shape 33 | assert width % (3 * self.n_heads) == 0 34 | ch = width // (3 * self.n_heads) 35 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 36 | # Legacy Attention 37 | # scale = 1 / math.sqrt(math.sqrt(ch)) 38 | # weight = torch.einsum( 39 | # "bct,bcs->bts", q * scale, k * scale 40 | # ) # More stable with f16 than dividing afterwards 41 | # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 42 | # a = torch.einsum("bts,bcs->bct", weight, v) 43 | # a = a.reshape(bs, -1, length) 44 | q, k, v = map( 45 | lambda t:t.permute(0,2,1) 46 | .contiguous(), 47 | (q, k, v), 48 | ) 49 | global attn_func 50 | a = attn_func(q, k, v) 51 | a = ( 52 | a.permute(0,2,1) 53 | .reshape(bs, -1, length) 54 | ) 55 | return a 56 | 57 | class AttentionBlock(nn.Module): 58 | """ 59 | An attention block that allows spatial positions to attend to each other. 60 | Originally ported from here, but adapted to the N-d case. 61 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | channels, 67 | num_heads=1, 68 | num_head_channels=-1, 69 | use_checkpoint=False, 70 | use_new_attention_order=False, 71 | ): 72 | super().__init__() 73 | self.channels = channels 74 | if num_head_channels == -1: 75 | self.num_heads = num_heads 76 | else: 77 | assert ( 78 | channels % num_head_channels == 0 79 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 80 | self.num_heads = channels // num_head_channels 81 | self.norm = normalization(channels) 82 | self.qkv = conv_nd(1, channels, channels * 3, 1) 83 | self.attention = QKVAttentionLegacy(self.num_heads) 84 | 85 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 86 | 87 | def forward(self, x): 88 | return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 89 | 90 | def _forward(self, x): 91 | b, c, *spatial = x.shape 92 | x = x.reshape(b, c, -1) 93 | qkv = self.qkv(self.norm(x)) 94 | h = self.attention(qkv) 95 | h = self.proj_out(h) 96 | return (x + h).reshape(b, c, *spatial) 97 | 98 | 99 | class EncoderUNetModelWT(nn.Module): 100 | """ 101 | The half UNet model with attention and timestep embedding. 102 | For usage, see UNet. 
103 | """ 104 | 105 | def __init__( 106 | self, 107 | in_channels, 108 | model_channels, 109 | out_channels, 110 | num_res_blocks, 111 | attention_resolutions, 112 | dropout=0, 113 | channel_mult=(1, 2, 4, 8), 114 | conv_resample=True, 115 | dims=2, 116 | use_checkpoint=False, 117 | use_fp16=False, 118 | num_heads=4, 119 | num_head_channels=-1, 120 | num_heads_upsample=-1, 121 | use_scale_shift_norm=False, 122 | resblock_updown=False, 123 | use_new_attention_order=False, 124 | ): 125 | super().__init__() 126 | 127 | if num_heads_upsample == -1: 128 | num_heads_upsample = num_heads 129 | 130 | self.in_channels = in_channels 131 | self.model_channels = model_channels 132 | self.out_channels = out_channels 133 | self.num_res_blocks = num_res_blocks 134 | self.attention_resolutions = attention_resolutions 135 | self.dropout = dropout 136 | self.channel_mult = channel_mult 137 | self.conv_resample = conv_resample 138 | self.use_checkpoint = use_checkpoint 139 | self.dtype = torch.float16 if use_fp16 else torch.float32 140 | self.num_heads = num_heads 141 | self.num_head_channels = num_head_channels 142 | self.num_heads_upsample = num_heads_upsample 143 | 144 | time_embed_dim = model_channels * 4 145 | self.time_embed = nn.Sequential( 146 | linear(model_channels, time_embed_dim), 147 | nn.SiLU(), 148 | linear(time_embed_dim, time_embed_dim), 149 | ) 150 | 151 | self.input_blocks = nn.ModuleList( 152 | [ 153 | TimestepEmbedSequential( 154 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 155 | ) 156 | ] 157 | ) 158 | self._feature_size = model_channels 159 | input_block_chans = [] 160 | ch = model_channels 161 | ds = 1 162 | for level, mult in enumerate(channel_mult): 163 | for _ in range(num_res_blocks): 164 | layers = [ 165 | ResBlock( 166 | ch, 167 | time_embed_dim, 168 | dropout, 169 | out_channels=mult * model_channels, 170 | dims=dims, 171 | use_checkpoint=use_checkpoint, 172 | use_scale_shift_norm=use_scale_shift_norm, 173 | ) 174 | ] 175 | ch = mult * model_channels 176 | if ds in attention_resolutions: 177 | layers.append( 178 | AttentionBlock( 179 | ch, 180 | use_checkpoint=use_checkpoint, 181 | num_heads=num_heads, 182 | num_head_channels=num_head_channels, 183 | use_new_attention_order=use_new_attention_order, 184 | ) 185 | ) 186 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 187 | self._feature_size += ch 188 | if level != len(channel_mult) - 1: 189 | out_ch = ch 190 | self.input_blocks.append( 191 | TimestepEmbedSequential( 192 | ResBlock( 193 | ch, 194 | time_embed_dim, 195 | dropout, 196 | out_channels=out_ch, 197 | dims=dims, 198 | use_checkpoint=use_checkpoint, 199 | use_scale_shift_norm=use_scale_shift_norm, 200 | down=True, 201 | ) 202 | if resblock_updown 203 | else Downsample( 204 | ch, conv_resample, dims=dims, out_channels=out_ch 205 | ) 206 | ) 207 | ) 208 | ch = out_ch 209 | input_block_chans.append(ch) 210 | ds *= 2 211 | self._feature_size += ch 212 | 213 | self.middle_block = TimestepEmbedSequential( 214 | ResBlock( 215 | ch, 216 | time_embed_dim, 217 | dropout, 218 | dims=dims, 219 | use_checkpoint=use_checkpoint, 220 | use_scale_shift_norm=use_scale_shift_norm, 221 | ), 222 | AttentionBlock( 223 | ch, 224 | use_checkpoint=use_checkpoint, 225 | num_heads=num_heads, 226 | num_head_channels=num_head_channels, 227 | use_new_attention_order=use_new_attention_order, 228 | ), 229 | ResBlock( 230 | ch, 231 | time_embed_dim, 232 | dropout, 233 | dims=dims, 234 | use_checkpoint=use_checkpoint, 235 | use_scale_shift_norm=use_scale_shift_norm, 236 | ), 
237 | ) 238 | input_block_chans.append(ch) 239 | self._feature_size += ch 240 | self.input_block_chans = input_block_chans 241 | 242 | self.fea_tran = nn.ModuleList([]) 243 | 244 | for i in range(len(input_block_chans)): 245 | self.fea_tran.append( 246 | ResBlock( 247 | input_block_chans[i], 248 | time_embed_dim, 249 | dropout, 250 | out_channels=out_channels, 251 | dims=dims, 252 | use_checkpoint=use_checkpoint, 253 | use_scale_shift_norm=use_scale_shift_norm, 254 | ) 255 | ) 256 | 257 | @torch.no_grad() 258 | def forward(self, x, timesteps): 259 | """ 260 | Apply the model to an input batch. 261 | :param x: an [N x C x ...] Tensor of inputs. 262 | :param timesteps: a 1-D batch of timesteps. 263 | :return: an [N x K] Tensor of outputs. 264 | """ 265 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels).to(x.dtype)) 266 | 267 | result_list = [] 268 | results = {} 269 | h = x.type(self.dtype) 270 | for module in self.input_blocks: 271 | last_h = h 272 | h = module(h, emb) 273 | if h.size(-1) != last_h.size(-1): 274 | result_list.append(last_h) 275 | h = self.middle_block(h, emb) 276 | result_list.append(h) 277 | 278 | assert len(result_list) == len(self.fea_tran) 279 | 280 | for i in range(len(result_list)): 281 | results[str(result_list[i].size(-1))] = self.fea_tran[i](result_list[i], emb) 282 | 283 | return results 284 | 285 | def load_from_dict(self, state_dict): 286 | """ 287 | Load model weights from a dictionary. 288 | :param state_dict: a dict of parameters. 289 | """ 290 | filtered_dict = {} 291 | for k, v in state_dict.items(): 292 | if k.startswith("structcond_stage_model."): 293 | filtered_dict[k[len("structcond_stage_model.") :]] = v 294 | self.load_state_dict(filtered_dict) 295 | 296 | 297 | def build_unetwt(use_fp16=False) -> EncoderUNetModelWT: 298 | """ 299 | Build a model from a state dict. 300 | :param state_dict: a dict of parameters. 301 | :return: a nn.Module. 302 | """ 303 | # The settings is from official setting yaml file. 304 | # https://github.com/IceClear/StableSR/blob/main/configs/stableSRNew/v2-finetune_text_T_512.yaml 305 | 306 | model = EncoderUNetModelWT( 307 | in_channels=4, 308 | model_channels=256, 309 | out_channels=256, 310 | num_res_blocks=2, 311 | attention_resolutions=[ 4, 2, 1 ], 312 | dropout=0.0, 313 | channel_mult=[1, 1, 2, 2], 314 | conv_resample=True, 315 | dims=2, 316 | use_checkpoint=False, 317 | use_fp16=use_fp16, 318 | num_heads=4, 319 | num_head_channels=-1, 320 | num_heads_upsample=-1, 321 | use_scale_shift_norm=False, 322 | resblock_updown=False, 323 | use_new_attention_order=False, 324 | ) 325 | global attn_func 326 | attn_func = get_attn_func() 327 | return model 328 | 329 | 330 | if __name__ == "__main__": 331 | ''' 332 | Test the lr encoder model. 
333 | ''' 334 | path = '../models/stablesr_sd21.ckpt' 335 | state_dict = torch.load(path) 336 | for key in state_dict.keys(): 337 | print(key) 338 | model = build_unetwt() 339 | model.load_from_dict(state_dict) 340 | model = model.cuda() 341 | test_latent = torch.zeros(1, 4, 64, 64).half().cuda() 342 | test_timesteps = torch.tensor([0]).half().cuda() 343 | with torch.no_grad(): 344 | test_result = model(test_latent, test_timesteps) 345 | print(test_result) -------------------------------------------------------------------------------- /modules/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL.Image as Image 3 | import torch 4 | 5 | 6 | def pil2tensor(image): 7 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 8 | 9 | 10 | def tensor2pil(image): 11 | return Image.fromarray( 12 | np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) 13 | ) 14 | 15 | 16 | def conv_nd(dims, *args, **kwargs): 17 | """ 18 | Create a 1D, 2D, or 3D convolution module. 19 | """ 20 | if dims == 1: 21 | return torch.nn.Conv1d(*args, **kwargs) 22 | elif dims == 2: 23 | return torch.nn.Conv2d(*args, **kwargs) 24 | elif dims == 3: 25 | return torch.nn.Conv3d(*args, **kwargs) 26 | 27 | 28 | def linear(*args, **kwargs): 29 | return torch.nn.Linear(*args, **kwargs) 30 | 31 | 32 | def normalization(channels): 33 | return torch.nn.GroupNorm(32, channels) 34 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | import comfy.model_management 7 | import comfy.sample 8 | import folder_paths 9 | 10 | from .modules.colorfix import adain_color_fix, wavelet_color_fix 11 | from .modules.spade import SPADELayers 12 | from .modules.struct_cond import EncoderUNetModelWT, build_unetwt 13 | from .modules.util import pil2tensor, tensor2pil 14 | 15 | model_path = folder_paths.models_dir 16 | folder_name = "stablesr" 17 | folder_path = os.path.join( 18 | model_path, "stablesr" 19 | ) # set a default path for the common comfyui model path 20 | if folder_name in folder_paths.folder_names_and_paths: 21 | folder_path = folder_paths.folder_names_and_paths[folder_name][0][ 22 | 0 23 | ] # if a custom path was set in extra_model_paths.yaml then use it 24 | folder_paths.folder_names_and_paths["stablesr"] = ( 25 | [folder_path], 26 | folder_paths.supported_pt_extensions, 27 | ) 28 | 29 | 30 | class StableSRColorFix: 31 | @classmethod 32 | def INPUT_TYPES(s): 33 | return { 34 | "required": { 35 | "image": ("IMAGE",), 36 | "color_map_image": ("IMAGE",), 37 | "color_fix": ( 38 | [ 39 | "Wavelet", 40 | "AdaIN", 41 | ], 42 | ), 43 | }, 44 | } 45 | 46 | RETURN_TYPES = ("IMAGE",) 47 | FUNCTION = "fix_color" 48 | CATEGORY = "image" 49 | 50 | def fix_color(self, image, color_map_image, color_fix): 51 | print(f"[StableSR] fix_color") 52 | try: 53 | color_fix_func = ( 54 | wavelet_color_fix if color_fix == "Wavelet" else adain_color_fix 55 | ) 56 | result_image = color_fix_func( 57 | tensor2pil(image), tensor2pil(color_map_image) 58 | ) 59 | refined_image = pil2tensor(result_image) 60 | return (refined_image,) 61 | except Exception as e: 62 | print(f"[StableSR] Error fix_color: {e}") 63 | 64 | 65 | original_sample = comfy.sample.sample 66 | SAMPLE_X = None 67 | 68 | 69 | def hook_sample(*args, **kwargs): 70 | global SAMPLE_X 71 | if len(args) 
>= 9: 72 | SAMPLE_X = args[8] 73 | elif "latent_image" in kwargs: 74 | SAMPLE_X = kwargs["latent_image"] 75 | return original_sample(*args, **kwargs) 76 | 77 | 78 | comfy.sample.sample = hook_sample 79 | 80 | 81 | class StableSR: 82 | """ 83 | Initializes a StableSR model. 84 | 85 | Args: 86 | path: The path to the StableSR checkpoint file. 87 | dtype: The data type of the model. If not specified, the default data type will be used. 88 | device: The device to run the model on. If not specified, the default device will be used. 89 | """ 90 | 91 | def __init__(self, stable_sr_model_path, dtype, device): 92 | print(f"[StbaleSR] in StableSR init - dtype: {dtype}, device: {device}") 93 | state_dict = comfy.utils.load_torch_file(stable_sr_model_path) 94 | 95 | self.struct_cond_model: EncoderUNetModelWT = build_unetwt( 96 | use_fp16=dtype == torch.float16 97 | ) 98 | self.spade_layers: SPADELayers = SPADELayers() 99 | self.struct_cond_model.load_from_dict(state_dict) 100 | self.spade_layers.load_from_dict(state_dict) 101 | del state_dict 102 | 103 | self.dtype = dtype 104 | self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device)) 105 | self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device)) 106 | self.latent_image: Tensor = None 107 | self.set_image_hooks = {} 108 | self.struct_cond: Tensor = None 109 | 110 | self.auto_set_latent = False 111 | self.last_t = 0.0 112 | 113 | def set_latent_image(self, latent_image): 114 | self.latent_image = latent_image 115 | 116 | def set_auto_set_latent(self, auto_set_latent): 117 | self.auto_set_latent = auto_set_latent 118 | 119 | def __call__(self, model_function, params): 120 | # explode packed args 121 | input_x = params.get("input") 122 | timestep = params.get("timestep") 123 | c = params.get("c") 124 | 125 | t = model_function.__self__.model_sampling.timestep(timestep) 126 | 127 | if self.auto_set_latent: 128 | tt = float(t[0]) 129 | if self.last_t <= tt: 130 | latent_image = model_function.__self__.process_latent_in(SAMPLE_X) 131 | self.set_latent_image(latent_image) 132 | self.last_t = tt 133 | 134 | # set latent image to device 135 | device = input_x.device 136 | latent_image = self.latent_image.to(dtype=self.dtype, device=device) 137 | 138 | # Ensure the device of all modules layers is the same as the unet 139 | # This will fix the issue when user use --medvram or --lowvram 140 | self.spade_layers.to(device) 141 | self.struct_cond_model.to(device) 142 | 143 | self.struct_cond = None # mitigate vram peak 144 | self.struct_cond = self.struct_cond_model( 145 | latent_image, t[: latent_image.shape[0]] 146 | ) 147 | 148 | self.spade_layers.hook( 149 | model_function.__self__.diffusion_model, lambda: self.struct_cond 150 | ) 151 | 152 | # Call the model_function with the provided arguments 153 | result = model_function(input_x, timestep, **c) 154 | 155 | self.spade_layers.unhook() 156 | 157 | # Return the result 158 | return result 159 | 160 | def to(self, device): 161 | if type(device) == torch.device: 162 | self.struct_cond_model.apply(lambda x: x.to(device=device)) 163 | self.spade_layers.apply(lambda x: x.to(device=device)) 164 | return self 165 | 166 | 167 | class ApplyStableSRUpscaler: 168 | @classmethod 169 | def INPUT_TYPES(s): 170 | return { 171 | "required": { 172 | "model": ("MODEL",), 173 | "stablesr_model": (folder_paths.get_filename_list("stablesr"),), 174 | }, 175 | "optional": { 176 | "latent_image": ("LATENT",), 177 | }, 178 | } 179 | 180 | RETURN_TYPES = ("MODEL",) 181 | 182 | FUNCTION = 
"apply_stable_sr_upscaler" 183 | CATEGORY = "image/upscaling" 184 | 185 | def apply_stable_sr_upscaler(self, model, stablesr_model, latent_image=None): 186 | stablesr_model_path = folder_paths.get_full_path("stablesr", stablesr_model) 187 | if not os.path.isfile(stablesr_model_path): 188 | raise Exception(f"[StableSR] Invalid StableSR model reference") 189 | 190 | upscaler = StableSR( 191 | stablesr_model_path, dtype=comfy.model_management.unet_dtype(), device="cpu" 192 | ) 193 | if latent_image != None: 194 | latent_image = model.model.process_latent_in(latent_image["samples"]) 195 | upscaler.set_latent_image(latent_image) 196 | else: 197 | upscaler.set_auto_set_latent(True) 198 | 199 | model_sr = model.clone() 200 | model_sr.set_model_unet_function_wrapper(upscaler) 201 | return (model_sr,) 202 | 203 | 204 | NODE_CLASS_MAPPINGS = { 205 | "StableSRColorFix": StableSRColorFix, 206 | "ApplyStableSRUpscaler": ApplyStableSRUpscaler, 207 | } 208 | 209 | NODE_DISPLAY_NAME_MAPPINGS = { 210 | "StableSRColorFix": "StableSRColorFix", 211 | "ApplyStableSRUpscaler": "ApplyStableSRUpscaler", 212 | } 213 | --------------------------------------------------------------------------------