├── .gitignore
├── LICENSE
├── README.md
├── model_data
│   └── README.md
├── nets
│   ├── __init__.py
│   ├── attention.py
│   ├── pipeline.py
│   └── transformer_2d.py
├── predict.py
└── requirements.txt
/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | map_out/ 3 | miou_out/ 4 | VOCdevkit/ 5 | datasets/ 6 | Medical_Datasets/ 7 | lfw/ 8 | logs/ 9 | model_data/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Scalable Diffusion Models with Transformers (DiT): A PyTorch Implementation 2 | --- 3 | 4 | ## Contents 5 | 1. [Top News](#top-news) 6 | 2. [Environment](#environment) 7 | 3. [Download](#download) 8 | 4. [How2predict](#how2predict) 9 | 5. [Reference](#reference) 10 | 11 | ## Top News 12 | **`2024-01`**: **Repository created. It supports simple prediction, and the DiT network code has been extracted from diffusers to make it easier to study.** 13 | 14 | ### Environment 15 | torch==1.7.1 or above 16 | 17 | ### Download 18 | The weights required for prediction can be downloaded from Baidu Netdisk. 19 | Link: https://pan.baidu.com/s/1xy_mujfbTs7gEITNQWf5Kw?pwd=kx7r 20 | Extraction code: kx7r 21 | 22 | ## How2predict 23 | ### a. Using pretrained weights 24 | 1. Download and unzip this repository, then download the weights from Baidu Netdisk and place them in model_data. 25 | 2. Run predict.py. 26 | 27 | ### Reference 28 | https://github.com/facebookresearch/DiT 29 | https://github.com/huggingface/diffusers -------------------------------------------------------------------------------- /model_data/README.md: -------------------------------------------------------------------------------- 1 | Please put weights here. -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/DiT-pytorch/4a8758a2dd8f11b9d82ed4c94682ddf01dcec99f/nets/__init__.py -------------------------------------------------------------------------------- /nets/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import maybe_allow_in_graph 21 | from diffusers.models.activations import get_activation 22 | from diffusers.models.attention_processor import Attention 23 | from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings 24 | 25 | 26 | @maybe_allow_in_graph 27 | class BasicTransformerBlock(nn.Module): 28 | r""" 29 | A basic Transformer block. 30 | 31 | Parameters: 32 | dim (`int`): The number of channels in the input and output. 33 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 34 | attention_head_dim (`int`): The number of channels in each head. 35 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 37 | only_cross_attention (`bool`, *optional*): 38 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 39 | double_self_attention (`bool`, *optional*): 40 | Whether to use two self-attention layers. In this case no cross attention layers are used. 41 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 42 | num_embeds_ada_norm (: 43 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 44 | attention_bias (: 45 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dim: int, 51 | num_attention_heads: int, 52 | attention_head_dim: int, 53 | dropout=0.0, 54 | cross_attention_dim: Optional[int] = None, 55 | activation_fn: str = "geglu", 56 | num_embeds_ada_norm: Optional[int] = None, 57 | attention_bias: bool = False, 58 | only_cross_attention: bool = False, 59 | double_self_attention: bool = False, 60 | upcast_attention: bool = False, 61 | norm_elementwise_affine: bool = True, 62 | norm_type: str = "layer_norm", 63 | final_dropout: bool = False, 64 | ): 65 | super().__init__() 66 | self.only_cross_attention = only_cross_attention 67 | 68 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 69 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 70 | 71 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 72 | raise ValueError( 73 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 74 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 75 | ) 76 | 77 | # Define 3 blocks. Each block has its own normalization layer. 78 | # 1. Self-Attn 79 | if self.use_ada_layer_norm: 80 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 81 | elif self.use_ada_layer_norm_zero: 82 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 83 | else: 84 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 85 | self.attn1 = Attention( 86 | query_dim=dim, 87 | heads=num_attention_heads, 88 | dim_head=attention_head_dim, 89 | dropout=dropout, 90 | bias=attention_bias, 91 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 92 | upcast_attention=upcast_attention, 93 | ) 94 | 95 | # 2. Cross-Attn 96 | if cross_attention_dim is not None or double_self_attention: 97 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 98 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 99 | # the second cross attention block. 100 | self.norm2 = ( 101 | AdaLayerNorm(dim, num_embeds_ada_norm) 102 | if self.use_ada_layer_norm 103 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 104 | ) 105 | self.attn2 = Attention( 106 | query_dim=dim, 107 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 108 | heads=num_attention_heads, 109 | dim_head=attention_head_dim, 110 | dropout=dropout, 111 | bias=attention_bias, 112 | upcast_attention=upcast_attention, 113 | ) # is self-attn if encoder_hidden_states is none 114 | else: 115 | self.norm2 = None 116 | self.attn2 = None 117 | 118 | # 3. 
Feed-forward 119 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 120 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 121 | 122 | # let chunk size default to None 123 | self._chunk_size = None 124 | self._chunk_dim = 0 125 | 126 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 127 | # Sets chunk feed-forward 128 | self._chunk_size = chunk_size 129 | self._chunk_dim = dim 130 | 131 | def forward( 132 | self, 133 | hidden_states: torch.FloatTensor, 134 | attention_mask: Optional[torch.FloatTensor] = None, 135 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 136 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 137 | timestep: Optional[torch.LongTensor] = None, 138 | cross_attention_kwargs: Dict[str, Any] = None, 139 | class_labels: Optional[torch.LongTensor] = None, 140 | ): 141 | # Notice that normalization is always applied before the real computation in the following blocks. 142 | # 1. Self-Attention 143 | # 在Self-Attention前先施加norm 144 | if self.use_ada_layer_norm: 145 | norm_hidden_states = self.norm1(hidden_states, timestep) 146 | elif self.use_ada_layer_norm_zero: 147 | # 在norm1中,已经进行了输入特征的缩放与偏置 148 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 149 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 150 | ) 151 | else: 152 | norm_hidden_states = self.norm1(hidden_states) 153 | 154 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 155 | 156 | # 然后施加Self-Attention 157 | attn_output = self.attn1( 158 | norm_hidden_states, 159 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 160 | attention_mask=attention_mask, 161 | **cross_attention_kwargs, 162 | ) 163 | # 在Self-Attention后,再次进行了特征的缩放(gate) 164 | if self.use_ada_layer_norm_zero: 165 | attn_output = gate_msa.unsqueeze(1) * attn_output 166 | hidden_states = attn_output + hidden_states 167 | 168 | # 2. Cross-Attention 169 | if self.attn2 is not None: 170 | norm_hidden_states = ( 171 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 172 | ) 173 | 174 | attn_output = self.attn2( 175 | norm_hidden_states, 176 | encoder_hidden_states=encoder_hidden_states, 177 | attention_mask=encoder_attention_mask, 178 | **cross_attention_kwargs, 179 | ) 180 | hidden_states = attn_output + hidden_states 181 | 182 | # 3. Feed-forward 183 | norm_hidden_states = self.norm3(hidden_states) 184 | 185 | # 在mlp前,进行了输入特征的缩放与偏置 186 | if self.use_ada_layer_norm_zero: 187 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 188 | 189 | # 然后施加全连接层 190 | if self._chunk_size is not None: 191 | # "feed_forward_chunk_size" can be used to save memory 192 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 193 | raise ValueError( 194 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
195 | ) 196 | 197 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 198 | ff_output = torch.cat( 199 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], 200 | dim=self._chunk_dim, 201 | ) 202 | else: 203 | ff_output = self.ff(norm_hidden_states) 204 | 205 | # 在mlp后,再次进行了特征的缩放(gate) 206 | if self.use_ada_layer_norm_zero: 207 | ff_output = gate_mlp.unsqueeze(1) * ff_output 208 | 209 | hidden_states = ff_output + hidden_states 210 | 211 | return hidden_states 212 | 213 | 214 | class FeedForward(nn.Module): 215 | r""" 216 | A feed-forward layer. 217 | 218 | Parameters: 219 | dim (`int`): The number of channels in the input. 220 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 221 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 222 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 223 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 224 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 225 | """ 226 | 227 | def __init__( 228 | self, 229 | dim: int, 230 | dim_out: Optional[int] = None, 231 | mult: int = 4, 232 | dropout: float = 0.0, 233 | activation_fn: str = "geglu", 234 | final_dropout: bool = False, 235 | ): 236 | super().__init__() 237 | inner_dim = int(dim * mult) 238 | dim_out = dim_out if dim_out is not None else dim 239 | 240 | if activation_fn == "gelu": 241 | act_fn = GELU(dim, inner_dim) 242 | if activation_fn == "gelu-approximate": 243 | act_fn = GELU(dim, inner_dim, approximate="tanh") 244 | elif activation_fn == "geglu": 245 | act_fn = GEGLU(dim, inner_dim) 246 | elif activation_fn == "geglu-approximate": 247 | act_fn = ApproximateGELU(dim, inner_dim) 248 | 249 | self.net = nn.ModuleList([]) 250 | # project in 251 | self.net.append(act_fn) 252 | # project dropout 253 | self.net.append(nn.Dropout(dropout)) 254 | # project out 255 | self.net.append(nn.Linear(inner_dim, dim_out)) 256 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 257 | if final_dropout: 258 | self.net.append(nn.Dropout(dropout)) 259 | 260 | def forward(self, hidden_states): 261 | for module in self.net: 262 | hidden_states = module(hidden_states) 263 | return hidden_states 264 | 265 | 266 | class GELU(nn.Module): 267 | r""" 268 | GELU activation function with tanh approximation support with `approximate="tanh"`. 269 | """ 270 | 271 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): 272 | super().__init__() 273 | self.proj = nn.Linear(dim_in, dim_out) 274 | self.approximate = approximate 275 | 276 | def gelu(self, gate): 277 | if gate.device.type != "mps": 278 | return F.gelu(gate, approximate=self.approximate) 279 | # mps: gelu is not implemented for float16 280 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) 281 | 282 | def forward(self, hidden_states): 283 | hidden_states = self.proj(hidden_states) 284 | hidden_states = self.gelu(hidden_states) 285 | return hidden_states 286 | 287 | 288 | class GEGLU(nn.Module): 289 | r""" 290 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 291 | 292 | Parameters: 293 | dim_in (`int`): The number of channels in the input. 294 | dim_out (`int`): The number of channels in the output. 
295 | """ 296 | 297 | def __init__(self, dim_in: int, dim_out: int): 298 | super().__init__() 299 | self.proj = nn.Linear(dim_in, dim_out * 2) 300 | 301 | def gelu(self, gate): 302 | if gate.device.type != "mps": 303 | return F.gelu(gate) 304 | # mps: gelu is not implemented for float16 305 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 306 | 307 | def forward(self, hidden_states): 308 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 309 | return hidden_states * self.gelu(gate) 310 | 311 | 312 | class ApproximateGELU(nn.Module): 313 | """ 314 | The approximate form of Gaussian Error Linear Unit (GELU) 315 | 316 | For more details, see section 2: https://arxiv.org/abs/1606.08415 317 | """ 318 | 319 | def __init__(self, dim_in: int, dim_out: int): 320 | super().__init__() 321 | self.proj = nn.Linear(dim_in, dim_out) 322 | 323 | def forward(self, x): 324 | x = self.proj(x) 325 | return x * torch.sigmoid(1.702 * x) 326 | 327 | 328 | class AdaLayerNorm(nn.Module): 329 | """ 330 | Norm layer modified to incorporate timestep embeddings. 331 | """ 332 | 333 | def __init__(self, embedding_dim, num_embeddings): 334 | super().__init__() 335 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 336 | self.silu = nn.SiLU() 337 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2) 338 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) 339 | 340 | def forward(self, x, timestep): 341 | emb = self.linear(self.silu(self.emb(timestep))) 342 | scale, shift = torch.chunk(emb, 2) 343 | x = self.norm(x) * (1 + scale) + shift 344 | return x 345 | 346 | 347 | class AdaLayerNormZero(nn.Module): 348 | """ 349 | Norm layer adaptive layer norm zero (adaLN-Zero). 350 | """ 351 | 352 | def __init__(self, embedding_dim, num_embeddings): 353 | super().__init__() 354 | 355 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) 356 | 357 | self.silu = nn.SiLU() 358 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 359 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 360 | 361 | def forward(self, x, timestep, class_labels, hidden_dtype=None): 362 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) 363 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 364 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 365 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 366 | 367 | 368 | class AdaGroupNorm(nn.Module): 369 | """ 370 | GroupNorm layer modified to incorporate timestep embeddings. 
371 | """ 372 | 373 | def __init__( 374 | self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 375 | ): 376 | super().__init__() 377 | self.num_groups = num_groups 378 | self.eps = eps 379 | 380 | if act_fn is None: 381 | self.act = None 382 | else: 383 | self.act = get_activation(act_fn) 384 | 385 | self.linear = nn.Linear(embedding_dim, out_dim * 2) 386 | 387 | def forward(self, x, emb): 388 | if self.act: 389 | emb = self.act(emb) 390 | emb = self.linear(emb) 391 | emb = emb[:, :, None, None] 392 | scale, shift = emb.chunk(2, dim=1) 393 | 394 | x = F.group_norm(x, self.num_groups, eps=self.eps) 395 | x = x * (1 + scale) + shift 396 | return x 397 | -------------------------------------------------------------------------------- /nets/pipeline.py: -------------------------------------------------------------------------------- 1 | # Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) 2 | # William Peebles and Saining Xie 3 | # 4 | # Copyright (c) 2021 OpenAI 5 | # MIT License 6 | # 7 | # Copyright 2023 The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | 21 | from typing import Dict, List, Optional, Tuple, Union 22 | 23 | import torch 24 | 25 | from diffusers.models import AutoencoderKL 26 | from diffusers.schedulers import KarrasDiffusionSchedulers 27 | from diffusers.utils import randn_tensor 28 | from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 29 | from diffusers import logging 30 | from .transformer_2d import Transformer2DModel 31 | 32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 33 | 34 | def randn_tensor( 35 | shape: Union[Tuple, List], 36 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 37 | device: Optional["torch.device"] = None, 38 | dtype: Optional["torch.dtype"] = None, 39 | layout: Optional["torch.layout"] = None, 40 | ): 41 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When 42 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 43 | is always created on the CPU. 44 | """ 45 | # device on which tensor is created defaults to device 46 | rand_device = device 47 | batch_size = shape[0] 48 | 49 | layout = layout or torch.strided 50 | device = device or torch.device("cpu") 51 | 52 | if generator is not None: 53 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 54 | if gen_device_type != device.type and gen_device_type == "cpu": 55 | rand_device = "cpu" 56 | if device != "mps": 57 | logger.info( 58 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 59 | f" Tensors will be created on 'cpu' and then moved to {device}. 
Note that one can probably" 60 | f" slighly speed up this function by passing a generator that was created on the {device} device." 61 | ) 62 | elif gen_device_type != device.type and gen_device_type == "cuda": 63 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 64 | 65 | if isinstance(generator, list): 66 | shape = (1,) + shape[1:] 67 | latents = [ 68 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 69 | for i in range(batch_size) 70 | ] 71 | latents = torch.cat(latents, dim=0).to(device) 72 | else: 73 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 74 | 75 | return latents 76 | 77 | class DiTPipeline(DiffusionPipeline): 78 | r""" 79 | Pipeline for image generation based on a Transformer backbone instead of a UNet. 80 | 81 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 82 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 83 | 84 | Parameters: 85 | transformer ([`Transformer2DModel`]): 86 | A class conditioned `Transformer2DModel` to denoise the encoded image latents. 87 | vae ([`AutoencoderKL`]): 88 | Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 89 | scheduler ([`DDIMScheduler`]): 90 | A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 91 | """ 92 | 93 | def __init__( 94 | self, 95 | transformer: Transformer2DModel, 96 | vae: AutoencoderKL, 97 | scheduler: KarrasDiffusionSchedulers, 98 | id2label: Optional[Dict[int, str]] = None, 99 | ): 100 | super().__init__() 101 | self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) 102 | 103 | # create a imagenet -> id dictionary for easier use 104 | self.labels = {} 105 | if id2label is not None: 106 | for key, value in id2label.items(): 107 | for label in value.split(","): 108 | self.labels[label.lstrip().rstrip()] = int(key) 109 | self.labels = dict(sorted(self.labels.items())) 110 | 111 | def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: 112 | r""" 113 | 114 | Map label strings from ImageNet to corresponding class ids. 115 | 116 | Parameters: 117 | label (`str` or `dict` of `str`): 118 | Label strings to be mapped to class ids. 119 | 120 | Returns: 121 | `list` of `int`: 122 | Class ids to be processed by pipeline. 123 | """ 124 | 125 | if not isinstance(label, list): 126 | label = list(label) 127 | 128 | for l in label: 129 | if l not in self.labels: 130 | raise ValueError( 131 | f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." 
132 | ) 133 | 134 | return [self.labels[l] for l in label] 135 | 136 | @torch.no_grad() 137 | def __call__( 138 | self, 139 | class_labels: List[int], 140 | guidance_scale: float = 4.0, 141 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 142 | num_inference_steps: int = 50, 143 | output_type: Optional[str] = "pil", 144 | return_dict: bool = True, 145 | ) -> Union[ImagePipelineOutput, Tuple]: 146 | # 生成的批次大小、隐含层的大小与隐含层通道 147 | batch_size = len(class_labels) 148 | latent_size = self.transformer.config.sample_size 149 | latent_channels = self.transformer.config.in_channels 150 | 151 | # --------------------------------- # 152 | # 前处理 153 | # --------------------------------- # 154 | # 生成latent 155 | latents = randn_tensor( 156 | shape=(batch_size, latent_channels, latent_size, latent_size), 157 | generator=generator, 158 | device=self.device, 159 | dtype=self.transformer.dtype, 160 | ) 161 | latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents 162 | 163 | # 将输入的label 与 null label进行concat,null label是负向提示类。 164 | class_labels = torch.tensor(class_labels, device=self.device).reshape(-1) 165 | class_null = torch.tensor([1000] * batch_size, device=self.device) 166 | class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels 167 | 168 | # 设置生成的步数 169 | self.scheduler.set_timesteps(num_inference_steps) 170 | 171 | # --------------------------------- # 172 | # 扩散生成 173 | # --------------------------------- # 174 | # 开始N步扩散的循环 175 | for t in self.progress_bar(self.scheduler.timesteps): 176 | if guidance_scale > 1: 177 | half = latent_model_input[: len(latent_model_input) // 2] 178 | latent_model_input = torch.cat([half, half], dim=0) 179 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 180 | 181 | # 处理timesteps 182 | timesteps = t 183 | if not torch.is_tensor(timesteps): 184 | is_mps = latent_model_input.device.type == "mps" 185 | if isinstance(timesteps, float): 186 | dtype = torch.float32 if is_mps else torch.float64 187 | else: 188 | dtype = torch.int32 if is_mps else torch.int64 189 | timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) 190 | elif len(timesteps.shape) == 0: 191 | timesteps = timesteps[None].to(latent_model_input.device) 192 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 193 | timesteps = timesteps.expand(latent_model_input.shape[0]) 194 | 195 | # 将隐含层特征、时间步和种类输入传入到transformers中 196 | noise_pred = self.transformer( 197 | latent_model_input, timestep=timesteps, class_labels=class_labels_input 198 | ).sample 199 | 200 | # perform guidance 201 | if guidance_scale > 1: 202 | # 在通道上做分割,取出生图部分的通道 203 | eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] 204 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 205 | 206 | half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) 207 | eps = torch.cat([half_eps, half_eps], dim=0) 208 | 209 | noise_pred = torch.cat([eps, rest], dim=1) 210 | 211 | # 对结果进行分割,取出生图部分的通道 212 | if self.transformer.config.out_channels // 2 == latent_channels: 213 | model_output, _ = torch.split(noise_pred, latent_channels, dim=1) 214 | else: 215 | model_output = noise_pred 216 | 217 | # 通过采样器将这一步噪声施加到隐含层 218 | latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample 219 | 220 | if guidance_scale > 1: 221 | latents, _ = latent_model_input.chunk(2, dim=0) 222 | else: 223 | latents = latent_model_input 
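# When guidance_scale > 1, the latent batch above was duplicated into a class-conditional
# half and a null-class half; both halves receive the same guided update at every step,
# so only the first half needs to be kept before decoding.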
224 | 225 | # --------------------------------- # 226 | # 后处理 227 | # --------------------------------- # 228 | # 通过vae进行解码 229 | latents = 1 / self.vae.config.scaling_factor * latents 230 | samples = self.vae.decode(latents).sample 231 | 232 | samples = (samples / 2 + 0.5).clamp(0, 1) 233 | 234 | # 转化为float32类别 235 | samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() 236 | 237 | if output_type == "pil": 238 | samples = self.numpy_to_pil(samples) 239 | 240 | if not return_dict: 241 | return (samples,) 242 | 243 | return ImagePipelineOutput(images=samples) 244 | -------------------------------------------------------------------------------- /nets/transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from diffusers.configuration_utils import ConfigMixin, register_to_config 21 | from diffusers.models.modeling_utils import ModelMixin 22 | from diffusers.utils import BaseOutput, deprecate 23 | from torch import nn 24 | 25 | from .attention import BasicTransformerBlock 26 | 27 | 28 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 29 | """ 30 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 31 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 32 | """ 33 | grid_h = np.arange(grid_size, dtype=np.float32) 34 | grid_w = np.arange(grid_size, dtype=np.float32) 35 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 36 | grid = np.stack(grid, axis=0) 37 | 38 | grid = grid.reshape([2, 1, grid_size, grid_size]) 39 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 40 | if cls_token and extra_tokens > 0: 41 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 42 | return pos_embed 43 | 44 | 45 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 46 | if embed_dim % 2 != 0: 47 | raise ValueError("embed_dim must be divisible by 2") 48 | 49 | # use half of dimensions to encode grid_h 50 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 51 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 52 | 53 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 54 | return emb 55 | 56 | 57 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 58 | """ 59 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 60 | """ 61 | if embed_dim % 2 != 0: 62 | raise ValueError("embed_dim must be divisible by 2") 63 | 64 | omega = np.arange(embed_dim // 2, dtype=np.float64) 65 | omega /= embed_dim / 2.0 66 | omega = 1.0 / 10000**omega # (D/2,) 67 | 68 | pos = pos.reshape(-1) # (M,) 69 | 
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 70 | 71 | emb_sin = np.sin(out) # (M, D/2) 72 | emb_cos = np.cos(out) # (M, D/2) 73 | 74 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 75 | return emb 76 | 77 | 78 | class PatchEmbed(nn.Module): 79 | """2D Image to Patch Embedding""" 80 | 81 | def __init__( 82 | self, 83 | height=224, 84 | width=224, 85 | patch_size=16, 86 | in_channels=3, 87 | embed_dim=768, 88 | layer_norm=False, 89 | flatten=True, 90 | bias=True, 91 | ): 92 | super().__init__() 93 | 94 | num_patches = (height // patch_size) * (width // patch_size) 95 | self.flatten = flatten 96 | self.layer_norm = layer_norm 97 | 98 | self.proj = nn.Conv2d( 99 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 100 | ) 101 | if layer_norm: 102 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 103 | else: 104 | self.norm = None 105 | 106 | pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) 107 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 108 | 109 | def forward(self, latent): 110 | latent = self.proj(latent) 111 | if self.flatten: 112 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 113 | if self.layer_norm: 114 | latent = self.norm(latent) 115 | return latent + self.pos_embed 116 | 117 | @dataclass 118 | class Transformer2DModelOutput(BaseOutput): 119 | """ 120 | The output of [`Transformer2DModel`]. 121 | 122 | Args: 123 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 124 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 125 | distributions for the unnoised latent pixels. 126 | """ 127 | 128 | sample: torch.FloatTensor 129 | 130 | 131 | class Transformer2DModel(ModelMixin, ConfigMixin): 132 | """ 133 | A 2D Transformer model for image-like data. 134 | 135 | Parameters: 136 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 137 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 138 | in_channels (`int`, *optional*): 139 | The number of channels in the input and output (specify if the input is **continuous**). 140 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 141 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 142 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 143 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 144 | This is fixed during training since it is used to learn a number of position embeddings. 145 | num_vector_embeds (`int`, *optional*): 146 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 147 | Includes the class for the masked latent pixel. 148 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 149 | num_embeds_ada_norm ( `int`, *optional*): 150 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 151 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 152 | added to the hidden states. 
153 | 154 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 155 | attention_bias (`bool`, *optional*): 156 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 157 | """ 158 | 159 | @register_to_config 160 | def __init__( 161 | self, 162 | num_attention_heads: int = 16, 163 | attention_head_dim: int = 88, 164 | in_channels: Optional[int] = None, 165 | out_channels: Optional[int] = None, 166 | num_layers: int = 1, 167 | dropout: float = 0.0, 168 | norm_num_groups: int = 32, 169 | cross_attention_dim: Optional[int] = None, 170 | attention_bias: bool = False, 171 | sample_size: Optional[int] = None, 172 | num_vector_embeds: Optional[int] = None, 173 | patch_size: Optional[int] = None, 174 | activation_fn: str = "geglu", 175 | num_embeds_ada_norm: Optional[int] = None, 176 | use_linear_projection: bool = False, 177 | only_cross_attention: bool = False, 178 | upcast_attention: bool = False, 179 | norm_type: str = "layer_norm", 180 | norm_elementwise_affine: bool = True, 181 | ): 182 | super().__init__() 183 | self.use_linear_projection = use_linear_projection 184 | self.num_attention_heads = num_attention_heads 185 | self.attention_head_dim = attention_head_dim 186 | inner_dim = num_attention_heads * attention_head_dim 187 | 188 | # 2. Define input layers 189 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 190 | 191 | self.height = sample_size 192 | self.width = sample_size 193 | 194 | self.patch_size = patch_size 195 | self.pos_embed = PatchEmbed( 196 | height=sample_size, 197 | width=sample_size, 198 | patch_size=patch_size, 199 | in_channels=in_channels, 200 | embed_dim=inner_dim, 201 | ) 202 | 203 | # 3. Define transformers blocks 204 | self.transformer_blocks = nn.ModuleList( 205 | [ 206 | BasicTransformerBlock( 207 | inner_dim, 208 | num_attention_heads, 209 | attention_head_dim, 210 | dropout=dropout, 211 | cross_attention_dim=cross_attention_dim, 212 | activation_fn=activation_fn, 213 | num_embeds_ada_norm=num_embeds_ada_norm, 214 | attention_bias=attention_bias, 215 | only_cross_attention=only_cross_attention, 216 | upcast_attention=upcast_attention, 217 | norm_type=norm_type, 218 | norm_elementwise_affine=norm_elementwise_affine, 219 | ) 220 | for d in range(num_layers) 221 | ] 222 | ) 223 | 224 | # 4. 
Define output layers 225 | self.out_channels = in_channels if out_channels is None else out_channels 226 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 227 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 228 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 229 | 230 | def forward( 231 | self, 232 | hidden_states: torch.Tensor, 233 | encoder_hidden_states: Optional[torch.Tensor] = None, 234 | timestep: Optional[torch.LongTensor] = None, 235 | class_labels: Optional[torch.LongTensor] = None, 236 | cross_attention_kwargs: Dict[str, Any] = None, 237 | attention_mask: Optional[torch.Tensor] = None, 238 | encoder_attention_mask: Optional[torch.Tensor] = None, 239 | return_dict: bool = True, 240 | ): 241 | if attention_mask is not None and attention_mask.ndim == 2: 242 | # assume that mask is expressed as: 243 | # (1 = keep, 0 = discard) 244 | # convert mask into a bias that can be added to attention scores: 245 | # (keep = +0, discard = -10000.0) 246 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 247 | attention_mask = attention_mask.unsqueeze(1) 248 | 249 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 250 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 251 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 252 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 253 | 254 | # 1. Input 255 | hidden_states = self.pos_embed(hidden_states) 256 | 257 | # 2. Blocks 258 | for block in self.transformer_blocks: 259 | hidden_states = block( 260 | hidden_states, 261 | attention_mask=attention_mask, 262 | encoder_hidden_states=encoder_hidden_states, 263 | encoder_attention_mask=encoder_attention_mask, 264 | timestep=timestep, 265 | cross_attention_kwargs=cross_attention_kwargs, 266 | class_labels=class_labels, 267 | ) 268 | 269 | # 3. 
Output 270 | conditioning = self.transformer_blocks[0].norm1.emb( 271 | timestep, class_labels, hidden_dtype=hidden_states.dtype 272 | ) 273 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 274 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 275 | hidden_states = self.proj_out_2(hidden_states) 276 | 277 | # unpatchify 278 | height = width = int(hidden_states.shape[1] ** 0.5) 279 | hidden_states = hidden_states.reshape( 280 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 281 | ) 282 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 283 | output = hidden_states.reshape( 284 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 285 | ) 286 | 287 | if not return_dict: 288 | return (output,) 289 | 290 | return Transformer2DModelOutput(sample=output) 291 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import json 4 | import os 5 | from diffusers import DPMSolverMultistepScheduler, AutoencoderKL 6 | 7 | from nets.transformer_2d import Transformer2DModel 8 | from nets.pipeline import DiTPipeline 9 | 10 | # Model path 11 | model_path = "model_data/DiT-XL-2-256" 12 | 13 | # Initialize the individual components of DiT 14 | scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder="scheduler") 15 | transformer = Transformer2DModel.from_pretrained(model_path, subfolder="transformer") 16 | vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae") 17 | id2label = json.load(open(os.path.join(model_path, "model_index.json"), "r"))['id2label'] 18 | 19 | # Initialize the DiT pipeline 20 | pipe = DiTPipeline(scheduler=scheduler, transformer=transformer, vae=vae, id2label=id2label) 21 | pipe = pipe.to("cuda") 22 | 23 | # Names of the ImageNet classes to generate 24 | words = ["white shark", "umbrella"] 25 | # Get the corresponding ImageNet class ids 26 | class_ids = pipe.get_label_ids(words) 27 | # Set the random seed 28 | generator = torch.manual_seed(42) 29 | 30 | # Run the pipeline forward pass 31 | output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator) 32 | 33 | # Save the generated images 34 | for index, image in enumerate(output.images): 35 | image.save(f"output-{index}.png") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GitPython==3.1.32 2 | Pillow==9.5.0 3 | accelerate==0.21.0 4 | basicsr==1.4.2 5 | blendmodes==2022 6 | clean-fid==0.1.35 7 | einops==0.4.1 8 | fastapi==0.94.0 9 | gfpgan==1.3.8 10 | gradio==3.41.2 11 | httpcore==0.15 12 | inflection==0.5.1 13 | jsonmerge==1.8.0 14 | kornia==0.6.7 15 | lark==1.1.2 16 | numpy==1.23.5 17 | omegaconf==2.2.3 18 | open-clip-torch==2.20.0 19 | piexif==1.1.3 20 | psutil==5.9.5 21 | pytorch_lightning==1.9.4 22 | realesrgan==0.3.0 23 | resize-right==0.0.2 24 | safetensors==0.3.1 25 | scikit-image==0.21.0 26 | timm==0.9.2 27 | tomesd==0.1.3 28 | torch 29 | diffusers==0.18.2 30 | torchdiffeq==0.2.3 31 | torchsde==0.2.5 32 | transformers==4.30.2 33 | --------------------------------------------------------------------------------