├── .github └── workflows │ └── publish.yml ├── LICENSE.md ├── README.md ├── __init__.py ├── aura_sr.py ├── nodes.py ├── nodes_preview ├── pv1.png └── pv2.png ├── pyproject.toml └── utils.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # if this is a forked repository. Skipping the workflow. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your Github Repository secrets and reference it here. 24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Creative Commons Attribution-ShareAlike 4.0 International 2 | 3 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 4 | 5 | **Using Creative Commons Public Licenses** 6 | 7 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | 9 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | 11 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. 
Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 12 | 13 | ## Creative Commons Attribution-ShareAlike 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | 23 | c. __BY-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 24 | 25 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution and ShareAlike. 32 | 33 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 
34 | 35 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 36 | 37 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 38 | 39 | k. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 40 | 41 | l. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 42 | 43 | m. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning. 44 | 45 | ### Section 2 – Scope. 46 | 47 | a. ___License grant.___ 48 | 49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 50 | 51 | A. reproduce and Share the Licensed Material, in whole or in part; and 52 | 53 | B. produce, reproduce, and Share Adapted Material. 54 | 55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 56 | 57 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 58 | 59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 60 | 61 | 5. __Downstream recipients.__ 62 | 63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 64 | 65 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 66 | 67 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 68 | 69 | 6. 
__No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 70 | 71 | b. ___Other rights.___ 72 | 73 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 74 | 75 | 2. Patent and trademark rights are not licensed under this Public License. 76 | 77 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties. 78 | 79 | ### Section 3 – License Conditions. 80 | 81 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 82 | 83 | a. ___Attribution.___ 84 | 85 | 1. If You Share the Licensed Material (including in modified form), You must: 86 | 87 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 88 | 89 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 90 | 91 | ii. a copyright notice; 92 | 93 | iii. a notice that refers to this Public License; 94 | 95 | iv. a notice that refers to the disclaimer of warranties; 96 | 97 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 98 | 99 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 100 | 101 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 102 | 103 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 104 | 105 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 106 | 107 | b. ___ShareAlike.___ 108 | 109 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 110 | 111 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-SA Compatible License. 112 | 113 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 114 | 115 | 3. 
You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 116 | 117 | ### Section 4 – Sui Generis Database Rights. 118 | 119 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 120 | 121 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database; 122 | 123 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 124 | 125 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 126 | 127 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 128 | 129 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 130 | 131 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 132 | 133 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 134 | 135 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 136 | 137 | ### Section 6 – Term and Termination. 138 | 139 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 140 | 141 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 142 | 143 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 144 | 145 | 2. upon express reinstatement by the Licensor. 146 | 147 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 148 | 149 | c. 
For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 150 | 151 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 152 | 153 | ### Section 7 – Other Terms and Conditions. 154 | 155 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 156 | 157 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 158 | 159 | ### Section 8 – Interpretation. 160 | 161 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 162 | 163 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 164 | 165 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 166 | 167 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 168 | 169 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the [CC0 Public Domain Dedication](https://creativecommons.org/publicdomain/zero/1.0/legalcode). Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 170 | > 171 | > Creative Commons may be contacted at creativecommons.org. 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AuraSR-ComfyUI 2 | ComfyUI implementation of [Aura-SR](https://github.com/fal-ai/aura-sr). 3 | 4 | Also supports [AuraSR v2](https://huggingface.co/fal/AuraSR-v2) 5 | 6 | ![Interface](nodes_preview/pv1.png) 7 | 8 | 9 | ### ATTENTION: 10 | 11 | AuraSR v1 (model) is ultra sensitive to ANY kind of image compression and when given such image the output will probably be terrible. 
It is highly recommended that you feed it images straight out of SD (prior to any saving). The example above does the opposite, which is why it shows some of the common artifacts introduced on compressed images. 12 | 13 | TIP: If you are loading an already-saved image (especially if it's a .JPEG), you can try using a 'compression artifact-removal' model such as [DeJPG_OmniSR](https://openmodeldb.info/models/1x-DeJPG-OmniSR) before passing the image to AuraSR. Check these links to judge the results yourself: [imgsli](https://imgsli.com/Mjc1NzYw/0/2) and [imgur](https://imgur.com/a/pwFwnwF). 14 | 15 | Example workflow with DeJPG: 16 | 17 | ![Interface](nodes_preview/pv2.png) 18 | 19 | ### NOTE: AuraSR v2 no longer seems to suffer from this issue - in fact, if you use a 'compression artifact-removal' model beforehand with it, the results may lose quality! The recommendation above is ONLY meant for the first version of the model! 20 | 21 | 22 | # Instructions: 23 | - Create a folder named 'Aura-SR' inside '\models'. 24 | - Alternatively, you can specify a (single) custom model location using ComfyUI's 'extra_model_paths.yaml' file with an entry named exactly 'aura-sr'. 25 | - Download the .safetensors AND config.json files from [HuggingFace](https://huggingface.co/fal/AuraSR/tree/main) and place them in '\models\Aura-SR'. 26 | - The V2 version of the model is available here: [link](https://huggingface.co/fal/AuraSR-v2/tree/main) (it seems better in some cases and much worse in others - do not use DeJPG (and similar models) with it! I'll personally just stick with V1 for now). 27 | - (Optional) Rename the model to whatever you want and rename the config file to the same name as the model (this allows for future, multiple models with their own unique configs). 28 | - Install with ComfyUI Manager, restart, then reload the browser's page. 29 | - Add Node > AuraSR > AuraSR Upscaler 30 | - All of the node's parameters are self-explanatory apart from 'transparency_mask' and 'reapply_transparency': 31 | - transparency_mask: (Optional) A mask obtained from loading an RGBA image (with transparent pixels). Can be connected directly to the 'Load Image' native node. 32 | - reapply_transparency: When given a valid mask AND/OR a single RGBA image, the node will attempt to reapply the transparency of the original image to the upscaled one. Keep in mind that the 'Load Image' native node auto-converts the input image to RGB (no transparency) before sending it to another node. Therefore, if you are not passing a valid 'transparency_mask', you need a specialized node capable of loading and outputting in RGBA mode. This feature is internally disabled whenever you send a batch of images to the node. 33 | 34 | 35 | 36 | 37 | # Changelog 38 | ### v3.0.0: 39 | - Batch image input is now supported. 40 | - Reapply_transparency is automatically set to False internally when receiving batches of images. 41 | 42 | - Added the ability to add an 'aura-sr' entry to extra_model_paths.yaml. This allows you to set a custom location for your AuraSR models. Example ('aura-sr' is case sensitive): 43 | 44 | ``` 45 | somethingHere: 46 | aura-sr: your/path/to/aurasrFolder 47 | ``` 48 | 49 | ### v2.1.0: 50 | - Added support for AuraSR v0.4.0 (code) 51 | - Which introduces 2 new upscaling methods: '4x_overlapped_checkboard' and '4x_overlapped_constant'. [Comparison](https://imgsli.com/MjgxMzgx) 52 | - These new methods take at least twice as long as the original but may offer better results.
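# Standalone usage (reference)

For reference, here is a minimal sketch of how the bundled `aura_sr.py` class can be driven outside ComfyUI, assuming the model and config were downloaded as described in the Instructions above. The file names and paths below are placeholders - when using the node inside ComfyUI, all of this is handled for you.

```python
import json
from pathlib import Path

from PIL import Image
from safetensors.torch import load_file

from aura_sr import AuraSR  # the aura_sr.py bundled with this repo

# Placeholder locations - point these at wherever you placed the model and its config.
model_dir = Path("models/Aura-SR")
config = json.loads((model_dir / "config.json").read_text())
checkpoint = load_file(str(model_dir / "model.safetensors"))

upscaler = AuraSR(config=config, device="cuda")
upscaler.upsampler.load_state_dict(checkpoint, strict=True)

image = Image.open("input.png").convert("RGB")
result = upscaler.upscale_4x(image, max_batch_size=8)
# or: upscaler.upscale_4x_overlapped(image, max_batch_size=8, weight_type='checkboard')
result.save("output_4x.png")
```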
53 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] 4 | -------------------------------------------------------------------------------- /aura_sr.py: -------------------------------------------------------------------------------- 1 | # AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is 2 | # based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there. 3 | # 4 | # https://mingukkang.github.io/GigaGAN/ 5 | from math import log2, ceil 6 | from functools import partial 7 | from typing import Any, Optional, List, Iterable 8 | 9 | import torch 10 | from torchvision import transforms 11 | from PIL import Image 12 | from torch import nn, einsum, Tensor 13 | import torch.nn.functional as F 14 | 15 | from einops import rearrange, repeat, reduce 16 | from einops.layers.torch import Rearrange 17 | from torchvision.utils import save_image 18 | import math 19 | 20 | 21 | def get_same_padding(size, kernel, dilation, stride): 22 | return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2 23 | 24 | 25 | class AdaptiveConv2DMod(nn.Module): 26 | def __init__( 27 | self, 28 | dim, 29 | dim_out, 30 | kernel, 31 | *, 32 | demod=True, 33 | stride=1, 34 | dilation=1, 35 | eps=1e-8, 36 | num_conv_kernels=1, # set this to be greater than 1 for adaptive 37 | ): 38 | super().__init__() 39 | self.eps = eps 40 | 41 | self.dim_out = dim_out 42 | 43 | self.kernel = kernel 44 | self.stride = stride 45 | self.dilation = dilation 46 | self.adaptive = num_conv_kernels > 1 47 | 48 | self.weights = nn.Parameter( 49 | torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel)) 50 | ) 51 | 52 | self.demod = demod 53 | 54 | nn.init.kaiming_normal_( 55 | self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu" 56 | ) 57 | 58 | def forward( 59 | self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None 60 | ): 61 | """ 62 | notation 63 | 64 | b - batch 65 | n - convs 66 | o - output 67 | i - input 68 | k - kernel 69 | """ 70 | 71 | b, h = fmap.shape[0], fmap.shape[-2] 72 | 73 | # account for feature map that has been expanded by the scale in the first dimension 74 | # due to multiscale inputs and outputs 75 | 76 | if mod.shape[0] != b: 77 | mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0]) 78 | 79 | if exists(kernel_mod): 80 | kernel_mod_has_el = kernel_mod.numel() > 0 81 | 82 | assert self.adaptive or not kernel_mod_has_el 83 | 84 | if kernel_mod_has_el and kernel_mod.shape[0] != b: 85 | kernel_mod = repeat( 86 | kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0] 87 | ) 88 | 89 | # prepare weights for modulation 90 | 91 | weights = self.weights 92 | 93 | if self.adaptive: 94 | weights = repeat(weights, "... -> b ...", b=b) 95 | 96 | # determine an adaptive weight and 'select' the kernel to use with softmax 97 | 98 | assert exists(kernel_mod) and kernel_mod.numel() > 0 99 | 100 | kernel_attn = kernel_mod.softmax(dim=-1) 101 | kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1") 102 | 103 | weights = reduce(weights * kernel_attn, "b n ... 
-> b ...", "sum") 104 | 105 | # do the modulation, demodulation, as done in stylegan2 106 | 107 | mod = rearrange(mod, "b i -> b 1 i 1 1") 108 | 109 | weights = weights * (mod + 1) 110 | 111 | if self.demod: 112 | inv_norm = ( 113 | reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum") 114 | .clamp(min=self.eps) 115 | .rsqrt() 116 | ) 117 | weights = weights * inv_norm 118 | 119 | fmap = rearrange(fmap, "b c h w -> 1 (b c) h w") 120 | 121 | weights = rearrange(weights, "b o ... -> (b o) ...") 122 | 123 | padding = get_same_padding(h, self.kernel, self.dilation, self.stride) 124 | fmap = F.conv2d(fmap, weights, padding=padding, groups=b) 125 | 126 | return rearrange(fmap, "1 (b o) ... -> b o ...", b=b) 127 | 128 | 129 | class Attend(nn.Module): 130 | def __init__(self, dropout=0.0, flash=False): 131 | super().__init__() 132 | self.dropout = dropout 133 | self.attn_dropout = nn.Dropout(dropout) 134 | self.scale = nn.Parameter(torch.randn(1)) 135 | self.flash = flash 136 | 137 | def flash_attn(self, q, k, v): 138 | q, k, v = map(lambda t: t.contiguous(), (q, k, v)) 139 | out = F.scaled_dot_product_attention( 140 | q, k, v, dropout_p=self.dropout if self.training else 0.0 141 | ) 142 | return out 143 | 144 | def forward(self, q, k, v): 145 | if self.flash: 146 | return self.flash_attn(q, k, v) 147 | 148 | scale = q.shape[-1] ** -0.5 149 | 150 | # similarity 151 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale 152 | 153 | # attention 154 | attn = sim.softmax(dim=-1) 155 | attn = self.attn_dropout(attn) 156 | 157 | # aggregate values 158 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 159 | 160 | return out 161 | 162 | 163 | def exists(x): 164 | return x is not None 165 | 166 | 167 | def default(val, d): 168 | if exists(val): 169 | return val 170 | return d() if callable(d) else d 171 | 172 | 173 | def cast_tuple(t, length=1): 174 | if isinstance(t, tuple): 175 | return t 176 | return (t,) * length 177 | 178 | 179 | def identity(t, *args, **kwargs): 180 | return t 181 | 182 | 183 | def is_power_of_two(n): 184 | return log2(n).is_integer() 185 | 186 | 187 | def null_iterator(): 188 | while True: 189 | yield None 190 | 191 | def Downsample(dim, dim_out=None): 192 | return nn.Sequential( 193 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2), 194 | nn.Conv2d(dim * 4, default(dim_out, dim), 1), 195 | ) 196 | 197 | 198 | class RMSNorm(nn.Module): 199 | def __init__(self, dim): 200 | super().__init__() 201 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 202 | self.eps = 1e-4 203 | 204 | def forward(self, x): 205 | return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5) 206 | 207 | 208 | # building block modules 209 | 210 | 211 | class Block(nn.Module): 212 | def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0): 213 | super().__init__() 214 | self.proj = AdaptiveConv2DMod( 215 | dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels 216 | ) 217 | self.kernel = 3 218 | self.dilation = 1 219 | self.stride = 1 220 | 221 | self.act = nn.SiLU() 222 | 223 | def forward(self, x, conv_mods_iter: Optional[Iterable] = None): 224 | conv_mods_iter = default(conv_mods_iter, null_iterator()) 225 | 226 | x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter)) 227 | 228 | x = self.act(x) 229 | return x 230 | 231 | 232 | class ResnetBlock(nn.Module): 233 | def __init__( 234 | self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = [] 235 | ): 236 | super().__init__() 237 | style_dims.extend([dim, num_conv_kernels, dim_out, 
num_conv_kernels]) 238 | 239 | self.block1 = Block( 240 | dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels 241 | ) 242 | self.block2 = Block( 243 | dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels 244 | ) 245 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 246 | 247 | def forward(self, x, conv_mods_iter: Optional[Iterable] = None): 248 | h = self.block1(x, conv_mods_iter=conv_mods_iter) 249 | h = self.block2(h, conv_mods_iter=conv_mods_iter) 250 | 251 | return h + self.res_conv(x) 252 | 253 | 254 | class LinearAttention(nn.Module): 255 | def __init__(self, dim, heads=4, dim_head=32): 256 | super().__init__() 257 | self.scale = dim_head**-0.5 258 | self.heads = heads 259 | hidden_dim = dim_head * heads 260 | 261 | self.norm = RMSNorm(dim) 262 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 263 | 264 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim)) 265 | 266 | def forward(self, x): 267 | b, c, h, w = x.shape 268 | 269 | x = self.norm(x) 270 | 271 | qkv = self.to_qkv(x).chunk(3, dim=1) 272 | q, k, v = map( 273 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 274 | ) 275 | 276 | q = q.softmax(dim=-2) 277 | k = k.softmax(dim=-1) 278 | 279 | q = q * self.scale 280 | 281 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 282 | 283 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 284 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 285 | return self.to_out(out) 286 | 287 | 288 | class Attention(nn.Module): 289 | def __init__(self, dim, heads=4, dim_head=32, flash=False): 290 | super().__init__() 291 | self.heads = heads 292 | hidden_dim = dim_head * heads 293 | 294 | self.norm = RMSNorm(dim) 295 | 296 | self.attend = Attend(flash=flash) 297 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 298 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 299 | 300 | def forward(self, x): 301 | b, c, h, w = x.shape 302 | x = self.norm(x) 303 | qkv = self.to_qkv(x).chunk(3, dim=1) 304 | 305 | q, k, v = map( 306 | lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv 307 | ) 308 | 309 | out = self.attend(q, k, v) 310 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 311 | 312 | return self.to_out(out) 313 | 314 | 315 | # feedforward 316 | def FeedForward(dim, mult=4): 317 | return nn.Sequential( 318 | RMSNorm(dim), 319 | nn.Conv2d(dim, dim * mult, 1), 320 | nn.GELU(), 321 | nn.Conv2d(dim * mult, dim, 1), 322 | ) 323 | 324 | 325 | # transformers 326 | class Transformer(nn.Module): 327 | def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4): 328 | super().__init__() 329 | self.layers = nn.ModuleList([]) 330 | 331 | for _ in range(depth): 332 | self.layers.append( 333 | nn.ModuleList( 334 | [ 335 | Attention( 336 | dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn 337 | ), 338 | FeedForward(dim=dim, mult=ff_mult), 339 | ] 340 | ) 341 | ) 342 | 343 | def forward(self, x): 344 | for attn, ff in self.layers: 345 | x = attn(x) + x 346 | x = ff(x) + x 347 | 348 | return x 349 | 350 | 351 | class LinearTransformer(nn.Module): 352 | def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4): 353 | super().__init__() 354 | self.layers = nn.ModuleList([]) 355 | 356 | for _ in range(depth): 357 | self.layers.append( 358 | nn.ModuleList( 359 | [ 360 | LinearAttention(dim=dim, dim_head=dim_head, heads=heads), 361 | FeedForward(dim=dim, mult=ff_mult), 362 | ] 363 | 
) 364 | ) 365 | 366 | def forward(self, x): 367 | for attn, ff in self.layers: 368 | x = attn(x) + x 369 | x = ff(x) + x 370 | 371 | return x 372 | 373 | 374 | class NearestNeighborhoodUpsample(nn.Module): 375 | def __init__(self, dim, dim_out=None): 376 | super().__init__() 377 | dim_out = default(dim_out, dim) 378 | self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1) 379 | 380 | def forward(self, x): 381 | 382 | if x.shape[0] >= 64: 383 | x = x.contiguous() 384 | 385 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 386 | x = self.conv(x) 387 | 388 | return x 389 | 390 | class EqualLinear(nn.Module): 391 | def __init__(self, dim, dim_out, lr_mul=1, bias=True): 392 | super().__init__() 393 | self.weight = nn.Parameter(torch.randn(dim_out, dim)) 394 | if bias: 395 | self.bias = nn.Parameter(torch.zeros(dim_out)) 396 | 397 | self.lr_mul = lr_mul 398 | 399 | def forward(self, input): 400 | return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul) 401 | 402 | 403 | class StyleGanNetwork(nn.Module): 404 | def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0): 405 | super().__init__() 406 | self.dim_in = dim_in 407 | self.dim_out = dim_out 408 | self.dim_text_latent = dim_text_latent 409 | 410 | layers = [] 411 | for i in range(depth): 412 | is_first = i == 0 413 | 414 | if is_first: 415 | dim_in_layer = dim_in + dim_text_latent 416 | else: 417 | dim_in_layer = dim_out 418 | 419 | dim_out_layer = dim_out 420 | 421 | layers.extend( 422 | [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)] 423 | ) 424 | 425 | self.net = nn.Sequential(*layers) 426 | 427 | def forward(self, x, text_latent=None): 428 | x = F.normalize(x, dim=1) 429 | if self.dim_text_latent > 0: 430 | assert exists(text_latent) 431 | x = torch.cat((x, text_latent), dim=-1) 432 | return self.net(x) 433 | 434 | 435 | class UnetUpsampler(torch.nn.Module): 436 | 437 | def __init__( 438 | self, 439 | dim: int, 440 | *, 441 | image_size: int, 442 | input_image_size: int, 443 | init_dim: Optional[int] = None, 444 | out_dim: Optional[int] = None, 445 | style_network: Optional[dict] = None, 446 | up_dim_mults: tuple = (1, 2, 4, 8, 16), 447 | down_dim_mults: tuple = (4, 8, 16), 448 | channels: int = 3, 449 | resnet_block_groups: int = 8, 450 | full_attn: tuple = (False, False, False, True, True), 451 | flash_attn: bool = True, 452 | self_attn_dim_head: int = 64, 453 | self_attn_heads: int = 8, 454 | attn_depths: tuple = (2, 2, 2, 2, 4), 455 | mid_attn_depth: int = 4, 456 | num_conv_kernels: int = 4, 457 | resize_mode: str = "bilinear", 458 | unconditional: bool = True, 459 | skip_connect_scale: Optional[float] = None, 460 | ): 461 | super().__init__() 462 | self.style_network = style_network = StyleGanNetwork(**style_network) 463 | self.unconditional = unconditional 464 | assert not ( 465 | unconditional 466 | and exists(style_network) 467 | and style_network.dim_text_latent > 0 468 | ) 469 | 470 | assert is_power_of_two(image_size) and is_power_of_two( 471 | input_image_size 472 | ), "both output image size and input image size must be power of 2" 473 | assert ( 474 | input_image_size < image_size 475 | ), "input image size must be smaller than the output image size, thus upsampling" 476 | 477 | self.image_size = image_size 478 | self.input_image_size = input_image_size 479 | 480 | style_embed_split_dims = [] 481 | 482 | self.channels = channels 483 | input_channels = channels 484 | 485 | init_dim = default(init_dim, dim) 486 | 487 | up_dims = 
[init_dim, *map(lambda m: dim * m, up_dim_mults)] 488 | init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)] 489 | down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)] 490 | self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3) 491 | 492 | up_in_out = list(zip(up_dims[:-1], up_dims[1:])) 493 | down_in_out = list(zip(down_dims[:-1], down_dims[1:])) 494 | 495 | block_klass = partial( 496 | ResnetBlock, 497 | groups=resnet_block_groups, 498 | num_conv_kernels=num_conv_kernels, 499 | style_dims=style_embed_split_dims, 500 | ) 501 | 502 | FullAttention = partial(Transformer, flash_attn=flash_attn) 503 | *_, mid_dim = up_dims 504 | 505 | self.skip_connect_scale = default(skip_connect_scale, 2**-0.5) 506 | 507 | self.downs = nn.ModuleList([]) 508 | self.ups = nn.ModuleList([]) 509 | 510 | block_count = 6 511 | 512 | for ind, ( 513 | (dim_in, dim_out), 514 | layer_full_attn, 515 | layer_attn_depth, 516 | ) in enumerate(zip(down_in_out, full_attn, attn_depths)): 517 | attn_klass = FullAttention if layer_full_attn else LinearTransformer 518 | 519 | blocks = [] 520 | for i in range(block_count): 521 | blocks.append(block_klass(dim_in, dim_in)) 522 | 523 | self.downs.append( 524 | nn.ModuleList( 525 | [ 526 | nn.ModuleList(blocks), 527 | nn.ModuleList( 528 | [ 529 | ( 530 | attn_klass( 531 | dim_in, 532 | dim_head=self_attn_dim_head, 533 | heads=self_attn_heads, 534 | depth=layer_attn_depth, 535 | ) 536 | if layer_full_attn 537 | else None 538 | ), 539 | nn.Conv2d( 540 | dim_in, dim_out, kernel_size=3, stride=2, padding=1 541 | ), 542 | ] 543 | ), 544 | ] 545 | ) 546 | ) 547 | 548 | self.mid_block1 = block_klass(mid_dim, mid_dim) 549 | self.mid_attn = FullAttention( 550 | mid_dim, 551 | dim_head=self_attn_dim_head, 552 | heads=self_attn_heads, 553 | depth=mid_attn_depth, 554 | ) 555 | self.mid_block2 = block_klass(mid_dim, mid_dim) 556 | 557 | *_, last_dim = up_dims 558 | 559 | for ind, ( 560 | (dim_in, dim_out), 561 | layer_full_attn, 562 | layer_attn_depth, 563 | ) in enumerate( 564 | zip( 565 | reversed(up_in_out), 566 | reversed(full_attn), 567 | reversed(attn_depths), 568 | ) 569 | ): 570 | attn_klass = FullAttention if layer_full_attn else LinearTransformer 571 | 572 | blocks = [] 573 | input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in 574 | for i in range(block_count): 575 | blocks.append(block_klass(input_dim, dim_in)) 576 | 577 | self.ups.append( 578 | nn.ModuleList( 579 | [ 580 | nn.ModuleList(blocks), 581 | nn.ModuleList( 582 | [ 583 | NearestNeighborhoodUpsample( 584 | last_dim if ind == 0 else dim_out, 585 | dim_in, 586 | ), 587 | ( 588 | attn_klass( 589 | dim_in, 590 | dim_head=self_attn_dim_head, 591 | heads=self_attn_heads, 592 | depth=layer_attn_depth, 593 | ) 594 | if layer_full_attn 595 | else None 596 | ), 597 | ] 598 | ), 599 | ] 600 | ) 601 | ) 602 | 603 | self.out_dim = default(out_dim, channels) 604 | self.final_res_block = block_klass(dim, dim) 605 | self.final_to_rgb = nn.Conv2d(dim, channels, 1) 606 | self.resize_mode = resize_mode 607 | self.style_to_conv_modulations = nn.Linear( 608 | style_network.dim_out, sum(style_embed_split_dims) 609 | ) 610 | self.style_embed_split_dims = style_embed_split_dims 611 | 612 | @property 613 | def allowable_rgb_resolutions(self): 614 | input_res_base = int(log2(self.input_image_size)) 615 | output_res_base = int(log2(self.image_size)) 616 | allowed_rgb_res_base = list(range(input_res_base, output_res_base)) 617 | return [*map(lambda p: 2**p, allowed_rgb_res_base)] 618 | 619 | 
@property 620 | def device(self): 621 | return next(self.parameters()).device 622 | 623 | @property 624 | def total_params(self): 625 | return sum([p.numel() for p in self.parameters()]) 626 | 627 | def resize_image_to(self, x, size): 628 | return F.interpolate(x, (size, size), mode=self.resize_mode) 629 | 630 | def forward( 631 | self, 632 | lowres_image: torch.Tensor, 633 | styles: Optional[torch.Tensor] = None, 634 | noise: Optional[torch.Tensor] = None, 635 | global_text_tokens: Optional[torch.Tensor] = None, 636 | return_all_rgbs: bool = False, 637 | ): 638 | x = lowres_image 639 | 640 | noise_scale = 0.001 # Adjust the scale of the noise as needed 641 | noise_aug = torch.randn_like(x) * noise_scale 642 | x = x + noise_aug 643 | x = x.clamp(0, 1) 644 | 645 | shape = x.shape 646 | batch_size = shape[0] 647 | 648 | assert shape[-2:] == ((self.input_image_size,) * 2) 649 | 650 | # styles 651 | if not exists(styles): 652 | assert exists(self.style_network) 653 | 654 | noise = default( 655 | noise, 656 | torch.randn( 657 | (batch_size, self.style_network.dim_in), device=self.device 658 | ), 659 | ) 660 | styles = self.style_network(noise, global_text_tokens) 661 | 662 | # project styles to conv modulations 663 | conv_mods = self.style_to_conv_modulations(styles) 664 | conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1) 665 | conv_mods = iter(conv_mods) 666 | 667 | x = self.init_conv(x) 668 | 669 | h = [] 670 | for blocks, (attn, downsample) in self.downs: 671 | for block in blocks: 672 | x = block(x, conv_mods_iter=conv_mods) 673 | h.append(x) 674 | 675 | if attn is not None: 676 | x = attn(x) 677 | 678 | x = downsample(x) 679 | 680 | x = self.mid_block1(x, conv_mods_iter=conv_mods) 681 | x = self.mid_attn(x) 682 | x = self.mid_block2(x, conv_mods_iter=conv_mods) 683 | 684 | for ( 685 | blocks, 686 | ( 687 | upsample, 688 | attn, 689 | ), 690 | ) in self.ups: 691 | x = upsample(x) 692 | for block in blocks: 693 | if h != []: 694 | res = h.pop() 695 | res = res * self.skip_connect_scale 696 | x = torch.cat((x, res), dim=1) 697 | 698 | x = block(x, conv_mods_iter=conv_mods) 699 | 700 | if attn is not None: 701 | x = attn(x) 702 | 703 | x = self.final_res_block(x, conv_mods_iter=conv_mods) 704 | rgb = self.final_to_rgb(x) 705 | 706 | if not return_all_rgbs: 707 | return rgb 708 | 709 | return rgb, [] 710 | 711 | 712 | def tile_image(image, chunk_size=64): 713 | c, h, w = image.shape 714 | h_chunks = ceil(h / chunk_size) 715 | w_chunks = ceil(w / chunk_size) 716 | tiles = [] 717 | for i in range(h_chunks): 718 | for j in range(w_chunks): 719 | tile = image[:, i * chunk_size:(i + 1) * chunk_size, j * chunk_size:(j + 1) * chunk_size] 720 | tiles.append(tile) 721 | return tiles, h_chunks, w_chunks 722 | 723 | # This helps create a checkboard pattern with some edge blending 724 | def create_checkerboard_weights(tile_size): 725 | x = torch.linspace(-1, 1, tile_size) 726 | y = torch.linspace(-1, 1, tile_size) 727 | 728 | x, y = torch.meshgrid(x, y, indexing='ij') 729 | d = torch.sqrt(x*x + y*y) 730 | sigma, mu = 0.5, 0.0 731 | weights = torch.exp(-((d-mu)**2 / (2.0 * sigma**2))) 732 | 733 | # saturate the values to sure get high weights in the center 734 | weights = weights**8 735 | 736 | return weights / weights.max() # Normalize to [0, 1] 737 | 738 | def repeat_weights(weights, image_size): 739 | tile_size = weights.shape[0] 740 | repeats = (math.ceil(image_size[0] / tile_size), math.ceil(image_size[1] / tile_size)) 741 | return weights.repeat(repeats)[:image_size[0], :image_size[1]] 
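# Helper for upscale_4x_overlapped below. The overlapped mode runs the tiler twice - once on the
# padded image and once on a copy shifted by half a tile - so each pass's tile seams fall on the
# other pass's tile centers. create_offset_weights shifts the checkerboard weights by half a tile
# so the offset pass is weighted highest at its own tile centers; the two passes are then blended
# with per-pixel normalized weights ('checkboard') or simply averaged ('constant').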
742 | 743 | def create_offset_weights(weights, image_size): 744 | tile_size = weights.shape[0] 745 | offset = tile_size // 2 746 | full_weights = repeat_weights(weights, (image_size[0] + offset, image_size[1] + offset)) 747 | return full_weights[offset:, offset:] 748 | 749 | def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64): 750 | # Determine the shape of the output tensor 751 | c = tiles[0].shape[0] 752 | h = h_chunks * chunk_size 753 | w = w_chunks * chunk_size 754 | 755 | # Create an empty tensor to hold the merged image 756 | merged = torch.zeros((c, h, w), dtype=tiles[0].dtype) 757 | 758 | # Iterate over the tiles and place them in the correct position 759 | for idx, tile in enumerate(tiles): 760 | i = idx // w_chunks 761 | j = idx % w_chunks 762 | 763 | h_start = i * chunk_size 764 | w_start = j * chunk_size 765 | 766 | tile_h, tile_w = tile.shape[1:] 767 | merged[:, h_start:h_start+tile_h, w_start:w_start+tile_w] = tile 768 | 769 | return merged 770 | 771 | class AuraSR: 772 | def __init__(self, config: dict[str, Any], device: str = "cuda"): 773 | self.upsampler = UnetUpsampler(**config).to(device) 774 | self.input_image_size = config["input_image_size"] 775 | 776 | ## Disabled from_pretrained because it imports huggingface_hub and its never really used by the Node 777 | #@classmethod 778 | #def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True): 779 | # import json 780 | # import torch 781 | # from pathlib import Path 782 | # from huggingface_hub import snapshot_download 783 | # 784 | # # Check if model_id is a local file 785 | # if Path(model_id).is_file(): 786 | # local_file = Path(model_id) 787 | # if local_file.suffix == '.safetensors': 788 | # use_safetensors = True 789 | # elif local_file.suffix == '.ckpt': 790 | # use_safetensors = False 791 | # else: 792 | # raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.") 793 | # 794 | # # For local files, we need to provide the config separately 795 | # config_path = local_file.with_name('config.json') 796 | # if not config_path.exists(): 797 | # raise FileNotFoundError( 798 | # f"Config file not found: {config_path}. " 799 | # f"When loading from a local file, ensure that 'config.json' " 800 | # f"is present in the same directory as '{local_file.name}'. " 801 | # f"If you're trying to load a model from Hugging Face, " 802 | # f"please provide the model ID instead of a file path." 803 | # ) 804 | # 805 | # config = json.loads(config_path.read_text()) 806 | # hf_model_path = local_file.parent 807 | # else: 808 | # hf_model_path = Path(snapshot_download(model_id)) 809 | # config = json.loads((hf_model_path / "config.json").read_text()) 810 | # 811 | # model = cls(config) 812 | # 813 | # if use_safetensors: 814 | # try: 815 | # from safetensors.torch import load_file 816 | # checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id) 817 | # except ImportError: 818 | # raise ImportError( 819 | # "The safetensors library is not installed. " 820 | # "Please install it with `pip install safetensors` " 821 | # "or use `use_safetensors=False` to load the model with PyTorch." 
822 | # ) 823 | # else: 824 | # checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id) 825 | # 826 | # model.upsampler.load_state_dict(checkpoint, strict=True) 827 | # return model 828 | 829 | @torch.no_grad() 830 | def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image: 831 | tensor_transform = transforms.ToTensor() 832 | device = self.upsampler.device 833 | 834 | image_tensor = tensor_transform(image).unsqueeze(0) 835 | _, _, h, w = image_tensor.shape 836 | pad_h = (self.input_image_size - h % self.input_image_size) % self.input_image_size 837 | pad_w = (self.input_image_size - w % self.input_image_size) % self.input_image_size 838 | 839 | # Pad the image 840 | image_tensor = torch.nn.functional.pad(image_tensor, (0, pad_w, 0, pad_h), mode='reflect').squeeze(0) 841 | tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size) 842 | 843 | # Batch processing of tiles 844 | num_tiles = len(tiles) 845 | batches = [tiles[i:i + max_batch_size] for i in range(0, num_tiles, max_batch_size)] 846 | reconstructed_tiles = [] 847 | 848 | for batch in batches: 849 | model_input = torch.stack(batch).to(device) 850 | generator_output = self.upsampler( 851 | lowres_image=model_input, 852 | noise=torch.randn(model_input.shape[0], 128, device=device) 853 | ) 854 | reconstructed_tiles.extend(list(generator_output.clamp_(0, 1).detach().cpu())) 855 | 856 | merged_tensor = merge_tiles(reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4) 857 | unpadded = merged_tensor[:, :h * 4, :w * 4] 858 | 859 | to_pil = transforms.ToPILImage() 860 | return to_pil(unpadded) 861 | 862 | # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts 863 | # weights options are 'checkboard' and 'constant' 864 | @torch.no_grad() 865 | def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type='checkboard'): 866 | tensor_transform = transforms.ToTensor() 867 | device = self.upsampler.device 868 | 869 | image_tensor = tensor_transform(image).unsqueeze(0) 870 | _, _, h, w = image_tensor.shape 871 | 872 | # Calculate paddings 873 | pad_h = ( 874 | self.input_image_size - h % self.input_image_size 875 | ) % self.input_image_size 876 | pad_w = ( 877 | self.input_image_size - w % self.input_image_size 878 | ) % self.input_image_size 879 | 880 | # Pad the image 881 | image_tensor = torch.nn.functional.pad( 882 | image_tensor, (0, pad_w, 0, pad_h), mode="reflect" 883 | ).squeeze(0) 884 | 885 | # Function to process tiles 886 | def process_tiles(tiles, h_chunks, w_chunks): 887 | num_tiles = len(tiles) 888 | batches = [ 889 | tiles[i : i + max_batch_size] 890 | for i in range(0, num_tiles, max_batch_size) 891 | ] 892 | reconstructed_tiles = [] 893 | 894 | for batch in batches: 895 | model_input = torch.stack(batch).to(device) 896 | generator_output = self.upsampler( 897 | lowres_image=model_input, 898 | noise=torch.randn(model_input.shape[0], 128, device=device), 899 | ) 900 | reconstructed_tiles.extend( 901 | list(generator_output.clamp_(0, 1).detach().cpu()) 902 | ) 903 | 904 | return merge_tiles( 905 | reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4 906 | ) 907 | 908 | # First pass 909 | tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size) 910 | result1 = process_tiles(tiles1, h_chunks1, w_chunks1) 911 | 912 | # Second pass with offset 913 | offset = self.input_image_size // 2 914 | image_tensor_offset = torch.nn.functional.pad(image_tensor, (offset, offset, offset, offset), 
mode='reflect').squeeze(0) 915 | 916 | tiles2, h_chunks2, w_chunks2 = tile_image( 917 | image_tensor_offset, self.input_image_size 918 | ) 919 | result2 = process_tiles(tiles2, h_chunks2, w_chunks2) 920 | 921 | # unpad 922 | offset_4x = offset * 4 923 | result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x] 924 | 925 | if weight_type == 'checkboard': 926 | weight_tile = create_checkerboard_weights(self.input_image_size * 4) 927 | 928 | weight_shape = result2_interior.shape[1:] 929 | weights_1 = create_offset_weights(weight_tile, weight_shape) 930 | weights_2 = repeat_weights(weight_tile, weight_shape) 931 | 932 | normalizer = weights_1 + weights_2 933 | weights_1 = weights_1 / normalizer 934 | weights_2 = weights_2 / normalizer 935 | 936 | weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1) 937 | weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1) 938 | elif weight_type == 'constant': 939 | weights_1 = torch.ones_like(result2_interior) * 0.5 940 | weights_2 = weights_1 941 | else: 942 | raise ValueError("weight_type should be either 'gaussian' or 'constant' but got", weight_type) 943 | 944 | result1 = result1 * weights_2 945 | result2 = result2_interior * weights_1 946 | 947 | # Average the overlapping region 948 | result1 = ( 949 | result1 + result2 950 | ) 951 | 952 | # Remove padding 953 | unpadded = result1[:, : h * 4, : w * 4] 954 | 955 | to_pil = transforms.ToPILImage() 956 | return to_pil(unpadded) 957 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import json 4 | import folder_paths 5 | from comfy import model_management 6 | import comfy.utils 7 | from .aura_sr import AuraSR 8 | from .utils import * 9 | 10 | 11 | if "aura-sr" not in folder_paths.folder_names_and_paths: 12 | aurasr_folders = [p for p in os.listdir(folder_paths.models_dir) if os.path.isdir(os.path.join(folder_paths.models_dir, p)) and (p.lower() == "aura-sr" or p.lower() == "aurasr" or p.lower() == "aura_sr")] 13 | aurasr_fullpath = os.path.join(folder_paths.models_dir, aurasr_folders[0]) if len(aurasr_folders) > 0 else os.path.join(folder_paths.models_dir, "Aura-SR") 14 | if not os.path.isdir(aurasr_fullpath): 15 | os.mkdir(aurasr_fullpath) 16 | 17 | folder_paths.folder_names_and_paths["aura-sr"] = ([aurasr_fullpath], folder_paths.supported_pt_extensions) 18 | else: 19 | aurasr_fullpath = folder_paths.folder_names_and_paths["aura-sr"][0][0] 20 | folder_paths.folder_names_and_paths.pop('aura-sr', None) 21 | folder_paths.folder_names_and_paths["aura-sr"] = ([aurasr_fullpath], folder_paths.supported_pt_extensions) 22 | 23 | 24 | 25 | AuraSRUpscalers = [] 26 | 27 | 28 | def get_config(model_path): 29 | configs = [f for f in Path(aurasr_fullpath).rglob('*') if f.is_file() and f.name.lower().endswith(".json")] 30 | # picking rules by priority (exit immediately when picked): 31 | # 0 - if a config file is at the same location of the model AND has the same name (without ext) 32 | # 1 - if a config file is at the same location of the model AND is named 'config' (without ext) 33 | # 2 - if a config file is at aurasr_fullpath and is named 'config' (without ext) 34 | # -- notify user for potential invalid config.json when case #2 35 | rule = 0 36 | while (rule < 3): 37 | for c in configs: 38 | if rule == 0 and c.parent == Path(model_path).parent and c.stem.lower() == Path(model_path).stem.lower(): 39 | return json.loads(c.read_text()) 40 | 
if rule == 1 and c.parent == Path(model_path).parent and c.stem.lower() == "config": 41 | return json.loads(c.read_text()) 42 | if rule == 2 and str(c.parent) == aurasr_fullpath and c.stem.lower() == "config": 43 | print(f"\n[AuraSR-ComfyUI] WARNING:\n\tCould not find a config named 'config.json'/modelname.json for model: '\\{c.parent.name}\\{Path(model_path).parent.name}\\{Path(model_path).name}'") 44 | print(f"\tUsing '\\{c.parent.name}\\{c.name}' instead.") 45 | print("\tIf this configuration is not intended for this model then it can cause errors or quality loss in the output!!\n") 46 | return json.loads(c.read_text()) 47 | rule += 1 48 | return None 49 | 50 | 51 | def getAuraClassFromMemory(model_name): 52 | i = 0 53 | while (i < len(AuraSRUpscalers)): 54 | if model_name == AuraSRUpscalers[i].model_name: 55 | if not AuraSRUpscalers[i].loaded: # remove if model not loaded 56 | AuraSRUpscalers[i].unload() 57 | AuraSRUpscalers.pop(i) 58 | else: 59 | return AuraSRUpscalers[i] 60 | i += 1 61 | return None 62 | 63 | 64 | class AuraSRUpscaler: 65 | @classmethod 66 | def INPUT_TYPES(s): 67 | return {"required": {"model_name": (folder_paths.get_filename_list("aura-sr"),), 68 | "image": ("IMAGE",), 69 | "mode": (["4x", "4x_overlapped_checkboard", "4x_overlapped_constant"],), 70 | "reapply_transparency": ("BOOLEAN", {"default": True}), 71 | "tile_batch_size": ("INT", {"default": 8, "min": 1, "max": 32}), 72 | "device": (["default", "cpu"],), 73 | "offload_to_cpu": ("BOOLEAN", {"default": False}), 74 | }, 75 | "optional": { 76 | "transparency_mask": ("MASK",), 77 | }, 78 | } 79 | 80 | RETURN_TYPES = ("IMAGE",) 81 | FUNCTION = "main" 82 | 83 | CATEGORY = "AuraSR" 84 | 85 | def __init__(self): 86 | self.loaded = False 87 | self.model_name = "" 88 | self.aura_sr = None 89 | self.upscaling_factor = 4 90 | self.device_warned = False 91 | self.config = None 92 | self.device = "cpu" 93 | 94 | 95 | def unload(self): 96 | if self.aura_sr is not None: 97 | self.aura_sr.upsampler = None # I don't know if this is the best way to unload a model but it should work 98 | self.aura_sr = None 99 | self.loaded = False 100 | self.model_name = "" 101 | self.upscaling_factor = 4 102 | self.config = None 103 | self.device = "cpu" 104 | 105 | 106 | def load(self, model_name, device): 107 | model_path = folder_paths.get_full_path("aura-sr", model_name) 108 | self.config = get_config(model_path) 109 | if self.config is None: 110 | return 111 | 112 | try: 113 | self.upscaling_factor = int(self.config["image_size"] / self.config["input_image_size"]) 114 | except: 115 | print(f"[AuraSR-ComfyUI] Failed to calculate {model_name}'s upscaling factor. 
Defaulting to 4.") 116 | self.upscaling_factor = 4 117 | 118 | checkpoint = comfy.utils.load_torch_file(model_path, safe_load=True) 119 | 120 | self.aura_sr = AuraSR(config=self.config, device=device) 121 | self.aura_sr.upsampler.load_state_dict(checkpoint, strict=True) 122 | 123 | self.loaded = True 124 | self.model_name = model_name 125 | self.device = device 126 | 127 | 128 | def load_from_memory(self, cl, device): 129 | self.loaded = True 130 | self.model_name = cl.model_name 131 | self.aura_sr = cl.aura_sr 132 | self.upscaling_factor = cl.upscaling_factor 133 | self.device_warned = cl.device_warned 134 | self.config = cl.config 135 | if device != cl.device: 136 | self.aura_sr.upsampler.to(device) 137 | cl.device = device 138 | self.device = device 139 | 140 | 141 | 142 | def main(self, model_name, image, mode, reapply_transparency, tile_batch_size, device, offload_to_cpu, transparency_mask=None): 143 | 144 | # set device 145 | torch_device = model_management.get_torch_device() 146 | if model_management.directml_enabled: 147 | if device == "default" and not self.device_warned: 148 | print("[AuraSR-ComfyUI] Cannot run AuraSR on DirectML device. Using CPU instead (this will be VERY SLOW!)") 149 | self.device_warned = True 150 | device = "cpu" 151 | else: 152 | device = torch_device if device == "default" else "cpu" 153 | device = device if str(device).lower() != "cpu" else "cpu" # force device to be "cpu" when using CPU in default mode 154 | 155 | # load/unload model 156 | class_in_memory = getAuraClassFromMemory(model_name) 157 | if not self.loaded or self.model_name != model_name: 158 | 159 | if class_in_memory is None: 160 | self.unload() 161 | self.load(model_name, device) 162 | AuraSRUpscalers.append(self) 163 | else: 164 | self.load_from_memory(class_in_memory, device) 165 | 166 | if self.config is None: 167 | print("[AuraSR-ComfyUI] Could not find a config/ModelName .json file! Please download it from the model's HF page and place it according to the instructions (https://github.com/GreenLandisaLie/AuraSR-ComfyUI?tab=readme-ov-file#instructions).\nReturning original image.") 168 | return (image, ) 169 | else: 170 | if self.device != device: 171 | self.aura_sr.upsampler.to(device) 172 | self.device = device 173 | if class_in_memory is not None: 174 | class_in_memory.device = device 175 | 176 | 177 | # iterate through images input 178 | upscaled_images = [] 179 | reapply_transparency = reapply_transparency if len(image) == 1 else False 180 | for tensor_image in image: 181 | 182 | # prepare input image and resized_alpha 183 | input_image, resized_alpha = prepare_input(tensor_image if len(image) != 1 else image, transparency_mask, reapply_transparency, self.upscaling_factor) 184 | 185 | # upscale 186 | inference_failed = False 187 | try: 188 | if mode == "4x": 189 | upscaled_image = self.aura_sr.upscale_4x(image=input_image, max_batch_size=tile_batch_size) 190 | elif mode == "4x_overlapped_checkboard": 191 | upscaled_image = self.aura_sr.upscale_4x_overlapped(image=input_image, max_batch_size=tile_batch_size, weight_type='checkboard') 192 | else: 193 | upscaled_image = self.aura_sr.upscale_4x_overlapped(image=input_image, max_batch_size=tile_batch_size, weight_type='constant') 194 | except: 195 | inference_failed = True 196 | print("[AuraSR-ComfyUI] Failed to upscale with AuraSR. 
Returning original image.") 197 | upscaled_image = input_image 198 | 199 | # apply resized_alpha 200 | if reapply_transparency and resized_alpha is not None: 201 | try: 202 | upscaled_image = paste_alpha(upscaled_image, resized_alpha) 203 | except: 204 | print("[AuraSR-ComfyUI] Failed to apply alpha layer.") 205 | 206 | # back to tensor and add to list 207 | upscaled_images.append(pil2tensor(upscaled_image)) 208 | 209 | 210 | # create output tensor from list of tensors 211 | output = torch.cat(upscaled_images, dim=0) 212 | 213 | # offload to cpu 214 | if offload_to_cpu: 215 | self.aura_sr.upsampler.to("cpu") 216 | self.device = "cpu" 217 | if class_in_memory is not None: 218 | class_in_memory.device = "cpu" 219 | 220 | # force unload when inference fails 221 | if inference_failed: 222 | self.unload() 223 | 224 | return (output, ) 225 | 226 | 227 | 228 | 229 | NODE_CLASS_MAPPINGS = { 230 | "AuraSR.AuraSRUpscaler": AuraSRUpscaler 231 | } 232 | 233 | NODE_DISPLAY_NAME_MAPPINGS = { 234 | "AuraSR.AuraSRUpscaler": "AuraSR Upscaler" 235 | } 236 | -------------------------------------------------------------------------------- /nodes_preview/pv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenLandisaLie/AuraSR-ComfyUI/4ceef9234a232f6b2729fe8679a2a82acc60b9f6/nodes_preview/pv1.png -------------------------------------------------------------------------------- /nodes_preview/pv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenLandisaLie/AuraSR-ComfyUI/4ceef9234a232f6b2729fe8679a2a82acc60b9f6/nodes_preview/pv2.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "aurasr-comfyui" 3 | description = "ComfyUI implementation of AuraSR" 4 | version = "3.0.1" 5 | license = { file = "LICENSE.md" } 6 | 7 | [project.urls] 8 | Repository = "https://github.com/GreenLandisaLie/AuraSR-ComfyUI" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "greenlandisalie" 13 | DisplayName = "AuraSR-ComfyUI" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def pil2tensor(image): 7 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 8 | 9 | def tensor2pil(image): 10 | return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 11 | 12 | def numpy2pil(image): 13 | return Image.fromarray(np.clip(255. 
* image.squeeze(), 0, 255).astype(np.uint8)) 14 | 15 | def to_pil(image): 16 | if isinstance(image, Image.Image): 17 | return image 18 | if isinstance(image, torch.Tensor): 19 | return tensor2pil(image) 20 | if isinstance(image, np.ndarray): 21 | return numpy2pil(image) 22 | raise ValueError(f"Cannot convert {type(image)} to PIL.Image") 23 | 24 | def has_transparency(image): 25 | if isinstance(image, Image.Image): 26 | if image.info.get("transparency", None) is not None: 27 | return True 28 | if image.mode == "P": 29 | transparent = image.info.get("transparency", -1) 30 | for _, index in image.getcolors(): 31 | if index == transparent: 32 | return True 33 | elif image.mode == "RGBA": 34 | extrema = image.getextrema() 35 | if extrema[3][0] < 255: 36 | return True 37 | 38 | if isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): 39 | return True if image.shape[-1] == 4 else False 40 | 41 | return False 42 | 43 | def copy_image(image): 44 | if isinstance(image, torch.Tensor): 45 | return image.clone().detach() 46 | return image.copy() # works for both numpy and pil 47 | 48 | def get_resized_alpha(image, transparency_mask, upscaling_factor): 49 | try: 50 | if transparency_mask is not None: 51 | if image.shape[:3] != transparency_mask.shape[:3] and len(transparency_mask.shape) != len(image.shape) + 1: 52 | # Invalid mask. Attempt with original image 53 | if has_transparency(image): 54 | img = copy_image(image) 55 | else: 56 | return None 57 | else: 58 | img = transparency_mask 59 | img = 1.0 - img # invert 60 | img = img.reshape((-1, 1, img.shape[-2], img.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) # mask -> image 61 | else: 62 | if has_transparency(image): 63 | img = copy_image(image) 64 | else: 65 | return None 66 | 67 | if isinstance(img, torch.Tensor): 68 | img = img.cpu().numpy() 69 | if isinstance(img, np.ndarray): 70 | mode = 'RGBA' if transparency_mask is None else 'RGB' 71 | img = Image.fromarray(np.clip(255. * img.squeeze(), 0, 255).astype(np.uint8), mode=mode) 72 | if not img.getbbox(): # some RGB images return fully black masks with the 'Load Image' node - cannot apply masking if so 73 | return None 74 | 75 | resized_alpha = img.resize((img.width * upscaling_factor, img.height * upscaling_factor)).split()[-1] 76 | except: 77 | return None 78 | return resized_alpha 79 | 80 | def paste_alpha(image, alpha): 81 | image = image.convert("RGBA") 82 | image.putalpha(alpha) 83 | return image 84 | 85 | 86 | def prepare_input(image, transparency_mask, reapply_transparency, upscaling_factor): 87 | resized_alpha = get_resized_alpha(image, transparency_mask, upscaling_factor) if reapply_transparency else None 88 | image = to_pil(image).convert("RGB") 89 | return image, resized_alpha 90 | 91 | --------------------------------------------------------------------------------
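
Usage note: the loader in nodes.py derives each model's upscaling factor from its config.json as int(image_size / input_image_size), falling back to 4 when that calculation fails. The short sketch below is illustrative only — the two key names are the ones nodes.py actually reads, but the numeric values are assumed, and a real AuraSR config.json (downloaded from the model's Hugging Face page) contains many more keys that the AuraSR constructor itself requires:

# Minimal sketch, assuming example values for a 4x model; not a complete AuraSR config.
config = {"image_size": 1024, "input_image_size": 256}  # assumed illustrative values
upscaling_factor = int(config["image_size"] / config["input_image_size"])  # -> 4, matching the node's fallback default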